Commit 4b28f34

FEAT: EntraID Access Token Support for BulkCopy (#426)
### Work Item / Issue Reference

> AB#42282
> [AB#42283](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/42283)

-------------------------------------------------------------------

### Summary

This pull request introduces significant improvements to Azure AD authentication handling for bulk copy operations, ensuring fresh token acquisition to prevent expired-token errors, and refactors related code for clarity and robustness. It also updates the connection and authentication APIs to propagate and utilize authentication type information more reliably.

Azure AD authentication enhancements:

* Added `AADAuth.get_raw_token` and refactored token acquisition logic to ensure a fresh Azure AD token is acquired each time bulkcopy is called, preventing expired-token errors. The new method avoids credential/token caching and is used specifically for bulk copy operations. (`mssql_python/auth.py`, [mssql_python/auth.pyL33-R51](diffhunk://#diff-19a0c93fc8573a5a7bfcadda0a2fb8f1b340c4502e1308c4f8a1e4508136c6e1L33-R51))
* Updated bulk copy logic to use the new `get_raw_token` method, storing the auth type on the connection and acquiring a fresh token at bulk copy time. Sensitive data is now removed from memory after use for improved security. (`mssql_python/cursor.py`, [[1]](diffhunk://#diff-deceea46ae01082ce8400e14fa02f4b7585afb7b5ed9885338b66494f5f38280L2610-R2633) [[2]](diffhunk://#diff-deceea46ae01082ce8400e14fa02f4b7585afb7b5ed9885338b66494f5f38280L2656-R2674))

Connection and authentication API changes:

* Refactored `process_connection_string` to return the authentication type as a third value, and added `extract_auth_type` to reliably extract auth type from the connection string when not propagated (e.g., Windows Interactive). Connection objects now store the auth type for later use. (`mssql_python/auth.py`, [[1]](diffhunk://#diff-19a0c93fc8573a5a7bfcadda0a2fb8f1b340c4502e1308c4f8a1e4508136c6e1R219-R249) [[2]](diffhunk://#diff-19a0c93fc8573a5a7bfcadda0a2fb8f1b340c4502e1308c4f8a1e4508136c6e1L262-R296) [[3]](diffhunk://#diff-19a0c93fc8573a5a7bfcadda0a2fb8f1b340c4502e1308c4f8a1e4508136c6e1L272-R306); `mssql_python/connection.py`, [[4]](diffhunk://#diff-29bb94de45aae51c23a6426d40133c28e4161e68769e08d046059c7186264e90L42-R42) [[5]](diffhunk://#diff-29bb94de45aae51c23a6426d40133c28e4161e68769e08d046059c7186264e90R266-R270) [[6]](diffhunk://#diff-29bb94de45aae51c23a6426d40133c28e4161e68769e08d046059c7186264e90R280-R283))

Testing improvements:

* Expanded test coverage to verify raw token acquisition, connection string processing, and correct storage of authentication type on connection objects. Tests ensure that the new APIs and behaviors work as expected. (`tests/test_008_auth.py`, [[1]](diffhunk://#diff-83e8bc8183c8cc53e88bf74d3cb8ef1751be6854edd9a727602fe618e691ecdbR86-R90) [[2]](diffhunk://#diff-83e8bc8183c8cc53e88bf74d3cb8ef1751be6854edd9a727602fe618e691ecdbL329-R365) [[3]](diffhunk://#diff-83e8bc8183c8cc53e88bf74d3cb8ef1751be6854edd9a727602fe618e691ecdbR380-R390))

Error handling and logging:

* Improved error handling and logging in token acquisition, providing clearer messages for unsupported authentication types and unexpected errors. (`mssql_python/auth.py`, [[1]](diffhunk://#diff-19a0c93fc8573a5a7bfcadda0a2fb8f1b340c4502e1308c4f8a1e4508136c6e1L56-L79) [[2]](diffhunk://#diff-19a0c93fc8573a5a7bfcadda0a2fb8f1b340c4502e1308c4f8a1e4508136c6e1L91))

Documentation and naming consistency:

* Updated docstrings and comments throughout the authentication code for clarity and consistency, reflecting the new behaviors and APIs. (`mssql_python/auth.py`, [mssql_python/auth.pyL183-R197](diffhunk://#diff-19a0c93fc8573a5a7bfcadda0a2fb8f1b340c4502e1308c4f8a1e4508136c6e1L183-R197))
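
The expired-token problem described in the summary can be illustrated with a small simulation. This is only a sketch: `FakeCredential`, `FakeAccessToken`, `ManualClock`, and the one-hour lifetime are hypothetical stand-ins, not the Azure Identity API or the driver's code. For brevity it also reuses one credential object, whereas the commit creates a new credential instance on every call.

```python
from dataclasses import dataclass


@dataclass
class FakeAccessToken:
    """Hypothetical token: an opaque string plus an expiry timestamp."""
    token: str
    expires_on: float


class ManualClock:
    """Deterministic clock so the example does not depend on wall time."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now

    def advance(self, seconds: float) -> None:
        self.now += seconds


class FakeCredential:
    """Hypothetical credential that issues short-lived tokens."""

    def __init__(self, lifetime_seconds: float, clock: ManualClock) -> None:
        self._lifetime = lifetime_seconds
        self._clock = clock
        self._issued = 0

    def get_token(self, scope: str) -> FakeAccessToken:
        self._issued += 1
        return FakeAccessToken(f"jwt-{self._issued}", self._clock() + self._lifetime)


def is_expired(token: FakeAccessToken, clock: ManualClock) -> bool:
    return clock() >= token.expires_on


SCOPE = "https://database.windows.net/.default"
clock = ManualClock()
credential = FakeCredential(lifetime_seconds=3600, clock=clock)

cached = credential.get_token(SCOPE)   # acquired at connect() time
clock.advance(2 * 3600)                # bulk copy runs two hours later
fresh = credential.get_token(SCOPE)    # acquired at bulk copy time

stale_expired = is_expired(cached, clock)
fresh_expired = is_expired(fresh, clock)
```

In this model, reusing `cached` would fail because it has already expired, while `fresh` is still valid. That is the motivation for acquiring a token inside the bulk copy call.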
1 parent 44f755b commit 4b28f34

5 files changed

Lines changed: 157 additions & 37 deletions

mssql_python/auth.py

Lines changed: 58 additions & 20 deletions
@@ -9,7 +9,7 @@
 from typing import Tuple, Dict, Optional, List

 from mssql_python.logging import logger
-from mssql_python.constants import AuthType
+from mssql_python.constants import AuthType, ConstantsDDBC


 class AADAuth:
@@ -30,7 +30,25 @@ def get_token_struct(token: str) -> bytes:

     @staticmethod
     def get_token(auth_type: str) -> bytes:
-        """Get token using the specified authentication type"""
+        """Get DDBC token struct for the specified authentication type."""
+        token_struct, _ = AADAuth._acquire_token(auth_type)
+        return token_struct
+
+    @staticmethod
+    def get_raw_token(auth_type: str) -> str:
+        """Acquire a fresh raw JWT for the mssql-py-core connection (bulk copy).
+
+        This deliberately does NOT cache the credential or token — each call
+        creates a new Azure Identity credential instance and requests a token.
+        A fresh acquisition avoids expired-token errors when bulkcopy() is
+        called long after the original DDBC connect().
+        """
+        _, raw_token = AADAuth._acquire_token(auth_type)
+        return raw_token
+
+    @staticmethod
+    def _acquire_token(auth_type: str) -> Tuple[bytes, str]:
+        """Internal: acquire token and return (ddbc_struct, raw_jwt)."""
         # Import Azure libraries inside method to support test mocking
         # pylint: disable=import-outside-toplevel
         try:
@@ -53,30 +71,27 @@ def get_token(auth_type: str) -> bytes:
             "interactive": InteractiveBrowserCredential,
         }

-        credential_class = credential_map[auth_type]
+        credential_class = credential_map.get(auth_type)
+        if not credential_class:
+            raise ValueError(
+                f"Unsupported auth_type '{auth_type}'. " f"Supported: {', '.join(credential_map)}"
+            )
         logger.info(
             "get_token: Starting Azure AD authentication - auth_type=%s, credential_class=%s",
             auth_type,
             credential_class.__name__,
         )

         try:
-            logger.debug(
-                "get_token: Creating credential instance - credential_class=%s",
-                credential_class.__name__,
-            )
             credential = credential_class()
-            logger.debug(
-                "get_token: Requesting token from Azure AD - scope=https://database.windows.net/.default"
-            )
-            token = credential.get_token("https://database.windows.net/.default").token
+            raw_token = credential.get_token("https://database.windows.net/.default").token
             logger.info(
                 "get_token: Azure AD token acquired successfully - token_length=%d chars",
-                len(token),
+                len(raw_token),
             )
-            return AADAuth.get_token_struct(token)
+            token_struct = AADAuth.get_token_struct(raw_token)
+            return token_struct, raw_token
         except ClientAuthenticationError as e:
-            # Re-raise with more specific context about Azure AD authentication failure
             logger.error(
                 "get_token: Azure AD authentication failed - credential_class=%s, error=%s",
                 credential_class.__name__,
@@ -88,7 +103,6 @@ def get_token(auth_type: str) -> bytes:
                 f"user cancellation, network issues, or unsupported configuration."
             ) from e
         except Exception as e:
-            # Catch any other unexpected exceptions
             logger.error(
                 "get_token: Unexpected error during credential creation - credential_class=%s, error=%s",
                 credential_class.__name__,
@@ -180,7 +194,7 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]:


 def get_auth_token(auth_type: str) -> Optional[bytes]:
-    """Get authentication token based on auth type"""
+    """Get DDBC authentication token struct based on auth type."""
     logger.debug("get_auth_token: Starting - auth_type=%s", auth_type)
     if not auth_type:
         logger.debug("get_auth_token: No auth_type specified, returning None")
@@ -202,17 +216,37 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:
         return None


+def extract_auth_type(connection_string: str) -> Optional[str]:
+    """Extract Entra ID auth type from a connection string.
+
+    Used as a fallback when process_connection_string does not propagate
+    auth_type (e.g. Windows Interactive where DDBC handles auth natively).
+    Bulkcopy still needs the auth type to acquire a token via Azure Identity.
+    """
+    auth_map = {
+        AuthType.INTERACTIVE.value: "interactive",
+        AuthType.DEVICE_CODE.value: "devicecode",
+        AuthType.DEFAULT.value: "default",
+    }
+    for part in connection_string.split(";"):
+        key, _, value = part.strip().partition("=")
+        if key.strip().lower() == "authentication":
+            return auth_map.get(value.strip().lower())
+    return None
+
+
 def process_connection_string(
     connection_string: str,
-) -> Tuple[str, Optional[Dict[int, bytes]]]:
+) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]:
     """
     Process connection string and handle authentication.

     Args:
         connection_string: The connection string to process

     Returns:
-        Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed
+        Tuple[str, Optional[Dict], Optional[str]]: Processed connection string,
+        attrs_before dict if needed, and auth_type string for bulk copy token acquisition

     Raises:
         ValueError: If the connection string is invalid or empty
@@ -259,7 +293,11 @@ def process_connection_string(
                 "process_connection_string: Token authentication configured successfully - auth_type=%s",
                 auth_type,
             )
-            return ";".join(modified_parameters) + ";", {1256: token_struct}
+            return (
+                ";".join(modified_parameters) + ";",
+                {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct},
+                auth_type,
+            )
         else:
             logger.warning(
                 "process_connection_string: Token acquisition failed, proceeding without token"
@@ -269,4 +307,4 @@ def process_connection_string(
         "process_connection_string: Connection string processing complete - has_auth=%s",
         bool(auth_type),
     )
-    return ";".join(modified_parameters) + ";", None
+    return ";".join(modified_parameters) + ";", None, auth_type
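
The refactoring in `auth.py` routes both public getters through a single internal acquisition function and validates the auth type before use. The sketch below shows that shape in isolation. `StubCredential`, `CREDENTIAL_MAP`, `to_struct`, and `acquire` are hypothetical names, and the length-prefixed UTF-16-LE layout in `to_struct` is an assumption, since `get_token_struct` is not part of this diff.

```python
import struct
from types import SimpleNamespace
from typing import Callable, Dict, Tuple

SCOPE = "https://database.windows.net/.default"


class StubCredential:
    """Hypothetical credential; the real code uses azure.identity classes."""

    def get_token(self, scope: str) -> SimpleNamespace:
        # Mimic the shape used in the diff: an object with a `.token` attribute.
        return SimpleNamespace(token="header.payload.signature")


# Hypothetical map from auth type to a credential factory.
CREDENTIAL_MAP: Dict[str, Callable[[], StubCredential]] = {
    "default": StubCredential,
    "devicecode": StubCredential,
    "interactive": StubCredential,
}


def to_struct(raw_token: str) -> bytes:
    """Assumed layout for illustration: 4-byte length prefix + UTF-16-LE bytes."""
    data = raw_token.encode("utf-16-le")
    return struct.pack("<I", len(data)) + data


def acquire(auth_type: str) -> Tuple[bytes, str]:
    """Acquire once, return both representations."""
    credential_class = CREDENTIAL_MAP.get(auth_type)
    if not credential_class:
        raise ValueError(
            f"Unsupported auth_type '{auth_type}'. Supported: {', '.join(CREDENTIAL_MAP)}"
        )
    raw_token = credential_class().get_token(SCOPE).token
    return to_struct(raw_token), raw_token


def get_token(auth_type: str) -> bytes:
    """Packed struct, as a connection attribute would need it."""
    return acquire(auth_type)[0]


def get_raw_token(auth_type: str) -> str:
    """Raw token string, as a bulk copy connection would need it."""
    return acquire(auth_type)[1]
```

Using `dict.get` with an explicit `ValueError` replaces the bare `KeyError` that `credential_map[auth_type]` would raise, so callers see which auth types are supported.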

mssql_python/connection.py

Lines changed: 10 additions & 1 deletion
@@ -39,7 +39,7 @@
     ProgrammingError,
     NotSupportedError,
 )
-from mssql_python.auth import process_connection_string
+from mssql_python.auth import extract_auth_type, process_connection_string
 from mssql_python.constants import ConstantsDDBC, GetInfoConstants
 from mssql_python.connection_string_parser import _ConnectionStringParser
 from mssql_python.connection_string_builder import _ConnectionStringBuilder
@@ -263,6 +263,11 @@ def __init__(
             },
         }

+        # Auth type for acquiring fresh tokens at bulk copy time.
+        # We intentionally do NOT cache the token — a fresh one is acquired
+        # each time bulkcopy() is called to avoid expired-token errors.
+        self._auth_type = None
+
         # Check if the connection string contains authentication parameters
         # This is important for processing the connection string correctly.
         # If authentication is specified, it will be processed to handle
@@ -272,6 +277,10 @@ def __init__(
             self.connection_str = connection_result[0]
             if connection_result[1]:
                 self._attrs_before.update(connection_result[1])
+            # Store auth type so bulkcopy() can acquire a fresh token later.
+            # On Windows Interactive, process_connection_string returns None
+            # (DDBC handles auth natively), so fall back to the connection string.
+            self._auth_type = connection_result[2] or extract_auth_type(self.connection_str)

         self._closed = False
         self._timeout = timeout
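
The fallback in `__init__` can be exercised on its own. This is a minimal sketch with two assumptions. First, the lowercase keys in `AUTH_MAP` are inferred from the tests in this commit, whereas the real code reads them from `AuthType.*.value`. Second, `resolve_auth_type` is a hypothetical helper that mirrors the expression `connection_result[2] or extract_auth_type(self.connection_str)`.

```python
from typing import Optional

# Assumed lowercase keyword values (the diff takes them from AuthType.*.value).
AUTH_MAP = {
    "activedirectoryinteractive": "interactive",
    "activedirectorydevicecode": "devicecode",
    "activedirectorydefault": "default",
}


def extract_auth_type(connection_string: str) -> Optional[str]:
    """Return the short auth type for the first Authentication= keyword, if any."""
    for part in connection_string.split(";"):
        key, _, value = part.strip().partition("=")
        if key.strip().lower() == "authentication":
            return AUTH_MAP.get(value.strip().lower())
    return None


def resolve_auth_type(propagated: Optional[str], connection_string: str) -> Optional[str]:
    """Prefer the propagated value; otherwise parse the connection string."""
    return propagated or extract_auth_type(connection_string)


# Windows Interactive: nothing propagated, so the connection string is parsed.
windows_interactive = resolve_auth_type(
    None, "Server=test;Authentication=ActiveDirectoryInteractive;"
)
# SQL authentication: no Authentication keyword, so the result stays None.
sql_auth = resolve_auth_type(None, "Server=test;UID=user;PWD=secret;")
# Propagated value wins when present.
propagated = resolve_auth_type("default", "Server=test;Authentication=ActiveDirectoryDefault;")
```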

mssql_python/constants.py

Lines changed: 3 additions & 0 deletions
@@ -158,6 +158,9 @@ class ConstantsDDBC(Enum):
     SQL_ATTR_SERVER_NAME = 13
     SQL_ATTR_RESET_CONNECTION = 116

+    # SQL Server-specific connection option constants
+    SQL_COPT_SS_ACCESS_TOKEN = 1256
+
     # Transaction Isolation Level Constants
     SQL_TXN_READ_UNCOMMITTED = 1
     SQL_TXN_READ_COMMITTED = 2
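
The new enum member replaces the literal `1256` used as a dictionary key elsewhere in this commit. A minimal sketch of that substitution follows. The enum here is trimmed to the single member, and `build_attrs_before` is a hypothetical helper, not a function in the driver.

```python
from enum import Enum
from typing import Dict


class ConstantsDDBC(Enum):
    """Trimmed-down copy containing only the member added in this commit."""

    SQL_COPT_SS_ACCESS_TOKEN = 1256


def build_attrs_before(token_struct: bytes) -> Dict[int, bytes]:
    """Key the pre-connect attribute by name instead of a magic number."""
    return {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}


attrs = build_attrs_before(b"\x04\x00\x00\x00a\x00b\x00")
```

The resulting dictionary is identical to `{1256: token_struct}`, so existing consumers keep working. Only the call sites become self-describing.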

mssql_python/cursor.py

Lines changed: 28 additions & 7 deletions
@@ -2607,15 +2607,36 @@ def _bulkcopy(
         context = {
             "server": params.get("server"),
             "database": params.get("database"),
-            "user_name": params.get("uid", ""),
             "trust_server_certificate": trust_cert,
             "encryption": encryption,
         }

-        # Extract password separately to avoid storing it in generic context that may be logged
-        password = params.get("pwd", "")
+        # Build pycore_context with appropriate authentication.
+        # For Azure AD: acquire a FRESH token right now instead of reusing
+        # the one from connect() time — avoids expired-token errors when
+        # bulkcopy() is called long after the original connection.
         pycore_context = dict(context)
-        pycore_context["password"] = password
+
+        if self.connection._auth_type:
+            # Fresh token acquisition for mssql-py-core connection
+            from mssql_python.auth import AADAuth
+
+            try:
+                raw_token = AADAuth.get_raw_token(self.connection._auth_type)
+            except (RuntimeError, ValueError) as e:
+                raise RuntimeError(
+                    f"Bulk copy failed: unable to acquire Azure AD token "
+                    f"for auth_type '{self.connection._auth_type}': {e}"
+                ) from e
+            pycore_context["access_token"] = raw_token
+            logger.debug(
+                "Bulk copy: acquired fresh Azure AD token for auth_type=%s",
+                self.connection._auth_type,
+            )
+        else:
+            # SQL Server authentication — use uid/password from connection string
+            pycore_context["user_name"] = params.get("uid", "")
+            pycore_context["password"] = params.get("pwd", "")

         pycore_connection = None
         pycore_cursor = None
@@ -2653,10 +2674,10 @@ def _bulkcopy(

         finally:
             # Clear sensitive data to minimize memory exposure
-            password = ""
             if pycore_context:
-                pycore_context["password"] = ""
-                pycore_context["user_name"] = ""
+                pycore_context.pop("password", None)
+                pycore_context.pop("user_name", None)
+                pycore_context.pop("access_token", None)
             # Clean up bulk copy resources
             for resource in (pycore_cursor, pycore_connection):
                 if resource and hasattr(resource, "close"):
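
The branching in `_bulkcopy` puts either an access token or a user name and password into the context, then removes whichever secrets were added. A self-contained sketch of that flow follows. `build_pycore_context` and `scrub` are hypothetical helpers, and the token acquirer is injected as a plain callable so the example runs without Azure Identity.

```python
from typing import Callable, Dict, Optional


def build_pycore_context(
    params: Dict[str, str],
    auth_type: Optional[str],
    acquire_raw_token: Callable[[str], str],
) -> Dict[str, Optional[str]]:
    """Build a context dict holding exactly one kind of credential."""
    context: Dict[str, Optional[str]] = {
        "server": params.get("server"),
        "database": params.get("database"),
    }
    if auth_type:
        try:
            # Token path: acquire at call time rather than reusing an older token.
            context["access_token"] = acquire_raw_token(auth_type)
        except (RuntimeError, ValueError) as e:
            raise RuntimeError(
                f"Bulk copy failed: unable to acquire Azure AD token "
                f"for auth_type '{auth_type}': {e}"
            ) from e
    else:
        # SQL authentication path: credentials come from the connection string.
        context["user_name"] = params.get("uid", "")
        context["password"] = params.get("pwd", "")
    return context


def scrub(context: Dict[str, Optional[str]]) -> None:
    """Remove secret entries; pop() with a default tolerates absent keys."""
    for key in ("password", "user_name", "access_token"):
        context.pop(key, None)


params = {"server": "test", "database": "testdb", "uid": "user", "pwd": "secret"}
token_ctx = build_pycore_context(params, "default", lambda _auth_type: "fake.jwt.value")
sql_ctx = build_pycore_context(params, None, lambda _auth_type: "unused")
token_ctx_keys = sorted(token_ctx)
sql_ctx_keys = sorted(sql_ctx)
scrub(token_ctx)
scrub(sql_ctx)
```

Popping the keys, rather than overwriting them with empty strings as the old code did, means the same cleanup works for both branches without leaving placeholder entries behind.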

tests/test_008_auth.py

Lines changed: 58 additions & 9 deletions
@@ -7,14 +7,16 @@
 import pytest
 import platform
 import sys
+from unittest.mock import patch, MagicMock
 from mssql_python.auth import (
     AADAuth,
     process_auth_parameters,
     remove_sensitive_params,
     get_auth_token,
     process_connection_string,
+    extract_auth_type,
 )
-from mssql_python.constants import AuthType
+from mssql_python.constants import AuthType, ConstantsDDBC
 import secrets

 SAMPLE_TOKEN = secrets.token_hex(44)
@@ -82,6 +84,11 @@ def test_get_token_struct(self):
         assert isinstance(token_struct, bytes)
         assert len(token_struct) > 4

+    def test_get_raw_token_default(self):
+        raw_token = AADAuth.get_raw_token("default")
+        assert isinstance(raw_token, str)
+        assert raw_token == SAMPLE_TOKEN
+
     def test_get_token_default(self):
         token_struct = AADAuth.get_token("default")
         assert isinstance(token_struct, bytes)
@@ -281,7 +288,7 @@ def test_interactive_auth_windows(self, monkeypatch):
         params = ["Authentication=ActiveDirectoryInteractive", "Server=test"]
         modified_params, auth_type = process_auth_parameters(params)
         assert "Authentication=ActiveDirectoryInteractive" in modified_params
-        assert auth_type == None
+        assert auth_type is None

     def test_interactive_auth_non_windows(self, monkeypatch):
         monkeypatch.setattr(platform, "system", lambda: "Darwin")
@@ -326,34 +333,37 @@ def test_remove_sensitive_parameters(self):
 class TestProcessConnectionString:
     def test_process_connection_string_with_default_auth(self):
         conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb"
-        result_str, attrs = process_connection_string(conn_str)
+        result_str, attrs, auth_type = process_connection_string(conn_str)

         assert "Server=test" in result_str
         assert "Database=testdb" in result_str
         assert attrs is not None
-        assert 1256 in attrs
-        assert isinstance(attrs[1256], bytes)
+        assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs
+        assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes)
+        assert auth_type == "default"

     def test_process_connection_string_no_auth(self):
         conn_str = "Server=test;Database=testdb;UID=user;PWD=password"
-        result_str, attrs = process_connection_string(conn_str)
+        result_str, attrs, auth_type = process_connection_string(conn_str)

         assert "Server=test" in result_str
         assert "Database=testdb" in result_str
         assert "UID=user" in result_str
         assert "PWD=password" in result_str
         assert attrs is None
+        assert auth_type is None

     def test_process_connection_string_interactive_non_windows(self, monkeypatch):
         monkeypatch.setattr(platform, "system", lambda: "Darwin")
         conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb"
-        result_str, attrs = process_connection_string(conn_str)
+        result_str, attrs, auth_type = process_connection_string(conn_str)

         assert "Server=test" in result_str
         assert "Database=testdb" in result_str
         assert attrs is not None
-        assert 1256 in attrs
-        assert isinstance(attrs[1256], bytes)
+        assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs
+        assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes)
+        assert auth_type == "interactive"


 def test_error_handling():
@@ -368,3 +378,42 @@ def test_error_handling():
     # Test non-string input
     with pytest.raises(ValueError, match="Connection string must be a string"):
         process_connection_string(None)
+
+
+class TestExtractAuthType:
+    def test_interactive(self):
+        assert (
+            extract_auth_type("Server=test;Authentication=ActiveDirectoryInteractive;")
+            == "interactive"
+        )
+
+    def test_default(self):
+        assert extract_auth_type("Server=test;Authentication=ActiveDirectoryDefault;") == "default"
+
+    def test_devicecode(self):
+        assert (
+            extract_auth_type("Server=test;Authentication=ActiveDirectoryDeviceCode;")
+            == "devicecode"
+        )
+
+    def test_no_auth(self):
+        assert extract_auth_type("Server=test;Database=db;") is None
+
+    def test_unsupported_auth(self):
+        assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None
+
+
+def test_acquire_token_unsupported_auth_type():
+    with pytest.raises(ValueError, match="Unsupported auth_type 'bogus'"):
+        AADAuth._acquire_token("bogus")
+
+
+class TestConnectionAuthType:
+    @patch("mssql_python.connection.ddbc_bindings.Connection")
+    def test_auth_type_stored_on_connection(self, mock_ddbc_conn):
+        mock_ddbc_conn.return_value = MagicMock()
+        from mssql_python import connect
+
+        conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault")
+        assert conn._auth_type == "default"
+        conn.close()
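
`test_get_raw_token_default` expects `SAMPLE_TOKEN` back, which implies the credential classes are mocked somewhere outside this diff. One way such a stub could be built with `unittest.mock` is sketched below. It is not the repository's fixture: `acquire_raw_token` and `make_mock_credential_class` are hypothetical, and `SAMPLE_TOKEN` is a fixed string here rather than a random hex value.

```python
from unittest.mock import MagicMock

SCOPE = "https://database.windows.net/.default"
SAMPLE_TOKEN = "sample-token-value"


def acquire_raw_token(credential_class) -> str:
    """Hypothetical function under test: instantiate, request, return `.token`."""
    return credential_class().get_token(SCOPE).token


def make_mock_credential_class(token: str) -> MagicMock:
    """Build a mock class whose instances hand back an object with `.token`."""
    mock_class = MagicMock(name="MockCredential")
    mock_class.return_value.get_token.return_value.token = token
    return mock_class


mock_class = make_mock_credential_class(SAMPLE_TOKEN)
raw_token = acquire_raw_token(mock_class)
```

Because the mock records its calls, a test can also check that a new credential instance was created and that the expected scope was requested.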
