Skip to content

Commit 053ad2b

Browse files
committed
fix(Cursor, write_dataframe): Convert the pandas DataFrame holding bind parameters to a Python list before query execution. This ensures that native Python datatypes, rather than NumPy datatypes, are sent to the Redshift server.
1 parent 7bd195f commit 053ad2b

File tree

2 files changed: 44 additions and 1 deletion

2 files changed

+44
-1
lines changed

redshift_connector/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -584,7 +584,7 @@ def write_dataframe(self: "Cursor", df: "pandas.DataFrame", table: str) -> None:
584584
if not self.__is_valid_table(table):
585585
raise InterfaceError("Invalid table name passed to write_dataframe: {}".format(table))
586586
sanitized_table_name: str = self.__sanitize_str(table)
587-
arrays: "numpy.ndarray" = df.values
587+
arrays: list = df.values.tolist()
588588
placeholder: str = ", ".join(["%s"] * len(arrays[0]))
589589
sql: str = "insert into {table} values ({placeholder})".format(
590590
table=sanitized_table_name, placeholder=placeholder

test/unit/test_cursor.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -459,3 +459,46 @@ def test_insert_data_raises_too_many_parameters(mocker) -> None:
459459

460460
with pytest.raises(DataError, match="Prepared statement exceeds bind parameter limit 32767."):
461461
mock_cursor.execute(prepared_stmt, params)
462+
463+
464+
@pandas_only
465+
def test_write_dataframe_handles_npdtyes(mocker):
466+
import numpy as np
467+
import pandas as pd
468+
469+
mocker.patch("redshift_connector.Cursor.execute", return_value=None)
470+
mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])
471+
mock_cursor: Cursor = Cursor.__new__(Cursor)
472+
mock_connection: Connection = Connection.__new__(Connection)
473+
mock_cursor._c = mock_connection
474+
475+
mock_cursor.paramstyle = "mocked_val"
476+
for datatype, data in (
477+
("int8_col", np.array([1], dtype=np.int8)),
478+
("int16_col", np.array([1], dtype=np.int16)),
479+
("int32_col", np.array([1], dtype=np.int32)),
480+
("int64_col", np.array([1], dtype=np.int64)),
481+
("uint8_col", np.array([1], dtype=np.uint8)),
482+
("uint16_col", np.array([1], dtype=np.uint16)),
483+
("uint32_col", np.array([1], dtype=np.uint32)),
484+
("uint64_col", np.array([1], dtype=np.uint64)),
485+
("float16_col", np.array([1.0], dtype=np.float16)),
486+
("float32_col", np.array([1.0], dtype=np.float32)),
487+
("float64_col", np.array([1.0], dtype=np.float64)),
488+
("complex64_col", np.array([1 + 1j], dtype=np.complex64)),
489+
("complex128_col", np.array([1 + 1j], dtype=np.complex128)),
490+
("bool_col", np.array([True], dtype=np.bool_)),
491+
("string_col", np.array(["hello"], dtype="U")),
492+
("object_col", np.array([{"key", "value"}], dtype=object)),
493+
):
494+
spy = mocker.spy(mock_cursor, "execute")
495+
dataframe = pd.DataFrame(data)
496+
mock_cursor.write_dataframe(df=dataframe, table=datatype)
497+
498+
assert spy.called
499+
assert spy.call_count == 2 # once for __is_valid_table, once for write_dataframe
500+
assert not isinstance(spy.mock_calls[1].args[1], np.ndarray)
501+
assert isinstance(spy.mock_calls[1].args[1], list)
502+
assert len(spy.mock_calls[1].args[1]) == 1
503+
# bind parameter list should not contain numpy objects
504+
assert not isinstance(spy.mock_calls[1].args[1][0], np.generic)

0 commit comments

Comments
 (0)