@@ -459,3 +459,46 @@ def test_insert_data_raises_too_many_parameters(mocker) -> None:
459459
460460 with pytest .raises (DataError , match = "Prepared statement exceeds bind parameter limit 32767." ):
461461 mock_cursor .execute (prepared_stmt , params )
462+
463+
464+ @pandas_only
465+ def test_write_dataframe_handles_npdtyes (mocker ):
466+ import numpy as np
467+ import pandas as pd
468+
469+ mocker .patch ("redshift_connector.Cursor.execute" , return_value = None )
470+ mocker .patch ("redshift_connector.Cursor.fetchone" , return_value = [1 ])
471+ mock_cursor : Cursor = Cursor .__new__ (Cursor )
472+ mock_connection : Connection = Connection .__new__ (Connection )
473+ mock_cursor ._c = mock_connection
474+
475+ mock_cursor .paramstyle = "mocked_val"
476+ for datatype , data in (
477+ ("int8_col" , np .array ([1 ], dtype = np .int8 )),
478+ ("int16_col" , np .array ([1 ], dtype = np .int16 )),
479+ ("int32_col" , np .array ([1 ], dtype = np .int32 )),
480+ ("int64_col" , np .array ([1 ], dtype = np .int64 )),
481+ ("uint8_col" , np .array ([1 ], dtype = np .uint8 )),
482+ ("uint16_col" , np .array ([1 ], dtype = np .uint16 )),
483+ ("uint32_col" , np .array ([1 ], dtype = np .uint32 )),
484+ ("uint64_col" , np .array ([1 ], dtype = np .uint64 )),
485+ ("float16_col" , np .array ([1.0 ], dtype = np .float16 )),
486+ ("float32_col" , np .array ([1.0 ], dtype = np .float32 )),
487+ ("float64_col" , np .array ([1.0 ], dtype = np .float64 )),
488+ ("complex64_col" , np .array ([1 + 1j ], dtype = np .complex64 )),
489+ ("complex128_col" , np .array ([1 + 1j ], dtype = np .complex128 )),
490+ ("bool_col" , np .array ([True ], dtype = np .bool_ )),
491+ ("string_col" , np .array (["hello" ], dtype = "U" )),
492+ ("object_col" , np .array ([{"key" , "value" }], dtype = object )),
493+ ):
494+ spy = mocker .spy (mock_cursor , "execute" )
495+ dataframe = pd .DataFrame (data )
496+ mock_cursor .write_dataframe (df = dataframe , table = datatype )
497+
498+ assert spy .called
499+ assert spy .call_count == 2 # once for __is_valid_table, once for write_dataframe
500+ assert not isinstance (spy .mock_calls [1 ].args [1 ], np .ndarray )
501+ assert isinstance (spy .mock_calls [1 ].args [1 ], list )
502+ assert len (spy .mock_calls [1 ].args [1 ]) == 1
503+ # bind parameter list should not contain numpy objects
504+ assert not isinstance (spy .mock_calls [1 ].args [1 ][0 ], np .generic )
0 commit comments