diff --git a/projects/policyengine-api-simulation/fixtures/gateway/package_imports.py b/projects/policyengine-api-simulation/fixtures/gateway/package_imports.py new file mode 100644 index 000000000..c1a6f73f9 --- /dev/null +++ b/projects/policyengine-api-simulation/fixtures/gateway/package_imports.py @@ -0,0 +1,68 @@ +"""Fixtures for gateway package import regression tests.""" + +from __future__ import annotations + +import importlib +import sys +from collections.abc import Iterator +from dataclasses import dataclass + +import pytest + + +GATEWAY_MODEL_MODULE = "src.modal.gateway.models" +GATEWAY_ENDPOINTS_MODULE = "src.modal.gateway.endpoints" +GATEWAY_PACKAGE_MODULE = "src.modal.gateway" +FASTAPI_MODULE = "fastapi" + +GATEWAY_MODEL_IMPORT_MODULES = ( + FASTAPI_MODULE, + GATEWAY_PACKAGE_MODULE, + GATEWAY_ENDPOINTS_MODULE, + GATEWAY_MODEL_MODULE, +) + + +@dataclass(frozen=True) +class GatewayImportModuleNames: + """Module names involved in the gateway model import boundary.""" + + endpoints: str = GATEWAY_ENDPOINTS_MODULE + fastapi: str = FASTAPI_MODULE + + +@pytest.fixture() +def gateway_import_module_names() -> GatewayImportModuleNames: + return GatewayImportModuleNames() + + +@pytest.fixture() +def isolated_gateway_model_import_modules() -> Iterator[None]: + """Temporarily clear modules that would mask import side effects.""" + previous_modules = { + module_name: sys.modules.pop(module_name, None) + for module_name in GATEWAY_MODEL_IMPORT_MODULES + } + + try: + yield + finally: + for module_name in GATEWAY_MODEL_IMPORT_MODULES: + sys.modules.pop(module_name, None) + sys.modules.update( + { + module_name: module + for module_name, module in previous_modules.items() + if module is not None + } + ) + + +@pytest.fixture() +def import_gateway_models(isolated_gateway_model_import_modules): + """Import gateway models from a clean module state.""" + + def import_models(): + return importlib.import_module(GATEWAY_MODEL_MODULE) + + return import_models diff --git a/projects/policyengine-api-simulation/src/modal/gateway/__init__.py b/projects/policyengine-api-simulation/src/modal/gateway/__init__.py index cc135505b..d15c1d386 100644 --- a/projects/policyengine-api-simulation/src/modal/gateway/__init__.py +++ b/projects/policyengine-api-simulation/src/modal/gateway/__init__.py @@ -1,13 +1,3 @@ """ Gateway package for PolicyEngine Simulation API. """ - -from .endpoints import router -from .models import JobStatusResponse, JobSubmitResponse, SimulationRequest - -__all__ = [ - "router", - "SimulationRequest", - "JobSubmitResponse", - "JobStatusResponse", -] diff --git a/projects/policyengine-api-simulation/tests/conftest.py b/projects/policyengine-api-simulation/tests/conftest.py index 4c934cf30..d5dc62e73 100644 --- a/projects/policyengine-api-simulation/tests/conftest.py +++ b/projects/policyengine-api-simulation/tests/conftest.py @@ -7,6 +7,7 @@ pytest_plugins = ( "fixtures.gateway.shared", "fixtures.gateway.test_endpoints", + "fixtures.gateway.package_imports", ) project_root = Path(__file__).parent.parent diff --git a/projects/policyengine-api-simulation/tests/gateway/test_package_imports.py b/projects/policyengine-api-simulation/tests/gateway/test_package_imports.py new file mode 100644 index 000000000..3e40c4eec --- /dev/null +++ b/projects/policyengine-api-simulation/tests/gateway/test_package_imports.py @@ -0,0 +1,11 @@ +import sys + + +def test_gateway_models_import_does_not_import_fastapi_endpoints( + import_gateway_models, + gateway_import_module_names, +): + import_gateway_models() + + assert gateway_import_module_names.endpoints not in sys.modules + assert gateway_import_module_names.fastapi not in sys.modules diff --git a/projects/policyengine-apis-integ/tests/simulation/conftest.py b/projects/policyengine-apis-integ/tests/simulation/conftest.py index 7139d7aa4..29718de4b 100644 --- a/projects/policyengine-apis-integ/tests/simulation/conftest.py +++ b/projects/policyengine-apis-integ/tests/simulation/conftest.py @@ -1,8 +1,30 @@ +import json +import time +from http import HTTPStatus + import httpx import pytest from pydantic_settings import BaseSettings, SettingsConfigDict from policyengine_api_simulation_client import AuthenticatedClient, Client +from policyengine_api_simulation_client.api.default import ( + get_budget_window_job_status_budget_window_jobs_batch_job_id_get, + submit_budget_window_batch_simulate_economy_budget_window_post, +) +from policyengine_api_simulation_client.models import ( + BudgetWindowBatchRequest, + BudgetWindowBatchStatusResponse, +) + + +BUDGET_WINDOW_YEARS = ["2026", "2027"] +BUDGET_WINDOW_REFORM = { + "gov.irs.credits.ctc.refundable.fully_refundable": {"2023-01-01.2100-12-31": True} +} +BUDGET_WINDOW_DATASET = "gs://policyengine-us-data/enhanced_cps_2024.h5" +BUDGET_WINDOW_REGION = "us" +BUDGET_WINDOW_SUBSAMPLE = 200 +BUDGET_WINDOW_MAX_PARALLEL = 2 class Settings(BaseSettings): @@ -49,3 +71,122 @@ def poll_interval() -> float: def max_wait_seconds() -> float: """Return max wait time in seconds.""" return settings.timeout_in_millis / 1000 + + +def _decode_response_content(content: bytes) -> str: + try: + return json.dumps(json.loads(content), sort_keys=True) + except (json.JSONDecodeError, UnicodeDecodeError): + return content.decode("utf-8", errors="replace") + + +def _poll_budget_window_batch( + *, + client: Client | AuthenticatedClient, + batch_job_id: str, + max_wait_seconds: float, + poll_interval: float, +) -> BudgetWindowBatchStatusResponse: + deadline = time.monotonic() + max_wait_seconds + last_status_code: HTTPStatus | None = None + last_content = b"" + + while time.monotonic() < deadline: + response = get_budget_window_job_status_budget_window_jobs_batch_job_id_get.sync_detailed( + batch_job_id=batch_job_id, client=client + ) + last_status_code = response.status_code + last_content = response.content + + if response.status_code == HTTPStatus.ACCEPTED: + time.sleep(poll_interval) + continue + + if response.status_code == HTTPStatus.OK: + assert isinstance(response.parsed, BudgetWindowBatchStatusResponse), ( + f"Unexpected response type: {type(response.parsed)}" + ) + assert response.parsed.status == "complete", ( + f"Unexpected budget-window status: {response.parsed}" + ) + return response.parsed + + if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR: + raise AssertionError( + "Budget-window batch failed: " + f"{_decode_response_content(response.content)}" + ) + + raise AssertionError( + "Unexpected budget-window poll status " + f"{response.status_code}: {_decode_response_content(response.content)}" + ) + + raise TimeoutError( + f"Budget-window batch {batch_job_id} did not complete within " + f"{max_wait_seconds}s; last response was " + f"{last_status_code}: {_decode_response_content(last_content)}" + ) + + +@pytest.fixture() +def budget_window_years() -> list[str]: + """Return the annual rows expected from the staging budget-window smoke run.""" + return list(BUDGET_WINDOW_YEARS) + + +@pytest.fixture() +def budget_window_request(us_model_version: str) -> BudgetWindowBatchRequest: + """Build the standard staging budget-window smoke request.""" + return BudgetWindowBatchRequest.from_dict( + { + "country": "us", + "version": us_model_version, + "region": BUDGET_WINDOW_REGION, + "scope": "macro", + "reform": BUDGET_WINDOW_REFORM, + "subsample": BUDGET_WINDOW_SUBSAMPLE, + "data": BUDGET_WINDOW_DATASET, + "start_year": BUDGET_WINDOW_YEARS[0], + "window_size": len(BUDGET_WINDOW_YEARS), + "max_parallel": BUDGET_WINDOW_MAX_PARALLEL, + } + ) + + +@pytest.fixture() +def decode_response_content(): + """Return a compact formatter for non-OK HTTP response payloads.""" + return _decode_response_content + + +@pytest.fixture() +def submit_budget_window_batch(client: Client | AuthenticatedClient): + """Submit a budget-window batch through the generated client.""" + + def submit(request: BudgetWindowBatchRequest): + return submit_budget_window_batch_simulate_economy_budget_window_post.sync_detailed( + client=client, + body=request, + ) + + return submit + + +@pytest.fixture() +def poll_budget_window_batch( + client: Client | AuthenticatedClient, + max_wait_seconds: float, + poll_interval: float, +): + """Poll a budget-window batch through the generated client.""" + + def poll(batch_job_id: str) -> BudgetWindowBatchStatusResponse: + return _poll_budget_window_batch( + client=client, + batch_job_id=batch_job_id, + max_wait_seconds=max_wait_seconds, + poll_interval=poll_interval, + ) + + return poll diff --git a/projects/policyengine-apis-integ/tests/simulation/test_budget_window.py b/projects/policyengine-apis-integ/tests/simulation/test_budget_window.py new file mode 100644 index 000000000..da4c4a5cb --- /dev/null +++ b/projects/policyengine-apis-integ/tests/simulation/test_budget_window.py @@ -0,0 +1,69 @@ +""" +Integration tests for Modal-based budget-window batches. + +These tests run against the staging Modal deployment and verify that the +gateway can spawn the parent budget-window worker, the parent can spawn child +simulation workers, and the completed batch result has the public response +shape expected by API consumers. +""" + +from http import HTTPStatus + +import pytest + +from policyengine_api_simulation_client.models import ( + BudgetWindowBatchSubmitResponse, + BudgetWindowResult, +) +from policyengine_api_simulation_client.types import Unset + + +@pytest.mark.beta_only +def test_budget_window_multi_year_batch_completes( + budget_window_request, + budget_window_years, + decode_response_content, + submit_budget_window_batch, + poll_budget_window_batch, + us_model_version: str, +): + """ + Given a two-year US budget-window request + When the batch is submitted and polled to completion + Then the response contains 2026 and 2027 annual impacts plus totals. + """ + submit_response = submit_budget_window_batch(budget_window_request) + + assert submit_response.status_code == HTTPStatus.OK, ( + "Unexpected submit status " + f"{submit_response.status_code}: " + f"{decode_response_content(submit_response.content)}" + ) + assert isinstance(submit_response.parsed, BudgetWindowBatchSubmitResponse), ( + f"Unexpected response type: {type(submit_response.parsed)}" + ) + assert submit_response.parsed.status == "submitted" + assert submit_response.parsed.version == us_model_version + + batch_job_id = submit_response.parsed.batch_job_id + assert submit_response.parsed.poll_url == f"/budget-window-jobs/{batch_job_id}" + + completed = poll_budget_window_batch(batch_job_id) + + assert completed.status == "complete" + assert completed.progress == 100 + assert completed.error is None or isinstance(completed.error, Unset) + assert isinstance(completed.result, BudgetWindowResult) + + result = completed.result + assert result.kind == "budgetWindow" + assert result.start_year == budget_window_years[0] + assert result.end_year == budget_window_years[-1] + assert result.window_size == len(budget_window_years) + annual_impacts = result.annual_impacts + assert not isinstance(annual_impacts, Unset) + assert [impact.year for impact in annual_impacts] == budget_window_years + assert result.totals.year == "Total" + assert all( + isinstance(impact.budgetary_impact, int | float) for impact in annual_impacts + )