diff --git a/src/ldlite/_database/__init__.py b/src/ldlite/_database/__init__.py index 5229176..b287c79 100644 --- a/src/ldlite/_database/__init__.py +++ b/src/ldlite/_database/__init__.py @@ -15,7 +15,7 @@ class Prefix: def __init__(self, prefix: str): - self._schema: str | None = None + self.schema: str | None = None sandt = prefix.split(".") if len(sandt) > 2: msg = f"Expected one or two identifiers but got {prefix}" @@ -24,32 +24,43 @@ def __init__(self, prefix: str): if len(sandt) == 1: (self._prefix,) = sandt else: - (self._schema, self._prefix) = sandt - - @property - def schema_name(self) -> sql.Identifier | None: - return None if self._schema is None else sql.Identifier(self._schema) + (self.schema, self._prefix) = sandt def _identifier(self, table: str) -> sql.Identifier: - if self._schema is None: + if self.schema is None: return sql.Identifier(table) - return sql.Identifier(self._schema, table) + return sql.Identifier(self.schema, table) @property - def load_history_key(self) -> str: - return (self._schema or "public") + "." + self._prefix + def schema_identifier(self) -> sql.Identifier | None: + return None if self.schema is None else sql.Identifier(self.schema) @property - def raw_table_name(self) -> sql.Identifier: + def raw_table_identifier(self) -> sql.Identifier: return self._identifier(self._prefix) @property - def catalog_table_name(self) -> sql.Identifier: - return self._identifier(f"{self._prefix}__tcatalog") + def catalog_table_name(self) -> str: + return f"{self._prefix}__tcatalog" + + @property + def catalog_table_identifier(self) -> sql.Identifier: + return self._identifier(self.catalog_table_name) + + @property + def legacy_jtable_name(self) -> str: + return f"{self._prefix}_jtable" @property - def legacy_jtable(self) -> sql.Identifier: - return self._identifier(f"{self._prefix}_jtable") + def legacy_jtable_identifier(self) -> sql.Identifier: + return self._identifier(self.legacy_jtable_name) + + @property + def load_history_key(self) -> str: + if self.schema is None: + return self._prefix + + return self.schema + "." + self._prefix @dataclass(frozen=True) @@ -102,8 +113,9 @@ def __init__(self, conn_factory: Callable[[], DB]): );""") conn.commit() + @property @abstractmethod - def _rollback(self, conn: DB) -> None: ... + def _default_schema(self) -> str: ... def drop_prefix( self, @@ -134,7 +146,7 @@ def _drop_raw_table( with closing(conn.cursor()) as cur: cur.execute( sql.SQL("DROP TABLE IF EXISTS {table};") - .format(table=prefix.raw_table_name) + .format(table=prefix.raw_table_identifier) .as_string(), ) @@ -146,9 +158,6 @@ def drop_extracted_tables( self._drop_extracted_tables(conn, prefix) conn.commit() - @property - @abstractmethod - def _missing_table_error(self) -> type[Exception]: ... def _drop_extracted_tables( self, conn: DB, @@ -156,28 +165,32 @@ def _drop_extracted_tables( ) -> None: tables: list[Sequence[Sequence[Any]]] = [] with closing(conn.cursor()) as cur: - try: - cur.execute( - sql.SQL("SELECT table_name FROM {catalog};") - .format(catalog=prefix.catalog_table_name) - .as_string(), - ) - except self._missing_table_error: - self._rollback(conn) - else: - tables.extend(cur.fetchall()) - - with closing(conn.cursor()) as cur: - try: - cur.execute( - sql.SQL("SELECT table_name FROM {catalog};") - .format(catalog=prefix.legacy_jtable) - .as_string(), - ) - except self._missing_table_error: - self._rollback(conn) - else: - tables.extend(cur.fetchall()) + cur.execute( + """ +SELECT table_name FROM information_schema.tables +WHERE table_schema = $1 and table_name IN ($2, $3);""", + ( + prefix.schema or self._default_schema, + prefix.catalog_table_name, + prefix.legacy_jtable_name, + ), + ) + for (tname,) in cur.fetchall(): + if tname == prefix.catalog_table_name: + cur.execute( + sql.SQL("SELECT table_name FROM {catalog};") + .format(catalog=prefix.catalog_table_identifier) + .as_string(), + ) + tables.extend(cur.fetchall()) + + if tname == prefix.legacy_jtable_name: + cur.execute( + sql.SQL("SELECT table_name FROM {catalog};") + .format(catalog=prefix.legacy_jtable_identifier) + .as_string(), + ) + tables.extend(cur.fetchall()) with closing(conn.cursor()) as cur: for (et,) in tables: @@ -188,12 +201,12 @@ def _drop_extracted_tables( ) cur.execute( sql.SQL("DROP TABLE IF EXISTS {catalog};") - .format(catalog=prefix.catalog_table_name) + .format(catalog=prefix.catalog_table_identifier) .as_string(), ) cur.execute( sql.SQL("DROP TABLE IF EXISTS {catalog};") - .format(catalog=prefix.legacy_jtable) + .format(catalog=prefix.legacy_jtable_identifier) .as_string(), ) @@ -206,17 +219,17 @@ def _prepare_raw_table( prefix: Prefix, ) -> None: with closing(conn.cursor()) as cur: - if prefix.schema_name is not None: + if prefix.schema_identifier is not None: cur.execute( sql.SQL("CREATE SCHEMA IF NOT EXISTS {schema};") - .format(schema=prefix.schema_name) + .format(schema=prefix.schema_identifier) .as_string(), ) self._drop_raw_table(conn, prefix) with closing(conn.cursor()) as cur: cur.execute( self._create_raw_table_sql.format( - table=prefix.raw_table_name, + table=prefix.raw_table_identifier, ).as_string(), ) diff --git a/src/ldlite/_database/duckdb.py b/src/ldlite/_database/duckdb.py index ebe5806..0676875 100644 --- a/src/ldlite/_database/duckdb.py +++ b/src/ldlite/_database/duckdb.py @@ -8,12 +8,9 @@ class DuckDbDatabase(TypedDatabase[duckdb.DuckDBPyConnection]): - def _rollback(self, conn: duckdb.DuckDBPyConnection) -> None: - pass - @property - def _missing_table_error(self) -> type[Exception]: - return duckdb.CatalogException + def _default_schema(self) -> str: + return "main" @property def _create_raw_table_sql(self) -> sql.SQL: @@ -31,7 +28,7 @@ def ingest_records( insert_sql = ( sql.SQL("INSERT INTO {table} VALUES(?, ?);") .format( - table=prefix.raw_table_name, + table=prefix.raw_table_identifier, ) .as_string() ) diff --git a/src/ldlite/_database/postgres.py b/src/ldlite/_database/postgres.py index adccf37..479018d 100644 --- a/src/ldlite/_database/postgres.py +++ b/src/ldlite/_database/postgres.py @@ -13,12 +13,9 @@ def __init__(self, dsn: str): # same sql between duckdb and postgres super().__init__(lambda: psycopg.connect(dsn, cursor_factory=psycopg.RawCursor)) - def _rollback(self, conn: psycopg.Connection) -> None: - conn.rollback() - @property - def _missing_table_error(self) -> type[Exception]: - return psycopg.errors.UndefinedTable + def _default_schema(self) -> str: + return "public" @property def _create_raw_table_sql(self) -> sql.SQL: @@ -40,7 +37,7 @@ def ingest_records( cur.copy( sql.SQL( "COPY {table} (__id, jsonb) FROM STDIN (FORMAT BINARY)", - ).format(table=prefix.raw_table_name), + ).format(table=prefix.raw_table_identifier), ) as copy, ): # postgres jsonb is always version 1 diff --git a/tests/test_cases/load_history_cases.py b/tests/test_cases/load_history_cases.py index 44fa6df..ed1727b 100644 --- a/tests/test_cases/load_history_cases.py +++ b/tests/test_cases/load_history_cases.py @@ -33,7 +33,31 @@ def case_one_load(self, query: str | None) -> LoadHistoryCase: }, queries={"prefix": [query]}, expected_loads={ - "public.prefix": (query, 2), + "prefix": (query, 2), + }, + ) + + def case_schema_load(self) -> LoadHistoryCase: + return LoadHistoryCase( + values={ + "schema.prefix": [ + { + "purchaseOrders": [ + { + "id": "b096504a-3d54-4664-9bf5-1b872466fd66", + "value": "value", + }, + { + "id": "b096504a-9999-4664-9bf5-1b872466fd66", + "value": "value-2", + }, + ], + }, + ], + }, + queries={"schema.prefix": [None]}, + expected_loads={ + "schema.prefix": (None, 2), }, ) @@ -69,6 +93,6 @@ def case_two_loads(self) -> LoadHistoryCase: }, queries={"prefix": [None, "a query"]}, expected_loads={ - "public.prefix": ("a query", 2), + "prefix": ("a query", 2), }, ) diff --git a/tests/test_duckdb.py b/tests/test_duckdb.py index bbab397..1c3f8ea 100644 --- a/tests/test_duckdb.py +++ b/tests/test_duckdb.py @@ -29,6 +29,7 @@ def test_drop_tables( dsn = f":memory:{tc.db}" ld.connect_folio("https://doesnt.matter", "", "", "") ld.connect_db(dsn) + ld.drop_tables(tc.drop) for prefix in tc.values: ld.query(table=prefix, path="/patched", keep_raw=tc.keep_raw) diff --git a/tests/test_postgres.py b/tests/test_postgres.py index d58d8b0..1489869 100644 --- a/tests/test_postgres.py +++ b/tests/test_postgres.py @@ -56,6 +56,7 @@ def test_drop_tables( dsn = pg_dsn(tc.db) ld.connect_folio("https://doesnt.matter", "", "", "") ld.connect_db_postgresql(dsn) + ld.drop_tables(tc.drop) for prefix in tc.values: ld.query(table=prefix, path="/patched", keep_raw=tc.keep_raw)