Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 116 additions & 23 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,57 @@
INFO_TYPE_STRING_THRESHOLD: int = 10000

# UTF-16 encoding variants that should use SQL_WCHAR by default
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16", "utf-16le", "utf-16be"])
# Note: "utf-16" with BOM is NOT included as it's problematic for SQL_WCHAR
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16le", "utf-16be"])


def _validate_utf16_wchar_compatibility(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document that SQL_WCHAR is always using UTF-16LE (ODBC specification requirement)

encoding: str, wchar_type: int, context: str = "SQL_WCHAR"
) -> None:
"""
Validates UTF-16 encoding compatibility with SQL_WCHAR.

Centralizes the validation logic to eliminate duplication across setencoding/setdecoding.

Args:
encoding: The encoding string (already normalized to lowercase)
wchar_type: The SQL_WCHAR constant value to check against
context: Context string for error messages ('SQL_WCHAR', 'SQL_WCHAR ctype', etc.)

Raises:
ProgrammingError: If encoding is incompatible with SQL_WCHAR
"""
if encoding == "utf-16":
# UTF-16 with BOM is rejected due to byte order ambiguity
logger.warning("utf-16 with BOM rejected for %s", context)
raise ProgrammingError(
driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
ddbc_error=(
"Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
"Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
),
)
elif encoding not in UTF16_ENCODINGS:
# Non-UTF-16 encodings are not supported with SQL_WCHAR
logger.warning(
"Non-UTF-16 encoding %s attempted with %s", sanitize_user_input(encoding), context
)

# Generate context-appropriate error messages
if "ctype" in context:
driver_error = f"SQL_WCHAR ctype only supports UTF-16 encodings"
ddbc_context = "SQL_WCHAR ctype"
else:
driver_error = f"SQL_WCHAR only supports UTF-16 encodings"
ddbc_context = "SQL_WCHAR"

raise ProgrammingError(
driver_error=driver_error,
ddbc_error=(
f"Cannot use encoding '{encoding}' with {ddbc_context}. "
f"SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
),
)


def _validate_encoding(encoding: str) -> bool:
Expand All @@ -70,7 +120,21 @@ def _validate_encoding(encoding: str) -> bool:
Note:
Uses LRU cache to avoid repeated expensive codecs.lookup() calls.
Cache size is limited to 128 entries which should cover most use cases.
Also validates that encoding name only contains safe characters.
"""
# Basic security checks - prevent obvious attacks
if not encoding or not isinstance(encoding, str):
return False

# Check length limit (prevent DOS)
if len(encoding) > 100:
return False

# Prevent null bytes and control characters that could cause issues
if "\x00" in encoding or any(ord(c) < 32 and c not in "\t\n\r" for c in encoding):
return False

# Then check if it's a valid Python codec
try:
codecs.lookup(encoding)
return True
Expand Down Expand Up @@ -227,6 +291,11 @@ def __init__(
self._output_converters = {}
self._converters_lock = threading.Lock()

# Initialize encoding/decoding settings lock for thread safety
# This lock protects both _encoding_settings and _decoding_settings dictionaries
# to prevent race conditions when multiple threads are reading/writing encoding settings
self._encoding_lock = threading.RLock() # RLock allows recursive locking
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a use case of multiple readers and single writer. Is Rlock the right synchronization primitive for such a purpose ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rlock would allow a single reader at a time, which can be detrimental to performance, assuming that we expect many concurrent readers. IF we dont expect concurrent opearations, then is locking useful?

Copy link
Contributor Author

@jahnvi480 jahnvi480 Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for sharing your review, you are correct about this.
Solution that we can implement to remove Rlock and still not have race condition will be to have our own custom implementation by replacing with a ReadWriteLock that allows:

  • Multiple concurrent readers getencoding
  • Exclusive writer access setencoding/setdecoding
    Should I make the changes or have it for next PR?
    @sumitmsft @bewithgaurav @saurabh500 Please let me know your views

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jahnvi480 IMO we need to honor concurrent readers from performance perspectives. I think Saurabh has a good point, Rlock will be detrimental to perf. I would rather with a primitive where multiple readers can concurrently access the resource. I'd rather do that change here in this PR.


# Initialize search escape character
self._searchescape = None

Expand Down Expand Up @@ -416,8 +485,7 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
# Validate encoding using cached validation for better performance
if not _validate_encoding(encoding):
# Log the sanitized encoding for security
logger.debug(
"warning",
logger.warning(
"Invalid encoding attempted: %s",
sanitize_user_input(str(encoding)),
)
Expand All @@ -430,6 +498,10 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
encoding = encoding.casefold()
logger.debug("setencoding: Encoding normalized to %s", encoding)

# Early validation if ctype is already specified as SQL_WCHAR
if ctype == ConstantsDDBC.SQL_WCHAR.value:
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR")

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
Expand All @@ -443,8 +515,7 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
# Log the sanitized ctype for security
logger.debug(
"warning",
logger.warning(
"Invalid ctype attempted: %s",
sanitize_user_input(str(ctype)),
)
Expand All @@ -456,20 +527,24 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
),
)

# Store the encoding settings
self._encoding_settings = {"encoding": encoding, "ctype": ctype}
# Final validation: SQL_WCHAR ctype only supports UTF-16 encodings (without BOM)
if ctype == ConstantsDDBC.SQL_WCHAR.value:
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR")

# Store the encoding settings (thread-safe with lock)
with self._encoding_lock:
self._encoding_settings = {"encoding": encoding, "ctype": ctype}

# Log with sanitized values for security
logger.debug(
"info",
logger.info(
"Text encoding set to %s with ctype %s",
sanitize_user_input(encoding),
sanitize_user_input(str(ctype)),
)

def getencoding(self) -> Dict[str, Union[str, int]]:
"""
Gets the current text encoding settings.
Gets the current text encoding settings (thread-safe).

