From 23195a91459b1ad11d94357efc1480f782efa72b Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 2 Jun 2026 14:09:29 -0400 Subject: [PATCH 1/8] feat: automatic pydantic model generation from json schema --- .github/workflows/ci.yml | 63 ++++++++ pyproject.toml | 1 + schema.conf | 3 + scripts/generate_models.sh | 85 ++++++++++ src/gtfs_diff/engine.py | 70 ++++---- src/gtfs_diff/models.py | 324 +++++++++++++++++++++++-------------- tests/test_engine.py | 12 +- tests/test_models.py | 31 ++-- 8 files changed, 413 insertions(+), 176 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 schema.conf create mode 100755 scripts/generate_models.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..84cd58b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,63 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +jobs: + changes: + runs-on: ubuntu-latest + permissions: + pull-requests: read + outputs: + src: ${{ steps.filter.outputs.src }} + models: ${{ steps.filter.outputs.models }} + steps: + - uses: actions/checkout@v6 + - uses: dorny/paths-filter@v4 + id: filter + with: + filters: | + src: + - 'src/**' + - 'tests/**' + - 'pyproject.toml' + models: + - 'schema.conf' + - 'scripts/generate_models.sh' + - 'src/gtfs_diff/models.py' + - '.github/workflows/ci.yml' + + test: + needs: changes + if: needs.changes.outputs.src == 'true' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.12", "3.14"] + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - run: pip install -e '.[dev]' + - run: pytest --tb=short + + models-freshness: + needs: changes + if: needs.changes.outputs.models == 'true' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - run: pip install 'datamodel-code-generator[ruff]>=0.59' + - run: ./scripts/generate_models.sh + - name: Check models.py is up to date + run: | + if ! git diff --exit-code src/gtfs_diff/models.py; then + echo "::error::src/gtfs_diff/models.py is out of date with the schema. Run ./scripts/generate_models.sh to regenerate." + exit 1 + fi diff --git a/pyproject.toml b/pyproject.toml index e5c7f60..6611b2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ dev = [ "pytest>=7.0", "pytest-cov", + "datamodel-code-generator[ruff]>=0.59", ] [project.scripts] diff --git a/schema.conf b/schema.conf new file mode 100644 index 0000000..f64bca1 --- /dev/null +++ b/schema.conf @@ -0,0 +1,3 @@ +# Schema version used to generate src/gtfs_diff/models.py +# Run ./scripts/generate_models.sh to regenerate after changing this. +SCHEMA_VERSION=v2-rc1 diff --git a/scripts/generate_models.sh b/scripts/generate_models.sh new file mode 100755 index 0000000..950e05f --- /dev/null +++ b/scripts/generate_models.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +# Generate Pydantic v2 models from the GTFS Diff JSON Schema. +# +# Usage: +# ./scripts/generate_models.sh # use version from schema.conf +# ./scripts/generate_models.sh v2-rc1 # fetch a specific version from GitHub +# ./scripts/generate_models.sh /path/to/local.json # use a local file +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +OUTPUT="$REPO_ROOT/src/gtfs_diff/models.py" +SCHEMA_REPO="MobilityData/gtfs-diff" +SCHEMA_BRANCH="main" + +# --- Resolve input: argument > schema.conf ----------------------------------- +if [ $# -ge 1 ]; then + INPUT="$1" +elif [ -f "$REPO_ROOT/schema.conf" ]; then + # shellcheck source=../schema.conf + source "$REPO_ROOT/schema.conf" + INPUT="${SCHEMA_VERSION:?SCHEMA_VERSION not set in schema.conf}" + echo "Using version from schema.conf: $INPUT" +else + echo "Usage: $0 []" >&2 + echo " e.g. $0 v2-rc1" >&2 + echo " or set SCHEMA_VERSION in schema.conf" >&2 + exit 1 +fi + +# --- Resolve schema source --------------------------------------------------- +if [ -f "$INPUT" ]; then + SCHEMA_FILE="$INPUT" + echo "Using local schema: $SCHEMA_FILE" +else + VERSION="$INPUT" + SCHEMA_URL="https://raw.githubusercontent.com/$SCHEMA_REPO/$SCHEMA_BRANCH/spec/v2/json_schema/${VERSION}.json" + SCHEMA_FILE="$(mktemp "${TMPDIR:-/tmp}/gtfs_diff_schema_XXXXXX.json")" + trap 'rm -f "$SCHEMA_FILE"' EXIT + + echo "Fetching schema from $SCHEMA_URL ..." + curl -fsSL "$SCHEMA_URL" -o "$SCHEMA_FILE" +fi + +# --- Ensure datamodel-code-generator is available ---------------------------- +if ! command -v datamodel-codegen &>/dev/null; then + echo "Installing datamodel-code-generator ..." + pip install 'datamodel-code-generator[ruff]>=0.59' +fi + +# --- Generate ---------------------------------------------------------------- +echo "Generating models → $OUTPUT" +datamodel-codegen \ + --formatters ruff-format ruff-check \ + --input "$SCHEMA_FILE" \ + --input-file-type jsonschema \ + --output-model-type pydantic_v2.BaseModel \ + --target-python-version 3.10 \ + --use-standard-collections \ + --use-union-operator \ + --snake-case-field \ + --collapse-root-models \ + --enum-field-as-literal all \ + --field-constraints \ + --use-schema-description \ + --class-name GtfsDiff \ + --output "$OUTPUT" + +# --- Post-process: clean header, append __all__ ------------------------------ +# Remove the timestamp and temp filename so re-generation doesn't create noisy diffs. +sed -i.bak '/^# timestamp:/d; /^# filename:/d' "$OUTPUT" && rm -f "$OUTPUT.bak" + +# Collect class names and append __all__. +CLASSES=$(grep -oE '^class ([A-Za-z_][A-Za-z0-9_]*)' "$OUTPUT" | awk '{print $2}') +{ + echo "" + echo "" + echo "__all__ = [" + for cls in $CLASSES; do + echo " \"$cls\"," + done + echo "]" +} >> "$OUTPUT" + +COUNT=$(echo "$CLASSES" | wc -w | tr -d ' ') +echo "Done — $COUNT model(s) generated." diff --git a/src/gtfs_diff/engine.py b/src/gtfs_diff/engine.py index b273d08..08fda8f 100644 --- a/src/gtfs_diff/engine.py +++ b/src/gtfs_diff/engine.py @@ -6,7 +6,7 @@ (line_number, raw_csv_string) for every row in each file. For typical transit feeds this is fine. For very large feeds (stop_times.txt can exceed 10 M rows) a disk-backed index (e.g. SQLite) would be more appropriate; that is left as a -future optimisation. +future optimization. """ from __future__ import annotations @@ -27,6 +27,7 @@ FeedSource, FieldChange, FileDiff, + FileStats, FileSummary, GtfsDiff, Metadata, @@ -71,7 +72,7 @@ def _trace(msg: str) -> None: # --------------------------------------------------------------------------- def _row_to_csv(values: list[str]) -> str: - """Serialise a list of string values to a single CSV line (no trailing newline).""" + """Serialize a list of string values to a single CSV line (no trailing newline).""" buf = io.StringIO() writer = csv.writer(buf, lineterminator="") writer.writerow(values) @@ -280,15 +281,9 @@ def _diff_file_added( file_action="added", columns_added=columns_added, columns_deleted=[], - row_changes=None, - truncated=None, - ) - summary = FileSummary( - file_name=file_name, - status="added", - columns_added_count=len(columns_added), - columns_deleted_count=0, + stats=FileStats(columns_added_count=len(columns_added), columns_deleted_count=0), ) + summary = FileSummary(file_name=file_name, status="added") return file_diff, summary @@ -308,15 +303,9 @@ def _diff_file_deleted( file_action="deleted", columns_added=[], columns_deleted=columns_deleted, - row_changes=None, - truncated=None, - ) - summary = FileSummary( - file_name=file_name, - status="deleted", - columns_added_count=0, - columns_deleted_count=len(columns_deleted), + stats=FileStats(columns_added_count=0, columns_deleted_count=len(columns_deleted)), ) + summary = FileSummary(file_name=file_name, status="deleted") return file_diff, summary @@ -522,16 +511,17 @@ def _remaining(used: int) -> int | None: columns_deleted=columns_deleted, row_changes=row_changes, truncated=truncated, + stats=FileStats( + total_rows_base=len(base_index), + total_rows_new=len(new_index), + columns_added_count=len(columns_added), + columns_deleted_count=len(columns_deleted), + rows_added_count=true_added, + rows_deleted_count=true_deleted, + rows_modified_count=true_modified, + ), ) - summary = FileSummary( - file_name=file_name, - status="modified", - columns_added_count=len(columns_added), - columns_deleted_count=len(columns_deleted), - rows_added_count=true_added, - rows_deleted_count=true_deleted, - rows_modified_count=true_modified, - ) + summary = FileSummary(file_name=file_name, status="modified") return file_diff, summary @@ -607,13 +597,15 @@ def diff_feeds( # Per spec: file_diffs[] contains only *changed* files. # Skip files present in both feeds with no actual changes. + stats = file_diff.stats if ( file_summary.status == "modified" - and not file_summary.columns_added_count - and not file_summary.columns_deleted_count - and not file_summary.rows_added_count - and not file_summary.rows_deleted_count - and not file_summary.rows_modified_count + and stats is not None + and not stats.columns_added_count + and not stats.columns_deleted_count + and not stats.rows_added_count + and not stats.rows_deleted_count + and not stats.rows_modified_count ): continue @@ -625,12 +617,15 @@ def diff_feeds( files_deleted = sum(1 for s in file_summaries if s.status == "deleted") files_modified = sum(1 for s in file_summaries if s.status == "modified") + def _stat(attr: str) -> int: + return sum(getattr(fd.stats, attr, 0) or 0 for fd in file_diffs if fd.stats) + total_changes = ( - sum(s.rows_added_count or 0 for s in file_summaries) - + sum(s.rows_deleted_count or 0 for s in file_summaries) - + sum(s.rows_modified_count or 0 for s in file_summaries) - + sum(s.columns_added_count or 0 for s in file_summaries) - + sum(s.columns_deleted_count or 0 for s in file_summaries) + _stat("rows_added_count") + + _stat("rows_deleted_count") + + _stat("rows_modified_count") + + _stat("columns_added_count") + + _stat("columns_deleted_count") + files_added + files_deleted ) @@ -648,6 +643,7 @@ def diff_feeds( files_added_count=files_added, files_deleted_count=files_deleted, files_modified_count=files_modified, + files_not_compared_count=0, files=file_summaries, ) result = GtfsDiff(metadata=metadata, summary=summary, file_diffs=file_diffs) diff --git a/src/gtfs_diff/models.py b/src/gtfs_diff/models.py index 6665e5e..5d7daac 100644 --- a/src/gtfs_diff/models.py +++ b/src/gtfs_diff/models.py @@ -1,159 +1,233 @@ -"""Pydantic v2 models for the GTFS Diff v2 output format.""" +# generated by datamodel-codegen: from __future__ import annotations +from pydantic import AwareDatetime, BaseModel, Field +from typing import Literal -from datetime import datetime -from typing import Literal, Optional -from pydantic import BaseModel, ConfigDict, Field +class UnsupportedFile(BaseModel): + file_name: str = Field(..., description="File name as it appears in the archive.") + present_in: Literal["base", "new", "both"] = Field( + ..., description="Which archive(s) contain this file." + ) class ColumnEntry(BaseModel): - """A column that was added or deleted, with its name and original position.""" - - model_config = ConfigDict(populate_by_name=True) - - name: str - position: int = Field(..., ge=1) + name: str = Field(..., description="Column name.") + position: int = Field( + ..., description="1-based position of this column in the CSV header row.", ge=1 + ) class FeedSource(BaseModel): - """Identifies a GTFS feed by its source URL and the time it was downloaded.""" - - model_config = ConfigDict(populate_by_name=True) - - source: str - downloaded_at: datetime - - -class UnsupportedFile(BaseModel): - """A file present in one or both feeds that the diff engine does not support.""" - - model_config = ConfigDict(populate_by_name=True) - - file_name: str - present_in: Literal["base", "new", "both"] - - -class Metadata(BaseModel): - """Top-level metadata describing the diff run and both feed sources.""" - - model_config = ConfigDict(populate_by_name=True) - - schema_version: str - generated_at: datetime - row_changes_cap_per_file: Optional[int] = Field(None, ge=0) - base_feed: FeedSource - new_feed: FeedSource - unsupported_files: list[UnsupportedFile] + source: str = Field(..., description="URL or local path to the GTFS archive.") + downloaded_at: AwareDatetime = Field( + ..., description="ISO 8601 timestamp of when the feed was downloaded." + ) class FileSummary(BaseModel): - """High-level change counts for a single GTFS file.""" - - model_config = ConfigDict(populate_by_name=True) - - file_name: str - status: Literal["added", "deleted", "modified"] - columns_added_count: Optional[int] = Field(None, ge=0) - columns_deleted_count: Optional[int] = Field(None, ge=0) - rows_added_count: Optional[int] = Field(None, ge=0) - rows_deleted_count: Optional[int] = Field(None, ge=0) - rows_modified_count: Optional[int] = Field(None, ge=0) - - -class Summary(BaseModel): - """Aggregate change counts across all GTFS files in the diff.""" + file_name: str = Field(..., description="Name of the GTFS file.") + status: Literal["added", "deleted", "modified", "not_compared"] = Field( + ..., description="File-level status." + ) - model_config = ConfigDict(populate_by_name=True) - total_changes: int = Field(..., ge=0) - files_added_count: int = Field(..., ge=0) - files_deleted_count: int = Field(..., ge=0) - files_modified_count: int = Field(..., ge=0) - files: list[FileSummary] - - -class FieldChange(BaseModel): - """The before and after values for a single field within a modified row.""" - - model_config = ConfigDict(populate_by_name=True) - - field: str - base_value: str - new_value: str +class Truncated(BaseModel): + is_truncated: Literal[True] = Field(..., description="Always true when present.") + omitted_count: int = Field( + ..., description="Number of row changes omitted due to the cap.", ge=1 + ) + + +class NotComparedReason(BaseModel): + code: str = Field( + ..., + description='Machine-readable reason code (e.g. "id_churn", "missing_primary_key", "file_too_large").', + ) + message: str = Field( + ..., + description="Human-readable explanation of why the file or column was not compared.", + ) + + +class IgnoredColumn(BaseModel): + column: str = Field(..., description="The column name that was ignored.") + reason: NotComparedReason + + +class ColumnStat(BaseModel): + column: str = Field(..., description="The column name.") + modifications_count: int = Field( + ..., + description="Number of modified rows that had a change in this column.", + ge=0, + ) + modifications_percentage: float = Field( + ..., + description="modifications_count as a percentage of total modified rows.", + ge=0.0, + le=100.0, + ) class RowAdded(BaseModel): - """A row that exists only in the new feed.""" - - model_config = ConfigDict(populate_by_name=True) - - identifier: dict[str, str] - raw_value: str - new_line_number: int = Field(..., ge=1) + identifier: dict[str, str] = Field( + ..., description="Primary key values identifying this row." + ) + raw_value: str = Field( + ..., + description="The CSV row from the new file as a comma-separated string. Field order matches the columns array, which preserves the raw CSV column order of the base feed (new-feed-only columns appended).", + ) + new_line_number: int = Field( + ..., description="1-based line number of this row in the new CSV file.", ge=1 + ) class RowDeleted(BaseModel): - """A row that exists only in the base feed.""" + identifier: dict[str, str] = Field( + ..., description="Primary key values identifying this row." + ) + raw_value: str = Field( + ..., + description="The CSV row from the base file as a comma-separated string. Field order matches the columns array, which preserves the raw CSV column order of the base feed (new-feed-only columns appended).", + ) + base_line_number: int = Field( + ..., description="1-based line number of this row in the base CSV file.", ge=1 + ) - model_config = ConfigDict(populate_by_name=True) - identifier: dict[str, str] - raw_value: str - base_line_number: int = Field(..., ge=1) +class FieldChange(BaseModel): + field: str = Field(..., description="The column name that changed.") + base_value: str = Field(..., description="The value in the base feed.") + new_value: str = Field(..., description="The value in the new feed.") class RowModified(BaseModel): - """A row present in both feeds whose field values differ.""" - - model_config = ConfigDict(populate_by_name=True) - - identifier: dict[str, str] - raw_value: str - base_line_number: int = Field(..., ge=1) - new_line_number: int = Field(..., ge=1) - field_changes: list[FieldChange] = Field(..., min_length=1) - - -class RowChanges(BaseModel): - """All row-level changes for a file, keyed by primary key columns.""" + identifier: dict[str, str] = Field( + ..., description="Primary key values identifying this row." + ) + raw_value: str = Field( + ..., + description="The base CSV row as a comma-separated string. Field order matches the columns array, which preserves the raw CSV column order of the base feed (new-feed-only columns appended).", + ) + base_line_number: int = Field( + ..., description="1-based line number of this row in the base CSV file.", ge=1 + ) + new_line_number: int = Field( + ..., description="1-based line number of this row in the new CSV file.", ge=1 + ) + field_changes: list[FieldChange] = Field( + ..., description="List of field-level changes.", min_length=1 + ) - model_config = ConfigDict(populate_by_name=True) - primary_key: list[str] = Field(..., min_length=1) - columns: list[str] - added: list[RowAdded] - deleted: list[RowDeleted] - modified: list[RowModified] +class Metadata(BaseModel): + schema_version: str = Field(..., description="The version of the schema.") + generated_at: AwareDatetime = Field( + ..., description="ISO 8601 timestamp of when the diff was generated." + ) + row_changes_cap_per_file: int | None = Field( + ..., + description="Maximum number of row changes included per file in file_diffs. 0 means no row changes are included; null means no cap is applied.", + ge=0, + ) + base_feed: FeedSource + new_feed: FeedSource + unsupported_files: list[UnsupportedFile] -class Truncated(BaseModel): - """Indicates that row changes were capped and some were omitted.""" +class Summary(BaseModel): + total_changes: int = Field( + ..., description="Total number of changes across all files.", ge=0 + ) + files_added_count: int = Field(..., description="Number of files added.", ge=0) + files_deleted_count: int = Field(..., description="Number of files deleted.", ge=0) + files_modified_count: int = Field( + ..., description="Number of files modified.", ge=0 + ) + files_not_compared_count: int = Field( + ..., + description="Number of files that could not be meaningfully compared.", + ge=0, + ) + files: list[FileSummary] - model_config = ConfigDict(populate_by_name=True) - is_truncated: Literal[True] - omitted_count: int = Field(..., ge=1) +class RowChanges(BaseModel): + primary_key: list[str] = Field( + ..., description="Column(s) that uniquely identify a row.", min_length=1 + ) + columns: list[str] = Field( + ..., + description="Union of all columns across both base and new versions. Order matches the base feed's original column order; columns only in the new feed are appended.", + ) + added: list[RowAdded] = Field(..., description="Added rows (capped).") + deleted: list[RowDeleted] = Field(..., description="Deleted rows (capped).") + modified: list[RowModified] = Field(..., description="Modified rows (capped).") + + +class FileStats(BaseModel): + total_rows_base: int | None = Field( + None, description="Total number of rows in the base version of the file.", ge=0 + ) + total_rows_new: int | None = Field( + None, description="Total number of rows in the new version of the file.", ge=0 + ) + columns_added_count: int | None = Field( + None, description="Number of columns added.", ge=0 + ) + columns_deleted_count: int | None = Field( + None, description="Number of columns deleted.", ge=0 + ) + rows_added_count: int | None = Field( + None, description="True count of rows added.", ge=0 + ) + rows_deleted_count: int | None = Field( + None, description="True count of rows deleted.", ge=0 + ) + rows_modified_count: int | None = Field( + None, description="True count of rows modified.", ge=0 + ) + rows_changed_percentage: float | None = Field( + None, + description="Percentage of rows that were added, deleted, or modified relative to the larger of the two versions.", + ge=0.0, + le=100.0, + ) + column_stats: list[ColumnStat] | None = Field( + None, + description="Per-column modification statistics. Only covers modified rows.", + ) class FileDiff(BaseModel): - """Complete diff for a single GTFS file, including column and row changes.""" - - model_config = ConfigDict(populate_by_name=True) - - file_name: str - file_action: Literal["modified", "added", "deleted"] - columns_added: list[ColumnEntry] - columns_deleted: list[ColumnEntry] - row_changes: Optional[RowChanges] = None - truncated: Optional[Truncated] = None + file_name: str = Field(..., description="Name of the GTFS file.") + file_action: Literal["modified", "added", "deleted", "not_compared"] = Field( + ..., description="Action describing how this file changed." + ) + not_compared_reason: NotComparedReason | None = None + ignored_columns: list[IgnoredColumn] | None = Field( + None, + description="Columns excluded from the diff because their values are unreliable (e.g. they reference a file that was not compared).", + ) + columns_added: list[ColumnEntry] | None = Field( + None, + description="Columns added to this file, in the order they appear in the new file's CSV header. Each entry includes the column name and its 1-based position in that header.", + ) + columns_deleted: list[ColumnEntry] | None = Field( + None, + description="Columns deleted from this file, in the order they appeared in the base file's CSV header. Each entry includes the column name and its 1-based position in that header.", + ) + row_changes: RowChanges | None = None + truncated: Truncated | None = None + stats: FileStats | None = None class GtfsDiff(BaseModel): - """Root model for the GTFS Diff v2 output format.""" - - model_config = ConfigDict(populate_by_name=True) + """ + Schema for GTFS Diff v2 output: a single JSON document describing all differences between two GTFS archives. + """ metadata: Metadata summary: Summary @@ -161,18 +235,22 @@ class GtfsDiff(BaseModel): __all__ = [ + "UnsupportedFile", "ColumnEntry", "FeedSource", - "UnsupportedFile", - "Metadata", "FileSummary", - "Summary", - "FieldChange", + "Truncated", + "NotComparedReason", + "IgnoredColumn", + "ColumnStat", "RowAdded", "RowDeleted", + "FieldChange", "RowModified", + "Metadata", + "Summary", "RowChanges", - "Truncated", + "FileStats", "FileDiff", "GtfsDiff", ] diff --git a/tests/test_engine.py b/tests/test_engine.py index 35378f7..f60e94d 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -123,8 +123,8 @@ def test_rows_added_count(self, tmp_path: Path): "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", }) result = diff_feeds(base, new) - fs = _get_file_summary(result, "stops.txt") - assert fs.rows_added_count == 1 + fd = _get_file_diff(result, "stops.txt") + assert fd.stats.rows_added_count == 1 def test_rows_added_identifier(self, tmp_path: Path): base = write_zip(tmp_path / "base.zip", { @@ -148,8 +148,8 @@ def test_rows_deleted_count(self, tmp_path: Path): "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", }) result = diff_feeds(base, new) - fs = _get_file_summary(result, "stops.txt") - assert fs.rows_deleted_count == 1 + fd = _get_file_diff(result, "stops.txt") + assert fd.stats.rows_deleted_count == 1 def test_rows_deleted_identifier(self, tmp_path: Path): base = write_zip(tmp_path / "base.zip", { @@ -173,8 +173,8 @@ def test_rows_modified_count(self, tmp_path: Path): "stops.txt": STOPS_HEADER + "S1,Stop One Renamed,1.0,2.0\n", }) result = diff_feeds(base, new) - fs = _get_file_summary(result, "stops.txt") - assert fs.rows_modified_count == 1 + fd = _get_file_diff(result, "stops.txt") + assert fd.stats.rows_modified_count == 1 def test_rows_modified_field_changes(self, tmp_path: Path): base = write_zip(tmp_path / "base.zip", { diff --git a/tests/test_models.py b/tests/test_models.py index 00bf7af..41c4954 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,6 +11,7 @@ ColumnEntry, FieldChange, FileDiff, + FileStats, FileSummary, GtfsDiff, Metadata, @@ -67,6 +68,7 @@ def _summary(**kwargs) -> Summary: files_added_count=0, files_deleted_count=0, files_modified_count=0, + files_not_compared_count=0, files=[], ) defaults.update(kwargs) @@ -195,23 +197,32 @@ def test_empty_field_changes_rejected(self): # --------------------------------------------------------------------------- class TestFileSummary: - def test_all_optional_counts_none(self): + def test_valid(self): fs = FileSummary(file_name="stops.txt", status="modified") - assert fs.rows_added_count is None - assert fs.rows_deleted_count is None - assert fs.rows_modified_count is None - assert fs.columns_added_count is None - assert fs.columns_deleted_count is None + assert fs.file_name == "stops.txt" + assert fs.status == "modified" + + def test_not_compared_status(self): + fs = FileSummary(file_name="stops.txt", status="not_compared") + assert fs.status == "not_compared" + + +class TestFileStats: + def test_all_optional_counts_none(self): + stats = FileStats() + assert stats.rows_added_count is None + assert stats.rows_deleted_count is None + assert stats.rows_modified_count is None + assert stats.columns_added_count is None + assert stats.columns_deleted_count is None def test_with_counts(self): - fs = FileSummary( - file_name="stops.txt", - status="modified", + stats = FileStats( rows_added_count=3, rows_deleted_count=1, rows_modified_count=0, ) - assert fs.rows_added_count == 3 + assert stats.rows_added_count == 3 # --------------------------------------------------------------------------- From be7f73190968f8d61842acc47fee0c9a220e2f2b Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 2 Jun 2026 14:41:48 -0400 Subject: [PATCH 2/8] added lint + improve ci --- .github/workflows/ci.yml | 28 +- pyproject.toml | 22 + scripts/generate_models.sh | 14 +- scripts/lint-fix.sh | 12 + scripts/lint.sh | 12 + src/gtfs_diff/cli.py | 53 ++- src/gtfs_diff/engine.py | 107 +++-- src/gtfs_diff/gtfs_definitions.py | 36 +- src/gtfs_diff/models.py | 6 +- tests/conftest.py | 1 - tests/test_cli.py | 59 ++- tests/test_engine.py | 692 ++++++++++++++++++++---------- tests/test_models.py | 12 +- 13 files changed, 746 insertions(+), 308 deletions(-) create mode 100755 scripts/lint-fix.sh create mode 100755 scripts/lint.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 84cd58b..a2056ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,7 +6,8 @@ on: pull_request: jobs: - changes: + detect-changes: + name: Detect changed paths runs-on: ubuntu-latest permissions: pull-requests: read @@ -30,8 +31,11 @@ jobs: - '.github/workflows/ci.yml' test: - needs: changes - if: needs.changes.outputs.src == 'true' + needs: + - detect-changes + - lint + name: Run tests on Python ${{ matrix.python-version }} + if: needs.detect-changes.outputs.src == 'true' runs-on: ubuntu-latest strategy: matrix: @@ -44,9 +48,23 @@ jobs: - run: pip install -e '.[dev]' - run: pytest --tb=short + lint: + needs: detect-changes + name: Lint and format check + if: needs.detect-changes.outputs.src == 'true' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - run: pip install -e '.[dev]' + - run: ./scripts/lint.sh + models-freshness: - needs: changes - if: needs.changes.outputs.models == 'true' + needs: detect-changes + name: Check models.py is up to date + if: needs.detect-changes.outputs.models == 'true' runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 diff --git a/pyproject.toml b/pyproject.toml index 6611b2f..4932a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dev = [ "pytest>=7.0", "pytest-cov", "datamodel-code-generator[ruff]>=0.59", + "ruff>=0.11", ] [project.scripts] @@ -28,3 +29,24 @@ packages = ["src/gtfs_diff"] [tool.pytest.ini_options] testpaths = ["tests"] + +[tool.ruff] +target-version = "py310" +src = ["src", "tests"] +exclude = ["src/gtfs_diff/models.py"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify +] +ignore = [ + "B905", # zip-without-strict — false positives on pre-padded rows +] + +[tool.ruff.format] +exclude = ["src/gtfs_diff/models.py"] diff --git a/scripts/generate_models.sh b/scripts/generate_models.sh index 950e05f..a2dacd2 100755 --- a/scripts/generate_models.sh +++ b/scripts/generate_models.sh @@ -65,9 +65,17 @@ datamodel-codegen \ --class-name GtfsDiff \ --output "$OUTPUT" -# --- Post-process: clean header, append __all__ ------------------------------ -# Remove the timestamp and temp filename so re-generation doesn't create noisy diffs. -sed -i.bak '/^# timestamp:/d; /^# filename:/d' "$OUTPUT" && rm -f "$OUTPUT.bak" +# --- Post-process: replace header, append __all__ ---------------------------- +# Replace the codegen header with a clear auto-generated notice. +{ + echo "# AUTO-GENERATED — DO NOT EDIT" + echo "# This file is generated from the GTFS Diff JSON Schema." + echo "# To regenerate: ./scripts/generate_models.sh" + echo "# Schema source: https://github.com/$SCHEMA_REPO" + echo "" + # Strip the original codegen comment block (everything before the first blank line). + sed -n '/^$/,$p' "$OUTPUT" +} > "$OUTPUT.tmp" && mv "$OUTPUT.tmp" "$OUTPUT" # Collect class names and append __all__. CLASSES=$(grep -oE '^class ([A-Za-z_][A-Za-z0-9_]*)' "$OUTPUT" | awk '{print $2}') diff --git a/scripts/lint-fix.sh b/scripts/lint-fix.sh new file mode 100755 index 0000000..2c450de --- /dev/null +++ b/scripts/lint-fix.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# Auto-fix lint violations and reformat code. +# Usage: ./scripts/lint-fix.sh +set -euo pipefail + +echo "Fixing lint violations ..." +ruff check --fix src/ tests/ + +echo "Formatting ..." +ruff format src/ tests/ + +echo "Done." diff --git a/scripts/lint.sh b/scripts/lint.sh new file mode 100755 index 0000000..25a957f --- /dev/null +++ b/scripts/lint.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# Check lint and formatting (exits non-zero on violations). +# Usage: ./scripts/lint.sh +set -euo pipefail + +echo "Checking lint rules ..." +ruff check src/ tests/ + +echo "Checking formatting ..." +ruff format --check src/ tests/ + +echo "All clean." diff --git a/src/gtfs_diff/cli.py b/src/gtfs_diff/cli.py index cb2fa35..5371563 100644 --- a/src/gtfs_diff/cli.py +++ b/src/gtfs_diff/cli.py @@ -13,16 +13,33 @@ @click.version_option(version="0.1.0", prog_name="gtfs-diff-engine") @click.argument("base_feed", type=click.Path(exists=True, path_type=Path)) @click.argument("new_feed", type=click.Path(exists=True, path_type=Path)) -@click.option("--output", "-o", type=click.Path(path_type=Path), default=None, - help="Write JSON output to FILE instead of stdout.") -@click.option("--cap", "-c", type=int, default=None, - help="Max row changes per file (0 = omit row-level detail).") -@click.option("--pretty/--no-pretty", default=True, - help="Pretty-print JSON (default: --pretty).") -@click.option("--base-downloaded-at", default=None, - help="ISO 8601 datetime for when base was downloaded.") -@click.option("--new-downloaded-at", default=None, - help="ISO 8601 datetime for when new was downloaded.") +@click.option( + "--output", + "-o", + type=click.Path(path_type=Path), + default=None, + help="Write JSON output to FILE instead of stdout.", +) +@click.option( + "--cap", + "-c", + type=int, + default=None, + help="Max row changes per file (0 = omit row-level detail).", +) +@click.option( + "--pretty/--no-pretty", default=True, help="Pretty-print JSON (default: --pretty)." +) +@click.option( + "--base-downloaded-at", + default=None, + help="ISO 8601 datetime for when base was downloaded.", +) +@click.option( + "--new-downloaded-at", + default=None, + help="ISO 8601 datetime for when new was downloaded.", +) def main( base_feed: Path, new_feed: Path, @@ -38,8 +55,12 @@ def main( NEW_FEED: path to the new GTFS feed (zip or directory) """ try: - base_dt = datetime.fromisoformat(base_downloaded_at) if base_downloaded_at else None - new_dt = datetime.fromisoformat(new_downloaded_at) if new_downloaded_at else None + base_dt = ( + datetime.fromisoformat(base_downloaded_at) if base_downloaded_at else None + ) + new_dt = ( + datetime.fromisoformat(new_downloaded_at) if new_downloaded_at else None + ) except ValueError as exc: click.echo(f"Error: {exc}", err=True) sys.exit(1) @@ -55,7 +76,8 @@ def main( except MissingPrimaryKeyError as exc: click.echo( f"ERROR: Cannot process '{exc.file_name}' — " - f"required primary key column(s) {exc.missing_columns} are missing from the file headers.\n" + f"required primary key column(s) {exc.missing_columns} " + f"are missing from the file headers.\n" f"Headers found: {exc.headers}", err=True, ) @@ -64,7 +86,10 @@ def main( click.echo(f"Error: {exc}", err=True) sys.exit(1) - json_str = result.model_dump_json(indent=2, exclude_none=True) if pretty else result.model_dump_json(exclude_none=True) + if pretty: + json_str = result.model_dump_json(indent=2, exclude_none=True) + else: + json_str = result.model_dump_json(exclude_none=True) if output is not None: try: diff --git a/src/gtfs_diff/engine.py b/src/gtfs_diff/engine.py index 08fda8f..f858c88 100644 --- a/src/gtfs_diff/engine.py +++ b/src/gtfs_diff/engine.py @@ -11,15 +11,17 @@ from __future__ import annotations +import configparser import csv import io import sys import time import zipfile +from collections.abc import Callable, Generator from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path -from typing import Callable, Generator, TextIO +from typing import TextIO from .gtfs_definitions import get_primary_key from .models import ( @@ -40,27 +42,40 @@ UnsupportedFile, ) -SCHEMA_VERSION = "2.0" + +def _read_schema_version() -> str: + conf_path = Path(__file__).resolve().parent.parent.parent / "schema.conf" + parser = configparser.ConfigParser() + parser.read_string("[default]\n" + conf_path.read_text()) + return parser.get("default", "SCHEMA_VERSION") class MissingPrimaryKeyError(ValueError): """Raised when a required primary key column is absent from a file's headers.""" - def __init__(self, file_name: str, missing_columns: list[str], headers: list[str]) -> None: + def __init__( + self, file_name: str, missing_columns: list[str], headers: list[str] + ) -> None: self.file_name = file_name self.missing_columns = missing_columns self.headers = headers super().__init__( - f"'{file_name}': required primary key column(s) {missing_columns} " - f"not found in headers {headers}." + f"'{file_name}': required primary key column(s) " + f"{missing_columns} not found in headers {headers}." ) def _trace(msg: str) -> None: """Print a timestamped progress message with current RSS to stderr.""" import psutil + rss_mb = psutil.Process().memory_info().rss / 1024 / 1024 - print(f"[gtfs-diff {datetime.now().strftime('%H:%M:%S')} {rss_mb:.0f}MB] {msg}", file=sys.stderr, flush=True) + print( + f"[gtfs-diff {datetime.now().strftime('%H:%M:%S')} {rss_mb:.0f}MB] {msg}", + file=sys.stderr, + flush=True, + ) + # A "lazy opener" maps a filename (e.g. "stops.txt") to a zero-arg callable # that opens the file and returns a text stream. @@ -71,6 +86,7 @@ def _trace(msg: str) -> None: # Low-level CSV helpers # --------------------------------------------------------------------------- + def _row_to_csv(values: list[str]) -> str: """Serialize a list of string values to a single CSV line (no trailing newline).""" buf = io.StringIO() @@ -140,8 +156,10 @@ def _read_csv_index( if pk_tuple in index: raise ValueError( - f"{file_name}: duplicate primary key {dict(zip(effective_pk, pk_tuple))} " - f"at line {line_num} (first seen at line {index[pk_tuple][0]})." + f"{file_name}: duplicate primary key " + f"{dict(zip(effective_pk, pk_tuple))} " + f"at line {line_num} " + f"(first seen at line {index[pk_tuple][0]})." ) index[pk_tuple] = (line_num, _row_to_csv(row_vals)) @@ -198,6 +216,7 @@ def _compute_raw_value( # Feed opener # --------------------------------------------------------------------------- + @contextmanager def _open_feed(path: str | Path) -> Generator[LazyOpeners, None, None]: """Open a GTFS feed (zip archive or directory) and yield lazy file openers. @@ -215,8 +234,10 @@ def _open_feed(path: str | Path) -> Generator[LazyOpeners, None, None]: if path.is_dir(): openers: LazyOpeners = {} for txt_file in sorted(path.glob("*.txt")): + def _make_opener(p: Path) -> Callable[[], TextIO]: return lambda: p.open(encoding="utf-8-sig") + openers[txt_file.name] = _make_opener(txt_file) yield openers @@ -233,8 +254,10 @@ def _make_opener(p: Path) -> Callable[[], TextIO]: openers = {} for basename, internal_path in name_map.items(): + def _make_opener(ip: str) -> Callable[[], TextIO]: # type: ignore[misc] return lambda: io.TextIOWrapper(zf.open(ip), encoding="utf-8-sig") + openers[basename] = _make_opener(internal_path) yield openers finally: @@ -250,6 +273,7 @@ def _make_opener(ip: str) -> Callable[[], TextIO]: # type: ignore[misc] # Per-file diff # --------------------------------------------------------------------------- + def _diff_file( file_name: str, base_opener: Callable[[], TextIO] | None, @@ -273,15 +297,16 @@ def _diff_file_added( with new_opener() as f: new_headers = _read_headers(f) columns_added = [ - ColumnEntry(name=col, position=i + 1) - for i, col in enumerate(new_headers) + ColumnEntry(name=col, position=i + 1) for i, col in enumerate(new_headers) ] file_diff = FileDiff( file_name=file_name, file_action="added", columns_added=columns_added, columns_deleted=[], - stats=FileStats(columns_added_count=len(columns_added), columns_deleted_count=0), + stats=FileStats( + columns_added_count=len(columns_added), columns_deleted_count=0 + ), ) summary = FileSummary(file_name=file_name, status="added") return file_diff, summary @@ -295,15 +320,16 @@ def _diff_file_deleted( with base_opener() as f: base_headers = _read_headers(f) columns_deleted = [ - ColumnEntry(name=col, position=i + 1) - for i, col in enumerate(base_headers) + ColumnEntry(name=col, position=i + 1) for i, col in enumerate(base_headers) ] file_diff = FileDiff( file_name=file_name, file_action="deleted", columns_added=[], columns_deleted=columns_deleted, - stats=FileStats(columns_added_count=0, columns_deleted_count=len(columns_deleted)), + stats=FileStats( + columns_added_count=0, columns_deleted_count=len(columns_deleted) + ), ) summary = FileSummary(file_name=file_name, status="deleted") return file_diff, summary @@ -362,7 +388,8 @@ def _scan_modifications( shared_cols = [col for col in base_headers if col in set(new_headers)] candidates: list[tuple[tuple, list[FieldChange], int, int]] = [] - _trace(f" [{file_name}] scanning {len(common_keys):,} common rows for modifications...") + n = len(common_keys) + _trace(f" [{file_name}] scanning {n:,} common rows...") t0 = time.monotonic() for pk_tuple in common_keys: b_line, b_raw = base_index[pk_tuple] @@ -377,7 +404,10 @@ def _scan_modifications( if field_changes: candidates.append((pk_tuple, field_changes, b_line, n_line)) - _trace(f" [{file_name}] scan done in {time.monotonic()-t0:.1f}s — {len(candidates):,} modified") + _trace( + f" [{file_name}] scan done in {time.monotonic() - t0:.1f}s — " + f"{len(candidates):,} modified" + ) return candidates @@ -403,16 +433,24 @@ def _diff_file_modified( _trace(f" [{file_name}] indexing base feed...") t0 = time.monotonic() base_headers, base_index = _read_csv_index(f, pk_cols, file_name=file_name) - _trace(f" [{file_name}] base index done: {len(base_index):,} rows in {time.monotonic()-t0:.1f}s") + _trace( + f" [{file_name}] base index done: {len(base_index):,} " + f"rows in {time.monotonic() - t0:.1f}s" + ) with new_opener() as f: _trace(f" [{file_name}] indexing new feed...") t0 = time.monotonic() new_headers, new_index = _read_csv_index(f, pk_cols, file_name=file_name) - _trace(f" [{file_name}] new index done: {len(new_index):,} rows in {time.monotonic()-t0:.1f}s") + _trace( + f" [{file_name}] new index done: {len(new_index):,} " + f"rows in {time.monotonic() - t0:.1f}s" + ) # Column-level diff - columns_added, columns_deleted, union_columns = _diff_columns(base_headers, new_headers) + columns_added, columns_deleted, union_columns = _diff_columns( + base_headers, new_headers + ) # Row-level diff base_keys = set(base_index) @@ -428,7 +466,12 @@ def _diff_file_modified( file_name, common_keys, base_index, new_index, base_headers, new_headers ) true_modified = len(modified_candidates) - _trace(f" [{file_name}] row diff summary — added={true_added:,} deleted={true_deleted:,} modified={true_modified:,}") + _trace( + f" [{file_name}] row diff summary — " + f"added={true_added:,} " + f"deleted={true_deleted:,} " + f"modified={true_modified:,}" + ) # Determine row-changes output based on cap. # cap=0 means "summary counts only" — row_changes is omitted from the output @@ -456,7 +499,9 @@ def _remaining(used: int) -> int | None: identifier = {col: n_dict.get(col, "") for col in pk_cols} raw_value = _compute_raw_value(n_dict, union_columns, new_header_set) added_rows.append( - RowAdded(identifier=identifier, raw_value=raw_value, new_line_number=n_line) + RowAdded( + identifier=identifier, raw_value=raw_value, new_line_number=n_line + ) ) # Fill deleted rows up to remaining cap. @@ -467,12 +512,16 @@ def _remaining(used: int) -> int | None: identifier = {col: b_dict.get(col, "") for col in pk_cols} raw_value = _compute_raw_value(b_dict, union_columns, base_header_set) deleted_rows.append( - RowDeleted(identifier=identifier, raw_value=raw_value, base_line_number=b_line) + RowDeleted( + identifier=identifier, raw_value=raw_value, base_line_number=b_line + ) ) # Fill modified rows up to remaining cap. modified_limit = _remaining(len(added_rows) + len(deleted_rows)) - for pk_tuple, field_changes, b_line, n_line in modified_candidates[:modified_limit]: + for pk_tuple, field_changes, b_line, n_line in modified_candidates[ + :modified_limit + ]: n_raw = new_index[pk_tuple][1] n_dict = _parse_raw_line(n_raw, new_headers) identifier = {col: n_dict.get(col, "") for col in pk_cols} @@ -490,7 +539,9 @@ def _remaining(used: int) -> int | None: total_included = len(added_rows) + len(deleted_rows) + len(modified_rows) total_true = true_added + true_deleted + true_modified if cap is not None and total_true > cap: - truncated = Truncated(is_truncated=True, omitted_count=total_true - total_included) + truncated = Truncated( + is_truncated=True, omitted_count=total_true - total_included + ) row_changes: RowChanges | None = None if include_row_changes: @@ -529,6 +580,7 @@ def _remaining(used: int) -> int | None: # Public API # --------------------------------------------------------------------------- + def diff_feeds( base_path: str | Path, new_path: str | Path, @@ -543,7 +595,8 @@ def diff_feeds( new_path: Path to the new GTFS feed — zip or directory. row_changes_cap_per_file: * ``None`` — include all row changes (default). - * ``0`` — omit all row-level detail (column diffs and counts still computed). + * ``0`` — omit all row-level detail (column diffs and counts + still computed). * ``N > 0`` — include up to *N* row changes per file (added first, then deleted, then modified); a :class:`Truncated` record is attached when the true count exceeds *N*. @@ -631,7 +684,7 @@ def _stat(attr: str) -> int: ) metadata = Metadata( - schema_version=SCHEMA_VERSION, + schema_version=_read_schema_version(), generated_at=now, row_changes_cap_per_file=row_changes_cap_per_file, base_feed=FeedSource(source=str(base_path), downloaded_at=base_downloaded_at), @@ -648,4 +701,4 @@ def _stat(attr: str) -> int: ) result = GtfsDiff(metadata=metadata, summary=summary, file_diffs=file_diffs) _trace("diff_feeds complete") - return result \ No newline at end of file + return result diff --git a/src/gtfs_diff/gtfs_definitions.py b/src/gtfs_diff/gtfs_definitions.py index 1fd0853..0190b51 100644 --- a/src/gtfs_diff/gtfs_definitions.py +++ b/src/gtfs_diff/gtfs_definitions.py @@ -25,17 +25,37 @@ "calendar_dates.txt": ["service_id", "date"], # Fares v1 "fare_attributes.txt": ["fare_id"], - "fare_rules.txt": ["fare_id", "route_id", "origin_id", "destination_id", "contains_id"], + "fare_rules.txt": [ + "fare_id", + "route_id", + "origin_id", + "destination_id", + "contains_id", + ], # Shapes / frequencies / transfers "shapes.txt": ["shape_id", "shape_pt_sequence"], "frequencies.txt": ["trip_id", "start_time"], - "transfers.txt": ["from_stop_id", "to_stop_id", "from_route_id", "to_route_id", "from_trip_id", "to_trip_id"], + "transfers.txt": [ + "from_stop_id", + "to_stop_id", + "from_route_id", + "to_route_id", + "from_trip_id", + "to_trip_id", + ], # Pathways / levels "pathways.txt": ["pathway_id"], "levels.txt": ["level_id"], # Feed metadata "feed_info.txt": [], # single-row file, no primary key - "translations.txt": ["table_name", "field_name", "language", "record_id", "record_sub_id", "field_value"], + "translations.txt": [ + "table_name", + "field_name", + "language", + "record_id", + "record_sub_id", + "field_value", + ], "attributions.txt": ["attribution_id"], # Areas "areas.txt": ["area_id"], @@ -47,7 +67,12 @@ "fare_media.txt": ["fare_media_id"], "fare_products.txt": ["fare_product_id"], "fare_leg_rules.txt": ["leg_group_id"], # partial key, best effort - "fare_transfer_rules.txt": ["from_leg_group_id", "to_leg_group_id", "transfer_count", "duration_limit"], + "fare_transfer_rules.txt": [ + "from_leg_group_id", + "to_leg_group_id", + "transfer_count", + "duration_limit", + ], "timeframes.txt": ["timeframe_group_id", "start_time", "end_time", "service_id"], # Rider categories / booking "rider_categories.txt": ["rider_category_id"], @@ -61,5 +86,6 @@ def get_primary_key(file_name: str) -> list[str] | None: - """Return the primary key columns for a supported GTFS file, or None if unsupported.""" + """Return the primary key columns for a supported GTFS file, + or None if unsupported.""" return GTFS_PRIMARY_KEYS.get(file_name) diff --git a/src/gtfs_diff/models.py b/src/gtfs_diff/models.py index 5d7daac..8c72bbe 100644 --- a/src/gtfs_diff/models.py +++ b/src/gtfs_diff/models.py @@ -1,4 +1,8 @@ -# generated by datamodel-codegen: +# AUTO-GENERATED — DO NOT EDIT +# This file is generated from the GTFS Diff JSON Schema. +# To regenerate: ./scripts/generate_models.sh +# Schema source: https://github.com/MobilityData/gtfs-diff + from __future__ import annotations from pydantic import AwareDatetime, BaseModel, Field diff --git a/tests/conftest.py b/tests/conftest.py index 7accda8..26ab1bc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,6 @@ from tests.helpers import write_zip - # --------------------------------------------------------------------------- # Reusable fixtures # --------------------------------------------------------------------------- diff --git a/tests/test_cli.py b/tests/test_cli.py index 4f8164b..e78eece 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,11 +5,9 @@ import json from pathlib import Path -import pytest from click.testing import CliRunner from gtfs_diff.cli import main - from tests.helpers import write_zip STOPS_HEADER = "stop_id,stop_name,stop_lat,stop_lon\n" @@ -89,35 +87,56 @@ def test_cap_stored_in_metadata(self, tmp_path: Path): class TestMissingPrimaryKeyError: def test_exits_nonzero_on_missing_pk_column(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", # stop_id absent - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\n" + "Stop One,1.0,2.0\n", # stop_id absent + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) runner = CliRunner() result = runner.invoke(main, [str(base), str(new)]) assert result.exit_code == 1 def test_error_message_names_file_and_missing_column(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", # stop_id absent - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\n" + "Stop One,1.0,2.0\n", # stop_id absent + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) runner = CliRunner() result = runner.invoke(main, [str(base), str(new)]) assert "stops.txt" in result.output assert "stop_id" in result.output def test_error_message_includes_headers_found(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", # stop_id absent - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\n" + "Stop One,1.0,2.0\n", # stop_id absent + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) runner = CliRunner() result = runner.invoke(main, [str(base), str(new)]) assert "stop_name" in result.output diff --git a/tests/test_engine.py b/tests/test_engine.py index f60e94d..383ae6f 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -8,14 +8,13 @@ from gtfs_diff.engine import MissingPrimaryKeyError, diff_feeds from gtfs_diff.models import GtfsDiff - from tests.helpers import write_zip - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _get_file_diff(result: GtfsDiff, file_name: str): """Return the FileDiff for a given file name, or raise.""" for fd in result.file_diffs: @@ -38,15 +37,22 @@ def _get_file_summary(result: GtfsDiff, file_name: str): # File-level tests # --------------------------------------------------------------------------- + class TestFileAdded: def test_file_added(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "routes.txt": "route_id,route_short_name\nR1,Route 1\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "routes.txt") assert fd.file_action == "added" @@ -56,13 +62,19 @@ def test_file_added(self, tmp_path: Path): assert "route_short_name" in column_names def test_file_added_summary_status(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "routes.txt": "route_id,route_short_name\nR1,Route 1\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }, + ) result = diff_feeds(base, new) fs = _get_file_summary(result, "routes.txt") assert fs.status == "added" @@ -71,13 +83,19 @@ def test_file_added_summary_status(self, tmp_path: Path): class TestFileDeleted: def test_file_deleted(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "routes.txt": "route_id,route_short_name\nR1,Route 1\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "routes.txt") assert fd.file_action == "deleted" @@ -86,13 +104,19 @@ def test_file_deleted(self, tmp_path: Path): assert "route_id" in column_names def test_file_deleted_summary_status(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "routes.txt": "route_id,route_short_name\nR1,Route 1\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) result = diff_feeds(base, new) fs = _get_file_summary(result, "routes.txt") assert fs.status == "deleted" @@ -114,25 +138,40 @@ def test_identical_feeds(self, tmp_path: Path): # Row-level tests # --------------------------------------------------------------------------- + class TestRowsAdded: def test_rows_added_count(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert fd.stats.rows_added_count == 1 def test_rows_added_identifier(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.added) == 1 @@ -141,23 +180,37 @@ def test_rows_added_identifier(self, tmp_path: Path): class TestRowsDeleted: def test_rows_deleted_count(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert fd.stats.rows_deleted_count == 1 def test_rows_deleted_identifier(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.deleted) == 1 @@ -166,23 +219,35 @@ def test_rows_deleted_identifier(self, tmp_path: Path): class TestRowsModified: def test_rows_modified_count(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One Renamed,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One Renamed,1.0,2.0\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert fd.stats.rows_modified_count == 1 def test_rows_modified_field_changes(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One Renamed,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One Renamed,1.0,2.0\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.modified) == 1 @@ -190,7 +255,9 @@ def test_rows_modified_field_changes(self, tmp_path: Path): assert mod.identifier == {"stop_id": "S1"} field_names = [fc.field for fc in mod.field_changes] assert "stop_name" in field_names - stop_name_change = next(fc for fc in mod.field_changes if fc.field == "stop_name") + stop_name_change = next( + fc for fc in mod.field_changes if fc.field == "stop_name" + ) assert stop_name_change.base_value == "Stop One" assert stop_name_change.new_value == "Stop One Renamed" @@ -224,12 +291,20 @@ def test_swapped_row_order_is_not_a_change(self, tmp_path: Path): # Swapping two rows must not produce any diff. # NOTE: should a row reorder be reported as a structural change even when # no field values differ? This is currently an open design question. - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S2,Stop Two,3.0,4.0\nS1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + + "S2,Stop Two,3.0,4.0\nS1,Stop One,1.0,2.0\n", + }, + ) result = diff_feeds(base, new) assert result.file_diffs == [] assert result.summary.total_changes == 0 @@ -237,18 +312,22 @@ def test_swapped_row_order_is_not_a_change(self, tmp_path: Path): def test_trailing_zeros_in_coordinates_are_not_a_change(self, tmp_path: Path): # A producer may write '-73.55625' in one version and '-73.556250' in the # next. These are numerically identical and must not be reported as a diff. - base = write_zip(tmp_path / "base.zip", { - "shapes.txt": ( - "shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence\n" - "11071,45.518332,-73.55625,150001\n" - ), - }) - new = write_zip(tmp_path / "new.zip", { - "shapes.txt": ( - "shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence\n" - "11071,45.518332,-73.556250,150001\n" - ), - }) + base = write_zip( + tmp_path / "base.zip", + { + "shapes.txt": ( + "shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence\n11071,45.518332,-73.55625,150001\n" + ), + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "shapes.txt": ( + "shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence\n11071,45.518332,-73.556250,150001\n" + ), + }, + ) result = diff_feeds(base, new) assert result.file_diffs == [] assert result.summary.total_changes == 0 @@ -258,12 +337,18 @@ def test_swapped_column_order_is_not_a_change(self, tmp_path: Path): # Swapping two columns must not produce any diff. # NOTE: should a column reorder be reported as a structural change even when # no field values differ? This is currently an open design question. - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name,stop_lat\nS1,Stop One,1.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_name,stop_id,stop_lat\nStop One,S1,1.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name,stop_lat\nS1,Stop One,1.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_name,stop_id,stop_lat\nStop One,S1,1.0\n", + }, + ) result = diff_feeds(base, new) assert result.file_diffs == [] assert result.summary.total_changes == 0 @@ -273,26 +358,39 @@ def test_swapped_column_order_is_not_a_change(self, tmp_path: Path): # Column-level tests # --------------------------------------------------------------------------- + class TestColumnAdded: def test_column_added(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") added_names = [c.name for c in fd.columns_added] assert "stop_desc" in added_names def test_column_added_position(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") stop_desc_col = next(c for c in fd.columns_added if c.name == "stop_desc") @@ -301,24 +399,36 @@ def test_column_added_position(self, tmp_path: Path): class TestColumnDeleted: def test_column_deleted(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") deleted_names = [c.name for c in fd.columns_deleted] assert "stop_desc" in deleted_names def test_column_deleted_position(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") stop_desc_col = next(c for c in fd.columns_deleted if c.name == "stop_desc") @@ -329,6 +439,7 @@ def test_column_deleted_position(self, tmp_path: Path): # Cap tests # --------------------------------------------------------------------------- + def _make_stops_csv(n: int) -> str: header = "stop_id,stop_name\n" rows = "".join(f"S{i},Stop {i}\n" for i in range(1, n + 1)) @@ -337,12 +448,18 @@ def _make_stops_csv(n: int) -> str: class TestCapZero: def test_cap_zero_omits_row_changes(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }, + ) result = diff_feeds(base, new, row_changes_cap_per_file=0) fd = _get_file_diff(result, "stops.txt") assert fd.row_changes is None @@ -351,12 +468,19 @@ def test_cap_zero_omits_row_changes(self, tmp_path: Path): class TestCapLimits: def test_cap_limits_row_changes(self, tmp_path: Path): # 5 new rows added, cap = 3 → 3 included, omitted_count = 2 - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS0,Stop Zero\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\n" + "".join(f"S{i},Stop {i}\n" for i in range(1, 7)), - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS0,Stop Zero\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\n" + + "".join(f"S{i},Stop {i}\n" for i in range(1, 7)), + }, + ) result = diff_feeds(base, new, row_changes_cap_per_file=3) fd = _get_file_diff(result, "stops.txt") # Total included across added+deleted+modified <= 3 @@ -372,16 +496,25 @@ def test_cap_limits_row_changes(self, tmp_path: Path): def test_truncated_omitted_count_correct(self, tmp_path: Path): # 5 added rows, 0 deleted, 0 modified; cap = 3 → omitted = 2 - base = write_zip(tmp_path / "base.zip", { - "stops.txt": _make_stops_csv(0).replace("\n", "", 1), # header only - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": _make_stops_csv(0).replace("\n", "", 1), # header only + }, + ) # Write header-only base - base = write_zip(tmp_path / "base2.zip", { - "stops.txt": "stop_id,stop_name\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": _make_stops_csv(5), - }) + base = write_zip( + tmp_path / "base2.zip", + { + "stops.txt": "stop_id,stop_name\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": _make_stops_csv(5), + }, + ) result = diff_feeds(base, new, row_changes_cap_per_file=3) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.added) == 3 @@ -390,12 +523,18 @@ def test_truncated_omitted_count_correct(self, tmp_path: Path): class TestCapNone: def test_cap_none_includes_all(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": _make_stops_csv(5), - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": _make_stops_csv(5), + }, + ) result = diff_feeds(base, new, row_changes_cap_per_file=None) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.added) == 5 @@ -406,58 +545,93 @@ def test_cap_none_includes_all(self, tmp_path: Path): # Missing primary key column # --------------------------------------------------------------------------- + class TestMissingPrimaryKeyError: def test_missing_pk_column_in_base_raises(self, tmp_path: Path): - """diff_feeds raises MissingPrimaryKeyError when the base feed is missing a required PK column.""" - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", # stop_id absent - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + """diff_feeds raises MissingPrimaryKeyError when the base feed + is missing a required PK column.""" + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\n" + "Stop One,1.0,2.0\n", # stop_id absent + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) with pytest.raises(MissingPrimaryKeyError): diff_feeds(base, new) def test_missing_pk_column_in_new_raises(self, tmp_path: Path): - """diff_feeds raises MissingPrimaryKeyError when the new feed is missing a required PK column.""" - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", # stop_id absent - }) + """diff_feeds raises MissingPrimaryKeyError when the new feed + is missing a required PK column.""" + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\n" + "Stop One,1.0,2.0\n", # stop_id absent + }, + ) with pytest.raises(MissingPrimaryKeyError): diff_feeds(base, new) def test_exception_carries_file_name(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) with pytest.raises(MissingPrimaryKeyError) as exc_info: diff_feeds(base, new) assert exc_info.value.file_name == "stops.txt" def test_exception_carries_missing_columns(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) with pytest.raises(MissingPrimaryKeyError) as exc_info: diff_feeds(base, new) assert "stop_id" in exc_info.value.missing_columns def test_exception_carries_headers(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_name,stop_lat,stop_lon\nStop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) with pytest.raises(MissingPrimaryKeyError) as exc_info: diff_feeds(base, new) assert exc_info.value.headers == ["stop_name", "stop_lat", "stop_lon"] @@ -467,31 +641,48 @@ def test_exception_carries_headers(self, tmp_path: Path): # Unsupported files # --------------------------------------------------------------------------- + class TestUnsupportedFile: def test_unsupported_file_in_metadata(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "custom_data.txt": "foo,bar\n1,2\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "custom_data.txt": "foo,bar\n1,2\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }, + ) result = diff_feeds(base, new) unsupported_names = [u.file_name for u in result.metadata.unsupported_files] assert "custom_data.txt" in unsupported_names def test_unsupported_file_present_in_both(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "custom_data.txt": "foo,bar\n1,2\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - "custom_data.txt": "foo,bar\n1,2\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }, + ) result = diff_feeds(base, new) - uf = next(u for u in result.metadata.unsupported_files if u.file_name == "custom_data.txt") + uf = next( + u + for u in result.metadata.unsupported_files + if u.file_name == "custom_data.txt" + ) assert uf.present_in == "both" @@ -499,35 +690,50 @@ def test_unsupported_file_present_in_both(self, tmp_path: Path): # Schema validation (model round-trip) # --------------------------------------------------------------------------- + class TestOutputMatchesSchema: def test_output_matches_schema(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": STOPS_HEADER + + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }, + ) result = diff_feeds(base, new) json_str = result.model_dump_json() restored = GtfsDiff.model_validate_json(json_str) assert restored.summary.total_changes >= 1 - assert restored.metadata.schema_version == "2.0" + assert restored.metadata.schema_version == "v2-rc1" # --------------------------------------------------------------------------- # raw_value column ordering # --------------------------------------------------------------------------- + class TestRawValueColumnOrder: def test_raw_value_has_empty_for_base_only_column(self, tmp_path: Path): # base has stop_lat; new does NOT have stop_lat (deleted column). # Added rows in new should have empty string for stop_lat in raw_value. - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name,stop_lat\nS1,Stop One,1.0\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name,stop_lat\nS1,Stop One,1.0\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") # S2 is an added row; union_columns = [stop_id, stop_name, stop_lat] @@ -536,7 +742,9 @@ def test_raw_value_has_empty_for_base_only_column(self, tmp_path: Path): added = fd.row_changes.added[0] assert added.identifier == {"stop_id": "S2"} # Parse raw_value CSV - import csv, io + import csv + import io + row = next(csv.reader(io.StringIO(added.raw_value))) # union_columns = base_headers + new_only_cols = [stop_id, stop_name, stop_lat] # stop_lat not in new → empty @@ -544,12 +752,19 @@ def test_raw_value_has_empty_for_base_only_column(self, tmp_path: Path): def test_union_columns_order_base_first(self, tmp_path: Path): # New feed has an extra column; union_columns must be base_headers + new_only - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,Desc One\nS2,Stop Two,Desc Two\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,Desc One\n" + "S2,Stop Two,Desc Two\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert fd.row_changes.columns == ["stop_id", "stop_name", "stop_desc"] @@ -559,39 +774,58 @@ def test_union_columns_order_base_first(self, tmp_path: Path): # Line numbers # --------------------------------------------------------------------------- + class TestLineNumbers: def test_added_row_line_number(self, tmp_path: Path): # Header = line 1, first data row = line 2, second data row = line 3 - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.added) == 1 assert fd.row_changes.added[0].new_line_number == 3 def test_deleted_row_line_number(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.deleted) == 1 assert fd.row_changes.deleted[0].base_line_number == 3 def test_modified_row_line_numbers(self, tmp_path: Path): - base = write_zip(tmp_path / "base.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", - }) - new = write_zip(tmp_path / "new.zip", { - "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two RENAMED\n", - }) + base = write_zip( + tmp_path / "base.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }, + ) + new = write_zip( + tmp_path / "new.zip", + { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two RENAMED\n", + }, + ) result = diff_feeds(base, new) fd = _get_file_diff(result, "stops.txt") assert len(fd.row_changes.modified) == 1 diff --git a/tests/test_models.py b/tests/test_models.py index 41c4954..1a0a7da 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,16 +9,14 @@ from gtfs_diff.models import ( ColumnEntry, + FeedSource, FieldChange, FileDiff, FileStats, FileSummary, GtfsDiff, Metadata, - FeedSource, - RowAdded, RowChanges, - RowDeleted, RowModified, Summary, Truncated, @@ -32,6 +30,7 @@ # Helpers to build valid model instances # --------------------------------------------------------------------------- + def _feed_source(url: str = "http://example.com/feed.zip") -> FeedSource: return FeedSource(source=url, downloaded_at=NOW) @@ -89,6 +88,7 @@ def _gtfs_diff(**kwargs) -> GtfsDiff: # GtfsDiff round-trip # --------------------------------------------------------------------------- + class TestGtfsDiffRoundTrip: def test_round_trip_empty(self): obj = _gtfs_diff() @@ -119,6 +119,7 @@ def test_round_trip_json(self): # ColumnEntry # --------------------------------------------------------------------------- + class TestColumnEntry: def test_valid(self): col = ColumnEntry(name="stop_id", position=1) @@ -142,6 +143,7 @@ def test_position_one_accepted(self): # RowChanges # --------------------------------------------------------------------------- + class TestRowChanges: def test_valid(self): rc = RowChanges( @@ -168,6 +170,7 @@ def test_empty_primary_key_rejected(self): # RowModified # --------------------------------------------------------------------------- + class TestRowModified: def test_valid(self): rm = RowModified( @@ -196,6 +199,7 @@ def test_empty_field_changes_rejected(self): # FileSummary # --------------------------------------------------------------------------- + class TestFileSummary: def test_valid(self): fs = FileSummary(file_name="stops.txt", status="modified") @@ -229,6 +233,7 @@ def test_with_counts(self): # Truncated # --------------------------------------------------------------------------- + class TestTruncated: def test_valid(self): t = Truncated(is_truncated=True, omitted_count=5) @@ -248,6 +253,7 @@ def test_omitted_count_must_be_positive(self): # UnsupportedFile # --------------------------------------------------------------------------- + class TestUnsupportedFile: @pytest.mark.parametrize("present_in", ["base", "new", "both"]) def test_valid_present_in(self, present_in: str): From 2c0bfae2fcafcf1ae939cbd337e07729b2cf90ee Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 2 Jun 2026 14:50:13 -0400 Subject: [PATCH 3/8] fix: zip without strict --- pyproject.toml | 3 --- src/gtfs_diff/engine.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4932a02..86fb16d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,9 +44,6 @@ select = [ "B", # flake8-bugbear "SIM", # flake8-simplify ] -ignore = [ - "B905", # zip-without-strict — false positives on pre-padded rows -] [tool.ruff.format] exclude = ["src/gtfs_diff/models.py"] diff --git a/src/gtfs_diff/engine.py b/src/gtfs_diff/engine.py index f858c88..58ad3d0 100644 --- a/src/gtfs_diff/engine.py +++ b/src/gtfs_diff/engine.py @@ -151,13 +151,13 @@ def _read_csv_index( if len(row) < n: row = row + [""] * (n - len(row)) row_vals = row[:n] - row_dict = dict(zip(headers, row_vals)) + row_dict = dict(zip(headers, row_vals, strict=True)) pk_tuple = tuple(row_dict.get(col, "") for col in effective_pk) if pk_tuple in index: raise ValueError( f"{file_name}: duplicate primary key " - f"{dict(zip(effective_pk, pk_tuple))} " + f"{dict(zip(effective_pk, pk_tuple, strict=True))} " f"at line {line_num} " f"(first seen at line {index[pk_tuple][0]})." ) @@ -176,7 +176,7 @@ def _parse_raw_line(raw_line: str, headers: list[str]) -> dict[str, str]: return {col: "" for col in headers} if len(row) < len(headers): row = row + [""] * (len(headers) - len(row)) - return dict(zip(headers, row)) + return dict(zip(headers, row, strict=True)) def _values_differ(a: str, b: str) -> bool: From 6d5af75caa9fdb4a686eef55710ebd7ffad380e6 Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 2 Jun 2026 14:51:23 -0400 Subject: [PATCH 4/8] update models --- src/gtfs_diff/models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gtfs_diff/models.py b/src/gtfs_diff/models.py index 8c72bbe..cc4f9ac 100644 --- a/src/gtfs_diff/models.py +++ b/src/gtfs_diff/models.py @@ -5,9 +5,11 @@ from __future__ import annotations -from pydantic import AwareDatetime, BaseModel, Field + from typing import Literal +from pydantic import AwareDatetime, BaseModel, Field + class UnsupportedFile(BaseModel): file_name: str = Field(..., description="File name as it appears in the archive.") From f96ef39fab54d2d99cb03154ff5b39919c9f23b5 Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 2 Jun 2026 15:06:55 -0400 Subject: [PATCH 5/8] fix: moved conf file --- .github/workflows/ci.yml | 2 +- scripts/generate_models.sh | 11 ++++++----- src/gtfs_diff/engine.py | 5 +++-- schema.conf => src/gtfs_diff/schema.conf | 0 4 files changed, 10 insertions(+), 8 deletions(-) rename schema.conf => src/gtfs_diff/schema.conf (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a2056ab..f18f0f2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: - 'tests/**' - 'pyproject.toml' models: - - 'schema.conf' + - 'src/gtfs_diff/schema.conf' - 'scripts/generate_models.sh' - 'src/gtfs_diff/models.py' - '.github/workflows/ci.yml' diff --git a/scripts/generate_models.sh b/scripts/generate_models.sh index a2dacd2..a57ecf7 100755 --- a/scripts/generate_models.sh +++ b/scripts/generate_models.sh @@ -2,7 +2,7 @@ # Generate Pydantic v2 models from the GTFS Diff JSON Schema. # # Usage: -# ./scripts/generate_models.sh # use version from schema.conf +# ./scripts/generate_models.sh # use version from src/gtfs_diff/schema.conf # ./scripts/generate_models.sh v2-rc1 # fetch a specific version from GitHub # ./scripts/generate_models.sh /path/to/local.json # use a local file set -euo pipefail @@ -13,17 +13,18 @@ SCHEMA_REPO="MobilityData/gtfs-diff" SCHEMA_BRANCH="main" # --- Resolve input: argument > schema.conf ----------------------------------- +SCHEMA_CONF="$REPO_ROOT/src/gtfs_diff/schema.conf" if [ $# -ge 1 ]; then INPUT="$1" -elif [ -f "$REPO_ROOT/schema.conf" ]; then - # shellcheck source=../schema.conf - source "$REPO_ROOT/schema.conf" +elif [ -f "$SCHEMA_CONF" ]; then + # shellcheck source=../src/gtfs_diff/schema.conf + source "$SCHEMA_CONF" INPUT="${SCHEMA_VERSION:?SCHEMA_VERSION not set in schema.conf}" echo "Using version from schema.conf: $INPUT" else echo "Usage: $0 []" >&2 echo " e.g. $0 v2-rc1" >&2 - echo " or set SCHEMA_VERSION in schema.conf" >&2 + echo " or set SCHEMA_VERSION in src/gtfs_diff/schema.conf" >&2 exit 1 fi diff --git a/src/gtfs_diff/engine.py b/src/gtfs_diff/engine.py index 58ad3d0..45cfc85 100644 --- a/src/gtfs_diff/engine.py +++ b/src/gtfs_diff/engine.py @@ -17,6 +17,7 @@ import sys import time import zipfile +from importlib import resources from collections.abc import Callable, Generator from contextlib import contextmanager from datetime import datetime, timezone @@ -44,9 +45,9 @@ def _read_schema_version() -> str: - conf_path = Path(__file__).resolve().parent.parent.parent / "schema.conf" + conf_text = resources.files("gtfs_diff").joinpath("schema.conf").read_text() parser = configparser.ConfigParser() - parser.read_string("[default]\n" + conf_path.read_text()) + parser.read_string("[default]\n" + conf_text) return parser.get("default", "SCHEMA_VERSION") diff --git a/schema.conf b/src/gtfs_diff/schema.conf similarity index 100% rename from schema.conf rename to src/gtfs_diff/schema.conf From 176b896acc4264bbad97021b533da9e116a388f6 Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 2 Jun 2026 15:08:29 -0400 Subject: [PATCH 6/8] fix: moved conf file --- src/gtfs_diff/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gtfs_diff/engine.py b/src/gtfs_diff/engine.py index 45cfc85..fda1ba9 100644 --- a/src/gtfs_diff/engine.py +++ b/src/gtfs_diff/engine.py @@ -17,10 +17,10 @@ import sys import time import zipfile -from importlib import resources from collections.abc import Callable, Generator from contextlib import contextmanager from datetime import datetime, timezone +from importlib import resources from pathlib import Path from typing import TextIO From 4c25ea79a8ef79b43ffd0fb24ccd1b9184041adc Mon Sep 17 00:00:00 2001 From: cka-y Date: Tue, 2 Jun 2026 15:12:32 -0400 Subject: [PATCH 7/8] fix: added extra allowed --- scripts/generate_models.sh | 1 + src/gtfs_diff/models.py | 56 +++++++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/scripts/generate_models.sh b/scripts/generate_models.sh index a57ecf7..59a394a 100755 --- a/scripts/generate_models.sh +++ b/scripts/generate_models.sh @@ -63,6 +63,7 @@ datamodel-codegen \ --enum-field-as-literal all \ --field-constraints \ --use-schema-description \ + --extra-fields allow \ --class-name GtfsDiff \ --output "$OUTPUT" diff --git a/src/gtfs_diff/models.py b/src/gtfs_diff/models.py index cc4f9ac..b8eb026 100644 --- a/src/gtfs_diff/models.py +++ b/src/gtfs_diff/models.py @@ -8,10 +8,13 @@ from typing import Literal -from pydantic import AwareDatetime, BaseModel, Field +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field class UnsupportedFile(BaseModel): + model_config = ConfigDict( + extra="allow", + ) file_name: str = Field(..., description="File name as it appears in the archive.") present_in: Literal["base", "new", "both"] = Field( ..., description="Which archive(s) contain this file." @@ -19,6 +22,9 @@ class UnsupportedFile(BaseModel): class ColumnEntry(BaseModel): + model_config = ConfigDict( + extra="allow", + ) name: str = Field(..., description="Column name.") position: int = Field( ..., description="1-based position of this column in the CSV header row.", ge=1 @@ -26,6 +32,9 @@ class ColumnEntry(BaseModel): class FeedSource(BaseModel): + model_config = ConfigDict( + extra="allow", + ) source: str = Field(..., description="URL or local path to the GTFS archive.") downloaded_at: AwareDatetime = Field( ..., description="ISO 8601 timestamp of when the feed was downloaded." @@ -33,6 +42,9 @@ class FeedSource(BaseModel): class FileSummary(BaseModel): + model_config = ConfigDict( + extra="allow", + ) file_name: str = Field(..., description="Name of the GTFS file.") status: Literal["added", "deleted", "modified", "not_compared"] = Field( ..., description="File-level status." @@ -40,6 +52,9 @@ class FileSummary(BaseModel): class Truncated(BaseModel): + model_config = ConfigDict( + extra="allow", + ) is_truncated: Literal[True] = Field(..., description="Always true when present.") omitted_count: int = Field( ..., description="Number of row changes omitted due to the cap.", ge=1 @@ -47,6 +62,9 @@ class Truncated(BaseModel): class NotComparedReason(BaseModel): + model_config = ConfigDict( + extra="allow", + ) code: str = Field( ..., description='Machine-readable reason code (e.g. "id_churn", "missing_primary_key", "file_too_large").', @@ -58,11 +76,17 @@ class NotComparedReason(BaseModel): class IgnoredColumn(BaseModel): + model_config = ConfigDict( + extra="allow", + ) column: str = Field(..., description="The column name that was ignored.") reason: NotComparedReason class ColumnStat(BaseModel): + model_config = ConfigDict( + extra="allow", + ) column: str = Field(..., description="The column name.") modifications_count: int = Field( ..., @@ -78,6 +102,9 @@ class ColumnStat(BaseModel): class RowAdded(BaseModel): + model_config = ConfigDict( + extra="allow", + ) identifier: dict[str, str] = Field( ..., description="Primary key values identifying this row." ) @@ -91,6 +118,9 @@ class RowAdded(BaseModel): class RowDeleted(BaseModel): + model_config = ConfigDict( + extra="allow", + ) identifier: dict[str, str] = Field( ..., description="Primary key values identifying this row." ) @@ -104,12 +134,18 @@ class RowDeleted(BaseModel): class FieldChange(BaseModel): + model_config = ConfigDict( + extra="allow", + ) field: str = Field(..., description="The column name that changed.") base_value: str = Field(..., description="The value in the base feed.") new_value: str = Field(..., description="The value in the new feed.") class RowModified(BaseModel): + model_config = ConfigDict( + extra="allow", + ) identifier: dict[str, str] = Field( ..., description="Primary key values identifying this row." ) @@ -129,6 +165,9 @@ class RowModified(BaseModel): class Metadata(BaseModel): + model_config = ConfigDict( + extra="allow", + ) schema_version: str = Field(..., description="The version of the schema.") generated_at: AwareDatetime = Field( ..., description="ISO 8601 timestamp of when the diff was generated." @@ -144,6 +183,9 @@ class Metadata(BaseModel): class Summary(BaseModel): + model_config = ConfigDict( + extra="allow", + ) total_changes: int = Field( ..., description="Total number of changes across all files.", ge=0 ) @@ -161,6 +203,9 @@ class Summary(BaseModel): class RowChanges(BaseModel): + model_config = ConfigDict( + extra="allow", + ) primary_key: list[str] = Field( ..., description="Column(s) that uniquely identify a row.", min_length=1 ) @@ -174,6 +219,9 @@ class RowChanges(BaseModel): class FileStats(BaseModel): + model_config = ConfigDict( + extra="allow", + ) total_rows_base: int | None = Field( None, description="Total number of rows in the base version of the file.", ge=0 ) @@ -208,6 +256,9 @@ class FileStats(BaseModel): class FileDiff(BaseModel): + model_config = ConfigDict( + extra="allow", + ) file_name: str = Field(..., description="Name of the GTFS file.") file_action: Literal["modified", "added", "deleted", "not_compared"] = Field( ..., description="Action describing how this file changed." @@ -235,6 +286,9 @@ class GtfsDiff(BaseModel): Schema for GTFS Diff v2 output: a single JSON document describing all differences between two GTFS archives. """ + model_config = ConfigDict( + extra="allow", + ) metadata: Metadata summary: Summary file_diffs: list[FileDiff] From 775320a8e8ffc549d7a74f94c03db45ce499fcfb Mon Sep 17 00:00:00 2001 From: cka-y Date: Thu, 4 Jun 2026 12:52:37 -0400 Subject: [PATCH 8/8] fix: base row raw value --- src/gtfs_diff/engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gtfs_diff/engine.py b/src/gtfs_diff/engine.py index fda1ba9..ebced6c 100644 --- a/src/gtfs_diff/engine.py +++ b/src/gtfs_diff/engine.py @@ -523,10 +523,10 @@ def _remaining(used: int) -> int | None: for pk_tuple, field_changes, b_line, n_line in modified_candidates[ :modified_limit ]: - n_raw = new_index[pk_tuple][1] - n_dict = _parse_raw_line(n_raw, new_headers) - identifier = {col: n_dict.get(col, "") for col in pk_cols} - raw_value = _compute_raw_value(n_dict, union_columns, new_header_set) + b_raw = base_index[pk_tuple][1] + b_dict = _parse_raw_line(b_raw, base_headers) + identifier = {col: b_dict.get(col, "") for col in pk_cols} + raw_value = _compute_raw_value(b_dict, union_columns, base_header_set) modified_rows.append( RowModified( identifier=identifier,