Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Fixtures for gateway package import regression tests."""

from __future__ import annotations

import importlib
import sys
from collections.abc import Iterator
from dataclasses import dataclass

import pytest


GATEWAY_MODEL_MODULE = "src.modal.gateway.models"
GATEWAY_ENDPOINTS_MODULE = "src.modal.gateway.endpoints"
GATEWAY_PACKAGE_MODULE = "src.modal.gateway"
FASTAPI_MODULE = "fastapi"

GATEWAY_MODEL_IMPORT_MODULES = (
FASTAPI_MODULE,
GATEWAY_PACKAGE_MODULE,
GATEWAY_ENDPOINTS_MODULE,
GATEWAY_MODEL_MODULE,
)


@dataclass(frozen=True)
class GatewayImportModuleNames:
"""Module names involved in the gateway model import boundary."""

endpoints: str = GATEWAY_ENDPOINTS_MODULE
fastapi: str = FASTAPI_MODULE


@pytest.fixture()
def gateway_import_module_names() -> GatewayImportModuleNames:
return GatewayImportModuleNames()


@pytest.fixture()
def isolated_gateway_model_import_modules() -> Iterator[None]:
"""Temporarily clear modules that would mask import side effects."""
previous_modules = {
module_name: sys.modules.pop(module_name, None)
for module_name in GATEWAY_MODEL_IMPORT_MODULES
}

try:
yield
finally:
for module_name in GATEWAY_MODEL_IMPORT_MODULES:
sys.modules.pop(module_name, None)
sys.modules.update(
{
module_name: module
for module_name, module in previous_modules.items()
if module is not None
}
)


@pytest.fixture()
def import_gateway_models(isolated_gateway_model_import_modules):
"""Import gateway models from a clean module state."""

def import_models():
return importlib.import_module(GATEWAY_MODEL_MODULE)

return import_models
10 changes: 0 additions & 10 deletions projects/policyengine-api-simulation/src/modal/gateway/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
"""
Gateway package for PolicyEngine Simulation API.
"""

from .endpoints import router
from .models import JobStatusResponse, JobSubmitResponse, SimulationRequest

__all__ = [
"router",
"SimulationRequest",
"JobSubmitResponse",
"JobStatusResponse",
]
1 change: 1 addition & 0 deletions projects/policyengine-api-simulation/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
pytest_plugins = (
"fixtures.gateway.shared",
"fixtures.gateway.test_endpoints",
"fixtures.gateway.package_imports",
)

project_root = Path(__file__).parent.parent
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys


def test_gateway_models_import_does_not_import_fastapi_endpoints(
import_gateway_models,
gateway_import_module_names,
):
import_gateway_models()

assert gateway_import_module_names.endpoints not in sys.modules
assert gateway_import_module_names.fastapi not in sys.modules
141 changes: 141 additions & 0 deletions projects/policyengine-apis-integ/tests/simulation/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,30 @@
import json
import time
from http import HTTPStatus

import httpx
import pytest
from pydantic_settings import BaseSettings, SettingsConfigDict

from policyengine_api_simulation_client import AuthenticatedClient, Client
from policyengine_api_simulation_client.api.default import (
get_budget_window_job_status_budget_window_jobs_batch_job_id_get,
submit_budget_window_batch_simulate_economy_budget_window_post,
)
from policyengine_api_simulation_client.models import (
BudgetWindowBatchRequest,
BudgetWindowBatchStatusResponse,
)


BUDGET_WINDOW_YEARS = ["2026", "2027"]
BUDGET_WINDOW_REFORM = {
"gov.irs.credits.ctc.refundable.fully_refundable": {"2023-01-01.2100-12-31": True}
}
BUDGET_WINDOW_DATASET = "gs://policyengine-us-data/enhanced_cps_2024.h5"
BUDGET_WINDOW_REGION = "us"
BUDGET_WINDOW_SUBSAMPLE = 200
BUDGET_WINDOW_MAX_PARALLEL = 2


class Settings(BaseSettings):
Expand Down Expand Up @@ -49,3 +71,122 @@ def poll_interval() -> float:
def max_wait_seconds() -> float:
"""Return max wait time in seconds."""
return settings.timeout_in_millis / 1000


def _decode_response_content(content: bytes) -> str:
try:
return json.dumps(json.loads(content), sort_keys=True)
except (json.JSONDecodeError, UnicodeDecodeError):
return content.decode("utf-8", errors="replace")


def _poll_budget_window_batch(
*,
client: Client | AuthenticatedClient,
batch_job_id: str,
max_wait_seconds: float,
poll_interval: float,
) -> BudgetWindowBatchStatusResponse:
deadline = time.monotonic() + max_wait_seconds
last_status_code: HTTPStatus | None = None
last_content = b""

while time.monotonic() < deadline:
response = get_budget_window_job_status_budget_window_jobs_batch_job_id_get.sync_detailed(
batch_job_id=batch_job_id, client=client
)
last_status_code = response.status_code
last_content = response.content

if response.status_code == HTTPStatus.ACCEPTED:
time.sleep(poll_interval)
continue

if response.status_code == HTTPStatus.OK:
assert isinstance(response.parsed, BudgetWindowBatchStatusResponse), (
f"Unexpected response type: {type(response.parsed)}"
)
assert response.parsed.status == "complete", (
f"Unexpected budget-window status: {response.parsed}"
)
return response.parsed

if response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR:
raise AssertionError(
"Budget-window batch failed: "
f"{_decode_response_content(response.content)}"
)

raise AssertionError(
"Unexpected budget-window poll status "
f"{response.status_code}: {_decode_response_content(response.content)}"
)

raise TimeoutError(
f"Budget-window batch {batch_job_id} did not complete within "
f"{max_wait_seconds}s; last response was "
f"{last_status_code}: {_decode_response_content(last_content)}"
)


@pytest.fixture()
def budget_window_years() -> list[str]:
"""Return the annual rows expected from the staging budget-window smoke run."""
return list(BUDGET_WINDOW_YEARS)


@pytest.fixture()
def budget_window_request(us_model_version: str) -> BudgetWindowBatchRequest:
"""Build the standard staging budget-window smoke request."""
return BudgetWindowBatchRequest.from_dict(
{
"country": "us",
"version": us_model_version,
"region": BUDGET_WINDOW_REGION,
"scope": "macro",
"reform": BUDGET_WINDOW_REFORM,
"subsample": BUDGET_WINDOW_SUBSAMPLE,
"data": BUDGET_WINDOW_DATASET,
"start_year": BUDGET_WINDOW_YEARS[0],
"window_size": len(BUDGET_WINDOW_YEARS),
"max_parallel": BUDGET_WINDOW_MAX_PARALLEL,
}
)


@pytest.fixture()
def decode_response_content():
"""Return a compact formatter for non-OK HTTP response payloads."""
return _decode_response_content


@pytest.fixture()
def submit_budget_window_batch(client: Client | AuthenticatedClient):
"""Submit a budget-window batch through the generated client."""

def submit(request: BudgetWindowBatchRequest):
return submit_budget_window_batch_simulate_economy_budget_window_post.sync_detailed(
client=client,
body=request,
)

return submit


@pytest.fixture()
def poll_budget_window_batch(
client: Client | AuthenticatedClient,
max_wait_seconds: float,
poll_interval: float,
):
"""Poll a budget-window batch through the generated client."""

def poll(batch_job_id: str) -> BudgetWindowBatchStatusResponse:
return _poll_budget_window_batch(
client=client,
batch_job_id=batch_job_id,
max_wait_seconds=max_wait_seconds,
poll_interval=poll_interval,
)

return poll
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Integration tests for Modal-based budget-window batches.

These tests run against the staging Modal deployment and verify that the
gateway can spawn the parent budget-window worker, the parent can spawn child
simulation workers, and the completed batch result has the public response
shape expected by API consumers.
"""

