1+ import re
12import typing
23from collections import deque
34from itertools import count, islice
@@ -185,7 +186,7 @@ def executemany(self: "Cursor", operation, param_sets) -> "Cursor":
185186 self._row_count = -1 if -1 in rowcounts else sum(rowcounts)
186187 return self
187188
188- def fetchone(self: "Cursor") -> typing.Optional["Cursor" ]:
189+ def fetchone(self: "Cursor") -> typing.Optional[typing.List ]:
189190 """Fetch the next row of a query result set.
190191
191192 This method is part of the `DBAPI 2.0 specification
@@ -196,7 +197,7 @@ def fetchone(self: "Cursor") -> typing.Optional["Cursor"]:
196197 are available.
197198 """
198199 try:
199- return typing.cast("Cursor", next(self) )
200+ return next(self)
200201 except StopIteration:
201202 return None
202203 except TypeError:
@@ -271,7 +272,7 @@ def setoutputsize(self: "Cursor", size, column=None):
271272 """
272273 pass
273274
274- def __next__(self: "Cursor"):
275+ def __next__(self: "Cursor") -> typing.List :
275276 try:
276277 return self._cached_rows.popleft()
277278 except IndexError:
@@ -311,16 +312,48 @@ def fetch_dataframe(self: "Cursor", num: typing.Optional[int] = None) -> typing.
311312 return None
312313 return pandas.DataFrame(result, columns=columns)
313314
315+ def __is_valid_table(self: "Cursor", table: str) -> bool:
316+ split_table_name: typing.List[str] = table.split(".")
317+
318+ if len(split_table_name) > 2:
319+ return False
320+
321+ q: str = "select 1 from information_schema.tables where table_name = ?"
322+
323+ temp = self.paramstyle
324+ self.paramstyle = "qmark"
325+
326+ try:
327+ if len(split_table_name) == 2:
328+ q += " and table_schema = ?"
329+ self.execute(q, (split_table_name[1], split_table_name[0]))
330+ else:
331+ self.execute(q, (split_table_name[0],))
332+ except:
333+ raise
334+ finally:
335+ # reset paramstyle to it's original value
336+ self.paramstyle = temp
337+
338+ result = self.fetchone()
339+
340+ return result[0] == 1 if result is not None else False
341+
314342 def write_dataframe(self: "Cursor", df: "pandas.DataFrame", table: str) -> None:
315343 """write same structure dataframe into Redshift database"""
316344 try:
317345 import pandas
318346 except ModuleNotFoundError:
319347 raise ModuleNotFoundError(MISSING_MODULE_ERROR_MSG.format(module="pandas"))
320348
349+ if not self.__is_valid_table(table):
350+ raise InterfaceError("Invalid table name passed to write_dataframe: {}".format(table))
351+ sanitized_table_name: str = self.__sanitize_str(table)
321352 arrays: "numpy.ndarray" = df.values
322353 placeholder: str = ", ".join(["%s"] * len(arrays[0]))
323- sql: str = "insert into {table} values ({placeholder})".format(table=table, placeholder=placeholder)
354+ sql: str = "insert into {table} values ({placeholder})".format(
355+ table=sanitized_table_name, placeholder=placeholder
356+ )
324357 if len(arrays) == 1:
325358 self.execute(sql, arrays[0])
326359 elif len(arrays) > 1:
@@ -361,16 +394,33 @@ def get_procedures(
361394 " LEFT JOIN pg_catalog.pg_namespace pn ON (c.relnamespace=pn.oid AND pn.nspname='pg_catalog') "
362395 " WHERE p.pronamespace=n.oid "
363396 )
397+ query_args: typing.List[str] = []
364398 if schema_pattern is not None:
365- sql += " AND n.nspname LIKE {schema}".format(schema=self.__escape_quotes(schema_pattern))
399+ sql += " AND n.nspname LIKE ?"
400+ query_args.append(self.__sanitize_str(schema_pattern))
366401 else:
367402 sql += "and pg_function_is_visible(p.prooid)"
368403
369404 if procedure_name_pattern is not None:
370- sql += " AND p.proname LIKE {procedure}".format(procedure=self.__escape_quotes(procedure_name_pattern))
405+ sql += " AND p.proname LIKE ?"
406+ query_args.append(self.__sanitize_str(procedure_name_pattern))
371407 sql += " ORDER BY PROCEDURE_SCHEM, PROCEDURE_NAME, p.prooid::text "
372408
373- self.execute(sql)
409+ if len(query_args) > 0:
410+ # temporarily use qmark paramstyle
411+ temp = self.paramstyle
412+ self.paramstyle = "qmark"
413+
414+ try:
415+ self.execute(sql, tuple(query_args))
416+ except:
417+ raise
418+ finally:
419+ # reset the original value of paramstyle
420+ self.paramstyle = temp
421+ else:
422+ self.execute(sql)
423+
374424 procedures: tuple = self.fetchall()
375425 return procedures
376426
@@ -383,11 +433,25 @@ def get_schemas(
383433 " OR nspname = (pg_catalog.current_schemas(true))[1]) AND (nspname !~ '^pg_toast_temp_' "
384434 " OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_')) "
385435 )
436+ query_args: typing.List[str] = []
386437 if schema_pattern is not None:
387- sql += " AND nspname LIKE {schema}".format(schema=self.__escape_quotes(schema_pattern))
438+ sql += " AND nspname LIKE ?"
439+ query_args.append(self.__sanitize_str(schema_pattern))
388440 sql += " ORDER BY TABLE_SCHEM"
389441
390- self.execute(sql)
442+ if len(query_args) == 1:
443+ # temporarily use qmark paramstyle
444+ temp = self.paramstyle
445+ self.paramstyle = "qmark"
446+ try:
447+ self.execute(sql, tuple(query_args))
448+ except:
449+ raise
450+ finally:
451+ self.paramstyle = temp
452+ else:
453+ self.execute(sql)
454+
391455 schemas: tuple = self.fetchall()
392456 return schemas
393457
@@ -418,13 +482,28 @@ def get_primary_keys(
418482 "i.indisprimary AND "
419483 "ct.relnamespace = n.oid "
420484 )
485+ query_args: typing.List[str] = []
421486 if schema is not None:
422- sql += " AND n.nspname = {schema}".format(schema=self.__escape_quotes(schema))
487+ sql += " AND n.nspname = ?"
488+ query_args.append(self.__sanitize_str(schema))
423489 if table is not None:
424- sql += " AND ct.relname = {table}".format(table=self.__escape_quotes(table))
490+ sql += " AND ct.relname = ?"
491+ query_args.append(self.__sanitize_str(table))
425492
426493 sql += " ORDER BY table_name, pk_name, key_seq"
427- self.execute(sql)
494+
495+ if len(query_args) > 0:
496+ # temporarily use qmark paramstyle
497+ temp = self.paramstyle
498+ self.paramstyle = "qmark"
499+ try:
500+ self.execute(sql, tuple(query_args))
501+ except:
502+ raise
503+ finally:
504+ self.paramstyle = temp
505+ else:
506+ self.execute(sql)
428507 keys: tuple = self.fetchall()
429508 return keys
430509
@@ -437,15 +516,30 @@ def get_tables(
437516 ) -> tuple:
438517 """Returns the unique public tables which are user-defined within the system"""
439518 sql: str = ""
519+ sql_args: typing.Tuple[str, ...] = tuple()
440520 schema_pattern_type: str = self.__schema_pattern_match(schema_pattern)
441521 if schema_pattern_type == "LOCAL_SCHEMA_QUERY":
442- sql = self.__build_local_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
522+ sql, sql_args = self.__build_local_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
443523 elif schema_pattern_type == "NO_SCHEMA_UNIVERSAL_QUERY":
444- sql = self.__build_universal_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
524+ sql, sql_args = self.__build_universal_schema_tables_query(
525+ catalog, schema_pattern, table_name_pattern, types
526+ )
445527 elif schema_pattern_type == "EXTERNAL_SCHEMA_QUERY":
446- sql = self.__build_external_schema_tables_query(catalog, schema_pattern, table_name_pattern, types)
528+ sql, sql_args = self.__build_external_schema_tables_query(
529+ catalog, schema_pattern, table_name_pattern, types
530+ )
447531
448- self.execute(sql)
532+ if len(sql_args) > 0:
533+ temp = self.paramstyle
534+ self.paramstyle = "qmark"
535+ try:
536+ self.execute(sql, sql_args)
537+ except:
538+ raise
539+ finally:
540+ self.paramstyle = temp
541+ else:
542+ self.execute(sql)
449543 tables: tuple = self.fetchall()
450544 return tables
451545
@@ -455,7 +549,7 @@ def __build_local_schema_tables_query(
455549 schema_pattern: typing.Optional[str],
456550 table_name_pattern: typing.Optional[str],
457551 types: list,
458- ) -> str:
552+ ) -> typing.Tuple[ str, typing.Tuple[str, ...]] :
459553 sql: str = (
460554 "SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT, n.nspname AS TABLE_SCHEM, c.relname AS TABLE_NAME, "
461555 " CASE n.nspname ~ '^pg_' OR n.nspname = 'information_schema' "
@@ -502,32 +596,41 @@ def __build_local_schema_tables_query(
502596 " LEFT JOIN pg_catalog.pg_namespace dn ON (dn.oid=dc.relnamespace AND dn.nspname='pg_catalog') "
503597 " WHERE c.relnamespace = n.oid "
504598 )
505- filter_clause: str = self.__get_table_filter_clause(
599+ filter_clause, filter_args = self.__get_table_filter_clause(
506600 catalog, schema_pattern, table_name_pattern, types, "LOCAL_SCHEMA_QUERY"
507601 )
508602 orderby: str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
509603
510- return sql + filter_clause + orderby
604+ return sql + filter_clause + orderby, filter_args
511605
512606 def __get_table_filter_clause(
513607 self: "Cursor",
514608 catalog: typing.Optional[str],
515609 schema_pattern: typing.Optional[str],
516610 table_name_pattern: typing.Optional[str],
517- types: list ,
611+ types: typing.List[str] ,
518612 schema_pattern_type: str,
519- ) -> str:
613+ ) -> typing.Tuple[ str, typing.Tuple[str, ...]] :
520614 filter_clause: str = ""
521615 use_schemas: str = "SCHEMAS"
616+ query_args: typing.List[str] = []
522617 if schema_pattern is not None:
523- filter_clause += " AND TABLE_SCHEM LIKE {schema}".format(schema=self.__escape_quotes(schema_pattern))
618+ filter_clause += " AND TABLE_SCHEM LIKE ?"
619+ query_args.append(self.__sanitize_str(schema_pattern))
524620 if table_name_pattern is not None:
525- filter_clause += " AND TABLE_NAME LIKE {table}".format(table=self.__escape_quotes(table_name_pattern))
621+ filter_clause += " AND TABLE_NAME LIKE ?"
622+ query_args.append(self.__sanitize_str(table_name_pattern))
526623 if len(types) > 0:
527624 if schema_pattern_type == "LOCAL_SCHEMA_QUERY":
528625 filter_clause += " AND (false "
529626 orclause: str = ""
530627 for type in types:
628+ if type not in table_type_clauses.keys():
629+ raise InterfaceError(
630+ "Invalid type: {} provided. types may only contain: {}".format(
631+ type, table_type_clauses.keys()
632+ )
633+ )
531634 clauses = table_type_clauses[type]
532635 if len(clauses) > 0:
533636 cluase = clauses[use_schemas]
@@ -538,21 +641,28 @@ def __get_table_filter_clause(
538641 filter_clause += " AND TABLE_TYPE IN ( "
539642 length = len(types)
540643 for type in types:
541- filter_clause += self.__escape_quotes(type)
644+ if type not in table_type_clauses.keys():
645+ raise InterfaceError(
646+ "Invalid type: {} provided. types may only contain: {}".format(
647+ type, table_type_clauses.keys()
648+ )
649+ )
650+ filter_clause += "?"
651+ query_args.append(self.__sanitize_str(type))
542652 length -= 1
543653 if length > 0:
544654 filter_clause += ", "
545655 filter_clause += ") "
546656
547- return filter_clause
657+ return filter_clause, tuple(query_args)
548658
549659 def __build_universal_schema_tables_query(
550660 self: "Cursor",
551661 catalog: typing.Optional[str],
552662 schema_pattern: typing.Optional[str],
553663 table_name_pattern: typing.Optional[str],
554664 types: list,
555- ) -> str:
665+ ) -> typing.Tuple[ str, typing.Tuple[str, ...]] :
556666 sql: str = (
557667 "SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
558668 " table_schema AS TABLE_SCHEM,"
@@ -583,20 +693,20 @@ def __build_universal_schema_tables_query(
583693 " FROM svv_tables)"
584694 " WHERE true "
585695 )
586- filter_clause: str = self.__get_table_filter_clause(
696+ filter_clause, filter_args = self.__get_table_filter_clause(
587697 catalog, schema_pattern, table_name_pattern, types, "NO_SCHEMA_UNIVERSAL_QUERY"
588698 )
589699 orderby: str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
590700 sql += filter_clause + orderby
591- return sql
701+ return sql, filter_args
592702
593703 def __build_external_schema_tables_query(
594704 self: "Cursor",
595705 catalog: typing.Optional[str],
596706 schema_pattern: typing.Optional[str],
597707 table_name_pattern: typing.Optional[str],
598708 types: list,
599- ) -> str:
709+ ) -> typing.Tuple[ str, typing.Tuple[str, ...]] :
600710 sql: str = (
601711 "SELECT * FROM (SELECT CAST(current_database() AS VARCHAR(124)) AS TABLE_CAT,"
602712 " schemaname AS table_schem,"
@@ -611,12 +721,12 @@ def __build_external_schema_tables_query(
611721 " FROM svv_external_tables)"
612722 " WHERE true "
613723 )
614- filter_clause: str = self.__get_table_filter_clause(
724+ filter_clause, filter_args = self.__get_table_filter_clause(
615725 catalog, schema_pattern, table_name_pattern, types, "EXTERNAL_SCHEMA_QUERY"
616726 )
617727 orderby: str = " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "
618728 sql += filter_clause + orderby
619- return sql
729+ return sql, filter_args
620730
621731 def get_columns(
622732 self: "Cursor",
@@ -1477,5 +1587,8 @@ def __schema_pattern_match(self: "Cursor", schema_pattern: typing.Optional[str])
14771587 else:
14781588 return "NO_SCHEMA_UNIVERSAL_QUERY"
14791589
1590+ def __sanitize_str(self: "Cursor", s: str) -> str:
1591+ return re.sub(r"[-;/'\"\n\r ]", "", s)
1592+
14801593 def __escape_quotes(self: "Cursor", s: str) -> str:
1481- return "'{s}'".format(s=s )
1594+ return "'{s}'".format(s=self.__sanitize_str(s) )
0 commit comments