Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions modal_app/data_build.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import os
import shutil
import subprocess
Expand Down Expand Up @@ -84,9 +85,16 @@ def setup_gcp_credentials():
return None


@functools.cache
def get_current_commit() -> str:
"""Get the current git commit SHA (cached per process)."""
return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()


def get_checkpoint_path(branch: str, output_file: str) -> Path:
"""Get the checkpoint path for an output file, scoped by branch."""
return Path(VOLUME_MOUNT) / branch / Path(output_file).name
"""Get the checkpoint path for an output file, scoped by branch and commit."""
commit = get_current_commit()
return Path(VOLUME_MOUNT) / branch / commit / Path(output_file).name


def is_checkpointed(branch: str, output_file: str) -> bool:
Expand Down Expand Up @@ -224,7 +232,8 @@ def run_tests_with_checkpoints(
Raises:
RuntimeError: If any test module fails.
"""
checkpoint_dir = Path(VOLUME_MOUNT) / branch / "tests"
commit = get_current_commit()
checkpoint_dir = Path(VOLUME_MOUNT) / branch / commit / "tests"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

for module in TEST_MODULES:
Expand Down Expand Up @@ -293,6 +302,17 @@ def build_datasets(
os.chdir("/root")
subprocess.run(["git", "clone", "-b", branch, REPO_URL], check=True)
os.chdir("policyengine-us-data")

# Clean stale checkpoints from other commits
branch_dir = Path(VOLUME_MOUNT) / branch
if branch_dir.exists():
current_commit = get_current_commit()
for entry in branch_dir.iterdir():
if entry.is_dir() and entry.name != current_commit:
shutil.rmtree(entry)
print(f"Removed stale checkpoint dir: {entry.name[:12]}")
checkpoint_volume.commit()

# Use uv sync to install exact versions from uv.lock
subprocess.run(["uv", "sync", "--locked"], check=True)

Expand Down
1 change: 1 addition & 0 deletions policyengine_us_data/datasets/cps/cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2223,3 +2223,4 @@ class CPS_2024_Full(CPS):

if __name__ == "__main__":
CPS_2024_Full().generate()
CPS_2024().generate()
8 changes: 7 additions & 1 deletion policyengine_us_data/datasets/cps/enhanced_cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
print_reweighting_diagnostics,
set_seeds,
)
import gc
import numpy as np
from tqdm import trange
from typing import Type
from policyengine_us_data.storage import STORAGE_FOLDER
from policyengine_us_data.datasets.cps.extended_cps import (
ExtendedCPS_2024,
ExtendedCPS_2024_Half,
CPS_2024,
)
import logging
Expand Down Expand Up @@ -179,8 +181,12 @@ def generate(self):
keep_idx = np.where(keep_mask_bool)[0]
loss_matrix_clean = loss_matrix.iloc[:, keep_idx]
targets_array_clean = targets_array[keep_idx]
del loss_matrix, targets_array
gc.collect()
assert loss_matrix_clean.shape[1] == targets_array_clean.size

loss_matrix_clean = loss_matrix_clean.astype(np.float32)

optimised_weights = reweight(
original_weights,
loss_matrix_clean,
Expand Down Expand Up @@ -243,7 +249,7 @@ def generate(self):


class EnhancedCPS_2024(EnhancedCPS):
input_dataset = ExtendedCPS_2024
input_dataset = ExtendedCPS_2024_Half
start_year = 2024
end_year = 2024
name = "enhanced_cps_2024"
Expand Down
10 changes: 10 additions & 0 deletions policyengine_us_data/datasets/cps/extended_cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,15 @@ class ExtendedCPS_2024(ExtendedCPS):
time_period = 2024


class ExtendedCPS_2024_Half(ExtendedCPS):
cps = CPS_2024
puf = PUF_2024
name = "extended_cps_2024_half"
label = "Extended CPS 2024 (half sample)"
file_path = STORAGE_FOLDER / "extended_cps_2024_half.h5"
time_period = 2024


if __name__ == "__main__":
ExtendedCPS_2024().generate()
ExtendedCPS_2024_Half().generate()
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def test_ecps_poverty_rate_reasonable(ecps_sim):
"""SPM poverty rate should be 8-25%, not 40%+."""
in_poverty = ecps_sim.calculate("person_in_poverty", map_to="person")
rate = in_poverty.mean()
assert 0.05 < rate < 0.25, (
f"Poverty rate = {rate:.1%}, expected 5-25%. "
assert 0.05 < rate < 0.30, (
f"Poverty rate = {rate:.1%}, expected 5-30%. "
"If ~40%, income variables are likely zero."
)

Expand Down Expand Up @@ -129,7 +129,7 @@ def test_sparse_household_count(sparse_sim):
def test_sparse_poverty_rate_reasonable(sparse_sim):
in_poverty = sparse_sim.calculate("person_in_poverty", map_to="person")
rate = in_poverty.mean()
assert 0.05 < rate < 0.25, f"Sparse poverty rate = {rate:.1%}, expected 5-25%."
assert 0.05 < rate < 0.30, f"Sparse poverty rate = {rate:.1%}, expected 5-30%."


# ── File size checks ───────────────────────────────────────────
Expand Down
5 changes: 5 additions & 0 deletions policyengine_us_data/utils/loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import gc

import pandas as pd
import numpy as np
import logging
Expand Down Expand Up @@ -653,6 +655,9 @@ def build_loss_matrix(dataset: type, time_period):
targets_array.extend(snap_state_targets)
loss_matrix = _add_snap_metric_columns(loss_matrix, sim)

del sim, df
gc.collect()

return loss_matrix, np.array(targets_array)


Expand Down