3 changes: 3 additions & 0 deletions mssql_python/constants.py
@@ -115,6 +115,7 @@ class ConstantsDDBC(Enum):
SQL_FETCH_RELATIVE = 6
SQL_FETCH_BOOKMARK = 8
SQL_DATETIMEOFFSET = -155
SQL_SS_UDT = -151 # SQL Server User-Defined Types (geometry, geography, hierarchyid)
SQL_C_SS_TIMESTAMPOFFSET = 0x4001
SQL_SCOPE_CURROW = 0
SQL_BEST_ROWID = 1
@@ -499,3 +500,5 @@ def get_attribute_set_timing(attribute):
# internally.
"packetsize": "PacketSize",
}

# (Function removed; no replacement needed)
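The new SQL_SS_UDT code can be used directly when registering output converters. Below is a minimal sketch, assuming the ConstantsDDBC enum above and the connection-level add_output_converter() API exercised by the tests in this PR; the udt_to_hex converter and the `connection` variable name are illustrative only.

    from mssql_python.constants import ConstantsDDBC as ddbc_sql_const

    def udt_to_hex(value):
        # UDT columns (geometry, geography, hierarchyid) arrive as raw bytes;
        # expose them as a hex string for easier inspection.
        if value is None:
            return None
        return value.hex()

    # connection.add_output_converter(ddbc_sql_const.SQL_SS_UDT.value, udt_to_hex)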
25 changes: 18 additions & 7 deletions mssql_python/cursor.py
@@ -16,6 +16,7 @@
import datetime
import warnings
from typing import List, Union, Any, Optional, Tuple, Sequence, TYPE_CHECKING
import xml
from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes
from mssql_python.helpers import check_error
from mssql_python.logging import logger
@@ -131,6 +132,9 @@ def __init__(self, connection: "Connection", timeout: int = 0) -> None:
)
self.messages = [] # Store diagnostic messages

# Store raw column metadata for converter lookups
self._column_metadata = None

def _is_unicode_string(self, param: str) -> bool:
"""
Check if a string contains non-ASCII characters.
@@ -836,8 +840,12 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None
"""Initialize the description attribute from column metadata."""
if not column_metadata:
self.description = None
self._column_metadata = None # Clear metadata too
return

# Store raw metadata for converter map building
self._column_metadata = column_metadata

description = []
for _, col in enumerate(column_metadata):
# Get column name - lowercase it if the lowercase flag is set
@@ -851,7 +859,7 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None
description.append(
(
column_name, # name
self._map_data_type(col["DataType"]), # type_code
col["DataType"], # type_code (SQL type integer) - CHANGED THIS LINE
None, # display_size
col["ColumnSize"], # internal_size
col["ColumnSize"], # precision - should match ColumnSize
@@ -869,18 +877,17 @@ def _build_converter_map(self):
"""
if (
not self.description
or not self._column_metadata
or not hasattr(self.connection, "_output_converters")
or not self.connection._output_converters
):
return None

converter_map = []

for desc in self.description:
if desc is None:
converter_map.append(None)
continue
sql_type = desc[1]
for col_meta in self._column_metadata:
# Use the raw SQL type code from metadata, not the mapped Python type
sql_type = col_meta["DataType"]
converter = self.connection.get_output_converter(sql_type)
# If no converter found for the SQL type, try the WVARCHAR converter as a fallback
if converter is None:
@@ -947,6 +954,11 @@ def _map_data_type(self, sql_type):
ddbc_sql_const.SQL_VARBINARY.value: bytes,
ddbc_sql_const.SQL_LONGVARBINARY.value: bytes,
ddbc_sql_const.SQL_GUID.value: uuid.UUID,
ddbc_sql_const.SQL_SS_UDT.value: bytes, # UDTs mapped to bytes
ddbc_sql_const.SQL_XML.value: str, # XML mapped to str
ddbc_sql_const.SQL_DATETIME2.value: datetime.datetime,
ddbc_sql_const.SQL_SMALLDATETIME.value: datetime.datetime,
ddbc_sql_const.SQL_DATETIMEOFFSET.value: datetime.datetime,
# Add more mappings as needed
}
return sql_to_python_type.get(sql_type, str)
@@ -2370,7 +2382,6 @@ def __del__(self):
Destructor to ensure the cursor is closed when it is no longer needed.
This is a safety net to ensure resources are cleaned up
even if close() was not called explicitly.
If the cursor is already closed, it will not raise an exception during cleanup.
"""
if "closed" not in self.__dict__ or not self.closed:
try:
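With description[1] now holding the raw SQL type integer, converter lookups in _build_converter_map key on the same codes that add_output_converter registers. A minimal usage sketch follows, assuming the SELECT below reports its column as SQL_WVARCHAR (as the updated tests in this PR expect); the tag_nvarchar converter and the `connection` variable name are illustrative.

    from mssql_python.constants import ConstantsDDBC as ddbc_sql_const

    def tag_nvarchar(value):
        # Prefix NVARCHAR results so the converter's effect is visible.
        return None if value is None else "CONVERTED:" + str(value)

    # connection.add_output_converter(ddbc_sql_const.SQL_WVARCHAR.value, tag_nvarchar)
    # cursor = connection.cursor()
    # cursor.execute("SELECT N'test string' AS test_col")
    # cursor.description[0][1]   # raw SQL type code (e.g. ddbc_sql_const.SQL_WVARCHAR.value)
    # cursor.fetchone()[0]       # "CONVERTED:test string"
    # connection.clear_output_converters()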
29 changes: 22 additions & 7 deletions mssql_python/pybind/ddbc_bindings.cpp
@@ -27,6 +27,9 @@
#define MAX_DIGITS_IN_NUMERIC 64
#define SQL_MAX_NUMERIC_LEN 16
#define SQL_SS_XML (-152)
#define SQL_SS_UDT (-151) // SQL Server User-Defined Types (geometry, geography, hierarchyid)
#define SQL_DATETIME2 (42)
#define SQL_SMALLDATETIME (58)

#define STRINGIFY_FOR_CASE(x) \
case x: \
@@ -2827,6 +2830,11 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
}
break;
}
case SQL_SS_UDT: {
LOG("SQLGetData: Streaming UDT (geometry/geography) for column %d", i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true));
break;
}
case SQL_SS_XML: {
LOG("SQLGetData: Streaming XML for column %d", i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false));
@@ -3050,6 +3058,8 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
}
case SQL_TIMESTAMP:
case SQL_TYPE_TIMESTAMP:
case SQL_DATETIME2:
case SQL_SMALLDATETIME:
case SQL_DATETIME: {
SQL_TIMESTAMP_STRUCT timestampValue;
ret = SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIMESTAMP, &timestampValue,
@@ -3633,6 +3643,8 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
}
case SQL_TIMESTAMP:
case SQL_TYPE_TIMESTAMP:
case SQL_DATETIME2:
case SQL_SMALLDATETIME:
case SQL_DATETIME: {
const SQL_TIMESTAMP_STRUCT& ts = buffers.timestampBuffers[col - 1][i];
PyObject* datetimeObj = PythonObjectCache::get_datetime_class()(
@@ -3812,6 +3824,9 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
case SQL_SS_TIMESTAMPOFFSET:
rowSize += sizeof(DateTimeOffset);
break;
case SQL_SS_UDT:
rowSize += columnSize; // UDT types use column size as-is
break;
default:
std::wstring columnName = columnMeta["ColumnName"].cast<std::wstring>();
std::ostringstream errorString;
@@ -3864,28 +3879,29 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR ||
dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
lobColumns.push_back(i + 1); // 1-based
}
}

SQLULEN numRowsFetched = 0;
// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
if (!lobColumns.empty()) {
LOG("FetchMany_wrap: LOB columns detected (%zu columns), using per-row "
"SQLGetData path",
lobColumns.size());
while (true) {
while (numRowsFetched < (SQLULEN)fetchSize) {
ret = SQLFetch_ptr(hStmt);
if (ret == SQL_NO_DATA)
break;
if (!SQL_SUCCEEDED(ret))
return ret;

py::list row;
SQLGetData_wrap(StatementHandle, numCols,
row); // <-- streams LOBs correctly
SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly
rows.append(row);
numRowsFetched++;
}
return SQL_SUCCESS;
}
@@ -3899,8 +3915,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
LOG("FetchMany_wrap: Error when binding columns - SQLRETURN=%d", ret);
return ret;
}

SQLULEN numRowsFetched;

SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0);
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);

@@ -3994,7 +4009,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR ||
dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
lobColumns.push_back(i + 1); // 1-based
}
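At the Python level, the effect of routing SQL_SS_UDT through the LOB path above is that geometry/geography/hierarchyid columns are streamed via SQLGetData and surface as bytes (per the _map_data_type mapping in cursor.py). A brief sketch; the table and column names are hypothetical.

    # cursor.execute("SELECT geom_col FROM spatial_table")  # hypothetical table/column
    # row = cursor.fetchone()
    # isinstance(row[0], bytes)  # expected True: UDT payload fetched as SQL_C_BINARY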
80 changes: 20 additions & 60 deletions tests/test_003_connection.py
@@ -4273,32 +4273,15 @@ def test_converter_integration(db_connection):
cursor = db_connection.cursor()
sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value

# Test with string converter
# Register converter for SQL_WVARCHAR type
db_connection.add_output_converter(sql_wvarchar, custom_string_converter)

# Test a simple string query
cursor.execute("SELECT N'test string' AS test_col")
row = cursor.fetchone()

# Check if the type matches what we expect for SQL_WVARCHAR
# For Cursor.description, the second element is the type code
column_type = cursor.description[0][1]

# If the cursor description has SQL_WVARCHAR as the type code,
# then our converter should be applied
if column_type == sql_wvarchar:
assert row[0].startswith("CONVERTED:"), "Output converter not applied"
else:
# If the type code is different, adjust the test or the converter
print(f"Column type is {column_type}, not {sql_wvarchar}")
# Add converter for the actual type used
db_connection.clear_output_converters()
db_connection.add_output_converter(column_type, custom_string_converter)

# Re-execute the query
cursor.execute("SELECT N'test string' AS test_col")
row = cursor.fetchone()
assert row[0].startswith("CONVERTED:"), "Output converter not applied"
# The converter should be applied based on the SQL type code
assert row[0].startswith("CONVERTED:"), "Output converter not applied"

# Clean up
db_connection.clear_output_converters()
@@ -4385,26 +4368,23 @@ def test_multiple_output_converters(db_connection):
"""Test that multiple output converters can work together"""
cursor = db_connection.cursor()

# Execute a query to get the actual type codes used
cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col")
int_type = cursor.description[0][1] # Type code for integer column
str_type = cursor.description[1][1] # Type code for string column
# Use SQL type constants directly
sql_integer = ConstantsDDBC.SQL_INTEGER.value # SQL type code for INT
sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value # SQL type code for NVARCHAR

# Add converter for string type
db_connection.add_output_converter(str_type, custom_string_converter)
db_connection.add_output_converter(sql_wvarchar, custom_string_converter)

# Add converter for integer type
def int_converter(value):
if value is None:
return None
# Convert from bytes to int and multiply by 2
if isinstance(value, bytes):
return int.from_bytes(value, byteorder="little") * 2
elif isinstance(value, int):
# Integers are already Python ints, so just multiply by 2
if isinstance(value, int):
return value * 2
return value

db_connection.add_output_converter(int_type, int_converter)
db_connection.add_output_converter(sql_integer, int_converter)

# Test query with both types
cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col")
@@ -4811,32 +4791,15 @@ def test_converter_integration(db_connection):
cursor = db_connection.cursor()
sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value

# Test with string converter
# Register converter for SQL_WVARCHAR type
db_connection.add_output_converter(sql_wvarchar, custom_string_converter)

# Test a simple string query
cursor.execute("SELECT N'test string' AS test_col")
row = cursor.fetchone()

# Check if the type matches what we expect for SQL_WVARCHAR
# For Cursor.description, the second element is the type code
column_type = cursor.description[0][1]

# If the cursor description has SQL_WVARCHAR as the type code,
# then our converter should be applied
if column_type == sql_wvarchar:
assert row[0].startswith("CONVERTED:"), "Output converter not applied"
else:
# If the type code is different, adjust the test or the converter
print(f"Column type is {column_type}, not {sql_wvarchar}")
# Add converter for the actual type used
db_connection.clear_output_converters()
db_connection.add_output_converter(column_type, custom_string_converter)

# Re-execute the query
cursor.execute("SELECT N'test string' AS test_col")
row = cursor.fetchone()
assert row[0].startswith("CONVERTED:"), "Output converter not applied"
# The converter should be applied based on the SQL type code
assert row[0].startswith("CONVERTED:"), "Output converter not applied"

# Clean up
db_connection.clear_output_converters()
@@ -4923,26 +4886,23 @@ def test_multiple_output_converters(db_connection):
"""Test that multiple output converters can work together"""
cursor = db_connection.cursor()

# Execute a query to get the actual type codes used
cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col")
int_type = cursor.description[0][1] # Type code for integer column
str_type = cursor.description[1][1] # Type code for string column
# Use SQL type constants directly
sql_integer = ConstantsDDBC.SQL_INTEGER.value # SQL type code for INT
sql_wvarchar = ConstantsDDBC.SQL_WVARCHAR.value # SQL type code for NVARCHAR

# Add converter for string type
db_connection.add_output_converter(str_type, custom_string_converter)
db_connection.add_output_converter(sql_wvarchar, custom_string_converter)

# Add converter for integer type
def int_converter(value):
if value is None:
return None
# Convert from bytes to int and multiply by 2
if isinstance(value, bytes):
return int.from_bytes(value, byteorder="little") * 2
elif isinstance(value, int):
# Integers are already Python ints, so just multiply by 2
if isinstance(value, int):
return value * 2
return value

db_connection.add_output_converter(int_type, int_converter)
db_connection.add_output_converter(sql_integer, int_converter)

# Test query with both types
cursor.execute("SELECT CAST(42 AS INT) as int_col, N'test' as str_col")