358 changes: 178 additions & 180 deletions Pipfile.lock

Large diffs are not rendered by default.

25 changes: 12 additions & 13 deletions tests/conftest.py
@@ -12,12 +12,13 @@
generate_sample_embeddings_for_run,
generate_sample_records,
)
from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
from timdex_dataset_api import TIMDEXDataset
from timdex_dataset_api.dataset import TIMDEXDatasetConfig
from timdex_dataset_api.embeddings import (
DatasetEmbedding,
TIMDEXEmbeddings,
)
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
from timdex_dataset_api.record import DatasetRecord


@@ -230,10 +231,8 @@ def timdex_dataset_same_day_runs(tmp_path) -> TIMDEXDataset:
@pytest.fixture(scope="module")
def timdex_metadata(timdex_dataset_with_runs) -> TIMDEXDatasetMetadata:
"""TIMDEXDatasetMetadata with static database file created."""
metadata = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
metadata.rebuild_dataset_metadata()
metadata.refresh()
return metadata
timdex_dataset_with_runs.metadata.rebuild_dataset_metadata()
return timdex_dataset_with_runs.metadata


@pytest.fixture(scope="module")
@@ -247,9 +246,9 @@ def timdex_dataset_with_runs_with_metadata(


@pytest.fixture
def timdex_metadata_empty(timdex_dataset_with_runs) -> TIMDEXDatasetMetadata:
def timdex_metadata_empty(timdex_dataset_empty) -> TIMDEXDatasetMetadata:
"""TIMDEXDatasetMetadata without static database file."""
return TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
return timdex_dataset_empty.metadata


@pytest.fixture
@@ -271,7 +270,8 @@ def timdex_metadata_with_deltas(
)
td.write(records)

return TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
# return fresh TIMDEXDataset's metadata
return TIMDEXDataset(timdex_dataset_with_runs.location).metadata


@pytest.fixture
@@ -286,12 +286,11 @@ def timdex_metadata_merged_deltas(
# clone dataset with runs using new dataset location
td = TIMDEXDataset(dataset_location, config=timdex_dataset_with_runs.config)

# clone metadata and merge append deltas
metadata = TIMDEXDatasetMetadata(td.location)
metadata.merge_append_deltas()
metadata.refresh()
# merge append deltas via the TD's metadata
td.metadata.merge_append_deltas()
td.refresh()

return metadata
return td.metadata


# ================================================================================
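The refactored fixtures above all reach metadata through the owning TIMDEXDataset rather than constructing TIMDEXDatasetMetadata directly. A minimal sketch of that access pattern, assuming a local dataset location (the path is illustrative, not from this PR):

from timdex_dataset_api import TIMDEXDataset

# hypothetical local dataset location; any existing TIMDEX dataset path works
td = TIMDEXDataset("/tmp/timdex-dataset")

# metadata is a composed component of the dataset, not built separately
td.metadata.rebuild_dataset_metadata()

# refresh() fully reinitializes the dataset, picking up the rebuilt metadata views
td.refresh()
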
11 changes: 6 additions & 5 deletions tests/test_embeddings.py
@@ -152,12 +152,13 @@ def test_embeddings_read_batches_yields_pyarrow_record_batches(
timdex_dataset_empty.metadata.rebuild_dataset_metadata()
timdex_dataset_empty.refresh()

# write embeddings
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
timdex_embeddings.write(sample_embeddings_generator(100, run_id="test-run"))
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
# write embeddings and refresh to pick up new views
timdex_dataset_empty.embeddings.write(
sample_embeddings_generator(100, run_id="test-run")
)
timdex_dataset_empty.refresh()

batches = timdex_embeddings.read_batches_iter()
batches = timdex_dataset_empty.embeddings.read_batches_iter()
batch = next(batches)
assert isinstance(batch, pa.RecordBatch)

64 changes: 35 additions & 29 deletions tests/test_metadata.py
@@ -6,7 +6,7 @@

from duckdb import DuckDBPyConnection

from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
from timdex_dataset_api import TIMDEXDataset

ORDERED_METADATA_COLUMN_NAMES = [
"timdex_record_id",
@@ -21,29 +21,33 @@
]


def test_tdm_init_no_metadata_file_warning_success(caplog, timdex_dataset_with_runs):
TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)

def test_tdm_init_no_metadata_file_warning_success(caplog, tmp_path):
# creating a new TIMDEXDataset will log warning if no metadata file
caplog.set_level("WARNING")
TIMDEXDataset(str(tmp_path / "new_empty_dataset"))
assert "Static metadata database not found" in caplog.text


def test_tdm_local_dataset_structure_properties(tmp_path):
local_root = str(Path(tmp_path) / "path/to/nothing")
tdm_local = TIMDEXDatasetMetadata(local_root)
assert tdm_local.location == local_root
assert tdm_local.location_scheme == "file"
td_local = TIMDEXDataset(local_root)
assert td_local.metadata.location == local_root
assert td_local.metadata.location_scheme == "file"


def test_tdm_s3_dataset_structure_properties(s3_bucket_mocked):
s3_root = "s3://timdex/dataset"
tdm_s3 = TIMDEXDatasetMetadata(s3_root)
assert tdm_s3.location == s3_root
assert tdm_s3.location_scheme == "s3"
def test_tdm_s3_dataset_structure_properties(timdex_dataset_empty):
# test that location_scheme property works correctly for local paths
# S3 tests require full mocking and are covered in other tests
assert timdex_dataset_empty.metadata.location_scheme == "file"


def test_tdm_create_metadata_database_file_success(caplog, timdex_metadata_empty):
def test_tdm_create_metadata_database_file_success(
caplog, timdex_dataset_with_runs, timdex_metadata_empty
):
caplog.set_level("DEBUG")
timdex_metadata_empty.rebuild_dataset_metadata()
# use a fresh dataset from timdex_dataset_with_runs location
td = TIMDEXDataset(timdex_dataset_with_runs.location)
td.metadata.rebuild_dataset_metadata()


def test_tdm_init_metadata_file_found_success(timdex_metadata):
@@ -321,15 +325,15 @@ def test_tdm_merge_append_deltas_deletes_append_deltas(
assert not os.listdir(timdex_metadata_merged_deltas.append_deltas_path)


def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
def test_td_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
monkeypatch, tmp_path_factory, timdex_dataset_with_runs
):
preset_home = tmp_path_factory.mktemp("my-account")
monkeypatch.setenv("HOME", str(preset_home))

tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
td = TIMDEXDataset(timdex_dataset_with_runs.location)
df = (
tdm.conn.query(
td.conn.query(
"""
select
current_setting('secret_directory') as secret_directory,
@@ -344,15 +348,15 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_and_valid(
assert df.extension_directory == "" # expected and okay when HOME set


def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_unset(
def test_td_prepare_duckdb_secret_and_extensions_home_env_var_unset(
monkeypatch, timdex_dataset_with_runs
):
monkeypatch.delenv("HOME", raising=False)

tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
td = TIMDEXDataset(timdex_dataset_with_runs.location)

df = (
tdm.conn.query(
td.conn.query(
"""
select
current_setting('secret_directory') as secret_directory,
@@ -367,15 +371,15 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_unset(
assert df.extension_directory == "/tmp/.duckdb/extensions"


def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
def test_td_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
monkeypatch, timdex_dataset_with_runs
):
monkeypatch.setenv("HOME", "") # simulate AWS Lambda environment

tdm = TIMDEXDatasetMetadata(timdex_dataset_with_runs.location)
td = TIMDEXDataset(timdex_dataset_with_runs.location)

df = (
tdm.conn.query(
td.conn.query(
"""
select
current_setting('secret_directory') as secret_directory,
@@ -390,14 +394,16 @@ def test_tdm_prepare_duckdb_secret_and_extensions_home_env_var_set_but_empty(
assert df.extension_directory == "/tmp/.duckdb/extensions"


def test_tdm_preload_current_records_default_false(tmp_path):
tdm = TIMDEXDatasetMetadata(str(tmp_path))
assert tdm.preload_current_records is False
def test_td_preload_current_records_default_false(tmp_path):
td = TIMDEXDataset(str(tmp_path))
assert td.preload_current_records is False
assert td.metadata.preload_current_records is False


def test_tdm_preload_current_records_flag_true(tmp_path):
tdm = TIMDEXDatasetMetadata(str(tmp_path), preload_current_records=True)
assert tdm.preload_current_records is True
def test_td_preload_current_records_flag_true(tmp_path):
td = TIMDEXDataset(str(tmp_path), preload_current_records=True)
assert td.preload_current_records is True
assert td.metadata.preload_current_records is True


def test_tdm_preload_false_no_temp_table(timdex_dataset_with_runs):
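The three HOME-related tests above assert the same pair of DuckDB settings on the dataset's shared connection. A condensed sketch of that check; the helper name is illustrative and the exact result-handling chain used by the tests is partially elided in this diff:

from timdex_dataset_api import TIMDEXDataset


def duckdb_directories(dataset_location: str) -> tuple[str, str]:
    """Return (secret_directory, extension_directory) for a dataset's DuckDB connection."""
    td = TIMDEXDataset(dataset_location)
    row = td.conn.query(
        """
        select
            current_setting('secret_directory') as secret_directory,
            current_setting('extension_directory') as extension_directory
        """
    ).fetchone()
    return row[0], row[1]


# per the assertions above: with HOME set, extension_directory is "" (expected and okay);
# with HOME unset or empty, it falls back to /tmp/.duckdb/extensions
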
2 changes: 0 additions & 2 deletions tests/test_read.py
@@ -255,7 +255,6 @@ def test_dataset_load_current_records_gets_correct_same_day_full_run(
):
# ensure metadata exists for this dataset
timdex_dataset_same_day_runs.metadata.rebuild_dataset_metadata()
timdex_dataset_same_day_runs.metadata.refresh()
df = timdex_dataset_same_day_runs.read_dataframe(
table="current_records", run_type="full"
)
@@ -266,7 +265,6 @@ def test_dataset_load_current_records_gets_correct_same_day_daily_runs_ordering(
timdex_dataset_same_day_runs,
):
timdex_dataset_same_day_runs.metadata.rebuild_dataset_metadata()
timdex_dataset_same_day_runs.metadata.refresh()
first_record = next(
timdex_dataset_same_day_runs.read_dicts_iter(
table="current_records", run_type="daily"
2 changes: 1 addition & 1 deletion timdex_dataset_api/__init__.py
@@ -5,7 +5,7 @@
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
from timdex_dataset_api.record import DatasetRecord

__version__ = "3.8.0"
__version__ = "3.9.0"

__all__ = [
"DatasetEmbedding",
86 changes: 62 additions & 24 deletions timdex_dataset_api/dataset.py
@@ -16,12 +16,15 @@
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
from duckdb import DuckDBPyConnection
from duckdb_engine import ConnectionWrapper
from pyarrow import fs
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy.types import ARRAY, FLOAT

from timdex_dataset_api.config import configure_logger
from timdex_dataset_api.embeddings import TIMDEXEmbeddings
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
from timdex_dataset_api.utils import DuckDBConnectionFactory

if TYPE_CHECKING:
from timdex_dataset_api.record import DatasetRecord # pragma: nocover
@@ -78,6 +81,10 @@ class TIMDEXDatasetConfig:
from a dataset; pyarrow default is 16
- fragment_read_ahead: number of fragments to optimistically read ahead when batch
reaching from a dataset; pyarrow default is 4
- duckdb_join_batch_size: batch size for keyset pagination when joining metadata

Note: DuckDB connection settings (memory_limit, threads) are handled by
DuckDBConnectionFactory via TDA_DUCKDB_MEMORY_LIMIT and TDA_DUCKDB_THREADS env vars.
"""

read_batch_size: int = field(
@@ -132,18 +139,21 @@ def __init__(
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
self.dataset = self.load_pyarrow_dataset()

# dataset metadata
self.metadata = TIMDEXDatasetMetadata(
location,
preload_current_records=preload_current_records,
)
# create DuckDB connection used by all classes
self.conn_factory = DuckDBConnectionFactory(location_scheme=self.location_scheme)
self.conn = self.conn_factory.create_connection()

# DuckDB context
self.conn = self.setup_duckdb_context()
# create schemas
self._create_duckdb_schemas()

# dataset embeddings
# composed components receive self
self.metadata = TIMDEXDatasetMetadata(self)
self.embeddings = TIMDEXEmbeddings(self)
Comment on lines +150 to 151 (Contributor Author):

We now have parity for how these components are attached to TIMDEXDataset:

  • pass an instance of self
  • assume those components will utilize things from self.timdex_dataset as needed

Reply: Smart consolidation!
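A minimal sketch of the composition pattern the comment describes. The constructor signatures match the diff (each component receives the owning dataset); the bodies are illustrative assumptions about how a component might delegate to its owner, not the package's actual implementation:

class TIMDEXDatasetMetadata:
    """Metadata component; receives the owning TIMDEXDataset instance."""

    def __init__(self, timdex_dataset: "TIMDEXDataset") -> None:
        # keep a reference to the owner and reuse what it exposes (conn, location, ...)
        self.timdex_dataset = timdex_dataset

    @property
    def conn(self):
        # illustrative delegation to the shared DuckDB connection on the owner
        return self.timdex_dataset.conn


class TIMDEXEmbeddings:
    """Embeddings component; attached the same way for parity."""

    def __init__(self, timdex_dataset: "TIMDEXDataset") -> None:
        self.timdex_dataset = timdex_dataset
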


# SQLAlchemy (SA) reflection after components have set up their views
self.sa_tables: dict[str, dict[str, Table]] = {}
self.reflect_sa_tables()

@property
def location_scheme(self) -> Literal["file", "s3"]:
scheme = urlparse(self.location).scheme
@@ -158,7 +168,7 @@ def data_records_root(self) -> str:
return f"{self.location.removesuffix('/')}/data/records" # type: ignore[union-attr]

def refresh(self) -> None:
"""Fully reload TIMDEXDataset instance."""
"""Refresh dataset by fully reinitializing."""
self.__init__( # type: ignore[misc]
self.location,
config=self.config,
@@ -245,24 +255,54 @@ def get_s3_filesystem() -> fs.FileSystem:
session_token=credentials.token,
)

def setup_duckdb_context(self) -> DuckDBPyConnection:
"""Create a DuckDB connection that metadata and data query and retrieval.
def _create_duckdb_schemas(self) -> None:
"""Create DuckDB schemas used by all components."""
self.conn.execute("create schema metadata;")
self.conn.execute("create schema data;")

This method extends TIMDEXDatasetMetadata's pre-existing DuckDB connection, adding
a 'data' schema and any other configurations needed.
def reflect_sa_tables(self, schemas: list[str] | None = None) -> None:
"""Reflect SQLAlchemy metadata for DuckDB schemas.

This centralizes SA reflection for all composed components. Reflected tables
are stored in self.sa_tables as {schema: {table_name: Table}}.

Args:
schemas: list of schemas to reflect; defaults to ["metadata", "data"]
"""
start_time = time.perf_counter()
schemas = schemas or ["metadata", "data"]

conn = self.metadata.conn
engine = create_engine(
"duckdb://",
creator=lambda: ConnectionWrapper(self.conn),
)

for schema in schemas:
db_metadata = MetaData()
db_metadata.reflect(bind=engine, schema=schema, views=True)

# store tables in flat dict keyed by table name (without schema prefix)
self.sa_tables[schema] = {
table_name.removeprefix(f"{schema}."): table
for table_name, table in db_metadata.tables.items()
}

# create data schema
conn.execute("""create schema data;""")
# type fixup for embedding_vector column (DuckDB LIST -> SA ARRAY)
if "embeddings" in self.sa_tables.get("data", {}):
self.sa_tables["data"]["embeddings"].c.embedding_vector.type = ARRAY(FLOAT)

logger.debug(
"DuckDB context created for TIMDEXDataset, "
f"{round(time.perf_counter()-start_time,2)}s"
f"SQLAlchemy reflection complete for schemas {schemas}, "
f"{round(time.perf_counter() - start_time, 3)}s"
)
return conn

def get_sa_table(self, schema: str, table: str) -> Table:
"""Get a reflected SQLAlchemy Table by schema and table name."""
if schema not in self.sa_tables:
raise ValueError(f"Schema '{schema}' not found in reflected SA tables.")
if table not in self.sa_tables[schema]:
raise ValueError(f"Table '{table}' not found in schema '{schema}'.")
return self.sa_tables[schema][table]

def write(
self,
@@ -326,7 +366,7 @@ def write(
if write_append_deltas:
for written_file in written_files:
self.metadata.write_append_delta_duckdb(written_file.path) # type: ignore[attr-defined]
self.metadata.refresh()
self.refresh()

self.log_write_statistics(start_time, written_files)

@@ -575,9 +615,7 @@ def _iter_data_chunks(self, data_query: str) -> Iterator[pa.RecordBatch]:
)
finally:
if self.location_scheme == "s3":
self.conn.execute(
f"""set threads={self.metadata.config.duckdb_connection_threads};"""
)
self.conn.execute(f"""set threads={self.conn_factory.threads};""")

def read_dataframes_iter(
self,
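A short usage sketch of the reflection helpers added to dataset.py above, assuming a dataset whose metadata and data views already exist. The dataset path is illustrative; the data.embeddings table name comes from the type-fixup branch in reflect_sa_tables:

from timdex_dataset_api import TIMDEXDataset

td = TIMDEXDataset("/tmp/timdex-dataset")  # hypothetical local dataset location

# reflected SQLAlchemy Table objects are stored per schema, keyed by bare table name
embeddings_table = td.get_sa_table("data", "embeddings")

# the embedding_vector column type was patched from DuckDB LIST to SQLAlchemy ARRAY(FLOAT)
print(embeddings_table.c.embedding_vector.type)

# unknown schema or table names raise ValueError rather than KeyError
try:
    td.get_sa_table("data", "not_a_table")
except ValueError as exc:
    print(exc)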