From e400fc1f2fa4be54102dce4c476445170e5bffcd Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Mon, 22 Sep 2025 17:02:55 +0000
Subject: [PATCH 1/5] Make _database module a directory

---
 src/ldlite/{_database.py => _database/__init__.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename src/ldlite/{_database.py => _database/__init__.py} (100%)

diff --git a/src/ldlite/_database.py b/src/ldlite/_database/__init__.py
similarity index 100%
rename from src/ldlite/_database.py
rename to src/ldlite/_database/__init__.py

From 5077545355ea07c567b5c25cf87248885a374299 Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Mon, 22 Sep 2025 17:48:11 +0000
Subject: [PATCH 2/5] WIP: Add duckdb implementation of database abstraction

---
 src/ldlite/_database/__init__.py | 38 +++++++----------------------
 src/ldlite/_database/duckdb.py   | 41 ++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 29 deletions(-)
 create mode 100644 src/ldlite/_database/duckdb.py

diff --git a/src/ldlite/_database/__init__.py b/src/ldlite/_database/__init__.py
index 258dedc..429bd49 100644
--- a/src/ldlite/_database/__init__.py
+++ b/src/ldlite/_database/__init__.py
@@ -6,9 +6,8 @@
 from psycopg import sql
 
 if TYPE_CHECKING:
-    from _typeshed import dbapi
-
-DB = TypeVar("DB", bound="dbapi.DBAPIConnection")
+    import duckdb
+    import psycopg
 
 
 class Prefix:
@@ -42,13 +41,13 @@ def legacy_jtable(self) -> sql.Identifier:
         return self.identifier(f"{self._prefix}_jtable")
 
 
+DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection")
+
+
 class Database(ABC, Generic[DB]):
     def __init__(self, conn_factory: Callable[[], DB]):
         self._conn_factory = conn_factory
 
-    @abstractmethod
-    def _rollback(self, conn: DB) -> None: ...
-
     def drop_prefix(
         self,
         prefix: Prefix,
@@ -86,9 +85,8 @@ def drop_extracted_tables(
             self._drop_extracted_tables(conn, prefix)
             conn.commit()
 
-    @property
     @abstractmethod
-    def _missing_table_error(self) -> tuple[type[Exception], ...]: ...
+    def _missing_table_error(self) -> type[Exception]: ...
     def _drop_extracted_tables(
         self,
         conn: DB,
@@ -103,7 +101,7 @@ def _drop_extracted_tables(
                     .as_string(),
                 )
             except self._missing_table_error:
-                self._rollback(conn)
+                conn.rollback()
             else:
                 tables.extend(cur.fetchall())
 
@@ -115,7 +113,7 @@ def _drop_extracted_tables(
                     .as_string(),
                 )
             except self._missing_table_error:
-                self._rollback(conn)
+                conn.rollback()
             else:
                 tables.extend(cur.fetchall())
 
@@ -137,9 +135,6 @@ def _drop_extracted_tables(
                 .as_string(),
             )
 
-    @property
-    @abstractmethod
-    def _truncate_raw_table_sql(self) -> sql.SQL: ...
     @property
     @abstractmethod
     def _create_raw_table_sql(self) -> sql.SQL: ...
@@ -163,25 +158,10 @@ def _prepare_raw_table(
             ).as_string(),
         )
 
-    @property
     @abstractmethod
-    def _insert_record_sql(self) -> sql.SQL: ...
     def ingest_records(
         self,
         prefix: Prefix,
         on_processed: Callable[[], bool],
         records: Iterator[tuple[int, bytes]],
-    ) -> None:
-        with closing(self._conn_factory()) as conn:
-            self._prepare_raw_table(conn, prefix)
-
-            insert_sql = self._insert_record_sql.format(
-                table=prefix.raw_table_name,
-            ).as_string()
-            with closing(conn.cursor()) as cur:
-                for pkey, r in records:
-                    cur.execute(insert_sql, (pkey, r.decode()))
-                    if not on_processed():
-                        break
-
-            conn.commit()
+    ) -> None: ...
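Everything the base class executes is assembled with psycopg's sql composition rather than string interpolation, so the prefix-derived table names are always quoted as identifiers. A minimal sketch of that mechanism with a made-up schema and table (the real names come from Prefix.raw_table_name and the other Prefix helpers):

    from psycopg import sql

    # Hypothetical identifiers; Prefix derives the real ones from a "schema.table" prefix.
    stmt = sql.SQL("DROP TABLE IF EXISTS {table};").format(
        table=sql.Identifier("folio", "loans__t"),
    )
    print(stmt.as_string())  # DROP TABLE IF EXISTS "folio"."loans__t";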
diff --git a/src/ldlite/_database/duckdb.py b/src/ldlite/_database/duckdb.py
new file mode 100644
index 0000000..ed470bd
--- /dev/null
+++ b/src/ldlite/_database/duckdb.py
@@ -0,0 +1,41 @@
+from collections.abc import Callable, Iterator
+
+import duckdb
+from psycopg import sql
+
+from . import Database, Prefix
+
+
+class DuckDbDatabase(
+    Database[duckdb.DuckDBPyConnection],
+):
+    def _missing_table_error(self) -> type[Exception]:
+        return duckdb.CatalogException
+
+    @property
+    def _create_raw_table_sql(self) -> sql.SQL:
+        return sql.SQL("CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb text);")
+
+    def ingest_records(
+        self,
+        prefix: Prefix,
+        on_processed: Callable[[], bool],
+        records: Iterator[tuple[int, bytes]],
+    ) -> None:
+        with self._conn_factory() as conn, conn.begin() as tx:
+            self._prepare_raw_table(tx, prefix)
+
+            insert_sql = (
+                sql.SQL("INSERT INTO {table} VALUES(?, ?);")
+                .format(
+                    table=prefix.raw_table_name,
+                )
+                .as_string()
+            )
+            with tx.cursor() as cur:
+                for pkey, r in records:
+                    cur.execute(insert_sql, (pkey, r.decode()))
+                    if not on_processed():
+                        break
+
+            tx.commit()
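A rough sketch of how this WIP backend is meant to be driven, assuming the module layout introduced above (the hand-rolled record iterator stands in for the FOLIO client):

    import duckdb

    from ldlite._database import Prefix
    from ldlite._database.duckdb import DuckDbDatabase

    db = duckdb.connect("ldlite.db")
    # The factory hands every operation its own cursor over the shared connection.
    store = DuckDbDatabase(lambda: db.cursor())

    records = iter([(1, b'{"id": "a"}'), (2, b'{"id": "b"}')])
    store.ingest_records(Prefix("folio.loans"), on_processed=lambda: True, records=records)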
From f2ec7c3326671ec250d79d49d7cc7e5079a83da0 Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Mon, 22 Sep 2025 18:56:45 +0000
Subject: [PATCH 3/5] Refactor existing dbtype implementation to postgres/duckdb

---
 src/ldlite/__init__.py           | 23 ++-----
 src/ldlite/_database/__init__.py | 38 +++++++++++++--
 src/ldlite/_database/duckdb.py   | 18 +++----
 src/ldlite/_database/postgres.py | 51 ++++++++++++++++++++
 src/ldlite/_sqlx.py              | 82 --------------------------------
 5 files changed, 105 insertions(+), 107 deletions(-)
 create mode 100644 src/ldlite/_database/postgres.py

diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index 4d46edf..fae102b 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -43,14 +43,13 @@
 from httpx_folio.auth import FolioParams
 from tqdm import tqdm
 
+from . import _database as _db
 from ._csv import to_csv
-from ._database import Prefix
 from ._folio import FolioClient
 from ._jsonx import Attr, transform_json
 from ._select import select
 from ._sqlx import (
     DBType,
-    DBTypeDatabase,
     as_postgres,
     autocommit,
     sqlid,
@@ -77,7 +76,7 @@ def __init__(self) -> None:
         self._quiet = False
         self.dbtype: DBType = DBType.UNDEFINED
         self.db: dbapi.DBAPIConnection | None = None
-        self._db: DBTypeDatabase | None = None
+        self._db: _db.Database | None = None
         self._folio: FolioClient | None = None
         self.page_size = 1000
         self._okapi_timeout = 60
@@ -124,14 +123,13 @@ def _connect_db_duckdb(
             db = ld.connect_db_duckdb(filename='ldlite.db')
 
         """
+        from ._database.duckdb import DuckDbDatabase  # noqa: PLC0415
+
         self.dbtype = DBType.DUCKDB
         fn = filename if filename is not None else ":memory:"
         db = duckdb.connect(database=fn)
         self.db = cast("dbapi.DBAPIConnection", db.cursor())
-        self._db = DBTypeDatabase(
-            DBType.DUCKDB,
-            lambda: cast("dbapi.DBAPIConnection", db.cursor()),
-        )
+        self._db = DuckDbDatabase(lambda: db.cursor())
 
         return db.cursor()
@@ -146,13 +144,12 @@ def connect_db_postgresql(self, dsn: str) -> psycopg.Connection:
             db = ld.connect_db_postgresql(dsn='dbname=ld host=localhost user=ldlite')
 
         """
+        from ._database.postgres import PostgresDatabase  # noqa: PLC0415
+
         self.dbtype = DBType.POSTGRES
         db = psycopg.connect(dsn)
         self.db = cast("dbapi.DBAPIConnection", db)
-        self._db = DBTypeDatabase(
-            DBType.POSTGRES,
-            lambda: cast("dbapi.DBAPIConnection", psycopg.connect(dsn)),
-        )
+        self._db = PostgresDatabase(lambda: psycopg.connect(dsn))
 
         ret_db = psycopg.connect(dsn)
         ret_db.rollback()
@@ -203,7 +200,7 @@ def drop_tables(self, table: str) -> None:
         schema_table = table.strip().split(".")
         if len(schema_table) != 1 and len(schema_table) != 2:
             raise ValueError("invalid table name: " + table)
-        prefix = Prefix(table)
+        prefix = _db.Prefix(table)
         self._db.drop_prefix(prefix)
 
     def set_folio_max_retries(self, max_retries: int) -> None:
@@ -304,7 +301,7 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
         if self.db is None or self._db is None:
             self._check_db()
             return []
-        prefix = Prefix(table)
+        prefix = _db.Prefix(table)
         if not self._quiet:
             print("ldlite: querying: " + path, file=sys.stderr)
         try:
diff --git a/src/ldlite/_database/__init__.py b/src/ldlite/_database/__init__.py
index 429bd49..01ec251 100644
--- a/src/ldlite/_database/__init__.py
+++ b/src/ldlite/_database/__init__.py
@@ -41,13 +41,44 @@ def legacy_jtable(self) -> sql.Identifier:
         return self.identifier(f"{self._prefix}_jtable")
 
 
+class Database(ABC):
+    @abstractmethod
+    def drop_prefix(
+        self,
+        prefix: Prefix,
+    ) -> None: ...
+
+    @abstractmethod
+    def drop_raw_table(
+        self,
+        prefix: Prefix,
+    ) -> None: ...
+
+    @abstractmethod
+    def drop_extracted_tables(
+        self,
+        prefix: Prefix,
+    ) -> None: ...
+
+    @abstractmethod
+    def ingest_records(
+        self,
+        prefix: Prefix,
+        on_processed: Callable[[], bool],
+        records: Iterator[tuple[int, bytes]],
+    ) -> None: ...
+
+
 DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection")
 
 
-class Database(ABC, Generic[DB]):
+class TypedDatabase(Database, Generic[DB]):
     def __init__(self, conn_factory: Callable[[], DB]):
         self._conn_factory = conn_factory
 
+    @abstractmethod
+    def _rollback(self, conn: DB) -> None: ...
+
     def drop_prefix(
         self,
         prefix: Prefix,
@@ -85,6 +116,7 @@ def drop_extracted_tables(
             self._drop_extracted_tables(conn, prefix)
             conn.commit()
 
+    @property
     @abstractmethod
     def _missing_table_error(self) -> type[Exception]: ...
     def _drop_extracted_tables(
@@ -101,7 +133,7 @@
                     .as_string(),
                 )
             except self._missing_table_error:
-                conn.rollback()
+                self._rollback(conn)
             else:
                 tables.extend(cur.fetchall())
@@ -113,7 +145,7 @@
                     .as_string(),
                 )
             except self._missing_table_error:
-                conn.rollback()
+                self._rollback(conn)
             else:
                 tables.extend(cur.fetchall())

diff --git a/src/ldlite/_database/duckdb.py b/src/ldlite/_database/duckdb.py
index ed470bd..c62fc19 100644
--- a/src/ldlite/_database/duckdb.py
+++ b/src/ldlite/_database/duckdb.py
@@ -3,12 +3,14 @@
 import duckdb
 from psycopg import sql
 
-from . import Database, Prefix
+from . import Prefix, TypedDatabase
 
 
-class DuckDbDatabase(
-    Database[duckdb.DuckDBPyConnection],
-):
+class DuckDbDatabase(TypedDatabase[duckdb.DuckDBPyConnection]):
+    def _rollback(self, conn: duckdb.DuckDBPyConnection) -> None:
+        pass
+
+    @property
     def _missing_table_error(self) -> type[Exception]:
         return duckdb.CatalogException
@@ -22,8 +24,8 @@ def ingest_records(
         on_processed: Callable[[], bool],
         records: Iterator[tuple[int, bytes]],
     ) -> None:
-        with self._conn_factory() as conn, conn.begin() as tx:
-            self._prepare_raw_table(tx, prefix)
+        with self._conn_factory() as conn:
+            self._prepare_raw_table(conn, prefix)
 
             insert_sql = (
                 sql.SQL("INSERT INTO {table} VALUES(?, ?);")
@@ -32,10 +34,8 @@ def ingest_records(
                 )
                 .as_string()
             )
-            with tx.cursor() as cur:
+            with conn.cursor() as cur:
                 for pkey, r in records:
                     cur.execute(insert_sql, (pkey, r.decode()))
                     if not on_processed():
                         break
-
-            tx.commit()
diff --git a/src/ldlite/_database/postgres.py b/src/ldlite/_database/postgres.py
new file mode 100644
index 0000000..99ae59b
--- /dev/null
+++ b/src/ldlite/_database/postgres.py
@@ -0,0 +1,51 @@
+from collections.abc import Callable, Iterator
+
+import psycopg
+from psycopg import sql
+
+from . import Prefix, TypedDatabase
+
+
+class PostgresDatabase(TypedDatabase[psycopg.Connection]):
+    def _rollback(self, conn: psycopg.Connection) -> None:
+        conn.rollback()
+
+    @property
+    def _missing_table_error(self) -> type[Exception]:
+        return psycopg.errors.UndefinedTable
+
+    @property
+    def _create_raw_table_sql(self) -> sql.SQL:
+        return sql.SQL(
+            "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb jsonb);",
+        )
+
+    def ingest_records(
+        self,
+        prefix: Prefix,
+        on_processed: Callable[[], bool],
+        records: Iterator[tuple[int, bytes]],
+    ) -> None:
+        with self._conn_factory() as conn:
+            self._prepare_raw_table(conn, prefix)
+
+            with (
+                conn.cursor() as cur,
+                cur.copy(
+                    sql.SQL(
+                        "COPY {table} (__id, jsonb) FROM STDIN (FORMAT BINARY)",
+                    ).format(table=prefix.raw_table_name),
+                ) as copy,
+            ):
+                # postgres jsonb is always version 1
+                # and it always goes in front
+                jver = (1).to_bytes(1, "big")
+                for pkey, r in records:
+                    rb = bytearray()
+                    rb.extend(jver)
+                    rb.extend(r)
+                    copy.write_row((pkey.to_bytes(4, "big"), rb))
+                    if not on_processed():
+                        break
+
+            conn.commit()
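The COPY above trades per-row INSERT round trips for PostgreSQL's binary COPY stream: __id goes over the wire as a 4-byte big-endian integer and jsonb as the JSON text prefixed with the single version byte 0x01. The same technique in isolation, mirroring the code above (DSN and table name are placeholders):

    import psycopg

    with psycopg.connect("dbname=ld user=ldlite") as conn, conn.cursor() as cur:
        cur.execute("CREATE TABLE IF NOT EXISTS demo_raw (__id integer, jsonb jsonb)")
        with cur.copy("COPY demo_raw (__id, jsonb) FROM STDIN (FORMAT BINARY)") as copy:
            for pkey, payload in enumerate([b'{"id": "a"}', b'{"id": "b"}'], start=1):
                # binary jsonb: one version byte (always 1), then the JSON text
                copy.write_row((pkey.to_bytes(4, "big"), b"\x01" + payload))
        conn.commit()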
diff --git a/src/ldlite/_sqlx.py b/src/ldlite/_sqlx.py
index b52fce8..cf2e749 100644
--- a/src/ldlite/_sqlx.py
+++ b/src/ldlite/_sqlx.py
@@ -1,19 +1,13 @@
 import secrets
-from collections.abc import Callable, Iterator
-from contextlib import closing
 from enum import Enum
 from typing import TYPE_CHECKING, cast
 
 import duckdb
 import psycopg
-from psycopg import sql
-
-from ._database import Database
 
 if TYPE_CHECKING:
     from _typeshed import dbapi
 
-    from ._database import Prefix
     from ._jsonx import JsonValue
 
 
@@ -23,82 +17,6 @@ class DBType(Enum):
     POSTGRES = 2
 
 
-class DBTypeDatabase(Database["dbapi.DBAPIConnection"]):
-    def __init__(self, dbtype: DBType, factory: Callable[[], "dbapi.DBAPIConnection"]):
-        self._dbtype = dbtype
-        super().__init__(factory)
-
-    @property
-    def _missing_table_error(self) -> tuple[type[Exception], ...]:
-        return (
-            psycopg.errors.UndefinedTable,
-            duckdb.CatalogException,
-        )
-
-    def _rollback(self, conn: "dbapi.DBAPIConnection") -> None:
-        if pgdb := as_postgres(conn, self._dbtype):
-            pgdb.rollback()
-
-    @property
-    def _create_raw_table_sql(self) -> sql.SQL:
-        create_sql = "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb text);"
-        if self._dbtype == DBType.POSTGRES:
-            create_sql = (
-                "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb jsonb);"
-            )
-
-        return sql.SQL(create_sql)
-
-    @property
-    def _truncate_raw_table_sql(self) -> sql.SQL:
-        truncate_sql = "TRUNCATE TABLE {table};"
-
-        return sql.SQL(truncate_sql)
-
-    @property
-    def _insert_record_sql(self) -> sql.SQL:
-        insert_sql = "INSERT INTO {table} VALUES(?, ?);"
-        if self._dbtype == DBType.POSTGRES:
-            insert_sql = "INSERT INTO {table} VALUES(%s, %s);"
-
-        return sql.SQL(insert_sql)
-
-    def ingest_records(
-        self,
-        prefix: "Prefix",
-        on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, bytes]],
-    ) -> None:
-        if self._dbtype != DBType.POSTGRES:
-            super().ingest_records(prefix, on_processed, records)
-            return
-
-        with closing(self._conn_factory()) as conn:
-            self._prepare_raw_table(conn, prefix)
-
-            if pgconn := as_postgres(conn, self._dbtype):
-                with (
-                    pgconn.cursor() as cur,
-                    cur.copy(
-                        sql.SQL(
-                            "COPY {table} (__id, jsonb) FROM STDIN (FORMAT BINARY)",
-                        ).format(table=prefix.raw_table_name),
-                    ) as copy,
-                ):
-                    # postgres jsonb is always version 1
-                    # and it always goes in front
-                    jver = (1).to_bytes(1, "big")
-                    for pkey, r in records:
-                        rb = bytearray()
-                        rb.extend(jver)
-                        rb.extend(r)
-                        copy.write_row((pkey.to_bytes(4, "big"), rb))
-                        if not on_processed():
-                            break
-
-            pgconn.commit()
-
-
 def as_duckdb(
     db: "dbapi.DBAPIConnection",
     dbtype: DBType,
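The split also clarifies why _rollback is a real rollback for Postgres and a no-op for DuckDB: probing for legacy tables is expected to fail, and in PostgreSQL a failed statement leaves the open transaction aborted until it is rolled back. A small sketch of that failure mode (DSN is a placeholder):

    import psycopg

    with psycopg.connect("dbname=ld user=ldlite") as conn, conn.cursor() as cur:
        try:
            cur.execute("SELECT 1 FROM table_that_does_not_exist")
        except psycopg.errors.UndefinedTable:
            # Without this, the next statement raises InFailedSqlTransaction.
            conn.rollback()
        cur.execute("SELECT 1")  # usable again after the rollback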
From ed48d8237cf52e22a5af4bbbc395cf96b8ee8497 Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Mon, 22 Sep 2025 20:40:44 +0000
Subject: [PATCH 4/5] Refactor iterate/ingest workflow with bonus tqdm cleanup

---
 src/ldlite/__init__.py           |  72 +++++++--------
 src/ldlite/_database/__init__.py |  13 +--
 src/ldlite/_database/duckdb.py   |  17 ++--
 src/ldlite/_database/postgres.py |  16 ++--
 src/ldlite/_folio.py             | 147 ++++++++++++++++++++-----------
 5 files changed, 147 insertions(+), 118 deletions(-)

diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index fae102b..16dad11 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -35,7 +35,6 @@
 """
 
 import sys
-from itertools import count
 from typing import TYPE_CHECKING, NoReturn, cast
 
 import duckdb
@@ -56,6 +55,8 @@
 )
 
 if TYPE_CHECKING:
+    from collections.abc import Iterator
+
     from _typeshed import dbapi
 
     from httpx_folio.query import QueryType
@@ -305,57 +306,39 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
         if not self._quiet:
             print("ldlite: querying: " + path, file=sys.stderr)
         try:
-            # First get total number of records
-            records = self._folio.iterate_records(
+            (total_records, records) = self._folio.iterate_records(
                 path,
                 self._okapi_timeout,
                 self._okapi_max_retries,
                 self.page_size,
                 query=cast("QueryType", query),
             )
-            (total_records, _) = next(records)
-            total = min(total_records, limit or total_records)
+            if limit is not None:
+                total_records = min(total_records, limit)
+                records = (x for _, x in zip(range(limit), records, strict=False))
             if self._verbose:
-                print("ldlite: estimated row count: " + str(total), file=sys.stderr)
-
-            class PbarNoop:
-                def update(self, _: int) -> None: ...
-                def close(self) -> None: ...
-
-            p_count = count(1)
-            processed = 0
-            pbar: tqdm | PbarNoop  # type:ignore[type-arg]
-            if not self._quiet:
-                pbar = tqdm(
-                    desc="reading",
-                    total=total,
-                    leave=False,
-                    mininterval=3,
-                    smoothing=0,
-                    colour="#A9A9A9",
-                    bar_format="{desc} {bar}{postfix}",
+                print(
+                    "ldlite: estimated row count: " + str(total_records),
+                    file=sys.stderr,
                 )
-            else:
-                pbar = PbarNoop()
-
-            def on_processed() -> bool:
-                pbar.update(1)
-                nonlocal processed
-                processed = next(p_count)
-                return True
-
-            def on_processed_limit() -> bool:
-                pbar.update(1)
-                nonlocal processed, limit
-                processed = next(p_count)
-                return limit is None or processed < limit
-
-            self._db.ingest_records(
+            processed = self._db.ingest_records(
                 prefix,
-                on_processed_limit if limit is not None else on_processed,
-                records,
+                cast(
+                    "Iterator[bytes]",
+                    tqdm(
+                        records,
+                        desc="downloading",
+                        total=total_records,
+                        leave=False,
+                        mininterval=5,
+                        disable=self._quiet,
+                        unit=table.split(".")[-1],
+                        unit_scale=True,
+                        delay=5,
+                    ),
+                ),
             )
-            pbar.close()
 
             self._db.drop_extracted_tables(prefix)
 
             newtables = [table]
@@ -383,6 +366,13 @@ def on_processed_limit() -> bool:
             autocommit(self.db, self.dbtype, True)
             # Create indexes on id columns (for postgres)
             if self.dbtype == DBType.POSTGRES:
+
+                class PbarNoop:
+                    def update(self, _: int) -> None: ...
+                    def close(self) -> None: ...
+
+                pbar: tqdm | PbarNoop = PbarNoop()  # type:ignore[type-arg]
+
                 indexable_attrs = [
                     (t, a)
                     for t, attrs in newattrs.items()
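Wrapping the record iterator in tqdm replaces the manual pbar.update(1) callbacks: the bar advances as the consumer pulls items, and disable/delay keep it silent for quiet or short runs. A small sketch with a plain generator standing in for the FOLIO records:

    from time import sleep

    from tqdm import tqdm

    records = (b"{}" for _ in range(250))
    progress = tqdm(
        records, desc="downloading", total=250, unit="loans", unit_scale=True, delay=5
    )
    for record in progress:
        sleep(0.01)  # stand-in for writing the record to the database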
diff --git a/src/ldlite/_database/__init__.py b/src/ldlite/_database/__init__.py
index 01ec251..ebae4b9 100644
--- a/src/ldlite/_database/__init__.py
+++ b/src/ldlite/_database/__init__.py
@@ -64,9 +64,8 @@ def drop_extracted_tables(
     def ingest_records(
         self,
         prefix: Prefix,
-        on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, bytes]],
-    ) -> None: ...
+        records: Iterator[bytes],
+    ) -> int: ...
 
 
 DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection")
@@ -189,11 +188,3 @@ def _prepare_raw_table(
                 table=prefix.raw_table_name,
             ).as_string(),
         )
-
-    @abstractmethod
-    def ingest_records(
-        self,
-        prefix: Prefix,
-        on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, bytes]],
-    ) -> None: ...

diff --git a/src/ldlite/_database/duckdb.py b/src/ldlite/_database/duckdb.py
index c62fc19..f4c6909 100644
--- a/src/ldlite/_database/duckdb.py
+++ b/src/ldlite/_database/duckdb.py
@@ -1,4 +1,5 @@
-from collections.abc import Callable, Iterator
+from collections.abc import Iterator
+from itertools import count
 
 import duckdb
 from psycopg import sql
@@ -21,9 +22,9 @@ def _create_raw_table_sql(self) -> sql.SQL:
     def ingest_records(
         self,
         prefix: Prefix,
-        on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, bytes]],
-    ) -> None:
+        records: Iterator[bytes],
+    ) -> int:
+        pkey = count(1)
         with self._conn_factory() as conn:
             self._prepare_raw_table(conn, prefix)
@@ -35,7 +36,7 @@ def ingest_records(
                 .as_string()
             )
             with conn.cursor() as cur:
-                for pkey, r in records:
-                    cur.execute(insert_sql, (pkey, r.decode()))
-                    if not on_processed():
-                        break
+                for r in records:
+                    cur.execute(insert_sql, (next(pkey), r.decode()))
+
+        return next(pkey) - 1

diff --git a/src/ldlite/_database/postgres.py b/src/ldlite/_database/postgres.py
index 99ae59b..8008d22 100644
--- a/src/ldlite/_database/postgres.py
+++ b/src/ldlite/_database/postgres.py
@@ -1,4 +1,5 @@
-from collections.abc import Callable, Iterator
+from collections.abc import Iterator
+from itertools import count
 
 import psycopg
 from psycopg import sql
@@ -23,9 +24,9 @@ def _create_raw_table_sql(self) -> sql.SQL:
     def ingest_records(
         self,
         prefix: Prefix,
-        on_processed: Callable[[], bool],
-        records: Iterator[tuple[int, bytes]],
-    ) -> None:
+        records: Iterator[bytes],
+    ) -> int:
+        pkey = count(1)
         with self._conn_factory() as conn:
             self._prepare_raw_table(conn, prefix)
@@ -40,12 +41,11 @@ def ingest_records(
                 # postgres jsonb is always version 1
                 # and it always goes in front
                 jver = (1).to_bytes(1, "big")
-                for pkey, r in records:
+                for r in records:
                     rb = bytearray()
                     rb.extend(jver)
                     rb.extend(r)
-                    copy.write_row((pkey.to_bytes(4, "big"), rb))
-                    if not on_processed():
-                        break
+                    copy.write_row((next(pkey).to_bytes(4, "big"), rb))
 
             conn.commit()
+        return next(pkey) - 1
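With the on_processed callback gone, both backends report how many records they wrote using the same itertools.count idiom: next() hands out 1, 2, 3, ... as row ids, so next(pkey) - 1 after the loop is exactly the number of rows ingested. For example:

    from itertools import count

    pkey = count(1)
    for record in (b"{}", b"{}", b"{}"):
        row_id = next(pkey)  # 1, 2, 3 - used as __id for the inserted row
    print(next(pkey) - 1)  # 3 records written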
diff --git a/src/ldlite/_folio.py b/src/ldlite/_folio.py
index e8d5aed..7530725 100644
--- a/src/ldlite/_folio.py
+++ b/src/ldlite/_folio.py
@@ -1,5 +1,6 @@
 from collections.abc import Iterator
 from itertools import count
+from typing import cast
 
 import orjson
 from httpx_folio.factories import (
@@ -36,7 +37,7 @@ def iterate_records(
         retries: int,
         page_size: int,
         query: QueryType | None = None,
-    ) -> Iterator[tuple[int, bytes]]:
+    ) -> tuple[int, Iterator[bytes]]:
         """Iterates all records for a given path.
 
         Returns:
             A tuple of the autoincrementing key + the json for each record.
             The first result will be the total record count.
         """
         is_srs = path.lower() in _SOURCESTATS
         # this is Java's max size of int because we want all the source records
         params = QueryParams(query, 2_147_483_647 - 1 if is_srs else page_size)
 
-        with self._client_factory(
-            BasicClientOptions(
-                retries=retries,
-                timeout=timeout,
-            ),
-        ) as client:
+        client_opts = BasicClientOptions(retries=retries, timeout=timeout)
+        with self._client_factory(client_opts) as client:
             res = client.get(
                 path if not is_srs else _SOURCESTATS[path.lower()],
                 params=params.stats(),
             )
             res.raise_for_status()
             j = orjson.loads(res.text)
             r = int(j["totalRecords"])
-            yield (r, b"")
-
-            if r == 0:
-                return
-
-            pkey = count(start=1)
-            if is_srs:
-                # streaming is a more stable endpoint for source records
-                with client.stream(
-                    "GET",
-                    _SOURCESTREAM[path.lower()],
-                    params=params.normalized(),
-                ) as res:
-                    res.raise_for_status()
-                    record = ""
-                    for f in res.iter_lines():
-                        # HTTPX can return partial json fragments during iteration
-                        # if they contain "newline-ish" characters like U+2028
-                        record += f
-                        if len(f) == 0 or f[-1] != "}":
-                            continue
-                        yield (next(pkey), orjson.dumps(orjson.Fragment(record)))
-                        record = ""
-                return
-
-            key = next(iter(j.keys()))
-            nonid_key = (
-                # Grab the first key if there isn't an id column
-                # because we need it to offset page properly
-                next(iter(j[key][0].keys())) if "id" not in j[key][0] else None
+        if r == 0:
+            return (0, iter([]))
+
+        if is_srs:
+            return (r, self._iterate_records_srs(client_opts, path, params))
+
+        key = cast("str", next(iter(j.keys())))
+        r1 = j[key][0]
+        if (
+            nonid_key := cast("str", next(iter(r1.keys()))) if "id" not in r1 else None
+        ) or not params.can_page_by_id():
+            return (
+                r,
+                self._iterate_records_offset(
+                    client_opts,
+                    path,
+                    params,
+                    key,
+                    nonid_key,
+                ),
             )
-            last_id: str | None = None
+        return (
+            r,
+            self._iterate_records_id(
+                client_opts,
+                path,
+                params,
+                key,
+            ),
+        )
+
+    def _iterate_records_srs(
+        self,
+        client_opts: BasicClientOptions,
+        path: str,
+        params: QueryParams,
+    ) -> Iterator[bytes]:
+        with (
+            self._client_factory(client_opts) as client,
+            client.stream(
+                "GET",
+                _SOURCESTREAM[path.lower()],
+                params=params.normalized(),
+            ) as res,
+        ):
+            res.raise_for_status()
+            record = ""
+            for f in res.iter_lines():
+                # HTTPX can return partial json fragments during iteration
+                # if they contain "newline-ish" characters like U+2028
+                record += f
+                if len(f) == 0 or f[-1] != "}":
+                    continue
+                yield orjson.dumps(orjson.Fragment(record))
+                record = ""
+
+    def _iterate_records_offset(
+        self,
+        client_opts: BasicClientOptions,
+        path: str,
+        params: QueryParams,
+        key: str,
+        nonid_key: str | None,
+    ) -> Iterator[bytes]:
+        with self._client_factory(client_opts) as client:
             page = count(start=1)
             while True:
-                if nonid_key is not None:
-                    p = params.offset_paging(key=nonid_key, page=next(page))
-                elif params.can_page_by_id():
-                    p = params.id_paging(last_id=last_id)
-                else:
-                    p = params.offset_paging(page=next(page))
-
-                res = client.get(path, params=p)
+                res = client.get(
+                    path,
+                    params=params.offset_paging(page=next(page))
+                    if nonid_key is None
+                    else params.offset_paging(key=nonid_key, page=next(page)),
+                )
+                res.raise_for_status()
+
+                last = None
+                for r in (o for o in orjson.loads(res.text)[key] if o is not None):
+                    last = r
+                    yield orjson.dumps(r)
+
+                if last is None:
+                    return
+
+    def _iterate_records_id(
+        self,
+        client_opts: BasicClientOptions,
+        path: str,
+        params: QueryParams,
+        key: str,
+    ) -> Iterator[bytes]:
+        with self._client_factory(client_opts) as client:
+            last_id: str | None = None
+            while True:
+                res = client.get(path, params=params.id_paging(last_id=last_id))
                 res.raise_for_status()
 
                 last = None
                 for r in (o for o in orjson.loads(res.text)[key] if o is not None):
                     last = r
-                    yield (next(pkey), orjson.dumps(r))
+                    yield orjson.dumps(r)
 
                 if last is None:
                     return
-                last_id = last.get(
-                    "id",
-                    "this value is unused because we're offset paging",
-                )
+                last_id = last["id"]
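The three private generators separate the paging strategies that were previously tangled into one loop: streaming for SRS, offset paging keyed on the first attribute, and id-based keyset paging. A simplified sketch of the keyset variant using httpx directly; the real code goes through httpx_folio's client factory and QueryParams, and the CQL strings here are only illustrative:

    from collections.abc import Iterator

    import httpx
    import orjson


    def iterate_by_id(base_url: str, path: str, key: str, page_size: int) -> Iterator[bytes]:
        last_id = None
        with httpx.Client(base_url=base_url) as client:
            while True:
                # Sort by id and ask only for records strictly after the last id seen.
                cql = "cql.allRecords=1" if last_id is None else f'id > "{last_id}"'
                res = client.get(path, params={"query": f"{cql} sortBy id", "limit": page_size})
                res.raise_for_status()
                batch = [r for r in orjson.loads(res.text)[key] if r is not None]
                if not batch:
                    return
                yield from (orjson.dumps(r) for r in batch)
                last_id = batch[-1]["id"]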
From 98f304b3fe999965f6bd131404bd4929218dffb0 Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Mon, 22 Sep 2025 21:28:02 +0000
Subject: [PATCH 5/5] Misc code review changes

---
 src/ldlite/__init__.py           | 14 ++++----------
 src/ldlite/_database/__init__.py | 29 ++++++++++-------------------
 src/ldlite/_database/duckdb.py   |  4 +++-
 src/ldlite/_folio.py             |  6 ------
 4 files changed, 17 insertions(+), 36 deletions(-)

diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index 16dad11..9b5bbbd 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -42,8 +42,8 @@
 from httpx_folio.auth import FolioParams
 from tqdm import tqdm
 
-from . import _database as _db
 from ._csv import to_csv
+from ._database import Database, Prefix
 from ._folio import FolioClient
 from ._jsonx import Attr, transform_json
 from ._select import select
@@ -77,7 +77,7 @@ def __init__(self) -> None:
         self._quiet = False
         self.dbtype: DBType = DBType.UNDEFINED
         self.db: dbapi.DBAPIConnection | None = None
-        self._db: _db.Database | None = None
+        self._db: Database | None = None
         self._folio: FolioClient | None = None
         self.page_size = 1000
         self._okapi_timeout = 60
@@ -198,10 +198,7 @@ def drop_tables(self, table: str) -> None:
         if self.db is None or self._db is None:
             self._check_db()
             return
-        schema_table = table.strip().split(".")
-        if len(schema_table) != 1 and len(schema_table) != 2:
-            raise ValueError("invalid table name: " + table)
-        prefix = _db.Prefix(table)
+        prefix = Prefix(table)
         self._db.drop_prefix(prefix)
 
     def set_folio_max_retries(self, max_retries: int) -> None:
@@ -291,9 +288,6 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
                 "use json_depth=0 to disable JSON transformation"
             )
             raise ValueError(msg)
-        schema_table = table.split(".")
-        if len(schema_table) != 1 and len(schema_table) != 2:
-            raise ValueError("invalid table name: " + table)
         if json_depth is None or json_depth < 0 or json_depth > 4:
             raise ValueError("invalid value for json_depth: " + str(json_depth))
         if self._folio is None:
@@ -302,7 +296,7 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
         if self.db is None or self._db is None:
             self._check_db()
             return []
-        prefix = _db.Prefix(table)
+        prefix = Prefix(table)
         if not self._quiet:
             print("ldlite: querying: " + path, file=sys.stderr)
         try:

diff --git a/src/ldlite/_database/__init__.py b/src/ldlite/_database/__init__.py
index ebae4b9..c0f3ebe 100644
--- a/src/ldlite/_database/__init__.py
+++ b/src/ldlite/_database/__init__.py
@@ -11,9 +11,13 @@ class Prefix:
-    def __init__(self, table: str):
+    def __init__(self, prefix: str):
         self._schema: str | None = None
-        sandt = table.split(".")
+        sandt = prefix.split(".")
+        if len(sandt) > 2:
+            msg = f"Expected one or two identifiers but got {prefix}"
+            raise ValueError(msg)
+
         if len(sandt) == 1:
             (self._prefix,) = sandt
         else:
@@ -43,29 +47,16 @@ def legacy_jtable(self) -> sql.Identifier:
         return self.identifier(f"{self._prefix}_jtable")
 
 
 class Database(ABC):
     @abstractmethod
-    def drop_prefix(
-        self,
-        prefix: Prefix,
-    ) -> None: ...
+    def drop_prefix(self, prefix: Prefix) -> None: ...
 
     @abstractmethod
-    def drop_raw_table(
-        self,
-        prefix: Prefix,
-    ) -> None: ...
+    def drop_raw_table(self, prefix: Prefix) -> None: ...
 
     @abstractmethod
-    def drop_extracted_tables(
-        self,
-        prefix: Prefix,
-    ) -> None: ...
+    def drop_extracted_tables(self, prefix: Prefix) -> None: ...
 
     @abstractmethod
-    def ingest_records(
-        self,
-        prefix: Prefix,
-        records: Iterator[bytes],
-    ) -> int: ...
+    def ingest_records(self, prefix: Prefix, records: Iterator[bytes]) -> int: ...
 
 
 DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection")

diff --git a/src/ldlite/_database/duckdb.py b/src/ldlite/_database/duckdb.py
index f4c6909..ebe5806 100644
--- a/src/ldlite/_database/duckdb.py
+++ b/src/ldlite/_database/duckdb.py
@@ -35,8 +35,10 @@ def ingest_records(
                 )
                 .as_string()
             )
-            with conn.cursor() as cur:
+            # duckdb has better performance bulk inserting in a transaction
+            with conn.begin() as tx, tx.cursor() as cur:
                 for r in records:
                     cur.execute(insert_sql, (next(pkey), r.decode()))
+            tx.commit()
 
         return next(pkey) - 1

diff --git a/src/ldlite/_folio.py b/src/ldlite/_folio.py
index 7530725..1fa87fa 100644
--- a/src/ldlite/_folio.py
+++ b/src/ldlite/_folio.py
@@ -38,12 +38,6 @@ def iterate_records(
         page_size: int,
         query: QueryType | None = None,
     ) -> tuple[int, Iterator[bytes]]:
-        """Iterates all records for a given path.
-
-        Returns:
-            A tuple of the autoincrementing key + the json for each record.
-            The first result will be the total record count.
-        """
         is_srs = path.lower() in _SOURCESTATS
         # this is Java's max size of int because we want all the source records
         params = QueryParams(query, 2_147_483_647 - 1 if is_srs else page_size)
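The final duckdb change wraps the whole batch in one explicit transaction instead of committing row by row, which is markedly faster for bulk inserts. The same idea in isolation (file name and row count are arbitrary):

    import duckdb

    conn = duckdb.connect("ldlite.db")
    conn.execute("CREATE TABLE IF NOT EXISTS demo_raw (__id INTEGER, jsonb TEXT)")

    conn.begin()  # one transaction around the whole batch
    for i in range(10_000):
        conn.execute("INSERT INTO demo_raw VALUES (?, ?)", (i + 1, '{"n": %d}' % i))
    conn.commit()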