diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 33607f002..e2ef6e7e1 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -9,7 +9,7 @@ from typing import Tuple, Dict, Optional, List from mssql_python.logging import logger -from mssql_python.constants import AuthType +from mssql_python.constants import AuthType, ConstantsDDBC class AADAuth: @@ -30,7 +30,25 @@ def get_token_struct(token: str) -> bytes: @staticmethod def get_token(auth_type: str) -> bytes: - """Get token using the specified authentication type""" + """Get DDBC token struct for the specified authentication type.""" + token_struct, _ = AADAuth._acquire_token(auth_type) + return token_struct + + @staticmethod + def get_raw_token(auth_type: str) -> str: + """Acquire a fresh raw JWT for the mssql-py-core connection (bulk copy). + + This deliberately does NOT cache the credential or token — each call + creates a new Azure Identity credential instance and requests a token. + A fresh acquisition avoids expired-token errors when bulkcopy() is + called long after the original DDBC connect(). + """ + _, raw_token = AADAuth._acquire_token(auth_type) + return raw_token + + @staticmethod + def _acquire_token(auth_type: str) -> Tuple[bytes, str]: + """Internal: acquire token and return (ddbc_struct, raw_jwt).""" # Import Azure libraries inside method to support test mocking # pylint: disable=import-outside-toplevel try: @@ -53,7 +71,11 @@ def get_token(auth_type: str) -> bytes: "interactive": InteractiveBrowserCredential, } - credential_class = credential_map[auth_type] + credential_class = credential_map.get(auth_type) + if not credential_class: + raise ValueError( + f"Unsupported auth_type '{auth_type}'. " f"Supported: {', '.join(credential_map)}" + ) logger.info( "get_token: Starting Azure AD authentication - auth_type=%s, credential_class=%s", auth_type, @@ -61,22 +83,15 @@ def get_token(auth_type: str) -> bytes: ) try: - logger.debug( - "get_token: Creating credential instance - credential_class=%s", - credential_class.__name__, - ) credential = credential_class() - logger.debug( - "get_token: Requesting token from Azure AD - scope=https://database.windows.net/.default" - ) - token = credential.get_token("https://database.windows.net/.default").token + raw_token = credential.get_token("https://database.windows.net/.default").token logger.info( "get_token: Azure AD token acquired successfully - token_length=%d chars", - len(token), + len(raw_token), ) - return AADAuth.get_token_struct(token) + token_struct = AADAuth.get_token_struct(raw_token) + return token_struct, raw_token except ClientAuthenticationError as e: - # Re-raise with more specific context about Azure AD authentication failure logger.error( "get_token: Azure AD authentication failed - credential_class=%s, error=%s", credential_class.__name__, @@ -88,7 +103,6 @@ def get_token(auth_type: str) -> bytes: f"user cancellation, network issues, or unsupported configuration." ) from e except Exception as e: - # Catch any other unexpected exceptions logger.error( "get_token: Unexpected error during credential creation - credential_class=%s, error=%s", credential_class.__name__, @@ -180,7 +194,7 @@ def remove_sensitive_params(parameters: List[str]) -> List[str]: def get_auth_token(auth_type: str) -> Optional[bytes]: - """Get authentication token based on auth type""" + """Get DDBC authentication token struct based on auth type.""" logger.debug("get_auth_token: Starting - auth_type=%s", auth_type) if not auth_type: logger.debug("get_auth_token: No auth_type specified, returning None") @@ -202,9 +216,28 @@ def get_auth_token(auth_type: str) -> Optional[bytes]: return None +def extract_auth_type(connection_string: str) -> Optional[str]: + """Extract Entra ID auth type from a connection string. + + Used as a fallback when process_connection_string does not propagate + auth_type (e.g. Windows Interactive where DDBC handles auth natively). + Bulkcopy still needs the auth type to acquire a token via Azure Identity. + """ + auth_map = { + AuthType.INTERACTIVE.value: "interactive", + AuthType.DEVICE_CODE.value: "devicecode", + AuthType.DEFAULT.value: "default", + } + for part in connection_string.split(";"): + key, _, value = part.strip().partition("=") + if key.strip().lower() == "authentication": + return auth_map.get(value.strip().lower()) + return None + + def process_connection_string( connection_string: str, -) -> Tuple[str, Optional[Dict[int, bytes]]]: +) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str]]: """ Process connection string and handle authentication. @@ -212,7 +245,8 @@ def process_connection_string( connection_string: The connection string to process Returns: - Tuple[str, Optional[Dict]]: Processed connection string and attrs_before dict if needed + Tuple[str, Optional[Dict], Optional[str]]: Processed connection string, + attrs_before dict if needed, and auth_type string for bulk copy token acquisition Raises: ValueError: If the connection string is invalid or empty @@ -259,7 +293,11 @@ def process_connection_string( "process_connection_string: Token authentication configured successfully - auth_type=%s", auth_type, ) - return ";".join(modified_parameters) + ";", {1256: token_struct} + return ( + ";".join(modified_parameters) + ";", + {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, + auth_type, + ) else: logger.warning( "process_connection_string: Token acquisition failed, proceeding without token" @@ -269,4 +307,4 @@ def process_connection_string( "process_connection_string: Connection string processing complete - has_auth=%s", bool(auth_type), ) - return ";".join(modified_parameters) + ";", None + return ";".join(modified_parameters) + ";", None, auth_type diff --git a/mssql_python/connection.py b/mssql_python/connection.py index ba79e2a3f..c6c4944de 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -39,7 +39,7 @@ ProgrammingError, NotSupportedError, ) -from mssql_python.auth import process_connection_string +from mssql_python.auth import extract_auth_type, process_connection_string from mssql_python.constants import ConstantsDDBC, GetInfoConstants from mssql_python.connection_string_parser import _ConnectionStringParser from mssql_python.connection_string_builder import _ConnectionStringBuilder @@ -263,6 +263,11 @@ def __init__( }, } + # Auth type for acquiring fresh tokens at bulk copy time. + # We intentionally do NOT cache the token — a fresh one is acquired + # each time bulkcopy() is called to avoid expired-token errors. + self._auth_type = None + # Check if the connection string contains authentication parameters # This is important for processing the connection string correctly. # If authentication is specified, it will be processed to handle @@ -272,6 +277,10 @@ def __init__( self.connection_str = connection_result[0] if connection_result[1]: self._attrs_before.update(connection_result[1]) + # Store auth type so bulkcopy() can acquire a fresh token later. + # On Windows Interactive, process_connection_string returns None + # (DDBC handles auth natively), so fall back to the connection string. + self._auth_type = connection_result[2] or extract_auth_type(self.connection_str) self._closed = False self._timeout = timeout diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 03d40c833..1c4332969 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -158,6 +158,9 @@ class ConstantsDDBC(Enum): SQL_ATTR_SERVER_NAME = 13 SQL_ATTR_RESET_CONNECTION = 116 + # SQL Server-specific connection option constants + SQL_COPT_SS_ACCESS_TOKEN = 1256 + # Transaction Isolation Level Constants SQL_TXN_READ_UNCOMMITTED = 1 SQL_TXN_READ_COMMITTED = 2 diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 3dd7aa283..0e28168e8 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2607,15 +2607,36 @@ def _bulkcopy( context = { "server": params.get("server"), "database": params.get("database"), - "user_name": params.get("uid", ""), "trust_server_certificate": trust_cert, "encryption": encryption, } - # Extract password separately to avoid storing it in generic context that may be logged - password = params.get("pwd", "") + # Build pycore_context with appropriate authentication. + # For Azure AD: acquire a FRESH token right now instead of reusing + # the one from connect() time — avoids expired-token errors when + # bulkcopy() is called long after the original connection. pycore_context = dict(context) - pycore_context["password"] = password + + if self.connection._auth_type: + # Fresh token acquisition for mssql-py-core connection + from mssql_python.auth import AADAuth + + try: + raw_token = AADAuth.get_raw_token(self.connection._auth_type) + except (RuntimeError, ValueError) as e: + raise RuntimeError( + f"Bulk copy failed: unable to acquire Azure AD token " + f"for auth_type '{self.connection._auth_type}': {e}" + ) from e + pycore_context["access_token"] = raw_token + logger.debug( + "Bulk copy: acquired fresh Azure AD token for auth_type=%s", + self.connection._auth_type, + ) + else: + # SQL Server authentication — use uid/password from connection string + pycore_context["user_name"] = params.get("uid", "") + pycore_context["password"] = params.get("pwd", "") pycore_connection = None pycore_cursor = None @@ -2653,10 +2674,10 @@ def _bulkcopy( finally: # Clear sensitive data to minimize memory exposure - password = "" if pycore_context: - pycore_context["password"] = "" - pycore_context["user_name"] = "" + pycore_context.pop("password", None) + pycore_context.pop("user_name", None) + pycore_context.pop("access_token", None) # Clean up bulk copy resources for resource in (pycore_cursor, pycore_connection): if resource and hasattr(resource, "close"): diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 0c0716cb6..f44bf86e2 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -7,14 +7,16 @@ import pytest import platform import sys +from unittest.mock import patch, MagicMock from mssql_python.auth import ( AADAuth, process_auth_parameters, remove_sensitive_params, get_auth_token, process_connection_string, + extract_auth_type, ) -from mssql_python.constants import AuthType +from mssql_python.constants import AuthType, ConstantsDDBC import secrets SAMPLE_TOKEN = secrets.token_hex(44) @@ -82,6 +84,11 @@ def test_get_token_struct(self): assert isinstance(token_struct, bytes) assert len(token_struct) > 4 + def test_get_raw_token_default(self): + raw_token = AADAuth.get_raw_token("default") + assert isinstance(raw_token, str) + assert raw_token == SAMPLE_TOKEN + def test_get_token_default(self): token_struct = AADAuth.get_token("default") assert isinstance(token_struct, bytes) @@ -281,7 +288,7 @@ def test_interactive_auth_windows(self, monkeypatch): params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] modified_params, auth_type = process_auth_parameters(params) assert "Authentication=ActiveDirectoryInteractive" in modified_params - assert auth_type == None + assert auth_type is None def test_interactive_auth_non_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Darwin") @@ -326,34 +333,37 @@ def test_remove_sensitive_parameters(self): class TestProcessConnectionString: def test_process_connection_string_with_default_auth(self): conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs = process_connection_string(conn_str) + result_str, attrs, auth_type = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str assert attrs is not None - assert 1256 in attrs - assert isinstance(attrs[1256], bytes) + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs + assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) + assert auth_type == "default" def test_process_connection_string_no_auth(self): conn_str = "Server=test;Database=testdb;UID=user;PWD=password" - result_str, attrs = process_connection_string(conn_str) + result_str, attrs, auth_type = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str assert "UID=user" in result_str assert "PWD=password" in result_str assert attrs is None + assert auth_type is None def test_process_connection_string_interactive_non_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Darwin") conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb" - result_str, attrs = process_connection_string(conn_str) + result_str, attrs, auth_type = process_connection_string(conn_str) assert "Server=test" in result_str assert "Database=testdb" in result_str assert attrs is not None - assert 1256 in attrs - assert isinstance(attrs[1256], bytes) + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs + assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) + assert auth_type == "interactive" def test_error_handling(): @@ -368,3 +378,42 @@ def test_error_handling(): # Test non-string input with pytest.raises(ValueError, match="Connection string must be a string"): process_connection_string(None) + + +class TestExtractAuthType: + def test_interactive(self): + assert ( + extract_auth_type("Server=test;Authentication=ActiveDirectoryInteractive;") + == "interactive" + ) + + def test_default(self): + assert extract_auth_type("Server=test;Authentication=ActiveDirectoryDefault;") == "default" + + def test_devicecode(self): + assert ( + extract_auth_type("Server=test;Authentication=ActiveDirectoryDeviceCode;") + == "devicecode" + ) + + def test_no_auth(self): + assert extract_auth_type("Server=test;Database=db;") is None + + def test_unsupported_auth(self): + assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None + + +def test_acquire_token_unsupported_auth_type(): + with pytest.raises(ValueError, match="Unsupported auth_type 'bogus'"): + AADAuth._acquire_token("bogus") + + +class TestConnectionAuthType: + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_auth_type_stored_on_connection(self, mock_ddbc_conn): + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault") + assert conn._auth_type == "default" + conn.close()