diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 448b3a95..036f1b83 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -261,10 +261,14 @@ def __init__( } # Initialize decoding settings with Python 3 defaults + # SQL_CHAR default uses SQL_WCHAR ctype so the ODBC driver returns + # UTF-16 data for VARCHAR columns. This avoids encoding mismatches on + # Windows where the driver returns raw bytes in the server's native + # code page (e.g. CP-1252) that may fail to decode as UTF-8. self._decoding_settings = { ConstantsDDBC.SQL_CHAR.value: { - "encoding": "utf-8", - "ctype": ConstantsDDBC.SQL_CHAR.value, + "encoding": "utf-16le", + "ctype": ConstantsDDBC.SQL_WCHAR.value, }, ConstantsDDBC.SQL_WCHAR.value: { "encoding": "utf-16le", diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 7c8f9ad4..95b911c9 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2368,8 +2368,9 @@ def fetchone(self) -> Union[None, Row]: ret = ddbc_bindings.DDBCSQLFetchOne( self.hstmt, row_data, - char_decoding.get("encoding", "utf-8"), + char_decoding.get("encoding", "utf-16le"), wchar_decoding.get("encoding", "utf-16le"), + char_decoding.get("ctype", ddbc_sql_const.SQL_WCHAR.value), ) if self.hstmt: @@ -2434,8 +2435,9 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: self.hstmt, rows_data, size, - char_decoding.get("encoding", "utf-8"), + char_decoding.get("encoding", "utf-16le"), wchar_decoding.get("encoding", "utf-16le"), + char_decoding.get("ctype", ddbc_sql_const.SQL_WCHAR.value), ) if self.hstmt: @@ -2492,8 +2494,9 @@ def fetchall(self) -> List[Row]: ret = ddbc_bindings.DDBCSQLFetchAll( self.hstmt, rows_data, - char_decoding.get("encoding", "utf-8"), + char_decoding.get("encoding", "utf-16le"), wchar_decoding.get("encoding", "utf-16le"), + char_decoding.get("ctype", ddbc_sql_const.SQL_WCHAR.value), ) # Check for errors diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index ee548319..4c69da3f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -269,38 +269,38 @@ struct ArrowSchemaPrivateData { #define ARROW_FLAG_MAP_KEYS_SORTED 4 struct ArrowSchema { - // Array type description - const char* format; - const char* name; - const char* metadata; - int64_t flags; - int64_t n_children; - struct ArrowSchema** children; - struct ArrowSchema* dictionary; - - // Release callback - void (*release)(struct ArrowSchema*); - // Opaque producer-specific data - // Only our child-arrays will set this, so we can give it the correct type - ArrowSchemaPrivateData* private_data; + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + // Only our child-arrays will set this, so we can give it the correct type + ArrowSchemaPrivateData* private_data; }; struct ArrowArray { - // Array data description - int64_t length; - int64_t null_count; - int64_t offset; - int64_t n_buffers; - int64_t n_children; - const void** buffers; - struct ArrowArray** children; - struct ArrowArray* dictionary; - - // Release callback - void (*release)(struct ArrowArray*); - // Opaque producer-specific data - // Only our child-arrays will set this, so we can give it the correct type - ArrowArrayPrivateData* private_data; + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + // Only our child-arrays will set this, so we can give it the correct type + ArrowArrayPrivateData* private_data; }; #endif // ARROW_C_DATA_INTERFACE @@ -3112,13 +3112,14 @@ static inline bool IsLobOrVariantColumn(SQLSMALLINT dataType, SQLULEN columnSize // Helper function to retrieve column data SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, py::list& row, const std::string& charEncoding = "utf-8", - const std::string& wcharEncoding = "utf-16le") { + const std::string& wcharEncoding = "utf-16le", + int charCtype = SQL_C_WCHAR) { // Note: wcharEncoding parameter is reserved for future use // Currently WCHAR data always uses UTF-16LE for Windows compatibility (void)wcharEncoding; // Suppress unused parameter warning - LOG("SQLGetData: Getting data from %d columns for statement_handle=%p", colCount, - (void*)StatementHandle->get()); + LOG("SQLGetData: Getting data from %d columns for statement_handle=%p (charCtype=%d)", colCount, + (void*)StatementHandle->get(), charCtype); if (!SQLGetData_ptr) { LOG("SQLGetData: Function pointer not initialized, loading driver"); DriverLoader::getInstance().loadDriver(); // Load the driver @@ -3186,13 +3187,99 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { + // When charCtype == SQL_C_WCHAR, ask ODBC to convert VARCHAR + // data to UTF-16. This avoids encoding mismatches on Windows + // where the driver returns raw bytes in the server's native + // code page (e.g. CP-1252) that may fail to decode as UTF-8. + // When charCtype == SQL_C_CHAR, use the existing narrow-char + // path with Python codec decoding. + // + // Exception: sql_variant columns always use SQL_C_CHAR. + // The variant probe call (SQLGetData with SQL_C_BINARY) has + // already consumed the column header, and requesting + // SQL_C_WCHAR after the probe fails on some ODBC drivers + // (notably unixODBC on Linux). SQL_C_CHAR works reliably + // because the Linux ODBC driver pre-converts to UTF-8. + const bool isSqlVariant = (dataType == SQL_SS_VARIANT); + const bool useWideChar = (charCtype == SQL_C_WCHAR) && !isSqlVariant; + + // For sql_variant, the SQL_C_CHAR path returns raw bytes in + // the server's native encoding (Windows) or UTF-8 + // (Linux/macOS, driver converts). Force "utf-8" so + // GetEffectiveCharDecoding picks the right codec on each + // platform, avoiding mismatch with the default "utf-16le" + // encoding which is only valid for the SQL_C_WCHAR path. + const std::string& effectiveCharEnc = + isSqlVariant ? std::string("utf-8") : charEncoding; + if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > SQL_MAX_LOB_SIZE) { - LOG("SQLGetData: Streaming LOB for column %d (SQL_C_CHAR) " + LOG("SQLGetData: Streaming LOB for column %d (%s) " "- columnSize=%lu", - i, (unsigned long)columnSize); - row.append( - FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, charEncoding)); + i, useWideChar ? "SQL_C_WCHAR" : "SQL_C_CHAR", (unsigned long)columnSize); + if (useWideChar) { + row.append( + FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false, "utf-16le")); + } else { + row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, + effectiveCharEnc)); + } + } else if (useWideChar) { + // Wide-char path: fetch VARCHAR data as SQL_C_WCHAR + uint64_t fetchBufferSize = + (columnSize + 1) * sizeof(SQLWCHAR); // +1 for null terminator + std::vector dataBuffer(columnSize + 1); + SQLLEN dataLen; + ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(), fetchBufferSize, + &dataLen); + if (SQL_SUCCEEDED(ret)) { + if (dataLen > 0) { + uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + if (numCharsInData < dataBuffer.size()) { +#if defined(__APPLE__) || defined(__linux__) + std::wstring wstr = + SQLWCHARToWString(dataBuffer.data(), numCharsInData); + std::string utf8str = WideToUTF8(wstr); + row.append(py::str(utf8str)); +#else + std::wstring wstr(reinterpret_cast(dataBuffer.data())); + row.append(py::cast(wstr)); +#endif + LOG("SQLGetData: CHAR column %d fetched as WCHAR, " + "length=%lu", + i, (unsigned long)numCharsInData); + } else { + // Buffer too small, fallback to streaming + LOG("SQLGetData: CHAR column %d (WCHAR path) data " + "truncated, using streaming LOB", + i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false, + "utf-16le")); + } + } else if (dataLen == SQL_NULL_DATA) { + LOG("SQLGetData: Column %d is NULL (CHAR via WCHAR)", i); + row.append(py::none()); + } else if (dataLen == 0) { + row.append(py::str("")); + } else if (dataLen == SQL_NO_TOTAL) { + LOG("SQLGetData: Cannot determine data length " + "(SQL_NO_TOTAL) for column %d (CHAR via WCHAR), " + "returning NULL", + i); + row.append(py::none()); + } else if (dataLen < 0) { + LOG("SQLGetData: Unexpected negative data length " + "for column %d - dataType=%d, dataLen=%ld", + i, dataType, (long)dataLen); + ThrowStdException("SQLGetData returned an unexpected negative " + "data length"); + } + } else { + LOG("SQLGetData: Error retrieving data for column %d " + "(CHAR via WCHAR) - SQLRETURN=%d, returning NULL", + i, ret); + row.append(py::none()); + } } else { // Allocate columnSize * 4 + 1 on ALL platforms (no #if guard). // @@ -3223,7 +3310,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // SQLGetData will null-terminate the data // Use Python's codec system to decode bytes. const std::string decodeEncoding = - GetEffectiveCharDecoding(charEncoding); + GetEffectiveCharDecoding(effectiveCharEnc); py::bytes raw_bytes(reinterpret_cast(dataBuffer.data()), static_cast(dataLen)); try { @@ -3247,7 +3334,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p "(buffer_size=%zu), using streaming LOB", i, dataBuffer.size()); row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, - charEncoding)); + effectiveCharEnc)); } } else if (dataLen == SQL_NULL_DATA) { LOG("SQLGetData: Column %d is NULL (CHAR)", i); @@ -3709,8 +3796,9 @@ SQLRETURN SQLFetchScroll_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT FetchOri // For column in the result set, binds a buffer to retrieve column data // TODO: Move to anonymous namespace, since it is not used outside this file SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, - SQLUSMALLINT numCols, int fetchSize) { + SQLUSMALLINT numCols, int fetchSize, int charCtype = SQL_C_WCHAR) { SQLRETURN ret = SQL_SUCCESS; + const bool useWideChar = (charCtype == SQL_C_WCHAR); // Bind columns based on their data types for (SQLUSMALLINT col = 1; col <= numCols; col++) { auto columnMeta = columnNames[col - 1].cast(); @@ -3721,32 +3809,27 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - // TODO: handle variable length data correctly. This logic wont - // suffice HandleZeroColumnSizeAtFetch(columnSize); - // Use columnSize * 4 + 1 on Linux/macOS to accommodate UTF-8 - // expansion. The ODBC driver returns UTF-8 for SQL_C_CHAR where - // each character can be up to 4 bytes. + if (useWideChar) { + // Bind VARCHAR columns as SQL_C_WCHAR so the ODBC driver + // returns UTF-16 data, avoiding code-page decode issues. + uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; + buffers.wcharBuffers[col - 1].resize(fetchSize * fetchBufferSize); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_WCHAR, buffers.wcharBuffers[col - 1].data(), + fetchBufferSize * sizeof(SQLWCHAR), buffers.indicators[col - 1].data()); + } else { + // Original narrow-char path #if defined(__APPLE__) || defined(__linux__) - uint64_t fetchBufferSize = columnSize * 4 + 1 /*null-terminator*/; + uint64_t fetchBufferSize = columnSize * 4 + 1 /*null-terminator*/; #else - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; + uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; #endif - // TODO: For LONGVARCHAR/BINARY types, columnSize is returned as - // 2GB-1 by SQLDescribeCol. So fetchBufferSize = 2GB. - // fetchSize=1 if columnSize>1GB. So we'll allocate a vector of - // size 2GB. If a query fetches multiple (say N) LONG... - // columns, we will have allocated multiple (N) 2GB sized - // vectors. This will make driver very slow. And if the N is - // high enough, we could hit the OS limit for heap memory that - // we can allocate, & hence get a std::bad_alloc. The process - // could also be killed by OS for consuming too much memory. - // Hence this will be revisited in beta to not allocate 2GB+ - // memory, & use streaming instead - buffers.charBuffers[col - 1].resize(fetchSize * fetchBufferSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), - fetchBufferSize * sizeof(SQLCHAR), - buffers.indicators[col - 1].data()); + buffers.charBuffers[col - 1].resize(fetchSize * fetchBufferSize); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), + fetchBufferSize * sizeof(SQLCHAR), buffers.indicators[col - 1].data()); + } break; } case SQL_WCHAR: @@ -3880,7 +3963,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched, const std::vector& lobColumns, - const std::string& charEncoding = "utf-8") { + const std::string& charEncoding = "utf-8", int charCtype = SQL_C_WCHAR) { LOG("FetchBatchData: Fetching data in batches"); SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); if (ret == SQL_NO_DATA) { @@ -3901,6 +3984,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum uint64_t fetchBufferSize; bool isLob; }; + const bool useWideChar = (charCtype == SQL_C_WCHAR); std::vector columnInfos(numCols); for (SQLUSMALLINT col = 0; col < numCols; col++) { const auto& columnMeta = columnNames[col].cast(); @@ -3910,22 +3994,30 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum std::find(lobColumns.begin(), lobColumns.end(), col + 1) != lobColumns.end(); columnInfos[col].processedColumnSize = columnInfos[col].columnSize; HandleZeroColumnSizeAtFetch(columnInfos[col].processedColumnSize); - // On Linux/macOS, the ODBC driver returns UTF-8 for SQL_C_CHAR where - // each character can be up to 4 bytes. Must match SQLBindColums buffer. -#if defined(__APPLE__) || defined(__linux__) + SQLSMALLINT dt = columnInfos[col].dataType; bool isCharType = (dt == SQL_CHAR || dt == SQL_VARCHAR || dt == SQL_LONGVARCHAR); - if (isCharType) { - columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize * 4 + - 1; // *4 for UTF-8, +1 for null terminator + + if (isCharType && useWideChar) { + // When VARCHAR is bound as SQL_C_WCHAR, buffer size is in SQLWCHAR + // units (same as NVARCHAR). +1 for null terminator. + columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize + 1; } else { + // On Linux/macOS, the ODBC driver returns UTF-8 for SQL_C_CHAR where + // each character can be up to 4 bytes. Must match SQLBindColums buffer. +#if defined(__APPLE__) || defined(__linux__) + if (isCharType) { + columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize * 4 + + 1; // *4 for UTF-8, +1 for null terminator + } else { + columnInfos[col].fetchBufferSize = + columnInfos[col].processedColumnSize + 1; // +1 for null terminator + } +#else columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize + 1; // +1 for null terminator - } -#else - columnInfos[col].fetchBufferSize = - columnInfos[col].processedColumnSize + 1; // +1 for null terminator #endif + } } // Performance: Build function pointer dispatch table (once per batch) @@ -3947,6 +4039,10 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum columnInfosExt[col].isLob = columnInfos[col].isLob; columnInfosExt[col].charEncoding = effectiveCharEnc; columnInfosExt[col].isUtf8 = (effectiveCharEnc == "utf-8"); + // Set useWideChar for SQL_CHAR/VARCHAR columns when charCtype is SQL_C_WCHAR + SQLSMALLINT dt = columnInfos[col].dataType; + bool isCharType = (dt == SQL_CHAR || dt == SQL_VARCHAR || dt == SQL_LONGVARCHAR); + columnInfosExt[col].useWideChar = (isCharType && useWideChar); // Map data type to processor function (switch executed once per column, // not per cell) @@ -4328,7 +4424,8 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { // during fetching, it throws a runtime error. SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetchSize, const std::string& charEncoding = "utf-8", - const std::string& wcharEncoding = "utf-16le") { + const std::string& wcharEncoding = "utf-16le", + int charCtype = SQL_C_WCHAR) { SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count @@ -4368,8 +4465,8 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; py::list row; - SQLGetData_wrap(StatementHandle, numCols, row, charEncoding, - wcharEncoding); // <-- streams LOBs correctly + SQLGetData_wrap(StatementHandle, numCols, row, charEncoding, wcharEncoding, + charCtype); // <-- streams LOBs correctly rows.append(row); numRowsFetched++; } @@ -4380,7 +4477,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch ColumnBuffers buffers(numCols, fetchSize); // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize, charCtype); if (!SQL_SUCCEEDED(ret)) { LOG("FetchMany_wrap: Error when binding columns - SQLRETURN=%d", ret); return ret; @@ -4390,7 +4487,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns, - charEncoding); + charEncoding, charCtype); if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { LOG("FetchMany_wrap: Error when fetching data - SQLRETURN=%d", ret); return ret; @@ -4418,15 +4515,12 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch // @param indicator: Pointer to indicator value (SQL_NULL_DATA for NULL, or data length) // // @return SQLRETURN: SQL_SUCCESS on success, or error code on failure -template -SQLRETURN GetDataVar(SQLHSTMT hStmt, - SQLUSMALLINT colNumber, - SQLSMALLINT cType, - std::vector& dataVec, - SQLLEN* indicator) { +template +SQLRETURN GetDataVar(SQLHSTMT hStmt, SQLUSMALLINT colNumber, SQLSMALLINT cType, + std::vector& dataVec, SQLLEN* indicator) { size_t start = 0; size_t end = 0; - + // Determine null terminator size based on data type size_t sizeNullTerminator = 0; switch (cType) { @@ -4440,7 +4534,7 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, default: ThrowStdException("GetDataVar only supports SQL_C_CHAR, SQL_C_WCHAR, and SQL_C_BINARY"); } - + // Ensure initial buffer has space for at least the null terminator if (dataVec.size() < sizeNullTerminator) { dataVec.resize(sizeNullTerminator); @@ -4449,13 +4543,9 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, while (true) { SQLLEN localInd = 0; SQLRETURN ret = SQLGetData_ptr( - hStmt, - colNumber, - cType, - reinterpret_cast(dataVec.data() + start), + hStmt, colNumber, cType, reinterpret_cast(dataVec.data() + start), sizeof(T) * (dataVec.size() - start), // Available buffer size from start position - &localInd - ); + &localInd); // Handle NULL data if (localInd == SQL_NULL_DATA) { @@ -4489,10 +4579,10 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, assert(localInd % sizeof(T) == 0); end = start + static_cast(localInd) / sizeof(T) + sizeNullTerminator; } - + // The next read starts where the null terminator would have been placed start = dataVec.size() - sizeNullTerminator; - + // Resize buffer for next iteration dataVec.resize(end); } else { @@ -4529,17 +4619,14 @@ int32_t days_from_civil(int y, int m, int d) { // Returns number of days since Unix epoch (1970-01-01) y -= m <= 2; const int era = (y >= 0 ? y : y - 399) / 400; - const unsigned yoe = static_cast(y - era * 400); // [0, 399] - const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1; // [0, 365] - const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // [0, 146096] + const unsigned yoe = static_cast(y - era * 400); // [0, 399] + const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1; // [0, 365] + const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // [0, 146096] return era * 146097 + static_cast(doe) - 719468; } -SQLRETURN FetchArrowBatch_wrap( - SqlHandlePtr StatementHandle, - py::list& capsules, - int arrowBatchSize -) { +SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules, + int arrowBatchSize) { // An overly large fetch size doesn't seem to help performance int fetchSize = 64; @@ -4583,15 +4670,14 @@ SQLRETURN FetchArrowBatch_wrap( columnSizes[i] = columnSize; columnNullable[i] = (nullable != SQL_NO_NULLS); - if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || - dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || - dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || - dataType == SQL_SS_XML || dataType == SQL_SS_UDT) && + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR || + dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY || + dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) && (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { - hasLobColumns = true; - if (fetchSize > 1) { - fetchSize = 1; // LOBs require row-by-row fetch - } + hasLobColumns = true; + if (fetchSize > 1) { + fetchSize = 1; // LOBs require row-by-row fetch + } } std::string columnName = colMeta["ColumnName"].cast(); @@ -4600,7 +4686,7 @@ SQLRETURN FetchArrowBatch_wrap( std::memcpy(arrowSchemaPrivateData[i]->name.get(), columnName.c_str(), nameLen); std::string format = ""; - switch(dataType) { + switch (dataType) { case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: @@ -4663,7 +4749,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_DECIMAL: case SQL_NUMERIC: { std::ostringstream formatStream; - formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); + formatStream << "d:" << columnSize << "," + << colMeta["DecimalDigits"].cast(); std::string formatStr = formatStream.str(); size_t formatLen = formatStr.length() + 1; arrowSchemaPrivateData[i]->format = std::make_unique(formatLen); @@ -4705,13 +4792,14 @@ SQLRETURN FetchArrowBatch_wrap( break; default: std::ostringstream errorString; - errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << (i + 1); + errorString << "Unsupported data type for Arrow batch fetch for column - " + << columnName.c_str() << ", Type - " << dataType << ", column ID - " + << (i + 1); LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); break; } - + // Store format string if not already stored. // For non-decimal types, format is now a static string. if (!arrowSchemaPrivateData[i]->format) { @@ -4729,14 +4817,15 @@ SQLRETURN FetchArrowBatch_wrap( ColumnBuffers buffers(numCols, fetchSize); if (!hasLobColumns && fetchSize > 0) { - // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + // Bind columns — Arrow always uses SQL_C_CHAR for VARCHAR because + // it processes raw byte buffers directly, not via Python codecs. + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize, SQL_C_CHAR); if (!SQL_SUCCEEDED(ret)) { LOG("Error when binding columns"); return ret; } } - + SQLULEN numRowsFetched = 0; FetchStateGuard fetchStateGuard(hStmt, &numRowsFetched, fetchSize); @@ -4750,7 +4839,7 @@ SQLRETURN FetchArrowBatch_wrap( } ret = SQLFetch_ptr(hStmt); if (ret == SQL_NO_DATA) { - ret = SQL_SUCCESS; // Normal completion + ret = SQL_SUCCESS; // Normal completion break; } if (!SQL_SUCCEEDED(ret)) { @@ -4769,18 +4858,14 @@ SQLRETURN FetchArrowBatch_wrap( if (hasLobColumns) { assert(idxRowSql == 0 && "GetData only works one row at a time"); - switch(dataType) { + switch (dataType) { case SQL_SS_UDT: case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - ret = GetDataVar( - hStmt, - idxCol + 1, - SQL_C_BINARY, - buffers.charBuffers[idxCol], - buffers.indicators[idxCol].data() - ); + ret = GetDataVar(hStmt, idxCol + 1, SQL_C_BINARY, + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching BINARY LOB for column %d", idxCol + 1); return ret; @@ -4790,13 +4875,9 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - ret = GetDataVar( - hStmt, - idxCol + 1, - SQL_C_CHAR, - buffers.charBuffers[idxCol], - buffers.indicators[idxCol].data() - ); + ret = GetDataVar(hStmt, idxCol + 1, SQL_C_CHAR, + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching CHAR LOB for column %d", idxCol + 1); return ret; @@ -4807,13 +4888,9 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - ret = GetDataVar( - hStmt, - idxCol + 1, - SQL_C_WCHAR, - buffers.wcharBuffers[idxCol], - buffers.indicators[idxCol].data() - ); + ret = GetDataVar(hStmt, idxCol + 1, SQL_C_WCHAR, + buffers.wcharBuffers[idxCol], + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching WCHAR LOB data for column %d", idxCol + 1); return ret; @@ -4823,11 +4900,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_INTEGER: { buffers.intBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SLONG, - buffers.intBuffers[idxCol].data(), - sizeof(SQLINTEGER), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_SLONG, buffers.intBuffers[idxCol].data(), + sizeof(SQLINTEGER), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching SLONG data for column %d", idxCol + 1); return ret; @@ -4836,12 +4910,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_SMALLINT: { buffers.smallIntBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SSHORT, - buffers.smallIntBuffers[idxCol].data(), - sizeof(SQLSMALLINT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_SSHORT, + buffers.smallIntBuffers[idxCol].data(), + sizeof(SQLSMALLINT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching SSHORT data for column %d", idxCol + 1); return ret; @@ -4850,12 +4922,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_TINYINT: { buffers.charBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_TINYINT, - buffers.charBuffers[idxCol].data(), - sizeof(SQLCHAR), - buffers.indicators[idxCol].data() - ); + ret = + SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_TINYINT, + buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TINYINT data for column %d", idxCol + 1); return ret; @@ -4865,11 +4935,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BIT: { buffers.charBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_BIT, - buffers.charBuffers[idxCol].data(), - sizeof(SQLCHAR), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_BIT, buffers.charBuffers[idxCol].data(), + sizeof(SQLCHAR), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching BIT data for column %d", idxCol + 1); return ret; @@ -4879,11 +4946,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_REAL: { buffers.realBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_FLOAT, - buffers.realBuffers[idxCol].data(), - sizeof(SQLREAL), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_FLOAT, buffers.realBuffers[idxCol].data(), + sizeof(SQLREAL), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching FLOAT data for column %d", idxCol + 1); return ret; @@ -4893,12 +4957,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_DECIMAL: case SQL_NUMERIC: { buffers.charBuffers[idxCol].resize(MAX_DIGITS_IN_NUMERIC); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_CHAR, - buffers.charBuffers[idxCol].data(), - MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_CHAR, + buffers.charBuffers[idxCol].data(), + MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching CHAR data for column %d", idxCol + 1); return ret; @@ -4908,12 +4970,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_DOUBLE: case SQL_FLOAT: { buffers.doubleBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_DOUBLE, - buffers.doubleBuffers[idxCol].data(), - sizeof(SQLDOUBLE), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_DOUBLE, + buffers.doubleBuffers[idxCol].data(), + sizeof(SQLDOUBLE), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching DOUBLE data for column %d", idxCol + 1); return ret; @@ -4924,12 +4984,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { buffers.timestampBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, - buffers.timestampBuffers[idxCol].data(), - sizeof(SQL_TIMESTAMP_STRUCT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[idxCol].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TYPE_TIMESTAMP data for column %d", idxCol + 1); return ret; @@ -4938,12 +4996,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_BIGINT: { buffers.bigIntBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SBIGINT, - buffers.bigIntBuffers[idxCol].data(), - sizeof(SQLBIGINT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_SBIGINT, + buffers.bigIntBuffers[idxCol].data(), + sizeof(SQLBIGINT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching SBIGINT data for column %d", idxCol + 1); return ret; @@ -4952,12 +5008,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_TYPE_DATE: { buffers.dateBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_TYPE_DATE, - buffers.dateBuffers[idxCol].data(), - sizeof(SQL_DATE_STRUCT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_TYPE_DATE, + buffers.dateBuffers[idxCol].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TYPE_DATE data for column %d", idxCol + 1); return ret; @@ -4968,12 +5022,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_TYPE_TIME: case SQL_SS_TIME2: { buffers.timeBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_TYPE_TIME, - buffers.timeBuffers[idxCol].data(), - sizeof(SQL_TIME_STRUCT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_TYPE_TIME, + buffers.timeBuffers[idxCol].data(), + sizeof(SQL_TIME_STRUCT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TYPE_TIME data for column %d", idxCol + 1); return ret; @@ -4983,11 +5035,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_GUID: { buffers.guidBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_GUID, - buffers.guidBuffers[idxCol].data(), - sizeof(SQLGUID), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_GUID, buffers.guidBuffers[idxCol].data(), + sizeof(SQLGUID), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching GUID data for column %d", idxCol + 1); return ret; @@ -4996,14 +5045,13 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_SS_TIMESTAMPOFFSET: { buffers.datetimeoffsetBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, - buffers.datetimeoffsetBuffers[idxCol].data(), - sizeof(DateTimeOffset), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[idxCol].data(), + sizeof(DateTimeOffset), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { - LOG("Error fetching SS_TIMESTAMPOFFSET data for column %d", idxCol + 1); + LOG("Error fetching SS_TIMESTAMPOFFSET data for column %d", + idxCol + 1); return ret; } break; @@ -5029,8 +5077,7 @@ SQLRETURN FetchArrowBatch_wrap( // Value buffer for variable length data types needs to be set appropriately // as it will be used by the next non null value - switch (dataType) - { + switch (dataType) { case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: @@ -5043,7 +5090,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: - arrowColumnProducer->varVal[idxRowArrow + 1] = arrowColumnProducer->varVal[idxRowArrow]; + arrowColumnProducer->varVal[idxRowArrow + 1] = + arrowColumnProducer->varVal[idxRowArrow]; break; default: break; @@ -5053,7 +5101,9 @@ SQLRETURN FetchArrowBatch_wrap( continue; } else if (indicator < 0) { // Negative value is unexpected, log column index, SQL type & raise exception - LOG("Unexpected negative data length. Column ID - %d, SQL Type - %d, Data Length - %lld", idxCol + 1, dataType, (long long)indicator); + LOG("Unexpected negative data length. Column ID - %d, SQL Type - %d, Data " + "Length - %lld", + idxCol + 1, dataType, (long long)indicator); ThrowStdException("Unexpected negative data length."); } auto dataLen = static_cast(indicator); @@ -5070,7 +5120,9 @@ SQLRETURN FetchArrowBatch_wrap( target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + std::memcpy(&(*target_vec)[start], + &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], + dataLen); arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } @@ -5088,7 +5140,9 @@ SQLRETURN FetchArrowBatch_wrap( target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + std::memcpy(&(*target_vec)[start], + &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], + dataLen); arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } @@ -5098,16 +5152,21 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: { assert(dataLen % sizeof(SQLWCHAR) == 0); auto dataLenW = dataLen / sizeof(SQLWCHAR); - auto wcharSource = &buffers.wcharBuffers[idxCol][idxRowSql * (columnSize + 1)]; + auto wcharSource = + &buffers.wcharBuffers[idxCol][idxRowSql * (columnSize + 1)]; auto start = arrowColumnProducer->varVal[idxRowArrow]; auto target_vec = &arrowColumnProducer->varData; #if defined(_WIN32) // Convert wide string - int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast(dataLenW), NULL, 0, NULL, NULL); + int dataLenConverted = + WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast(dataLenW), + NULL, 0, NULL, NULL); while (target_vec->size() < start + dataLenConverted) { target_vec->resize(target_vec->size() * 2); } - WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast(dataLenW), reinterpret_cast(&(*target_vec)[start]), dataLenConverted, NULL, NULL); + WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast(dataLenW), + reinterpret_cast(&(*target_vec)[start]), + dataLenConverted, NULL, NULL); arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLenConverted; #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 @@ -5121,8 +5180,9 @@ SQLRETURN FetchArrowBatch_wrap( break; } case SQL_GUID: { - // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") - // Each GUID is exactly 36 bytes in UTF-8 + // GUID is stored as a 36-character string in Arrow (e.g., + // "550e8400-e29b-41d4-a716-446655440000") Each GUID is exactly 36 bytes in + // UTF-8 auto target_vec = &arrowColumnProducer->varData; auto start = arrowColumnProducer->varVal[idxRowArrow]; @@ -5136,37 +5196,40 @@ SQLRETURN FetchArrowBatch_wrap( // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx snprintf(reinterpret_cast(&target_vec->data()[start]), 37, - "%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X", - guidValue.Data1, - guidValue.Data2, - guidValue.Data3, - guidValue.Data4[0], guidValue.Data4[1], - guidValue.Data4[2], guidValue.Data4[3], - guidValue.Data4[4], guidValue.Data4[5], - guidValue.Data4[6], guidValue.Data4[7]); + "%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X", + guidValue.Data1, guidValue.Data2, guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], guidValue.Data4[2], + guidValue.Data4[3], guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); // Update offset for next row, ignoring null terminator arrowColumnProducer->varVal[idxRowArrow + 1] = start + 36; break; } case SQL_TINYINT: - arrowColumnProducer->uint8Val[idxRowArrow] = buffers.charBuffers[idxCol][idxRowSql]; + arrowColumnProducer->uint8Val[idxRowArrow] = + buffers.charBuffers[idxCol][idxRowSql]; break; case SQL_SMALLINT: - arrowColumnProducer->int16Val[idxRowArrow] = buffers.smallIntBuffers[idxCol][idxRowSql]; + arrowColumnProducer->int16Val[idxRowArrow] = + buffers.smallIntBuffers[idxCol][idxRowSql]; break; case SQL_INTEGER: - arrowColumnProducer->int32Val[idxRowArrow] = buffers.intBuffers[idxCol][idxRowSql]; + arrowColumnProducer->int32Val[idxRowArrow] = + buffers.intBuffers[idxCol][idxRowSql]; break; case SQL_BIGINT: - arrowColumnProducer->int64Val[idxRowArrow] = buffers.bigIntBuffers[idxCol][idxRowSql]; + arrowColumnProducer->int64Val[idxRowArrow] = + buffers.bigIntBuffers[idxCol][idxRowSql]; break; case SQL_REAL: - arrowColumnProducer->float32Val[idxRowArrow] = buffers.realBuffers[idxCol][idxRowSql]; + arrowColumnProducer->float32Val[idxRowArrow] = + buffers.realBuffers[idxCol][idxRowSql]; break; case SQL_FLOAT: case SQL_DOUBLE: - arrowColumnProducer->float64Val[idxRowArrow] = buffers.doubleBuffers[idxCol][idxRowSql]; + arrowColumnProducer->float64Val[idxRowArrow] = + buffers.doubleBuffers[idxCol][idxRowSql]; break; case SQL_DECIMAL: case SQL_NUMERIC: { @@ -5180,23 +5243,23 @@ SQLRETURN FetchArrowBatch_wrap( if (digitChar == '-') { sign = -1; } else if (digitChar >= '0' && digitChar <= '9') { - decimalValue = decimalValue.multiply_by_10() + (uint64_t)(digitChar - '0'); + decimalValue = + decimalValue.multiply_by_10() + (uint64_t)(digitChar - '0'); } } - arrowColumnProducer->decimalVal[idxRowArrow] = (sign > 0) ? decimalValue : -decimalValue; + arrowColumnProducer->decimalVal[idxRowArrow] = + (sign > 0) ? decimalValue : -decimalValue; break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[idxCol][idxRowSql]; - int64_t days = days_from_civil( - sql_value.year, - sql_value.month, - sql_value.day - ); - arrowColumnProducer->tsMicroVal[idxRowArrow] = - days * 86400 * 1000000 + + SQL_TIMESTAMP_STRUCT sql_value = + buffers.timestampBuffers[idxCol][idxRowSql]; + int64_t days = + days_from_civil(sql_value.year, sql_value.month, sql_value.day); + arrowColumnProducer->tsMicroVal[idxRowArrow] = + days * 86400 * 1000000 + static_cast(sql_value.hour) * 3600 * 1000000 + static_cast(sql_value.minute) * 60 * 1000000 + static_cast(sql_value.second) * 1000000 + @@ -5205,33 +5268,34 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_SS_TIMESTAMPOFFSET: { DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[idxCol][idxRowSql]; - int64_t days = days_from_civil( - sql_value.year, - sql_value.month, - sql_value.day - ); - arrowColumnProducer->tsMicroVal[idxRowArrow] = - days * 86400 * 1000000 + - (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + - (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + int64_t days = + days_from_civil(sql_value.year, sql_value.month, sql_value.day); + arrowColumnProducer->tsMicroVal[idxRowArrow] = + days * 86400 * 1000000 + + (static_cast(sql_value.hour) - + static_cast(sql_value.timezone_hour)) * + 3600 * 1000000 + + (static_cast(sql_value.minute) - + static_cast(sql_value.timezone_minute)) * + 60 * 1000000 + static_cast(sql_value.second) * 1000000 + static_cast(sql_value.fraction) / 1000; break; } case SQL_TYPE_DATE: - arrowColumnProducer->dateVal[idxRowArrow] = days_from_civil( - buffers.dateBuffers[idxCol][idxRowSql].year, - buffers.dateBuffers[idxCol][idxRowSql].month, - buffers.dateBuffers[idxCol][idxRowSql].day - ); + arrowColumnProducer->dateVal[idxRowArrow] = + days_from_civil(buffers.dateBuffers[idxCol][idxRowSql].year, + buffers.dateBuffers[idxCol][idxRowSql].month, + buffers.dateBuffers[idxCol][idxRowSql].day); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { - // NOTE: SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. - // To fully support SQL_SS_TIME2, the corresponding c-type should be used. + // NOTE: SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does + // not. To fully support SQL_SS_TIME2, the corresponding c-type should be + // used. const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[idxCol][idxRowSql]; - arrowColumnProducer->timeSecondVal[idxRowArrow] = + arrowColumnProducer->timeSecondVal[idxRowArrow] = static_cast(timeValue.hour) * 3600 + static_cast(timeValue.minute) * 60 + static_cast(timeValue.second); @@ -5241,11 +5305,11 @@ SQLRETURN FetchArrowBatch_wrap( // SQL_BIT is stored as a single bit in Arrow's bitmap format // Get the boolean value from the buffer bool bitValue = buffers.charBuffers[idxCol][idxRowSql] != 0; - + // Set the bit in the Arrow bitmap size_t byteIndex = idxRowArrow / 8; size_t bitIndex = idxRowArrow % 8; - + if (bitValue) { // Set bit to 1 arrowColumnProducer->bitVal[byteIndex] |= (1 << bitIndex); @@ -5281,7 +5345,7 @@ SQLRETURN FetchArrowBatch_wrap( // Second, transfer ownership to arrowSchemaBatch // No unhandled exceptions until the pycapsule owns the arrowSchemaBatch to avoid memory leaks - + for (SQLSMALLINT i = 0; i < numCols; i++) { *arrowSchemaBatchChildPointers[i] = { arrowSchemaPrivateData[i]->format.get(), @@ -5296,7 +5360,7 @@ SQLRETURN FetchArrowBatch_wrap( assert(schema->release != nullptr); assert(schema->private_data != nullptr); assert(schema->children == nullptr && schema->n_children == 0); - delete schema->private_data; // Frees format and name + delete schema->private_data; // Frees format and name schema->release = nullptr; }, arrowSchemaPrivateData[i].release(), @@ -5339,13 +5403,14 @@ SQLRETURN FetchArrowBatch_wrap( // Finally, transfer ownership of arrowSchemaBatch and its pointer to pycapsule py::capsule arrowSchemaBatchCapsule; try { - arrowSchemaBatchCapsule = py::capsule(arrowSchemaBatch.get(), "arrow_schema", [](void* ptr) { - auto arrowSchema = static_cast(ptr); - if (arrowSchema->release) { - arrowSchema->release(arrowSchema); - } - delete arrowSchema; - }); + arrowSchemaBatchCapsule = + py::capsule(arrowSchemaBatch.get(), "arrow_schema", [](void* ptr) { + auto arrowSchema = static_cast(ptr); + if (arrowSchema->release) { + arrowSchema->release(arrowSchema); + } + delete arrowSchema; + }); } catch (...) { arrowSchemaBatch->release(arrowSchemaBatch.get()); throw; @@ -5389,7 +5454,7 @@ SQLRETURN FetchArrowBatch_wrap( assert(array->release != nullptr); assert(array->children == nullptr); assert(array->n_children == 0); - delete array->private_data; // Frees all buffer entries + delete array->private_data; // Frees all buffer entries assert(array->buffers != nullptr); array->release = nullptr; }, @@ -5472,7 +5537,8 @@ SQLRETURN FetchArrowBatch_wrap( // throws a runtime error. SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, const std::string& charEncoding = "utf-8", - const std::string& wcharEncoding = "utf-16le") { + const std::string& wcharEncoding = "utf-16le", + int charCtype = SQL_C_WCHAR) { SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count @@ -5512,8 +5578,8 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, return ret; py::list row; - SQLGetData_wrap(StatementHandle, numCols, row, charEncoding, - wcharEncoding); // <-- streams LOBs correctly + SQLGetData_wrap(StatementHandle, numCols, row, charEncoding, wcharEncoding, + charCtype); // <-- streams LOBs correctly rows.append(row); } return SQL_SUCCESS; @@ -5563,7 +5629,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, ColumnBuffers buffers(numCols, fetchSize); // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize, charCtype); if (!SQL_SUCCEEDED(ret)) { LOG("FetchAll_wrap: Error when binding columns - SQLRETURN=%d", ret); return ret; @@ -5575,7 +5641,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, while (ret != SQL_NO_DATA) { ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns, - charEncoding); + charEncoding, charCtype); if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) { LOG("FetchAll_wrap: Error when fetching data - SQLRETURN=%d", ret); return ret; @@ -5610,7 +5676,8 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, // fetching, it throws a runtime error. SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row, const std::string& charEncoding = "utf-8", - const std::string& wcharEncoding = "utf-16le") { + const std::string& wcharEncoding = "utf-16le", + int charCtype = SQL_C_WCHAR) { SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); @@ -5624,7 +5691,8 @@ SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row, if (SQL_SUCCEEDED(ret)) { // Retrieve column count SQLSMALLINT colCount = SQLNumResultCols_wrap(StatementHandle); - ret = SQLGetData_wrap(StatementHandle, colCount, row, charEncoding, wcharEncoding); + ret = + SQLGetData_wrap(StatementHandle, colCount, row, charEncoding, wcharEncoding, charCtype); if (!SQL_SUCCEEDED(ret)) { LOG("FetchOne_wrap: Error retrieving data with SQLGetData - SQLRETURN=%d", ret); return ret; @@ -5778,14 +5846,16 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLMoreResults", &SQLMoreResults_wrap, "Check for more results in the result set"); m.def("DDBCSQLFetchOne", &FetchOne_wrap, "Fetch one row from the result set", py::arg("StatementHandle"), py::arg("row"), py::arg("charEncoding") = "utf-8", - py::arg("wcharEncoding") = "utf-16le"); + py::arg("wcharEncoding") = "utf-16le", py::arg("charCtype") = SQL_C_WCHAR); m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), py::arg("fetchSize"), py::arg("charEncoding") = "utf-8", - py::arg("wcharEncoding") = "utf-16le", "Fetch many rows from the result set"); + py::arg("wcharEncoding") = "utf-16le", py::arg("charCtype") = SQL_C_WCHAR, + "Fetch many rows from the result set"); m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set", py::arg("StatementHandle"), py::arg("rows"), py::arg("charEncoding") = "utf-8", - py::arg("wcharEncoding") = "utf-16le"); - m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, "Fetch an arrow batch of given length from the result set"); + py::arg("wcharEncoding") = "utf-16le", py::arg("charCtype") = SQL_C_WCHAR); + m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, + "Fetch an arrow batch of given length from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index dfc22252..5b4a0793 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -668,6 +668,7 @@ struct ColumnInfoExt { uint64_t fetchBufferSize; bool isLob; bool isUtf8; // Pre-computed from charEncoding (avoids string compare per cell) + bool useWideChar; // True when charCtype == SQL_C_WCHAR (VARCHAR fetched as UTF-16) std::string charEncoding; // Effective decoding encoding for SQL_C_CHAR data }; @@ -791,6 +792,10 @@ inline void ProcessDouble(PyObject* row, ColumnBuffers& buffers, const void*, SQ // Process SQL CHAR/VARCHAR (single-byte string) column into Python str // Performance: NULL/NO_TOTAL checks removed - handled centrally before // processor is called +// +// When useWideChar is true, the column was bound as SQL_C_WCHAR in +// SQLBindColums and data lives in wcharBuffers (UTF-16). Otherwise, +// charBuffers contain narrow data decoded with the configured codec. inline void ProcessChar(PyObject* row, ColumnBuffers& buffers, const void* colInfoPtr, SQLUSMALLINT col, SQLULEN rowIdx, SQLHSTMT hStmt) { const ColumnInfoExt* colInfo = static_cast(colInfoPtr); @@ -808,6 +813,37 @@ inline void ProcessChar(PyObject* row, ColumnBuffers& buffers, const void* colIn return; } + if (colInfo->useWideChar) { + // Wide-char path: data was bound as SQL_C_WCHAR, lives in wcharBuffers + uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + if (!colInfo->isLob && numCharsInData < colInfo->fetchBufferSize) { + SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][rowIdx * colInfo->fetchBufferSize]; +#if defined(__APPLE__) || defined(__linux__) + PyObject* pyStr = + PyUnicode_DecodeUTF16(reinterpret_cast(wcharData), + numCharsInData * sizeof(SQLWCHAR), nullptr, nullptr); +#else + PyObject* pyStr = + PyUnicode_FromWideChar(reinterpret_cast(wcharData), numCharsInData); +#endif + if (!pyStr) { + PyErr_Clear(); + Py_INCREF(Py_None); + PyList_SET_ITEM(row, col - 1, Py_None); + } else { + PyList_SET_ITEM(row, col - 1, pyStr); + } + } else { + // LOB / truncated: stream with SQL_C_WCHAR + PyList_SET_ITEM(row, col - 1, + FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false, "utf-16le") + .release() + .ptr()); + } + return; + } + + // Original narrow-char path (charCtype == SQL_C_CHAR) uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); // Fast path: Data fits in buffer (not LOB or truncated) // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence diff --git a/tests/test_013_encoding_decoding.py b/tests/test_013_encoding_decoding.py index 034afae6..d14ac398 100644 --- a/tests/test_013_encoding_decoding.py +++ b/tests/test_013_encoding_decoding.py @@ -604,12 +604,14 @@ def test_setencoding_cp1252(conn_str): def test_setdecoding_default_settings(db_connection): """Test that default decoding settings are correct for all SQL types.""" - # Check SQL_CHAR defaults + # Check SQL_CHAR defaults (now SQL_WCHAR/utf-16le to avoid CP-1252 decode issues) sql_char_settings = db_connection.getdecoding(mssql_python.SQL_CHAR) - assert sql_char_settings["encoding"] == "utf-8", "Default SQL_CHAR encoding should be utf-8" assert ( - sql_char_settings["ctype"] == mssql_python.SQL_CHAR - ), "Default SQL_CHAR ctype should be SQL_CHAR" + sql_char_settings["encoding"] == "utf-16le" + ), "Default SQL_CHAR encoding should be utf-16le" + assert ( + sql_char_settings["ctype"] == mssql_python.SQL_WCHAR + ), "Default SQL_CHAR ctype should be SQL_WCHAR" # Check SQL_WCHAR defaults sql_wchar_settings = db_connection.getdecoding(mssql_python.SQL_WCHAR) @@ -4921,8 +4923,8 @@ def test_pooled_connections_have_independent_encoding_settings(conn_str, reset_p dec3 = conn3.getdecoding(mssql_python.SQL_CHAR) assert dec1["encoding"] == "latin-1" - assert dec2["encoding"] == "utf-8" - assert dec3["encoding"] == "utf-8" + assert dec2["encoding"] == "utf-16le" + assert dec3["encoding"] == "utf-16le" conn1.close() conn2.close() @@ -5648,10 +5650,10 @@ def test_default_encoding_behavior_validation(conn_str): sql_char_settings = conn.getdecoding(SQL_CHAR) sql_wchar_settings = conn.getdecoding(SQL_WCHAR) - # SQL_CHAR should default to UTF-8 + # SQL_CHAR now defaults to UTF-16LE (SQL_C_WCHAR) to avoid CP-1252 decode issues assert ( - sql_char_settings["encoding"] == "utf-8" - ), f"SQL_CHAR should default to UTF-8, got {sql_char_settings['encoding']}" + sql_char_settings["encoding"] == "utf-16le" + ), f"SQL_CHAR should default to utf-16le, got {sql_char_settings['encoding']}" # SQL_WCHAR should default to UTF-16LE (or UTF-16BE) assert sql_wchar_settings["encoding"] in [ @@ -7256,5 +7258,374 @@ def test_dae_encoding_large_string(db_connection): cursor.close() +# ==================================================================================== +# VARCHAR BYTE VALUE DECODING ISSUE TESTS +# ==================================================================================== +# Validates VARCHAR decoding behavior for byte values that are valid in CP-1252 +# (the default Windows code page) but invalid as single-byte UTF-8 sequences. +# +# Fix: The default ctype for SQL_CHAR is now SQL_C_WCHAR, which tells the ODBC driver +# to convert VARCHAR data to UTF-16 internally. This means all byte values (including +# those >= 0x80) are consistently returned as Python str regardless of platform. +# +# Previously (before fix): Windows + default UTF-8 decoding → bytes (fallback) +# Now (after fix): Default SQL_C_WCHAR → str on all platforms +# +# Users can still explicitly set SQL_C_CHAR via setdecoding() for backward compat. +# ==================================================================================== + + +@pytest.mark.skipif( + sys.platform != "win32", reason="This test class targets Windows-specific ODBC driver behavior" +) +class TestVarcharByteDecodingIssue: + """Tests for VARCHAR byte value decoding with the SQL_C_WCHAR default fix.""" + + TABLE_NAME = "test_varchar_byte_decoding" + + @pytest.fixture(autouse=True) + def setup_table(self, db_connection, cursor): + """Create and clean up the test table for each test.""" + # Reset decoding to the new default (SQL_C_WCHAR + utf-16le) before each test + # to avoid leaking settings from previous tests (db_connection is module-scoped). + db_connection.setdecoding(SQL_CHAR, encoding="utf-16le", ctype=SQL_WCHAR) + cursor.execute(f"DROP TABLE IF EXISTS {self.TABLE_NAME}") + cursor.execute(f"CREATE TABLE {self.TABLE_NAME} (id INT PRIMARY KEY, data VARCHAR(256))") + db_connection.commit() + yield + try: + cursor.execute(f"DROP TABLE IF EXISTS {self.TABLE_NAME}") + db_connection.commit() + except Exception: + pass + + def test_byte_173_returns_str_with_default_wchar(self, db_connection, cursor): + """Byte 173 (0xAD, soft hyphen in CP-1252) returns str with default SQL_C_WCHAR. + + The default ctype for SQL_CHAR is now SQL_C_WCHAR, so the ODBC driver + converts the VARCHAR data to UTF-16 internally, avoiding encoding issues. + """ + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, CHAR(173))") + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + # With the new default SQL_C_WCHAR, byte 0xAD is correctly decoded + # by the ODBC driver to U+00AD SOFT HYPHEN. + assert isinstance( + val, str + ), f"Expected str with default SQL_C_WCHAR, got {type(val).__name__}: {repr(val)}" + assert val == "\u00ad", f"Expected U+00AD, got {repr(val)}" + + def test_byte_173_returns_bytes_with_explicit_sql_c_char(self, db_connection, cursor): + """Byte 173 returns bytes when explicitly using SQL_C_CHAR + utf-8 (old behavior). + + Users can opt into the old behavior by calling setdecoding(SQL_CHAR, encoding='utf-8'). + """ + db_connection.setdecoding(SQL_CHAR, encoding="utf-8") + + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, CHAR(173))") + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + # With explicit SQL_C_CHAR + UTF-8, byte 0xAD cannot be decoded + # as a single-byte UTF-8 sequence, so the C++ fallback returns bytes. + assert isinstance( + val, bytes + ), f"Expected bytes from UTF-8 decode failure fallback, got {type(val).__name__}: {repr(val)}" + assert val == b"\xad", f"Expected b'\\xad', got {repr(val)}" + + def test_byte_173_returns_str_with_cp1252_decoding(self, db_connection, cursor): + """Setting SQL_CHAR decoding to cp1252 correctly decodes byte 173 as str. + + Byte 173 in CP-1252 maps to U+00AD SOFT HYPHEN. + """ + db_connection.setdecoding(SQL_CHAR, encoding="cp1252") + + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, CHAR(173))") + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + assert isinstance( + val, str + ), f"Expected str with cp1252 decoding, got {type(val).__name__}: {repr(val)}" + assert val == "\u00ad", f"Expected U+00AD SOFT HYPHEN, got {repr(val)}" + + def test_byte_173_returns_str_with_latin1_decoding(self, db_connection, cursor): + """Setting SQL_CHAR decoding to latin-1 also correctly decodes byte 173. + + Latin-1 (ISO 8859-1) maps every byte 0x00-0xFF to the same Unicode + code point, so byte 173 → U+00AD. + """ + db_connection.setdecoding(SQL_CHAR, encoding="latin-1") + + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, CHAR(173))") + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + assert isinstance( + val, str + ), f"Expected str with latin-1 decoding, got {type(val).__name__}: {repr(val)}" + assert val == "\u00ad" + + def test_all_high_bytes_return_str_with_default_wchar(self, db_connection, cursor): + """With the default SQL_C_WCHAR ctype, all high bytes return str. + + The ODBC driver converts VARCHAR data to UTF-16 internally, so even + bytes >= 128 that are invalid in UTF-8 are correctly handled. + """ + # Verify we are using the new default SQL_C_WCHAR + settings = db_connection.getdecoding(SQL_CHAR) + assert settings["ctype"] == SQL_WCHAR + + # Insert a selection of high-byte values that are valid in CP-1252 + test_bytes = [128, 142, 150, 160, 173, 192, 224, 255] + for b in test_bytes: + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES ({b}, CHAR({b}))") + db_connection.commit() + + cursor.execute(f"SELECT id, data FROM {self.TABLE_NAME} ORDER BY id") + rows = cursor.fetchall() + + str_ids = [r[0] for r in rows if isinstance(r[1], str)] + bytes_ids = [r[0] for r in rows if isinstance(r[1], bytes)] + + assert len(str_ids) == len(test_bytes), ( + f"Expected all {len(test_bytes)} high bytes to return str with default SQL_C_WCHAR, " + f"but {len(bytes_ids)} returned bytes: {bytes_ids}" + ) + + def test_all_high_bytes_decode_with_cp1252(self, db_connection, cursor): + """With cp1252 decoding, all high-byte values decode to str. + + CP-1252 defines mappings for all byte values 0-255 (except 5 undefined + positions: 0x81, 0x8D, 0x8F, 0x90, 0x9D which Python's cp1252 codec + will raise on). + """ + db_connection.setdecoding(SQL_CHAR, encoding="cp1252") + + # CP-1252 defined high bytes (excluding 0x81, 0x8D, 0x8F, 0x90, 0x9D) + defined_bytes = [b for b in range(128, 256) if b not in (0x81, 0x8D, 0x8F, 0x90, 0x9D)] + for b in defined_bytes: + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES ({b}, CHAR({b}))") + db_connection.commit() + + cursor.execute(f"SELECT id, data FROM {self.TABLE_NAME} ORDER BY id") + rows = cursor.fetchall() + + for row in rows: + byte_val, data = row[0], row[1] + assert isinstance(data, str), ( + f"Byte {byte_val} (0x{byte_val:02X}): expected str with cp1252 decoding, " + f"got {type(data).__name__}: {repr(data)}" + ) + + def test_ascii_bytes_unaffected_by_encoding_choice(self, db_connection, cursor): + """ASCII bytes (0-127) are valid in all encodings and always return str.""" + # Test a selection of printable ASCII characters + test_chars = [32, 65, 90, 97, 122, 48, 57, 33, 126] + for c in test_chars: + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES ({c}, CHAR({c}))") + db_connection.commit() + + cursor.execute(f"SELECT id, data FROM {self.TABLE_NAME} ORDER BY id") + rows = cursor.fetchall() + + for row in rows: + char_val, data = row[0], row[1] + assert isinstance( + data, str + ), f"ASCII char {char_val} should always return str, got {type(data).__name__}" + + def test_mixed_ascii_and_high_bytes_returns_str_with_default_wchar(self, db_connection, cursor): + """A VARCHAR value mixing ASCII and high bytes returns str with default SQL_C_WCHAR. + + E.g. 'hello' + CHAR(173) + 'world' — the ODBC driver converts everything + to UTF-16 so the entire value is correctly decoded to str. + """ + cursor.execute( + f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, 'hello' + CHAR(173) + 'world')" + ) + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + assert isinstance( + val, str + ), f"Expected str with default SQL_C_WCHAR, got {type(val).__name__}: {repr(val)}" + assert val == "hello\u00adworld" + + def test_mixed_ascii_and_high_bytes_with_cp1252(self, db_connection, cursor): + """With cp1252 decoding, mixed ASCII + high bytes returns str correctly.""" + db_connection.setdecoding(SQL_CHAR, encoding="cp1252") + + cursor.execute( + f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, 'hello' + CHAR(173) + 'world')" + ) + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + assert isinstance( + val, str + ), f"Expected str with cp1252 decoding, got {type(val).__name__}: {repr(val)}" + assert val == "hello\u00adworld" + + def test_fetchmany_returns_str_for_high_byte_values(self, db_connection, cursor): + """fetchmany() returns str for high bytes with default SQL_C_WCHAR.""" + for i in range(5): + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES ({i}, CHAR(173))") + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME}") + rows = cursor.fetchmany(5) + + for row in rows: + assert isinstance( + row[0], str + ), f"Expected str from fetchmany with default SQL_C_WCHAR: {repr(row[0])}" + assert row[0] == "\u00ad" + + def test_fetchall_returns_str_for_high_byte_values(self, db_connection, cursor): + """fetchall() returns str for high bytes with default SQL_C_WCHAR.""" + for i in range(5): + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES ({i}, CHAR(173))") + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME}") + rows = cursor.fetchall() + + for row in rows: + assert isinstance( + row[0], str + ), f"Expected str from fetchall with default SQL_C_WCHAR: {repr(row[0])}" + assert row[0] == "\u00ad" + + def test_nvarchar_unaffected_by_varchar_decoding_issue(self, db_connection, cursor): + """NVARCHAR columns use SQL_WCHAR (UTF-16LE) and are not affected. + + The issue only affects VARCHAR (SQL_CHAR) columns where the server's + native encoding (CP-1252) doesn't match the default UTF-8 decoding. + """ + cursor.execute(f"DROP TABLE IF EXISTS {self.TABLE_NAME}") + cursor.execute(f"CREATE TABLE {self.TABLE_NAME} (id INT PRIMARY KEY, data NVARCHAR(256))") + # NCHAR(173) = U+00AD SOFT HYPHEN, stored as UTF-16 natively + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, NCHAR(173))") + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + assert isinstance( + val, str + ), f"NVARCHAR should always return str, got {type(val).__name__}: {repr(val)}" + assert val == "\u00ad" + + def test_cp1252_specific_characters_round_trip(self, db_connection, cursor): + """CP-1252 has characters not in Latin-1: smart quotes, euro sign, etc. + + Byte values like 0x80 (€), 0x93 (\u201c), 0x94 (\u201d), 0x96 (\u2013) are + Windows-specific and have no Latin-1 equivalent. + """ + db_connection.setdecoding(SQL_CHAR, encoding="cp1252") + + # CP-1252 specific mappings: byte -> Unicode codepoint + cp1252_specials = { + 0x80: "\u20ac", # € Euro sign + 0x85: "\u2026", # … Horizontal ellipsis + 0x93: "\u201c", # " Left double quotation mark + 0x94: "\u201d", # " Right double quotation mark + 0x96: "\u2013", # – En dash + 0x97: "\u2014", # — Em dash + 0x99: "\u2122", # ™ Trade mark sign + } + + for byte_val, expected_char in cp1252_specials.items(): + cursor.execute( + f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES ({byte_val}, CHAR({byte_val}))" + ) + db_connection.commit() + + cursor.execute(f"SELECT id, data FROM {self.TABLE_NAME} ORDER BY id") + rows = cursor.fetchall() + + for row in rows: + byte_val, data = row[0], row[1] + expected = cp1252_specials[byte_val] + assert isinstance( + data, str + ), f"Byte 0x{byte_val:02X}: expected str, got {type(data).__name__}" + assert ( + data == expected + ), f"Byte 0x{byte_val:02X}: expected {repr(expected)}, got {repr(data)}" + + def test_switching_decoding_mid_session(self, db_connection, cursor): + """Demonstrates switching between SQL_C_WCHAR (default) and SQL_C_CHAR + cp1252. + + First fetch with default SQL_C_WCHAR returns str, then we switch to + explicit SQL_C_CHAR + cp1252 and the same data still returns str. + Finally we switch to SQL_C_CHAR + utf-8 and it returns bytes (old behavior). + """ + cursor.execute(f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, CHAR(173))") + db_connection.commit() + + # First fetch: default SQL_C_WCHAR → str + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + assert isinstance(row[0], str), "Expected str with default SQL_C_WCHAR" + assert row[0] == "\u00ad" + + # Switch to cp1252 (auto-detects SQL_C_CHAR) + db_connection.setdecoding(SQL_CHAR, encoding="cp1252") + + # Second fetch: cp1252 + SQL_C_CHAR → str (cp1252 can decode byte 0xAD) + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + assert isinstance(row[0], str), "Expected str with cp1252" + assert row[0] == "\u00ad" + + # Switch to explicit utf-8 (auto-detects SQL_C_CHAR) — old behavior + db_connection.setdecoding(SQL_CHAR, encoding="utf-8") + + # Third fetch: utf-8 + SQL_C_CHAR → bytes (fallback due to decode error) + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + assert isinstance(row[0], bytes), "Expected bytes with explicit SQL_C_CHAR + utf-8" + + def test_multiple_cp1252_bytes_in_single_row(self, db_connection, cursor): + """A VARCHAR value with multiple CP-1252 high bytes all decode correctly.""" + db_connection.setdecoding(SQL_CHAR, encoding="cp1252") + + # Build a string with euro sign + en dash + smart quotes: €–"" + cursor.execute( + f"INSERT INTO {self.TABLE_NAME} (id, data) VALUES (1, " + f"CHAR(128) + CHAR(150) + CHAR(147) + CHAR(148))" + ) + db_connection.commit() + + cursor.execute(f"SELECT data FROM {self.TABLE_NAME} WHERE id = 1") + row = cursor.fetchone() + val = row[0] + + assert isinstance(val, str) + assert val == "\u20ac\u2013\u201c\u201d", f"Expected €–\u201c\u201d, got {repr(val)}" + + if __name__ == "__main__": pytest.main([__file__, "-v"])