Returns:
dict: A dictionary containing 'encoding' and 'ctype' keys.
Expand All @@ -481,14 +556,19 @@ def getencoding(self) -> Dict[str, Union[str, int]]:
settings = cnxn.getencoding()
print(f"Current encoding: {settings['encoding']}")
print(f"Current ctype: {settings['ctype']}")

Note:
This method is thread-safe and can be called from multiple threads concurrently.
"""
if self._closed:
raise InterfaceError(
driver_error="Connection is closed",
ddbc_error="Connection is closed",
)

return self._encoding_settings.copy()
# Thread-safe read with lock to prevent race conditions
with self._encoding_lock:
return self._encoding_settings.copy()

def setdecoding(
self, sqltype: int, encoding: Optional[str] = None, ctype: Optional[int] = None
Expand Down Expand Up @@ -539,8 +619,7 @@ def setdecoding(
SQL_WMETADATA,
]
if sqltype not in valid_sqltypes:
logger.debug(
"warning",
logger.warning(
"Invalid sqltype attempted: %s",
sanitize_user_input(str(sqltype)),
)
Expand All @@ -562,8 +641,7 @@ def setdecoding(

# Validate encoding using cached validation for better performance
if not _validate_encoding(encoding):
logger.debug(
"warning",
logger.warning(
"Invalid encoding attempted: %s",
sanitize_user_input(str(encoding)),
)
Expand All @@ -575,6 +653,13 @@ def setdecoding(
# Normalize encoding to lowercase for consistency
encoding = encoding.lower()

# Validate SQL_WCHAR encoding compatibility
if sqltype == ConstantsDDBC.SQL_WCHAR.value:
_validate_utf16_wchar_compatibility(encoding, sqltype, "SQL_WCHAR sqltype")

# SQL_WMETADATA can use any valid encoding (UTF-8, UTF-16, etc.)
# No restriction needed here - let users configure as needed

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
Expand All @@ -585,8 +670,7 @@ def setdecoding(
# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
logger.debug(
"warning",
logger.warning(
"Invalid ctype attempted: %s",
sanitize_user_input(str(ctype)),
)
Expand All @@ -598,8 +682,13 @@ def setdecoding(
),
)

# Store the decoding settings for the specified sqltype
self._decoding_settings[sqltype] = {"encoding": encoding, "ctype": ctype}
# Validate SQL_WCHAR ctype encoding compatibility
if ctype == ConstantsDDBC.SQL_WCHAR.value:
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR ctype")

# Store the decoding settings for the specified sqltype (thread-safe with lock)
with self._encoding_lock:
self._decoding_settings[sqltype] = {"encoding": encoding, "ctype": ctype}

# Log with sanitized values for security
sqltype_name = {
Expand All @@ -608,8 +697,7 @@ def setdecoding(
SQL_WMETADATA: "SQL_WMETADATA",
}.get(sqltype, str(sqltype))

logger.debug(
"info",
logger.info(
"Text decoding set for %s to %s with ctype %s",
sqltype_name,
sanitize_user_input(encoding),
Expand All @@ -618,7 +706,7 @@ def setdecoding(

def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
"""
Gets the current text decoding settings for the specified SQL type.
Gets the current text decoding settings for the specified SQL type (thread-safe).

Args:
sqltype (int): The SQL type to get settings for: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA.
Expand All @@ -634,6 +722,9 @@ def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
settings = cnxn.getdecoding(mssql_python.SQL_CHAR)
print(f"SQL_CHAR encoding: {settings['encoding']}")
print(f"SQL_CHAR ctype: {settings['ctype']}")

Note:
This method is thread-safe and can be called from multiple threads concurrently.
"""
if self._closed:
raise InterfaceError(
Expand All @@ -657,7 +748,9 @@ def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
),
)

return self._decoding_settings[sqltype].copy()
# Thread-safe read with lock to prevent race conditions
with self._encoding_lock:
return self._decoding_settings[sqltype].copy()

def set_attr(self, attribute: int, value: Union[int, str, bytes, bytearray]) -> None:
"""
Expand Down
Loading
Loading