-
Notifications
You must be signed in to change notification settings - Fork 46
Expand file tree
/
Copy pathhelpers.py
More file actions
385 lines (327 loc) · 13.8 KB
/
helpers.py
File metadata and controls
385 lines (327 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
This module provides helper functions for the mssql_python package.
"""
import re
import threading
import locale
from typing import Any, Union, Tuple, Optional
from mssql_python import ddbc_bindings
from mssql_python.exceptions import raise_exception
from mssql_python.logging import logger
from mssql_python.constants import ConstantsDDBC
# normalize_architecture import removed as it's unused
def check_error(handle_type: int, handle: Any, ret: int) -> None:
"""
Check for errors and raise an exception if an error is found.
Args:
handle_type: The type of the handle (e.g., SQL_HANDLE_ENV, SQL_HANDLE_DBC).
handle: The SqlHandle object associated with the operation.
ret: The return code from the DDBC function call.
Raises:
RuntimeError: If an error is found.
"""
if ret < 0:
logger.debug(
"check_error: Error detected - handle_type=%d, return_code=%d", handle_type, ret
)
error_info = ddbc_bindings.DDBCSQLCheckError(handle_type, handle, ret)
logger.error("Error: %s", error_info.ddbcErrorMsg)
logger.debug("check_error: SQL state=%s", error_info.sqlState)
raise_exception(error_info.sqlState, error_info.ddbcErrorMsg)
def sanitize_connection_string(conn_str: str) -> str:
"""
Sanitize the connection string by removing sensitive information.
Delegates to the parser-based implementation in connection_string_parser
which correctly handles ODBC braced values (e.g. PWD={Top;Secret}).
Args:
conn_str (str): The connection string to sanitize.
Returns:
str: The sanitized connection string.
"""
from mssql_python.connection_string_parser import (
sanitize_connection_string as _sanitize,
)
return _sanitize(conn_str)
def sanitize_user_input(user_input: str, max_length: int = 50) -> str:
"""
Sanitize user input for safe logging by removing control characters,
limiting length, and ensuring safe characters only.
Args:
user_input (str): The user input to sanitize.
max_length (int): Maximum length of the sanitized output.
Returns:
str: The sanitized string safe for logging.
"""
logger.debug(
"sanitize_user_input: Sanitizing input (type=%s, length=%d)",
type(user_input).__name__,
len(user_input) if isinstance(user_input, str) else 0,
)
if not isinstance(user_input, str):
logger.debug("sanitize_user_input: Non-string input detected")
return "<non-string>"
# Remove control characters and non-printable characters
# Allow alphanumeric, dash, underscore, and dot (common in encoding names)
sanitized = re.sub(r"[^\w\-\.]", "", user_input)
# Limit length to prevent log flooding
was_truncated = False
if len(sanitized) > max_length:
sanitized = sanitized[:max_length] + "..."
was_truncated = True
# Return placeholder if nothing remains after sanitization
result = sanitized if sanitized else "<invalid>"
logger.debug(
"sanitize_user_input: Result length=%d, truncated=%s", len(result), str(was_truncated)
)
return result
def validate_attribute_value(
attribute: Union[int, str],
value: Union[int, str, bytes, bytearray],
is_connected: bool = True,
sanitize_logs: bool = True,
max_log_length: int = 50,
) -> Tuple[bool, Optional[str], str, str]:
"""
Validates attribute and value pairs for connection attributes.
Performs basic type checking and validation of ODBC connection attributes.
Args:
attribute (int): The connection attribute to validate (SQL_ATTR_*)
value: The value to set for the attribute (int, str, bytes, or bytearray)
is_connected (bool): Whether the connection is already established
sanitize_logs (bool): Whether to include sanitized versions for logging
max_log_length (int): Maximum length of sanitized output for logging
Returns:
tuple: (is_valid, error_message, sanitized_attribute, sanitized_value)
"""
logger.debug(
"validate_attribute_value: Validating attribute=%s, value_type=%s, is_connected=%s",
str(attribute),
type(value).__name__,
str(is_connected),
)
# Sanitize a value for logging
def _sanitize_for_logging(input_val: Any, max_length: int = max_log_length) -> str:
if not isinstance(input_val, str):
try:
input_val = str(input_val)
except (TypeError, ValueError):
return "<non-string>"
# Allow alphanumeric, dash, underscore, and dot
sanitized = re.sub(r"[^\w\-\.]", "", input_val)
# Limit length
if len(sanitized) > max_length:
sanitized = sanitized[:max_length] + "..."
return sanitized if sanitized else "<invalid>"
# Create sanitized versions for logging
sanitized_attr = _sanitize_for_logging(attribute) if sanitize_logs else str(attribute)
sanitized_val = _sanitize_for_logging(value) if sanitize_logs else str(value)
# Basic attribute validation - must be an integer
if not isinstance(attribute, int):
logger.debug(
"validate_attribute_value: Attribute not an integer - type=%s", type(attribute).__name__
)
return (
False,
f"Attribute must be an integer, got {type(attribute).__name__}",
sanitized_attr,
sanitized_val,
)
# Define driver-level attributes that are supported
supported_attributes = [
ConstantsDDBC.SQL_ATTR_ACCESS_MODE.value,
ConstantsDDBC.SQL_ATTR_CONNECTION_TIMEOUT.value,
ConstantsDDBC.SQL_ATTR_CURRENT_CATALOG.value,
ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value,
ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value,
ConstantsDDBC.SQL_ATTR_TXN_ISOLATION.value,
]
# Check if attribute is supported
if attribute not in supported_attributes:
logger.debug("validate_attribute_value: Unsupported attribute - attr=%d", attribute)
return (
False,
f"Unsupported attribute: {attribute}",
sanitized_attr,
sanitized_val,
)
# Check timing constraints for these specific attributes
before_only_attributes = [
ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value,
ConstantsDDBC.SQL_ATTR_PACKET_SIZE.value,
]
# Check if attribute can be set at the current connection state
if is_connected and attribute in before_only_attributes:
logger.debug(
"validate_attribute_value: Timing violation - attr=%d cannot be set after connection",
attribute,
)
return (
False,
(
f"Attribute {attribute} must be set before connection establishment. "
"Use the attrs_before parameter when creating the connection."
),
sanitized_attr,
sanitized_val,
)
# Basic value type validation
if isinstance(value, int):
# For integer values, check if negative (login timeout can be -1 for default)
if value < 0 and attribute != ConstantsDDBC.SQL_ATTR_LOGIN_TIMEOUT.value:
return (
False,
f"Integer value cannot be negative: {value}",
sanitized_attr,
sanitized_val,
)
elif isinstance(value, str):
# Basic string length check
max_string_size = 8192 # 8KB maximum
if len(value) > max_string_size:
return (
False,
f"String value too large: {len(value)} bytes (max {max_string_size})",
sanitized_attr,
sanitized_val,
)
elif isinstance(value, (bytes, bytearray)):
# Basic binary length check
max_binary_size = 32768 # 32KB maximum
if len(value) > max_binary_size:
return (
False,
f"Binary value too large: {len(value)} bytes (max {max_binary_size})",
sanitized_attr,
sanitized_val,
)
else:
# Reject unsupported value types
return (
False,
f"Unsupported attribute value type: {type(value).__name__}",
sanitized_attr,
sanitized_val,
)
# All basic validations passed
logger.debug(
"validate_attribute_value: Validation passed - attr=%d, value_type=%s",
attribute,
type(value).__name__,
)
return True, None, sanitized_attr, sanitized_val
def connstr_to_pycore_params(params: dict) -> dict:
"""Translate parsed ODBC connection-string params for py-core's bulk copy path.
When ``cursor.bulkcopy()`` is called, mssql-python opens a *separate*
connection through mssql-py-core.
py-core's ``connection.rs`` expects a Python dict with snake_case keys —
different from the ODBC-style keys that ``_ConnectionStringParser._parse``
returns.
This function bridges that gap: it maps lowercase ODBC keys (e.g.
``"trustservercertificate"``) to py-core keys (``"trust_server_certificate"``)
and converts numeric strings to ``int`` for timeout/size params.
Boolean params (TrustServerCertificate, MultiSubnetFailover) are passed as
strings — ``connection.rs`` validates Yes/No and rejects invalid values.
Unrecognised keys are silently dropped.
"""
# Only keys listed below are forwarded to py-core.
# Unknown/reserved keys (app, workstationid, language, connect_timeout,
# mars_connection) are silently dropped here. In the normal connect()
# path the parser validates keywords first (validate_keywords=True),
# but bulkcopy parses with validation off, so this mapping is the
# authoritative filter in that path.
key_map = {
# auth / credentials
"uid": "user_name",
"pwd": "password",
"trusted_connection": "trusted_connection",
"authentication": "authentication",
# server (accept parser synonyms)
"server": "server",
"addr": "server",
"address": "server",
# database
"database": "database",
"applicationintent": "application_intent",
# encryption / TLS (include snake_case alias the parser may emit)
"encrypt": "encryption",
"trustservercertificate": "trust_server_certificate",
"trust_server_certificate": "trust_server_certificate",
"hostnameincertificate": "host_name_in_certificate",
"servercertificate": "server_certificate",
# Kerberos
"serverspn": "server_spn",
# network
"multisubnetfailover": "multi_subnet_failover",
"ipaddresspreference": "ip_address_preference",
"keepalive": "keep_alive",
"keepaliveinterval": "keep_alive_interval",
# sizing / limits ("packet size" with space is a common pyodbc-ism)
"packetsize": "packet_size",
"packet size": "packet_size",
"connectretrycount": "connect_retry_count",
"connectretryinterval": "connect_retry_interval",
}
int_keys = {
"packet_size",
"connect_retry_count",
"connect_retry_interval",
"keep_alive",
"keep_alive_interval",
}
pycore_params: dict = {}
for connstr_key, pycore_key in key_map.items():
raw_value = params.get(connstr_key)
if raw_value is None:
continue
# First-wins: match ODBC behaviour — first synonym in the
# connection string takes precedence (e.g. Addr before Server).
if pycore_key in pycore_params:
continue
# ODBC values are always strings; py-core expects native types for int keys.
# Boolean params (trust_server_certificate, multi_subnet_failover) are passed
# as strings — all Yes/No validation is in connection.rs for single-location
# consistency with Encrypt, ApplicationIntent, IPAddressPreference, etc.
if pycore_key in int_keys:
# Numeric params (timeouts, packet size, etc.) — skip on bad input
try:
pycore_params[pycore_key] = int(raw_value)
except (ValueError, TypeError):
pass # let py-core fall back to its compiled-in default
else:
# String params (server, database, encryption, etc.) — pass through
pycore_params[pycore_key] = raw_value
return pycore_params
# Settings functionality moved here to avoid circular imports
# Initialize the locale setting only once at module import time
# This avoids thread-safety issues with locale
_default_decimal_separator: str = "."
try:
# Get the locale setting once during module initialization
locale_separator = locale.localeconv()["decimal_point"]
if locale_separator and len(locale_separator) == 1:
_default_decimal_separator = locale_separator
except (AttributeError, KeyError, TypeError, ValueError):
pass # Keep the default "." if locale access fails
class Settings:
"""
Settings class for mssql_python package configuration.
This class holds global settings that affect the behavior of the package,
including lowercase column names, decimal separator, and UUID handling.
"""
def __init__(self) -> None:
self.lowercase: bool = False
# Use the pre-determined separator - no locale access here
self.decimal_separator: str = _default_decimal_separator
# Controls whether UNIQUEIDENTIFIER columns return uuid.UUID (True)
# or str (False). Default True returns native uuid.UUID objects.
# Set to False to return str for pyodbc-compatible migration.
self.native_uuid: bool = True
# Global settings instance
_settings: Settings = Settings()
_settings_lock: threading.Lock = threading.Lock()
def get_settings() -> Settings:
"""Return the global settings object"""
with _settings_lock:
return _settings