From 21b718207a27a068ce69697f0cee0d7575215d49 Mon Sep 17 00:00:00 2001 From: Harsh Pandhe Date: Wed, 20 May 2026 13:42:17 +0530 Subject: [PATCH] feat(phase-4-plus): project save/load persists Tower offset + Phase 7+/12d state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Until now, ropeway.project.save_project() persisted only Tower.distance and Tower.height. A save+load round trip lost: - Tower.is_station (Phase 7+ pinned-station flag) - Tower.offset (Phase 12c lateral offset) - no_tower_zones (Phase 7+) - forced_flyover_zones (Phase 12d) - intermediate_stations (Phase 7+ pinned waypoints) So a Streamlit user saving an urban-gondola study and re-opening it got the centreline back without the constraints that made it feasible. Changes (additive + backward-compatible): - Schema: new extras_json TEXT NOT NULL DEFAULT '{}' column. Older DBs without it auto-migrate via ALTER TABLE ADD COLUMN inside _connect() (idempotent via try/except OperationalError). - Tower JSON now carries is_station + offset. - extras_json holds the three corridor-constraint lists (no_tower_zones, intermediate_stations, forced_flyover_zones) with stable, hand-readable shapes. - ProjectRecord gains the three new fields (default empty lists). - ProjectRecord.to_alignment(profile_fn, surface_fn=...) rebuilds an Alignment that includes every persisted constraint, so the round-tripped alignment evaluates identically to the saved one. - save_project gets a new optional intermediate_stations kwarg (the only constraint Alignment doesn't already carry directly). Backward compatibility: rows written by the pre-4+ code (no extras_json column populated) deserialise to empty zone/station lists and Tower.offset defaults to 0.0. Existing tests pass unchanged. Tests: tests/test_project_persistence_plus.py — 7 new - tower offset round-trips - no_tower_zones round-trip (kind enum, distances, name) - forced_flyover_zones round-trip (min_cable_elev_m, name) - intermediate_stations round-trip - to_alignment() rebuilds all persisted constraints - legacy-format DB row (no extras_json) loads without error - delete_project return-value semantics Full suite 206 -> 216, zero regressions in test_project.py. --- src/ropeway/project.py | 151 +++++++++++++++-- tests/test_project_persistence_plus.py | 222 +++++++++++++++++++++++++ 2 files changed, 359 insertions(+), 14 deletions(-) create mode 100644 tests/test_project_persistence_plus.py diff --git a/src/ropeway/project.py b/src/ropeway/project.py index dd9e982..ba6a2d3 100644 --- a/src/ropeway/project.py +++ b/src/ropeway/project.py @@ -32,6 +32,7 @@ from pathlib import Path from .alignment import Alignment, Tower +from .obstacles import ForcedFlyOverZone, NoTowerZone, ZoneKind from .safety import ConstraintConfig DEFAULT_DB_PATH = Path("data/projects.db") @@ -46,11 +47,62 @@ corridor_len REAL, utm_epsg INTEGER, config_json TEXT NOT NULL, - towers_json TEXT NOT NULL + towers_json TEXT NOT NULL, + extras_json TEXT NOT NULL DEFAULT '{}' ); """ +def _zones_to_json(zones: list) -> list[dict]: + """Serialise a list of NoTowerZone instances (Phase 7+).""" + out = [] + for z in zones or []: + out.append({ + "distance_start_m": z.distance_start_m, + "distance_end_m": z.distance_end_m, + "kind": z.kind.value if hasattr(z.kind, "value") else str(z.kind), + "name": z.name, + }) + return out + + +def _zones_from_json(rows: list[dict]) -> list[NoTowerZone]: + return [ + NoTowerZone( + distance_start_m=float(r["distance_start_m"]), + distance_end_m=float(r["distance_end_m"]), + kind=ZoneKind(r["kind"]) if r.get("kind") else ZoneKind.PROTECTED, + name=str(r.get("name", "")), + ) + for r in (rows or []) + ] + + +def _flyovers_to_json(zones: list) -> list[dict]: + """Serialise a list of ForcedFlyOverZone instances (Phase 12d).""" + out = [] + for z in zones or []: + out.append({ + "distance_start_m": z.distance_start_m, + "distance_end_m": z.distance_end_m, + "min_cable_elev_m": z.min_cable_elev_m, + "name": z.name, + }) + return out + + +def _flyovers_from_json(rows: list[dict]) -> list[ForcedFlyOverZone]: + return [ + ForcedFlyOverZone( + distance_start_m=float(r["distance_start_m"]), + distance_end_m=float(r["distance_end_m"]), + min_cable_elev_m=float(r["min_cable_elev_m"]), + name=str(r.get("name", "")), + ) + for r in (rows or []) + ] + + @dataclass class ProjectRecord: id: int @@ -62,14 +114,39 @@ class ProjectRecord: utm_epsg: int config: ConstraintConfig towers: list[Tower] + # Phase 4+ — Phase 7+ / 12d state persisted alongside the towers. + intermediate_stations: list[float] = None + no_tower_zones: list = None + forced_flyover_zones: list = None - def to_alignment(self, profile_fn, clearance_profile=None) -> Alignment: - """Rebuild a runnable Alignment given a freshly-sampled profile_fn.""" + def __post_init__(self) -> None: + if self.intermediate_stations is None: + self.intermediate_stations = [] + if self.no_tower_zones is None: + self.no_tower_zones = [] + if self.forced_flyover_zones is None: + self.forced_flyover_zones = [] + + def to_alignment(self, profile_fn, clearance_profile=None, + surface_fn=None) -> Alignment: + """Rebuild a runnable Alignment given a freshly-sampled profile_fn. + + Phase 4+: ``Tower.offset`` and ``is_station`` are now preserved, + and the Phase 7+ / 12d corridor constraints carry through so the + round-tripped alignment evaluates identically to the saved one. + """ return Alignment( - towers=[Tower(t.distance, t.height) for t in self.towers], + towers=[ + Tower(t.distance, t.height, + is_station=bool(t.is_station), offset=float(t.offset)) + for t in self.towers + ], profile_fn=profile_fn, cfg=self.config, clearance_profile=clearance_profile, + no_tower_zones=list(self.no_tower_zones), + surface_fn=surface_fn, + forced_flyover_zones=list(self.forced_flyover_zones), ) @@ -78,6 +155,14 @@ def _connect(db_path: str | Path) -> sqlite3.Connection: db_path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(db_path)) conn.execute(_SCHEMA) + # Phase 4+ schema migration: older DBs predate the extras_json column; + # add it idempotently so an upgraded process can still read old files. + try: + conn.execute("ALTER TABLE projects ADD COLUMN extras_json TEXT NOT NULL DEFAULT '{}'") + conn.commit() + except sqlite3.OperationalError: + # Column already exists — expected on every run after the first. + pass return conn @@ -90,20 +175,39 @@ def save_project( corridor_len: float, utm_epsg: int, db_path: str | Path = DEFAULT_DB_PATH, + intermediate_stations: list[float] | None = None, ) -> int: - """Persist a project and return its new row id.""" + """Persist a project and return its new row id. + + Phase 4+: Tower offset + is_station are now stored, and Phase 7+ / + 12d corridor constraints (no_tower_zones, intermediate_stations, + forced_flyover_zones) are serialised alongside as ``extras_json``. + """ conn = _connect(db_path) try: config_json = json.dumps(asdict(alignment.cfg)) - towers_json = json.dumps( - [{"distance": t.distance, "height": t.height} for t in alignment.towers] - ) + towers_json = json.dumps([ + { + "distance": t.distance, + "height": t.height, + "is_station": bool(t.is_station), + "offset": float(t.offset), + } + for t in alignment.towers + ]) + extras_json = json.dumps({ + "intermediate_stations": list(intermediate_stations or []), + "no_tower_zones": _zones_to_json(alignment.no_tower_zones), + "forced_flyover_zones": _flyovers_to_json( + alignment.forced_flyover_zones + ), + }) cur = conn.execute( """ INSERT INTO projects (name, created_at, start_lon, start_lat, end_lon, end_lat, - corridor_len, utm_epsg, config_json, towers_json) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + corridor_len, utm_epsg, config_json, towers_json, extras_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( name, @@ -111,7 +215,7 @@ def save_project( start_lonlat[0], start_lonlat[1], end_lonlat[0], end_lonlat[1], corridor_len, utm_epsg, - config_json, towers_json, + config_json, towers_json, extras_json, ), ) conn.commit() @@ -138,13 +242,19 @@ def list_projects(db_path: str | Path = DEFAULT_DB_PATH) -> list[dict]: def load_project(project_id: int, db_path: str | Path = DEFAULT_DB_PATH) -> ProjectRecord: - """Load a project by id. Raises KeyError if not found.""" + """Load a project by id. Raises KeyError if not found. + + Phase 4+: rehydrates Tower offset + is_station and the corridor- + constraint extras saved alongside the towers; pre-4+ records (no + extras column populated) deserialise to empty zone/station lists + so legacy callers see no behavioural change. + """ conn = _connect(db_path) try: row = conn.execute( """ SELECT id, name, created_at, start_lon, start_lat, end_lon, end_lat, - corridor_len, utm_epsg, config_json, towers_json + corridor_len, utm_epsg, config_json, towers_json, extras_json FROM projects WHERE id = ? """, (project_id,), @@ -153,7 +263,15 @@ def load_project(project_id: int, db_path: str | Path = DEFAULT_DB_PATH) -> Proj raise KeyError(f"project id {project_id} not found in {db_path}") cfg_dict = json.loads(row[9]) config = ConstraintConfig(**cfg_dict) - towers = [Tower(t["distance"], t["height"]) for t in json.loads(row[10])] + towers = [ + Tower( + t["distance"], t["height"], + is_station=bool(t.get("is_station", False)), + offset=float(t.get("offset", 0.0)), + ) + for t in json.loads(row[10]) + ] + extras = json.loads(row[11] or "{}") return ProjectRecord( id=row[0], name=row[1], @@ -164,6 +282,11 @@ def load_project(project_id: int, db_path: str | Path = DEFAULT_DB_PATH) -> Proj utm_epsg=row[8], config=config, towers=towers, + intermediate_stations=list(extras.get("intermediate_stations", [])), + no_tower_zones=_zones_from_json(extras.get("no_tower_zones", [])), + forced_flyover_zones=_flyovers_from_json( + extras.get("forced_flyover_zones", []) + ), ) finally: conn.close() diff --git a/tests/test_project_persistence_plus.py b/tests/test_project_persistence_plus.py new file mode 100644 index 0000000..efd374b --- /dev/null +++ b/tests/test_project_persistence_plus.py @@ -0,0 +1,222 @@ +"""Phase 4+ — project persistence carries Tower.offset, Tower.is_station, +and the Phase 7+/12d corridor constraints across save/load.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +import numpy as np +import pytest + +from ropeway.alignment import Alignment, Tower +from ropeway.obstacles import ForcedFlyOverZone, NoTowerZone, ZoneKind +from ropeway.project import ( + ProjectRecord, + delete_project, + load_project, + save_project, +) +from ropeway.safety import ConstraintConfig + + +def _profile_fn(level=100.0): + def fn(x): + x = np.asarray(x, dtype=float) + return np.full_like(x, level) + return fn + + +@pytest.fixture() +def db(tmp_path: Path) -> Path: + return tmp_path / "projects.db" + + +# --------------------------------------------------------------------------- +# Tower.offset + Tower.is_station round-trip +# --------------------------------------------------------------------------- + + +def test_save_load_preserves_tower_offset(db): + align = Alignment( + towers=[ + Tower(0.0, 25.0, is_station=True), + Tower(500.0, 30.0, offset=42.5), + Tower(900.0, 28.0, offset=-17.0, is_station=True), + Tower(1500.0, 22.0, is_station=True), + ], + profile_fn=_profile_fn(), + cfg=ConstraintConfig(max_span_m=2000.0, corridor_half_width_m=80.0), + ) + pid = save_project( + "offsets", align, + start_lonlat=(0.0, 0.0), end_lonlat=(0.01, 0.0), + corridor_len=1500.0, utm_epsg=32631, db_path=db, + ) + rec = load_project(pid, db_path=db) + assert [t.offset for t in rec.towers] == pytest.approx([0.0, 42.5, -17.0, 0.0]) + assert [t.is_station for t in rec.towers] == [True, False, True, True] + + +def test_save_load_preserves_no_tower_zones(db): + align = Alignment( + towers=[Tower(0.0, 20.0), Tower(2000.0, 20.0)], + profile_fn=_profile_fn(), + cfg=ConstraintConfig(max_span_m=2500.0), + no_tower_zones=[ + NoTowerZone(400.0, 800.0, kind=ZoneKind.WATER, name="lake"), + NoTowerZone(1200.0, 1500.0, kind=ZoneKind.PROTECTED, + name="reserve"), + ], + ) + pid = save_project( + "zones", align, + start_lonlat=(0.0, 0.0), end_lonlat=(0.02, 0.0), + corridor_len=2000.0, utm_epsg=32631, db_path=db, + ) + rec = load_project(pid, db_path=db) + assert len(rec.no_tower_zones) == 2 + assert rec.no_tower_zones[0].name == "lake" + assert rec.no_tower_zones[0].kind == ZoneKind.WATER + assert rec.no_tower_zones[1].distance_start_m == 1200.0 + assert rec.no_tower_zones[1].distance_end_m == 1500.0 + + +def test_save_load_preserves_forced_flyover_zones(db): + align = Alignment( + towers=[Tower(0.0, 30.0), Tower(1000.0, 30.0)], + profile_fn=_profile_fn(), + cfg=ConstraintConfig(max_span_m=1200.0), + forced_flyover_zones=[ + ForcedFlyOverZone(200.0, 400.0, min_cable_elev_m=180.0, + name="bridge"), + ForcedFlyOverZone(600.0, 700.0, min_cable_elev_m=200.0), + ], + ) + pid = save_project( + "flyovers", align, + start_lonlat=(0.0, 0.0), end_lonlat=(0.01, 0.0), + corridor_len=1000.0, utm_epsg=32631, db_path=db, + ) + rec = load_project(pid, db_path=db) + assert len(rec.forced_flyover_zones) == 2 + assert rec.forced_flyover_zones[0].min_cable_elev_m == 180.0 + assert rec.forced_flyover_zones[0].name == "bridge" + + +def test_save_load_preserves_intermediate_stations(db): + align = Alignment( + towers=[Tower(0.0, 20.0), Tower(1000.0, 20.0), Tower(2000.0, 20.0)], + profile_fn=_profile_fn(), + cfg=ConstraintConfig(max_span_m=2500.0), + ) + pid = save_project( + "stations", align, + start_lonlat=(0.0, 0.0), end_lonlat=(0.02, 0.0), + corridor_len=2000.0, utm_epsg=32631, db_path=db, + intermediate_stations=[750.0, 1500.0], + ) + rec = load_project(pid, db_path=db) + assert rec.intermediate_stations == [750.0, 1500.0] + + +# --------------------------------------------------------------------------- +# to_alignment() round-trip +# --------------------------------------------------------------------------- + + +def test_to_alignment_rebuilds_everything(db): + align = Alignment( + towers=[ + Tower(0.0, 25.0, is_station=True), + Tower(800.0, 30.0, offset=15.0), + Tower(1500.0, 22.0, is_station=True), + ], + profile_fn=_profile_fn(), + cfg=ConstraintConfig(max_span_m=2000.0, corridor_half_width_m=50.0), + no_tower_zones=[NoTowerZone(900.0, 1100.0, name="creek")], + forced_flyover_zones=[ + ForcedFlyOverZone(400.0, 600.0, min_cable_elev_m=150.0, + name="road"), + ], + ) + pid = save_project( + "roundtrip", align, + start_lonlat=(0.0, 0.0), end_lonlat=(0.01, 0.0), + corridor_len=1500.0, utm_epsg=32631, db_path=db, + ) + rec = load_project(pid, db_path=db) + rebuilt = rec.to_alignment(_profile_fn()) + # Tower state mirrored. + assert rebuilt.towers[1].offset == pytest.approx(15.0) + assert rebuilt.towers[0].is_station and rebuilt.towers[-1].is_station + # Constraint lists carried. + assert len(rebuilt.no_tower_zones) == 1 + assert rebuilt.no_tower_zones[0].name == "creek" + assert len(rebuilt.forced_flyover_zones) == 1 + assert rebuilt.forced_flyover_zones[0].min_cable_elev_m == 150.0 + + +# --------------------------------------------------------------------------- +# Backward compatibility — load a pre-4+ DB row that lacks extras_json +# --------------------------------------------------------------------------- + + +def test_loads_legacy_row_without_extras(db): + """Older saves predate the extras_json column; load must succeed and + return empty zone/station lists.""" + # Hand-write a row missing the extras_json column (simulate legacy DB). + conn = sqlite3.connect(db) + conn.execute(""" + CREATE TABLE IF NOT EXISTS projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, created_at TEXT NOT NULL, + start_lon REAL, start_lat REAL, end_lon REAL, end_lat REAL, + corridor_len REAL, utm_epsg INTEGER, + config_json TEXT NOT NULL, towers_json TEXT NOT NULL + ) + """) + import json as _json + cfg = ConstraintConfig() + from dataclasses import asdict + conn.execute(""" + INSERT INTO projects (name, created_at, + start_lon, start_lat, end_lon, end_lat, + corridor_len, utm_epsg, config_json, towers_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + "legacy", "2025-01-01T00:00:00", + 0.0, 0.0, 0.01, 0.0, + 1000.0, 32631, + _json.dumps(asdict(cfg)), + _json.dumps([{"distance": 0.0, "height": 20.0}, + {"distance": 1000.0, "height": 20.0}]), + )) + conn.commit() + conn.close() + + rec = load_project(1, db_path=db) + assert rec.name == "legacy" + assert len(rec.towers) == 2 + # Missing fields default sensibly. + assert all(t.offset == 0.0 for t in rec.towers) + assert rec.no_tower_zones == [] + assert rec.forced_flyover_zones == [] + assert rec.intermediate_stations == [] + + +def test_delete_project_returns_true_when_row_removed(db): + align = Alignment( + towers=[Tower(0.0, 20.0), Tower(500.0, 20.0)], + profile_fn=_profile_fn(), + cfg=ConstraintConfig(max_span_m=1000.0), + ) + pid = save_project( + "delete-me", align, + start_lonlat=(0.0, 0.0), end_lonlat=(0.005, 0.0), + corridor_len=500.0, utm_epsg=32631, db_path=db, + ) + assert delete_project(pid, db_path=db) is True + assert delete_project(pid, db_path=db) is False + with pytest.raises(KeyError): + load_project(pid, db_path=db)