from http import HTTPStatus

import pytest

from policyengine_api_simulation_client.models import (
BudgetWindowBatchSubmitResponse,
BudgetWindowResult,
)
from policyengine_api_simulation_client.types import Unset


@pytest.mark.beta_only
def test_budget_window_multi_year_batch_completes(
budget_window_request,
budget_window_years,
decode_response_content,
submit_budget_window_batch,
poll_budget_window_batch,
us_model_version: str,
):
"""
Given a two-year US budget-window request
When the batch is submitted and polled to completion
Then the response contains 2026 and 2027 annual impacts plus totals.
"""
submit_response = submit_budget_window_batch(budget_window_request)

assert submit_response.status_code == HTTPStatus.OK, (
"Unexpected submit status "
f"{submit_response.status_code}: "
f"{decode_response_content(submit_response.content)}"
)
assert isinstance(submit_response.parsed, BudgetWindowBatchSubmitResponse), (
f"Unexpected response type: {type(submit_response.parsed)}"
)
assert submit_response.parsed.status == "submitted"
assert submit_response.parsed.version == us_model_version

batch_job_id = submit_response.parsed.batch_job_id
assert submit_response.parsed.poll_url == f"/budget-window-jobs/{batch_job_id}"

completed = poll_budget_window_batch(batch_job_id)

assert completed.status == "complete"
assert completed.progress == 100
assert completed.error is None or isinstance(completed.error, Unset)
assert isinstance(completed.result, BudgetWindowResult)

result = completed.result
assert result.kind == "budgetWindow"
assert result.start_year == budget_window_years[0]
assert result.end_year == budget_window_years[-1]
assert result.window_size == len(budget_window_years)
annual_impacts = result.annual_impacts
assert not isinstance(annual_impacts, Unset)
assert [impact.year for impact in annual_impacts] == budget_window_years
assert result.totals.year == "Total"
assert all(
isinstance(impact.budgetary_impact, int | float) for impact in annual_impacts
)