diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index 4d46edf..9b5bbbd 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -35,7 +35,6 @@
 """
 
 import sys
-from itertools import count
 from typing import TYPE_CHECKING, NoReturn, cast
 
 import duckdb
@@ -44,19 +43,20 @@
 from tqdm import tqdm
 
 from ._csv import to_csv
-from ._database import Prefix
+from ._database import Database, Prefix
 from ._folio import FolioClient
 from ._jsonx import Attr, transform_json
 from ._select import select
 from ._sqlx import (
     DBType,
-    DBTypeDatabase,
     as_postgres,
     autocommit,
     sqlid,
 )
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
+
     from _typeshed import dbapi
     from httpx_folio.query import QueryType
 
@@ -77,7 +77,7 @@ def __init__(self) -> None:
         self._quiet = False
         self.dbtype: DBType = DBType.UNDEFINED
         self.db: dbapi.DBAPIConnection | None = None
-        self._db: DBTypeDatabase | None = None
+        self._db: Database | None = None
         self._folio: FolioClient | None = None
         self.page_size = 1000
         self._okapi_timeout = 60
@@ -124,14 +124,13 @@ def _connect_db_duckdb(
             db = ld.connect_db_duckdb(filename='ldlite.db')
 
         """
+        from ._database.duckdb import DuckDbDatabase  # noqa: PLC0415
+
         self.dbtype = DBType.DUCKDB
         fn = filename if filename is not None else ":memory:"
         db = duckdb.connect(database=fn)
         self.db = cast("dbapi.DBAPIConnection", db.cursor())
-        self._db = DBTypeDatabase(
-            DBType.DUCKDB,
-            lambda: cast("dbapi.DBAPIConnection", db.cursor()),
-        )
+        self._db = DuckDbDatabase(lambda: db.cursor())
 
         return db.cursor()
 
@@ -146,13 +145,12 @@ def connect_db_postgresql(self, dsn: str) -> psycopg.Connection:
             db = ld.connect_db_postgresql(dsn='dbname=ld host=localhost user=ldlite')
 
         """
+        from ._database.postgres import PostgresDatabase  # noqa: PLC0415
+
         self.dbtype = DBType.POSTGRES
         db = psycopg.connect(dsn)
         self.db = cast("dbapi.DBAPIConnection", db)
-        self._db = DBTypeDatabase(
-            DBType.POSTGRES,
-            lambda: cast("dbapi.DBAPIConnection", psycopg.connect(dsn)),
-        )
+        self._db = PostgresDatabase(lambda: psycopg.connect(dsn))
 
         ret_db = psycopg.connect(dsn)
         ret_db.rollback()
@@ -200,9 +198,6 @@ def drop_tables(self, table: str) -> None:
         if self.db is None or self._db is None:
             self._check_db()
             return
-        schema_table = table.strip().split(".")
-        if len(schema_table) != 1 and len(schema_table) != 2:
-            raise ValueError("invalid table name: " + table)
 
         prefix = Prefix(table)
         self._db.drop_prefix(prefix)
@@ -293,9 +288,6 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
                 "use json_depth=0 to disable JSON transformation"
             )
             raise ValueError(msg)
-        schema_table = table.split(".")
-        if len(schema_table) != 1 and len(schema_table) != 2:
-            raise ValueError("invalid table name: " + table)
         if json_depth is None or json_depth < 0 or json_depth > 4:
             raise ValueError("invalid value for json_depth: " + str(json_depth))
         if self._folio is None:
@@ -308,57 +300,39 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
         if not self._quiet:
             print("ldlite: querying: " + path, file=sys.stderr)
         try:
-            # First get total number of records
-            records = self._folio.iterate_records(
+            (total_records, records) = self._folio.iterate_records(
                 path,
                 self._okapi_timeout,
                 self._okapi_max_retries,
                 self.page_size,
                 query=cast("QueryType", query),
             )
-            (total_records, _) = next(records)
-            total = min(total_records, limit or total_records)
+            if limit is not None:
+                total_records = min(total_records, limit)
+                records = (x for _, x in zip(range(limit), records, strict=False))
             if self._verbose:
-                print("ldlite: estimated row count: " + str(total), file=sys.stderr)
-
-            class PbarNoop:
-                def update(self, _: int) -> None: ...
-                def close(self) -> None: ...
-
-            p_count = count(1)
-            processed = 0
-            pbar: tqdm | PbarNoop  # type:ignore[type-arg]
-            if not self._quiet:
-                pbar = tqdm(
-                    desc="reading",
-                    total=total,
-                    leave=False,
-                    mininterval=3,
-                    smoothing=0,
-                    colour="#A9A9A9",
-                    bar_format="{desc} {bar}{postfix}",
+                print(
+                    "ldlite: estimated row count: " + str(total_records),
+                    file=sys.stderr,
                 )
-            else:
-                pbar = PbarNoop()
-
-            def on_processed() -> bool:
-                pbar.update(1)
-                nonlocal processed
-                processed = next(p_count)
-                return True
 
-            def on_processed_limit() -> bool:
-                pbar.update(1)
-                nonlocal processed, limit
-                processed = next(p_count)
-                return limit is None or processed < limit
-
-            self._db.ingest_records(
+            processed = self._db.ingest_records(
                 prefix,
-                on_processed_limit if limit is not None else on_processed,
-                records,
+                cast(
+                    "Iterator[bytes]",
+                    tqdm(
+                        records,
+                        desc="downloading",
+                        total=total_records,
+                        leave=False,
+                        mininterval=5,
+                        disable=self._quiet,
+                        unit=table.split(".")[-1],
+                        unit_scale=True,
+                        delay=5,
+                    ),
+                ),
             )
-            pbar.close()
 
             self._db.drop_extracted_tables(prefix)
             newtables = [table]
@@ -386,6 +360,13 @@ def on_processed_limit() -> bool:
             autocommit(self.db, self.dbtype, True)
             # Create indexes on id columns (for postgres)
             if self.dbtype == DBType.POSTGRES:
+
+                class PbarNoop:
+                    def update(self, _: int) -> None: ...
+                    def close(self) -> None: ...
+
+                pbar: tqdm | PbarNoop = PbarNoop()  # type:ignore[type-arg]
+
                 indexable_attrs = [
                     (t, a)
                     for t, attrs in newattrs.items()
diff --git a/src/ldlite/_database.py b/src/ldlite/_database/__init__.py
similarity index 82%
rename from src/ldlite/_database.py
rename to src/ldlite/_database/__init__.py
index 258dedc..c0f3ebe 100644
--- a/src/ldlite/_database.py
+++ b/src/ldlite/_database/__init__.py
@@ -6,15 +6,18 @@
 from psycopg import sql
 
 if TYPE_CHECKING:
-    from _typeshed import dbapi
-
-DB = TypeVar("DB", bound="dbapi.DBAPIConnection")
+    import duckdb
+    import psycopg
 
 
 class Prefix:
-    def __init__(self, table: str):
+    def __init__(self, prefix: str):
         self._schema: str | None = None
-        sandt = table.split(".")
+        sandt = prefix.split(".")
+        if len(sandt) > 2:
+            msg = f"Expected one or two identifiers but got {prefix}"
+            raise ValueError(msg)
+
         if len(sandt) == 1:
             (self._prefix,) = sandt
         else:
@@ -42,7 +45,24 @@ def legacy_jtable(self) -> sql.Identifier:
         return self.identifier(f"{self._prefix}_jtable")
 
 
-class Database(ABC, Generic[DB]):
+class Database(ABC):
+    @abstractmethod
+    def drop_prefix(self, prefix: Prefix) -> None: ...
+
+    @abstractmethod
+    def drop_raw_table(self, prefix: Prefix) -> None: ...
+
+    @abstractmethod
+    def drop_extracted_tables(self, prefix: Prefix) -> None: ...
+
+    @abstractmethod
+    def ingest_records(self, prefix: Prefix, records: Iterator[bytes]) -> int: ...
+
+
+DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection")
+
+
+class TypedDatabase(Database, Generic[DB]):
     def __init__(self, conn_factory: Callable[[], DB]):
         self._conn_factory = conn_factory
 
@@ -88,7 +108,7 @@ def drop_extracted_tables(
 
     @property
     @abstractmethod
-    def _missing_table_error(self) -> tuple[type[Exception], ...]: ...
+    def _missing_table_error(self) -> type[Exception]: ...
     def _drop_extracted_tables(
         self,
         conn: DB,
@@ -137,9 +157,6 @@ def _drop_extracted_tables(
                 .as_string(),
             )
 
-    @property
-    @abstractmethod
-    def _truncate_raw_table_sql(self) -> sql.SQL: ...
     @property
     @abstractmethod
     def _create_raw_table_sql(self) -> sql.SQL: ...
@@ -162,26 +179,3 @@ def _prepare_raw_table(
                 table=prefix.raw_table_name,
             ).as_string(),
         )
-
-    @property
-    @abstractmethod
-    def _insert_record_sql(self) -> sql.SQL: ...
-    def ingest_records(
-        self,
-        prefix: Prefix,
-        on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, bytes]],
-    ) -> None:
-        with closing(self._conn_factory()) as conn:
-            self._prepare_raw_table(conn, prefix)
-
-            insert_sql = self._insert_record_sql.format(
-                table=prefix.raw_table_name,
-            ).as_string()
-            with closing(conn.cursor()) as cur:
-                for pkey, r in records:
-                    cur.execute(insert_sql, (pkey, r.decode()))
-                    if not on_processed():
-                        break
-
-            conn.commit()
diff --git a/src/ldlite/_database/duckdb.py b/src/ldlite/_database/duckdb.py
new file mode 100644
index 0000000..ebe5806
--- /dev/null
+++ b/src/ldlite/_database/duckdb.py
@@ -0,0 +1,44 @@
+from collections.abc import Iterator
+from itertools import count
+
+import duckdb
+from psycopg import sql
+
+from . import Prefix, TypedDatabase
+
+
+class DuckDbDatabase(TypedDatabase[duckdb.DuckDBPyConnection]):
+    def _rollback(self, conn: duckdb.DuckDBPyConnection) -> None:
+        pass
+
+    @property
+    def _missing_table_error(self) -> type[Exception]:
+        return duckdb.CatalogException
+
+    @property
+    def _create_raw_table_sql(self) -> sql.SQL:
+        return sql.SQL("CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb text);")
+
+    def ingest_records(
+        self,
+        prefix: Prefix,
+        records: Iterator[bytes],
+    ) -> int:
+        pkey = count(1)
+        with self._conn_factory() as conn:
+            self._prepare_raw_table(conn, prefix)
+
+            insert_sql = (
+                sql.SQL("INSERT INTO {table} VALUES(?, ?);")
+                .format(
+                    table=prefix.raw_table_name,
+                )
+                .as_string()
+            )
+            # duckdb has better performance bulk inserting in a transaction
+            with conn.begin() as tx, tx.cursor() as cur:
+                for r in records:
+                    cur.execute(insert_sql, (next(pkey), r.decode()))
+                tx.commit()
+
+        return next(pkey) - 1
diff --git a/src/ldlite/_database/postgres.py b/src/ldlite/_database/postgres.py
new file mode 100644
index 0000000..8008d22
--- /dev/null
+++ b/src/ldlite/_database/postgres.py
@@ -0,0 +1,51 @@
+from collections.abc import Iterator
+from itertools import count
+
+import psycopg
+from psycopg import sql
+
+from . import Prefix, TypedDatabase
+
+
+class PostgresDatabase(TypedDatabase[psycopg.Connection]):
+    def _rollback(self, conn: psycopg.Connection) -> None:
+        conn.rollback()
+
+    @property
+    def _missing_table_error(self) -> type[Exception]:
+        return psycopg.errors.UndefinedTable
+
+    @property
+    def _create_raw_table_sql(self) -> sql.SQL:
+        return sql.SQL(
+            "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb jsonb);",
+        )
+
+    def ingest_records(
+        self,
+        prefix: Prefix,
+        records: Iterator[bytes],
+    ) -> int:
+        pkey = count(1)
+        with self._conn_factory() as conn:
+            self._prepare_raw_table(conn, prefix)
+
+            with (
+                conn.cursor() as cur,
+                cur.copy(
+                    sql.SQL(
+                        "COPY {table} (__id, jsonb) FROM STDIN (FORMAT BINARY)",
+                    ).format(table=prefix.raw_table_name),
+                ) as copy,
+            ):
+                # postgres jsonb is always version 1
+                # and it always goes in front
+                jver = (1).to_bytes(1, "big")
+                for r in records:
+                    rb = bytearray()
+                    rb.extend(jver)
+                    rb.extend(r)
+                    copy.write_row((next(pkey).to_bytes(4, "big"), rb))
+
+            conn.commit()
+            return next(pkey) - 1
diff --git a/src/ldlite/_folio.py b/src/ldlite/_folio.py
index e8d5aed..1fa87fa 100644
--- a/src/ldlite/_folio.py
+++ b/src/ldlite/_folio.py
@@ -1,5 +1,6 @@
 from collections.abc import Iterator
 from itertools import count
+from typing import cast
 
 import orjson
 from httpx_folio.factories import (
@@ -36,23 +37,13 @@ def iterate_records(
         retries: int,
         page_size: int,
         query: QueryType | None = None,
-    ) -> Iterator[tuple[int, bytes]]:
-        """Iterates all records for a given path.
-
-        Returns:
-            A tuple of the autoincrementing key + the json for each record.
-            The first result will be the total record count.
-        """
+    ) -> tuple[int, Iterator[bytes]]:
         is_srs = path.lower() in _SOURCESTATS
         # this is Java's max size of int because we want all the source records
         params = QueryParams(query, 2_147_483_647 - 1 if is_srs else page_size)
 
-        with self._client_factory(
-            BasicClientOptions(
-                retries=retries,
-                timeout=timeout,
-            ),
-        ) as client:
+        client_opts = BasicClientOptions(retries=retries, timeout=timeout)
+        with self._client_factory(client_opts) as client:
             res = client.get(
                 path if not is_srs else _SOURCESTATS[path.lower()],
                 params=params.stats(),
@@ -60,60 +51,110 @@
             res.raise_for_status()
             j = orjson.loads(res.text)
             r = int(j["totalRecords"])
-            yield (r, b"")
-
-            if r == 0:
-                return
-
-            pkey = count(start=1)
-            if is_srs:
-                # streaming is a more stable endpoint for source records
-                with client.stream(
-                    "GET",
-                    _SOURCESTREAM[path.lower()],
-                    params=params.normalized(),
-                ) as res:
-                    res.raise_for_status()
-                    record = ""
-                    for f in res.iter_lines():
-                        # HTTPX can return partial json fragments during iteration
-                        # if they contain "newline-ish" characters like U+2028
-                        record += f
-                        if len(f) == 0 or f[-1] != "}":
-                            continue
-                        yield (next(pkey), orjson.dumps(orjson.Fragment(record)))
-                        record = ""
-                return
-            key = next(iter(j.keys()))
-            nonid_key = (
-                # Grab the first key if there isn't an id column
-                # because we need it to offset page properly
-                next(iter(j[key][0].keys())) if "id" not in j[key][0] else None
+        if r == 0:
+            return (0, iter([]))
+
+        if is_srs:
+            return (r, self._iterate_records_srs(client_opts, path, params))
+
+        key = cast("str", next(iter(j.keys())))
+        r1 = j[key][0]
+        if (
+            nonid_key := cast("str", next(iter(r1.keys()))) if "id" not in r1 else None
+        ) or not params.can_page_by_id():
+            return (
+                r,
+                self._iterate_records_offset(
+                    client_opts,
+                    path,
+                    params,
+                    key,
+                    nonid_key,
+                ),
             )
-            last_id: str | None = None
 
+        return (
+            r,
+            self._iterate_records_id(
+                client_opts,
+                path,
+                params,
+                key,
+            ),
+        )
+
+    def _iterate_records_srs(
+        self,
+        client_opts: BasicClientOptions,
+        path: str,
+        params: QueryParams,
+    ) -> Iterator[bytes]:
+        with (
+            self._client_factory(client_opts) as client,
+            client.stream(
+                "GET",
+                _SOURCESTREAM[path.lower()],
+                params=params.normalized(),
+            ) as res,
+        ):
+            res.raise_for_status()
+            record = ""
+            for f in res.iter_lines():
+                # HTTPX can return partial json fragments during iteration
+                # if they contain "newline-ish" characters like U+2028
+                record += f
+                if len(f) == 0 or f[-1] != "}":
+                    continue
+                yield orjson.dumps(orjson.Fragment(record))
+                record = ""
+
+    def _iterate_records_offset(
+        self,
+        client_opts: BasicClientOptions,
+        path: str,
+        params: QueryParams,
+        key: str,
+        nonid_key: str | None,
+    ) -> Iterator[bytes]:
+        with self._client_factory(client_opts) as client:
             page = count(start=1)
             while True:
-                if nonid_key is not None:
-                    p = params.offset_paging(key=nonid_key, page=next(page))
-                elif params.can_page_by_id():
-                    p = params.id_paging(last_id=last_id)
-                else:
-                    p = params.offset_paging(page=next(page))
-
-                res = client.get(path, params=p)
+                res = client.get(
+                    path,
+                    params=params.offset_paging(page=next(page))
+                    if nonid_key is None
+                    else params.offset_paging(key=nonid_key, page=next(page)),
+                )
                 res.raise_for_status()
 
                 last = None
                 for r in (o for o in orjson.loads(res.text)[key] if o is not None):
                     last = r
-                    yield (next(pkey), orjson.dumps(r))
+                    yield orjson.dumps(r)
 
                 if last is None:
                     return
 
-                last_id = last.get(
-                    "id",
-                    "this value is unused because we're offset paging",
-                )
+    def _iterate_records_id(
+        self,
+        client_opts: BasicClientOptions,
+        path: str,
+        params: QueryParams,
+        key: str,
+    ) -> Iterator[bytes]:
+        with self._client_factory(client_opts) as client:
+            last_id: str | None = None
+            while True:
+                res = client.get(path, params=params.id_paging(last_id=last_id))
+                res.raise_for_status()
+
+                last = None
+                for r in (o for o in orjson.loads(res.text)[key] if o is not None):
+                    last = r
+                    yield orjson.dumps(r)
+
+                if last is None:
+                    return
+
+                last_id = last["id"]
diff --git a/src/ldlite/_sqlx.py b/src/ldlite/_sqlx.py
index b52fce8..cf2e749 100644
--- a/src/ldlite/_sqlx.py
+++ b/src/ldlite/_sqlx.py
@@ -1,19 +1,13 @@
 import secrets
-from collections.abc import Callable, Iterator
-from contextlib import closing
 from enum import Enum
 from typing import TYPE_CHECKING, cast
 
 import duckdb
 import psycopg
-from psycopg import sql
-
-from ._database import Database
 
 if TYPE_CHECKING:
     from _typeshed import dbapi
 
-    from ._database import Prefix
     from ._jsonx import JsonValue
 
 
@@ -23,82 +17,6 @@ class DBType(Enum):
     POSTGRES = 2
 
 
-class DBTypeDatabase(Database["dbapi.DBAPIConnection"]):
-    def __init__(self, dbtype: DBType, factory: Callable[[], "dbapi.DBAPIConnection"]):
-        self._dbtype = dbtype
-        super().__init__(factory)
-
-    @property
-    def _missing_table_error(self) -> tuple[type[Exception], ...]:
-        return (
-            psycopg.errors.UndefinedTable,
-            duckdb.CatalogException,
-        )
-
-    def _rollback(self, conn: "dbapi.DBAPIConnection") -> None:
-        if pgdb := as_postgres(conn, self._dbtype):
-            pgdb.rollback()
-
-    @property
-    def _create_raw_table_sql(self) -> sql.SQL:
-        create_sql = "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb text);"
-        if self._dbtype == DBType.POSTGRES:
-            create_sql = (
-                "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb jsonb);"
-            )
-
-        return sql.SQL(create_sql)
-
-    @property
-    def _truncate_raw_table_sql(self) -> sql.SQL:
-        truncate_sql = "TRUNCATE TABLE {table};"
-
-        return sql.SQL(truncate_sql)
-
-    @property
-    def _insert_record_sql(self) -> sql.SQL:
-        insert_sql = "INSERT INTO {table} VALUES(?, ?);"
-        if self._dbtype == DBType.POSTGRES:
-            insert_sql = "INSERT INTO {table} VALUES(%s, %s);"
-
-        return sql.SQL(insert_sql)
-
-    def ingest_records(
-        self,
-        prefix: "Prefix",
-        on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, bytes]],
-    ) -> None:
-        if self._dbtype != DBType.POSTGRES:
-            super().ingest_records(prefix, on_processed, records)
-            return
-
-        with closing(self._conn_factory()) as conn:
-            self._prepare_raw_table(conn, prefix)
-
-            if pgconn := as_postgres(conn, self._dbtype):
-                with (
-                    pgconn.cursor() as cur,
-                    cur.copy(
-                        sql.SQL(
-                            "COPY {table} (__id, jsonb) FROM STDIN (FORMAT BINARY)",
-                        ).format(table=prefix.raw_table_name),
-                    ) as copy,
-                ):
-                    # postgres jsonb is always version 1
-                    # and it always goes in front
-                    jver = (1).to_bytes(1, "big")
-                    for pkey, r in records:
-                        rb = bytearray()
-                        rb.extend(jver)
-                        rb.extend(r)
-                        copy.write_row((pkey.to_bytes(4, "big"), rb))
-                        if not on_processed():
-                            break
-
-                pgconn.commit()
-
-
 def as_duckdb(
     db: "dbapi.DBAPIConnection",
     dbtype: DBType,
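
Review note (not part of the patch): a minimal sketch of how the split-out backend classes are wired after this change, using only identifiers introduced above (`Prefix`, `DuckDbDatabase`, `ingest_records`). It mirrors what `LDLite.connect_db_duckdb()` and `LDLite.query()` now do internally; the prefix name and sample records are invented for illustration.

```python
# Illustrative sketch only -- not part of the patch.
# Wires the new DuckDB backend by hand, the way LDLite.connect_db_duckdb()
# and LDLite.query() now do internally. The prefix and records are made up.
import duckdb

from ldlite._database import Prefix
from ldlite._database.duckdb import DuckDbDatabase

db = duckdb.connect(database=":memory:")

# The backend receives a factory so it can open a fresh cursor per unit of
# work, matching the lambda passed by connect_db_duckdb() above.
backend = DuckDbDatabase(lambda: db.cursor())

prefix = Prefix("users")  # one identifier, or "schema.table"
records = iter([b'{"id": "u1"}', b'{"id": "u2"}'])  # stand-in for FolioClient output

# ingest_records() now consumes a plain Iterator[bytes] and returns the number
# of rows written, which query() uses in place of the old callback counter.
processed = backend.ingest_records(prefix, records)
print(processed)  # 2
```

The Postgres path is symmetric: `PostgresDatabase(lambda: psycopg.connect(dsn))` exposes the same `ingest_records(prefix, records)` call, but streams the rows via binary COPY instead of per-row INSERTs.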