From 67d2af7ed114a4c1495fb49dd07649ff57548bc8 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 11:18:12 -0400 Subject: [PATCH 01/10] Add clone origin diagnostics for ECPS --- .../calibration/puf_impute.py | 22 +++ .../datasets/cps/enhanced_cps.py | 186 ++++++++++++++++++ .../utils/dataset_validation.py | 26 +++ .../test_calibration_puf_impute.py | 27 +++ tests/unit/test_dataset_validation.py | 56 ++++++ .../test_enhanced_cps_clone_diagnostics.py | 37 ++++ 6 files changed, 354 insertions(+) create mode 100644 tests/unit/test_enhanced_cps_clone_diagnostics.py diff --git a/policyengine_us_data/calibration/puf_impute.py b/policyengine_us_data/calibration/puf_impute.py index b87f846f8..a2159fa6e 100644 --- a/policyengine_us_data/calibration/puf_impute.py +++ b/policyengine_us_data/calibration/puf_impute.py @@ -25,6 +25,14 @@ logger = logging.getLogger(__name__) +CLONE_ORIGIN_FLAGS = { + "person": "person_is_puf_clone", + "tax_unit": "tax_unit_is_puf_clone", + "spm_unit": "spm_unit_is_puf_clone", + "family": "family_is_puf_clone", + "household": "household_is_puf_clone", +} + PUF_SUBSAMPLE_TARGET = 20_000 PUF_TOP_PERCENTILE = 99.5 @@ -531,6 +539,20 @@ def _map_to_entity(pred_values, variable_name): time_period: np.concatenate([state_fips, state_fips]).astype(np.int32) } + for entity_key, flag_name in CLONE_ORIGIN_FLAGS.items(): + id_variable = f"{entity_key}_id" + if id_variable not in data: + continue + n_entities = len(data[id_variable][time_period]) + new_data[flag_name] = { + time_period: np.concatenate( + [ + np.zeros(n_entities, dtype=np.int8), + np.ones(n_entities, dtype=np.int8), + ] + ) + } + if y_full: for var in IMPUTED_VARIABLES: if var not in data: diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 275d71e2b..f8cf8fe95 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -1,3 +1,7 @@ +import json +from pathlib import Path + +import h5py from policyengine_core.data import Dataset import pandas as pd from policyengine_us_data.utils import ( @@ -28,6 +32,173 @@ torch = None +def _to_numpy(value) -> np.ndarray: + return np.asarray(getattr(value, "values", value)) + + +def _weighted_share(mask, weights) -> float: + weights = np.asarray(weights, dtype=np.float64) + total_weight = float(weights.sum()) + if total_weight <= 0: + return 0.0 + mask = np.asarray(mask, dtype=bool) + return 100 * float(weights[mask].sum()) / total_weight + + +def compute_clone_diagnostics_summary( + *, + household_is_puf_clone, + household_weight, + person_is_puf_clone, + person_weight, + person_in_poverty, + person_reported_in_poverty, + spm_unit_is_puf_clone, + spm_unit_weight, + spm_unit_capped_work_childcare_expenses, + spm_unit_pre_subsidy_childcare_expenses, + spm_unit_taxes, + spm_unit_market_income, +) -> dict[str, float]: + household_is_puf_clone = np.asarray(household_is_puf_clone, dtype=bool) + household_weight = np.asarray(household_weight, dtype=np.float64) + person_is_puf_clone = np.asarray(person_is_puf_clone, dtype=bool) + person_weight = np.asarray(person_weight, dtype=np.float64) + person_in_poverty = np.asarray(person_in_poverty, dtype=bool) + person_reported_in_poverty = np.asarray(person_reported_in_poverty, dtype=bool) + spm_unit_is_puf_clone = np.asarray(spm_unit_is_puf_clone, dtype=bool) + spm_unit_weight = np.asarray(spm_unit_weight, dtype=np.float64) + capped_childcare = np.asarray( + spm_unit_capped_work_childcare_expenses, dtype=np.float64 + ) + pre_subsidy_childcare = np.asarray( + spm_unit_pre_subsidy_childcare_expenses, dtype=np.float64 + ) + spm_unit_taxes = np.asarray(spm_unit_taxes, dtype=np.float64) + spm_unit_market_income = np.asarray(spm_unit_market_income, dtype=np.float64) + + poor_modeled_only = person_in_poverty & ~person_reported_in_poverty + clone_spm_weight = spm_unit_weight[spm_unit_is_puf_clone].sum() + + return { + "clone_household_weight_share_pct": _weighted_share( + household_is_puf_clone, household_weight + ), + "clone_person_weight_share_pct": _weighted_share( + person_is_puf_clone, person_weight + ), + "clone_poor_modeled_only_person_weight_share_pct": _weighted_share( + person_is_puf_clone & poor_modeled_only, + person_weight, + ), + "poor_modeled_only_within_clone_person_weight_share_pct": ( + 0.0 + if person_weight[person_is_puf_clone].sum() <= 0 + else _weighted_share( + poor_modeled_only[person_is_puf_clone], + person_weight[person_is_puf_clone], + ) + ), + "clone_childcare_exceeds_pre_subsidy_share_pct": ( + 0.0 + if clone_spm_weight <= 0 + else _weighted_share( + capped_childcare[spm_unit_is_puf_clone] + > pre_subsidy_childcare[spm_unit_is_puf_clone] + 1, + spm_unit_weight[spm_unit_is_puf_clone], + ) + ), + "clone_childcare_above_5000_share_pct": ( + 0.0 + if clone_spm_weight <= 0 + else _weighted_share( + capped_childcare[spm_unit_is_puf_clone] > 5_000, + spm_unit_weight[spm_unit_is_puf_clone], + ) + ), + "clone_taxes_exceed_market_income_share_pct": ( + 0.0 + if clone_spm_weight <= 0 + else _weighted_share( + spm_unit_taxes[spm_unit_is_puf_clone] + > spm_unit_market_income[spm_unit_is_puf_clone] + 1, + spm_unit_weight[spm_unit_is_puf_clone], + ) + ), + } + + +def _load_saved_period_array( + file_path: str | Path, + variable_name: str, + period: int, +) -> np.ndarray: + with h5py.File(file_path, "r") as h5_file: + obj = h5_file[variable_name] + if isinstance(obj, h5py.Dataset): + return np.asarray(obj[...]) + period_key = str(period) + if period_key in obj: + return np.asarray(obj[period_key][...]) + if period in obj: + return np.asarray(obj[period][...]) + raise KeyError(f"{variable_name} missing period {period}") + + +def clone_diagnostics_path(file_path: str | Path) -> Path: + return Path(file_path).with_suffix(".clone_diagnostics.json") + + +def write_clone_diagnostics_report(file_path: str | Path, diagnostics: dict) -> Path: + output_path = clone_diagnostics_path(file_path) + output_path.write_text(json.dumps(diagnostics, indent=2, sort_keys=True) + "\n") + return output_path + + +def build_clone_diagnostics_for_saved_dataset( + dataset_cls: Type[Dataset], period: int +) -> dict[str, float]: + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=dataset_cls) + dataset_path = Path(dataset_cls.file_path) + + person_reported_in_poverty = _to_numpy( + sim.calculate("spm_unit_net_income_reported", period=period, map_to="person") + ) < _to_numpy( + sim.calculate("spm_unit_spm_threshold", period=period, map_to="person") + ) + + return compute_clone_diagnostics_summary( + household_is_puf_clone=_load_saved_period_array( + dataset_path, "household_is_puf_clone", period + ), + household_weight=_to_numpy(sim.calculate("household_weight", period=period)), + person_is_puf_clone=_load_saved_period_array( + dataset_path, "person_is_puf_clone", period + ), + person_weight=_to_numpy(sim.calculate("person_weight", period=period)), + person_in_poverty=_to_numpy( + sim.calculate("person_in_poverty", period=period) + ), + person_reported_in_poverty=person_reported_in_poverty, + spm_unit_is_puf_clone=_load_saved_period_array( + dataset_path, "spm_unit_is_puf_clone", period + ), + spm_unit_weight=_to_numpy(sim.calculate("spm_unit_weight", period=period)), + spm_unit_capped_work_childcare_expenses=_to_numpy( + sim.calculate("spm_unit_capped_work_childcare_expenses", period=period) + ), + spm_unit_pre_subsidy_childcare_expenses=_to_numpy( + sim.calculate("spm_unit_pre_subsidy_childcare_expenses", period=period) + ), + spm_unit_taxes=_to_numpy(sim.calculate("spm_unit_taxes", period=period)), + spm_unit_market_income=_to_numpy( + sim.calculate("spm_unit_market_income", period=period) + ), + ) + + def _get_period_array(period_values: dict, period: int) -> np.ndarray: """Get a period array from a TIME_PERIOD_ARRAYS variable dict.""" value = period_values.get(period) @@ -313,6 +484,21 @@ def generate(self): logging.info("Post-generation weight validation passed") self.save_dataset(data) + try: + diagnostics = build_clone_diagnostics_for_saved_dataset( + type(self), + base_year, + ) + diagnostics["period"] = base_year + output_path = write_clone_diagnostics_report(self.file_path, diagnostics) + logging.info("Saved clone diagnostics to %s", output_path) + logging.info("Clone diagnostics summary: %s", diagnostics) + except Exception: + logging.warning( + "Unable to compute clone diagnostics for %s", + self.file_path, + exc_info=True, + ) class ReweightedCPS_2024(Dataset): diff --git a/policyengine_us_data/utils/dataset_validation.py b/policyengine_us_data/utils/dataset_validation.py index 932003860..58960b472 100644 --- a/policyengine_us_data/utils/dataset_validation.py +++ b/policyengine_us_data/utils/dataset_validation.py @@ -22,6 +22,14 @@ "household": "household_id", } +AUXILIARY_ENTITY_PREFIXES = { + "person_": "person", + "tax_unit_": "tax_unit", + "family_": "family", + "spm_unit_": "spm_unit", + "household_": "household", +} + class DatasetContractError(Exception): """Raised when a built dataset does not match the active country package.""" @@ -127,6 +135,24 @@ def _infer_auxiliary_entity( entity_counts: dict[str, int], file_name: str, ) -> str: + for prefix, entity_key in AUXILIARY_ENTITY_PREFIXES.items(): + if not variable_name.startswith(prefix): + continue + expected_length = entity_counts.get(entity_key) + if expected_length is None: + raise DatasetContractError( + f"{file_name} contains auxiliary variable {variable_name} with a " + f"{entity_key}-scoped prefix, but {ENTITY_ID_VARIABLES[entity_key]} " + "is missing from the dataset." + ) + if actual_length != expected_length: + raise DatasetContractError( + f"{file_name} contains auxiliary variable {variable_name} with " + f"{prefix!r} prefix and expected {entity_key} length " + f"{expected_length}, found {actual_length}." + ) + return entity_key + candidate_entities = [ entity_key for entity_key, entity_count in entity_counts.items() diff --git a/tests/unit/calibration/test_calibration_puf_impute.py b/tests/unit/calibration/test_calibration_puf_impute.py index d803486ee..13a2f5176 100644 --- a/tests/unit/calibration/test_calibration_puf_impute.py +++ b/tests/unit/calibration/test_calibration_puf_impute.py @@ -34,9 +34,11 @@ def _make_mock_data(n_persons=20, n_households=5, time_period=2024): "household_id": {time_period: np.arange(1, n_households + 1)}, "tax_unit_id": {time_period: np.arange(1, n_households + 1)}, "spm_unit_id": {time_period: np.arange(1, n_households + 1)}, + "family_id": {time_period: np.arange(1, n_households + 1)}, "person_household_id": {time_period: household_ids_person}, "person_tax_unit_id": {time_period: tax_unit_ids_person}, "person_spm_unit_id": {time_period: spm_unit_ids_person}, + "person_family_id": {time_period: household_ids_person}, "age": {time_period: ages.astype(np.float32)}, "is_male": {time_period: is_male.astype(np.float32)}, "household_weight": {time_period: np.ones(n_households) * 1000}, @@ -135,6 +137,31 @@ def test_overridden_subset_of_imputed(self): for var in OVERRIDDEN_IMPUTED_VARIABLES: assert var in IMPUTED_VARIABLES + def test_clone_origin_flags_are_added(self): + data = _make_mock_data(n_persons=20, n_households=5) + state_fips = np.array([1, 2, 36, 6, 48]) + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + expected_lengths = { + "person_is_puf_clone": 20, + "tax_unit_is_puf_clone": 5, + "spm_unit_is_puf_clone": 5, + "family_is_puf_clone": 5, + "household_is_puf_clone": 5, + } + + for variable_name, half_length in expected_lengths.items(): + values = result[variable_name][2024] + assert values.dtype == np.int8 + np.testing.assert_array_equal(values[:half_length], 0) + np.testing.assert_array_equal(values[half_length:], 1) + class TestStratifiedSubsample: def test_noop_when_small(self): diff --git a/tests/unit/test_dataset_validation.py b/tests/unit/test_dataset_validation.py index c1ee1ea84..af9941577 100644 --- a/tests/unit/test_dataset_validation.py +++ b/tests/unit/test_dataset_validation.py @@ -248,3 +248,59 @@ def test_validate_dataset_contract_rejects_entity_length_mismatch( microsimulation_cls=_FakeMicrosimulation, dataset_loader=lambda path: path, ) + + +def test_validate_dataset_contract_accepts_prefixed_auxiliary_entity( + tmp_path, monkeypatch +): + file_path = tmp_path / "prefixed_aux.h5" + _write_test_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "household_is_puf_clone": np.array([1], dtype=np.int8), + }, + ) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0"), + ) + + summary = validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: path, + ) + + assert summary.variable_count == 3 + + +def test_validate_dataset_contract_rejects_prefixed_auxiliary_length_mismatch( + tmp_path, monkeypatch +): + file_path = tmp_path / "prefixed_aux_bad.h5" + _write_test_h5( + file_path, + { + "person_id": np.array([101], dtype=np.int32), + "household_id": np.array([201], dtype=np.int32), + "household_is_puf_clone": np.array([1, 0], dtype=np.int8), + }, + ) + monkeypatch.setattr( + "policyengine_us_data.utils.dataset_validation.assert_locked_policyengine_us_version", + lambda: PolicyEngineUSBuildInfo(version="1.587.0"), + ) + + with pytest.raises( + DatasetContractError, + match="expected household length 1, found 2", + ): + validate_dataset_contract( + file_path, + tax_benefit_system=_fake_tax_benefit_system(), + microsimulation_cls=_FakeMicrosimulation, + dataset_loader=lambda path: path, + ) diff --git a/tests/unit/test_enhanced_cps_clone_diagnostics.py b/tests/unit/test_enhanced_cps_clone_diagnostics.py new file mode 100644 index 000000000..89e2047ef --- /dev/null +++ b/tests/unit/test_enhanced_cps_clone_diagnostics.py @@ -0,0 +1,37 @@ +import pytest + +from policyengine_us_data.datasets.cps.enhanced_cps import ( + compute_clone_diagnostics_summary, +) + + +def test_compute_clone_diagnostics_summary(): + diagnostics = compute_clone_diagnostics_summary( + household_is_puf_clone=[False, True], + household_weight=[9.0, 1.0], + person_is_puf_clone=[False, True, True], + person_weight=[4.0, 3.0, 3.0], + person_in_poverty=[False, True, True], + person_reported_in_poverty=[False, False, True], + spm_unit_is_puf_clone=[False, True, True], + spm_unit_weight=[2.0, 3.0, 5.0], + spm_unit_capped_work_childcare_expenses=[0.0, 6000.0, 7000.0], + spm_unit_pre_subsidy_childcare_expenses=[0.0, 5000.0, 8000.0], + spm_unit_taxes=[100.0, 9000.0, 200.0], + spm_unit_market_income=[1000.0, 8000.0, 1000.0], + ) + + assert diagnostics["clone_household_weight_share_pct"] == pytest.approx(10.0) + assert diagnostics["clone_poor_modeled_only_person_weight_share_pct"] == pytest.approx( + 30.0 + ) + assert diagnostics["poor_modeled_only_within_clone_person_weight_share_pct"] == pytest.approx( + 50.0 + ) + assert diagnostics["clone_childcare_exceeds_pre_subsidy_share_pct"] == pytest.approx( + 37.5 + ) + assert diagnostics["clone_childcare_above_5000_share_pct"] == pytest.approx(100.0) + assert diagnostics["clone_taxes_exceed_market_income_share_pct"] == pytest.approx( + 37.5 + ) From b06a2bd5955b7623d7d5c3178fe22dfeffb2a646 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 11:18:51 -0400 Subject: [PATCH 02/10] Add changelog fragment for clone diagnostics --- changelog.d/703.fixed | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/703.fixed diff --git a/changelog.d/703.fixed b/changelog.d/703.fixed new file mode 100644 index 000000000..b3dd0c7d0 --- /dev/null +++ b/changelog.d/703.fixed @@ -0,0 +1 @@ +Added explicit clone-origin flags to extended/enhanced CPS datasets and saved ECPS clone diagnostics for clone weight share, modeled-only-poor share, and extreme childcare/tax checks. From f908809ae833afdb0d79c75a070d42ef4c191751 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 12:17:40 -0400 Subject: [PATCH 03/10] Format clone diagnostics files --- .../datasets/cps/enhanced_cps.py | 4 +--- .../test_enhanced_cps_clone_diagnostics.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index f8cf8fe95..c0c1a42b3 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -178,9 +178,7 @@ def build_clone_diagnostics_for_saved_dataset( dataset_path, "person_is_puf_clone", period ), person_weight=_to_numpy(sim.calculate("person_weight", period=period)), - person_in_poverty=_to_numpy( - sim.calculate("person_in_poverty", period=period) - ), + person_in_poverty=_to_numpy(sim.calculate("person_in_poverty", period=period)), person_reported_in_poverty=person_reported_in_poverty, spm_unit_is_puf_clone=_load_saved_period_array( dataset_path, "spm_unit_is_puf_clone", period diff --git a/tests/unit/test_enhanced_cps_clone_diagnostics.py b/tests/unit/test_enhanced_cps_clone_diagnostics.py index 89e2047ef..fe8704173 100644 --- a/tests/unit/test_enhanced_cps_clone_diagnostics.py +++ b/tests/unit/test_enhanced_cps_clone_diagnostics.py @@ -22,15 +22,15 @@ def test_compute_clone_diagnostics_summary(): ) assert diagnostics["clone_household_weight_share_pct"] == pytest.approx(10.0) - assert diagnostics["clone_poor_modeled_only_person_weight_share_pct"] == pytest.approx( - 30.0 - ) - assert diagnostics["poor_modeled_only_within_clone_person_weight_share_pct"] == pytest.approx( - 50.0 - ) - assert diagnostics["clone_childcare_exceeds_pre_subsidy_share_pct"] == pytest.approx( - 37.5 - ) + assert diagnostics[ + "clone_poor_modeled_only_person_weight_share_pct" + ] == pytest.approx(30.0) + assert diagnostics[ + "poor_modeled_only_within_clone_person_weight_share_pct" + ] == pytest.approx(50.0) + assert diagnostics[ + "clone_childcare_exceeds_pre_subsidy_share_pct" + ] == pytest.approx(37.5) assert diagnostics["clone_childcare_above_5000_share_pct"] == pytest.approx(100.0) assert diagnostics["clone_taxes_exceed_market_income_share_pct"] == pytest.approx( 37.5 From a8553841d1b0c0ceecfef85dad85bcb0a2123604 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 13:40:39 -0400 Subject: [PATCH 04/10] Harden clone diagnostics reporting --- .../datasets/cps/enhanced_cps.py | 54 ++++++++++++++++--- .../test_enhanced_cps_clone_diagnostics.py | 48 +++++++++++++++++ 2 files changed, 96 insertions(+), 6 deletions(-) diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index c0c1a42b3..863865261 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -149,12 +149,43 @@ def clone_diagnostics_path(file_path: str | Path) -> Path: return Path(file_path).with_suffix(".clone_diagnostics.json") +def build_clone_diagnostics_payload( + period_to_diagnostics: dict[int, dict[str, float]], +) -> dict: + if not period_to_diagnostics: + raise ValueError("Expected at least one period of clone diagnostics") + + ordered_periods = sorted(period_to_diagnostics) + if len(ordered_periods) == 1: + period = ordered_periods[0] + diagnostics = dict(period_to_diagnostics[period]) + diagnostics["period"] = int(period) + return diagnostics + + return { + "periods": { + str(period): period_to_diagnostics[period] for period in ordered_periods + } + } + + def write_clone_diagnostics_report(file_path: str | Path, diagnostics: dict) -> Path: output_path = clone_diagnostics_path(file_path) output_path.write_text(json.dumps(diagnostics, indent=2, sort_keys=True) + "\n") return output_path +def refresh_clone_diagnostics_report( + file_path: str | Path, + diagnostics_builder, +) -> Path: + output_path = clone_diagnostics_path(file_path) + if output_path.exists(): + output_path.unlink() + diagnostics = diagnostics_builder() + return write_clone_diagnostics_report(file_path, diagnostics) + + def build_clone_diagnostics_for_saved_dataset( dataset_cls: Type[Dataset], period: int ) -> dict[str, float]: @@ -483,14 +514,25 @@ def generate(self): self.save_dataset(data) try: - diagnostics = build_clone_diagnostics_for_saved_dataset( - type(self), - base_year, + periods = list(range(self.start_year, self.end_year + 1)) + diagnostics_payload = build_clone_diagnostics_payload( + { + period: build_clone_diagnostics_for_saved_dataset( + type(self), + period, + ) + for period in periods + } + ) + output_path = refresh_clone_diagnostics_report( + self.file_path, + lambda: diagnostics_payload, ) - diagnostics["period"] = base_year - output_path = write_clone_diagnostics_report(self.file_path, diagnostics) logging.info("Saved clone diagnostics to %s", output_path) - logging.info("Clone diagnostics summary: %s", diagnostics) + logging.info( + "Clone diagnostics summary: %s", + diagnostics_payload, + ) except Exception: logging.warning( "Unable to compute clone diagnostics for %s", diff --git a/tests/unit/test_enhanced_cps_clone_diagnostics.py b/tests/unit/test_enhanced_cps_clone_diagnostics.py index fe8704173..728f1fc13 100644 --- a/tests/unit/test_enhanced_cps_clone_diagnostics.py +++ b/tests/unit/test_enhanced_cps_clone_diagnostics.py @@ -1,7 +1,12 @@ +from pathlib import Path + import pytest from policyengine_us_data.datasets.cps.enhanced_cps import ( + build_clone_diagnostics_payload, compute_clone_diagnostics_summary, + clone_diagnostics_path, + refresh_clone_diagnostics_report, ) @@ -35,3 +40,46 @@ def test_compute_clone_diagnostics_summary(): assert diagnostics["clone_taxes_exceed_market_income_share_pct"] == pytest.approx( 37.5 ) + + +def test_build_clone_diagnostics_payload_single_period(): + payload = build_clone_diagnostics_payload( + {2024: {"clone_person_weight_share_pct": 12.5}} + ) + + assert payload == { + "period": 2024, + "clone_person_weight_share_pct": 12.5, + } + + +def test_build_clone_diagnostics_payload_multiple_periods(): + payload = build_clone_diagnostics_payload( + { + 2026: {"clone_person_weight_share_pct": 20.0}, + 2024: {"clone_person_weight_share_pct": 10.0}, + } + ) + + assert payload == { + "periods": { + "2024": {"clone_person_weight_share_pct": 10.0}, + "2026": {"clone_person_weight_share_pct": 20.0}, + } + } + + +def test_refresh_clone_diagnostics_report_removes_stale_sidecar_on_failure(tmp_path): + file_path = tmp_path / "enhanced_cps_2024.h5" + file_path.write_text("placeholder") + stale_path = clone_diagnostics_path(file_path) + stale_path.write_text("stale") + + def _raise(): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + refresh_clone_diagnostics_report(file_path, _raise) + + assert stale_path == Path(file_path).with_suffix(".clone_diagnostics.json") + assert not stale_path.exists() From dacb1f6d9d3be8c0c442c7143e649a51ec11cd54 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 21:49:08 -0400 Subject: [PATCH 05/10] Harden clone diagnostics refresh --- .../datasets/cps/enhanced_cps.py | 40 +++++++++----- .../test_enhanced_cps_clone_diagnostics.py | 53 +++++++++++++++++++ 2 files changed, 80 insertions(+), 13 deletions(-) diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 863865261..656fce122 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -186,6 +186,29 @@ def refresh_clone_diagnostics_report( return write_clone_diagnostics_report(file_path, diagnostics) +def save_clone_diagnostics_report( + dataset_cls: Type[Dataset], + *, + start_year: int, + end_year: int, +) -> tuple[Path, dict]: + periods = list(range(start_year, end_year + 1)) + output_path = refresh_clone_diagnostics_report( + dataset_cls.file_path, + lambda: build_clone_diagnostics_payload( + { + period: build_clone_diagnostics_for_saved_dataset( + dataset_cls, + period, + ) + for period in periods + } + ), + ) + diagnostics_payload = json.loads(output_path.read_text()) + return output_path, diagnostics_payload + + def build_clone_diagnostics_for_saved_dataset( dataset_cls: Type[Dataset], period: int ) -> dict[str, float]: @@ -514,19 +537,10 @@ def generate(self): self.save_dataset(data) try: - periods = list(range(self.start_year, self.end_year + 1)) - diagnostics_payload = build_clone_diagnostics_payload( - { - period: build_clone_diagnostics_for_saved_dataset( - type(self), - period, - ) - for period in periods - } - ) - output_path = refresh_clone_diagnostics_report( - self.file_path, - lambda: diagnostics_payload, + output_path, diagnostics_payload = save_clone_diagnostics_report( + type(self), + start_year=self.start_year, + end_year=self.end_year, ) logging.info("Saved clone diagnostics to %s", output_path) logging.info( diff --git a/tests/unit/test_enhanced_cps_clone_diagnostics.py b/tests/unit/test_enhanced_cps_clone_diagnostics.py index 728f1fc13..a0261aac3 100644 --- a/tests/unit/test_enhanced_cps_clone_diagnostics.py +++ b/tests/unit/test_enhanced_cps_clone_diagnostics.py @@ -7,6 +7,7 @@ compute_clone_diagnostics_summary, clone_diagnostics_path, refresh_clone_diagnostics_report, + save_clone_diagnostics_report, ) @@ -83,3 +84,55 @@ def _raise(): assert stale_path == Path(file_path).with_suffix(".clone_diagnostics.json") assert not stale_path.exists() + + +def test_save_clone_diagnostics_report_removes_stale_sidecar_on_failure( + tmp_path, monkeypatch +): + class DummyDataset: + file_path = tmp_path / "enhanced_cps_2024.h5" + + DummyDataset.file_path.write_text("placeholder") + stale_path = clone_diagnostics_path(DummyDataset.file_path) + stale_path.write_text("stale") + + monkeypatch.setattr( + "policyengine_us_data.datasets.cps.enhanced_cps.build_clone_diagnostics_for_saved_dataset", + lambda dataset_cls, period: (_ for _ in ()).throw(RuntimeError("boom")), + ) + + with pytest.raises(RuntimeError, match="boom"): + save_clone_diagnostics_report( + DummyDataset, + start_year=2024, + end_year=2024, + ) + + assert not stale_path.exists() + + +def test_save_clone_diagnostics_report_writes_fresh_payload(tmp_path, monkeypatch): + class DummyDataset: + file_path = tmp_path / "enhanced_cps_2024.h5" + + DummyDataset.file_path.write_text("placeholder") + + monkeypatch.setattr( + "policyengine_us_data.datasets.cps.enhanced_cps.build_clone_diagnostics_for_saved_dataset", + lambda dataset_cls, period: {"clone_person_weight_share_pct": float(period)}, + ) + + output_path, payload = save_clone_diagnostics_report( + DummyDataset, + start_year=2024, + end_year=2025, + ) + + assert output_path == clone_diagnostics_path(DummyDataset.file_path) + assert payload == { + "periods": { + "2024": {"clone_person_weight_share_pct": 2024.0}, + "2025": {"clone_person_weight_share_pct": 2025.0}, + } + } + assert output_path.exists() From 1f518c9381f0b489263228031f9baeeab8dd96f2 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 23:06:09 -0400 Subject: [PATCH 06/10] Upload clone diagnostics and preserve clone flags --- .../datasets/cps/small_enhanced_cps.py | 83 +++++++++++++++-- .../storage/upload_completed_datasets.py | 2 + tests/unit/test_small_enhanced_cps.py | 89 +++++++++++++++++++ tests/unit/test_upload_completed_datasets.py | 77 ++++++++++++++++ 4 files changed, 243 insertions(+), 8 deletions(-) create mode 100644 tests/unit/test_small_enhanced_cps.py diff --git a/policyengine_us_data/datasets/cps/small_enhanced_cps.py b/policyengine_us_data/datasets/cps/small_enhanced_cps.py index a15080321..e825d080b 100644 --- a/policyengine_us_data/datasets/cps/small_enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/small_enhanced_cps.py @@ -1,14 +1,81 @@ +import h5py +import logging import os -import pandas as pd import numpy as np -import h5py - +import pandas as pd from policyengine_us import Microsimulation -from policyengine_us_data.datasets import EnhancedCPS_2024 -from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_core.enums import Enum from policyengine_core.data.dataset import Dataset -import logging +from policyengine_us_data.calibration.puf_impute import CLONE_ORIGIN_FLAGS +from policyengine_us_data.datasets import EnhancedCPS_2024 +from policyengine_us_data.storage import STORAGE_FOLDER + + +def _load_saved_period_array( + h5_file: h5py.File, variable_name: str, period: int +) -> np.ndarray: + obj = h5_file[variable_name] + if isinstance(obj, h5py.Dataset): + return np.asarray(obj[...]) + + period_key = str(period) + if period_key in obj: + return np.asarray(obj[period_key][...]) + if period in obj: + return np.asarray(obj[period][...]) + + raise KeyError(f"{variable_name} missing period {period}") + + +def _attach_clone_origin_flags( + data: dict[str, dict[int, np.ndarray]], + source_dataset_path, +) -> None: + with h5py.File(source_dataset_path, "r") as source_h5: + for entity_key, flag_name in CLONE_ORIGIN_FLAGS.items(): + id_variable = f"{entity_key}_id" + if id_variable not in data: + raise KeyError( + f"Expected {id_variable} in sampled small ECPS data " + f"before attaching {flag_name}" + ) + + source_ids_by_period = {} + source_flags_by_period = {} + for period_key in data[id_variable]: + period = int(period_key) + source_ids = _load_saved_period_array(source_h5, id_variable, period) + source_flags = _load_saved_period_array(source_h5, flag_name, period) + if len(source_ids) != len(source_flags): + raise ValueError( + f"{flag_name} length {len(source_flags)} does not match " + f"{id_variable} length {len(source_ids)} for period {period}" + ) + source_ids_by_period[period_key] = source_ids + source_flags_by_period[period_key] = source_flags + + data[flag_name] = {} + for period_key, sampled_ids in data[id_variable].items(): + source_ids = source_ids_by_period[period_key] + source_flags = source_flags_by_period[period_key] + flag_lookup = { + int(entity_id): np.int8(flag_value) + for entity_id, flag_value in zip(source_ids, source_flags) + } + sampled_ids = np.asarray(sampled_ids) + missing_ids = [ + int(entity_id) + for entity_id in sampled_ids + if int(entity_id) not in flag_lookup + ] + if missing_ids: + raise KeyError( + f"Missing {flag_name} values for sampled IDs {missing_ids[:5]}" + ) + data[flag_name][period_key] = np.asarray( + [flag_lookup[int(entity_id)] for entity_id in sampled_ids], + dtype=np.int8, + ) def create_small_ecps(): @@ -51,6 +118,8 @@ def create_small_ecps(): if len(data[variable]) == 0: del data[variable] + _attach_clone_origin_flags(data, EnhancedCPS_2024.file_path) + with h5py.File(STORAGE_FOLDER / "small_enhanced_cps_2024.h5", "w") as f: for variable, periods in data.items(): grp = f.create_group(variable) @@ -65,7 +134,6 @@ def create_sparse_ecps(): ecps = EnhancedCPS_2024() h5 = ecps.load() sparse_weights = h5["household_weight"][str(time_period)][:] - hh_ids = h5["household_id"][str(time_period)][:] h5.close() template_sim = Microsimulation( @@ -78,7 +146,6 @@ def create_sparse_ecps(): household_weight_column = f"household_weight__{time_period}" df_household_id_column = f"household_id__{time_period}" - df_person_id_column = f"person_id__{time_period}" # Group by household ID and get the first entry for each group h_df = df.groupby(df_household_id_column).first() diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index a21a94b3c..e88504ab2 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -5,6 +5,7 @@ from policyengine_us_data.datasets import EnhancedCPS_2024 from policyengine_us_data.datasets.cps.cps import CPS_2024 +from policyengine_us_data.datasets.cps.enhanced_cps import clone_diagnostics_path from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.utils.data_upload import upload_data_files from policyengine_us_data.utils.dataset_validation import ( @@ -187,6 +188,7 @@ def upload_datasets(require_enhanced_cps: bool = True): ] enhanced_files = [ EnhancedCPS_2024.file_path, + clone_diagnostics_path(EnhancedCPS_2024.file_path), STORAGE_FOLDER / "small_enhanced_cps_2024.h5", ] if require_enhanced_cps: diff --git a/tests/unit/test_small_enhanced_cps.py b/tests/unit/test_small_enhanced_cps.py new file mode 100644 index 000000000..f9a926831 --- /dev/null +++ b/tests/unit/test_small_enhanced_cps.py @@ -0,0 +1,89 @@ +import h5py +import numpy as np +import pytest + +from policyengine_us_data.datasets.cps.small_enhanced_cps import ( + _attach_clone_origin_flags, +) + + +def _write_period_group( + h5_file, variable_name: str, values_by_period: dict[int, np.ndarray] +): + group = h5_file.create_group(variable_name) + for period, values in values_by_period.items(): + group.create_dataset(str(period), data=values) + + +def test_attach_clone_origin_flags_maps_sampled_entity_ids(tmp_path): + source_path = tmp_path / "enhanced_cps_2024.h5" + with h5py.File(source_path, "w") as h5_file: + _write_period_group(h5_file, "person_id", {2024: np.array([11, 12, 13])}) + _write_period_group( + h5_file, "person_is_puf_clone", {2024: np.array([0, 1, 0], dtype=np.int8)} + ) + _write_period_group(h5_file, "tax_unit_id", {2024: np.array([21, 22])}) + _write_period_group( + h5_file, "tax_unit_is_puf_clone", {2024: np.array([1, 0], dtype=np.int8)} + ) + _write_period_group(h5_file, "spm_unit_id", {2024: np.array([31, 32])}) + _write_period_group( + h5_file, "spm_unit_is_puf_clone", {2024: np.array([0, 1], dtype=np.int8)} + ) + _write_period_group(h5_file, "family_id", {2024: np.array([41, 42])}) + _write_period_group( + h5_file, "family_is_puf_clone", {2024: np.array([1, 1], dtype=np.int8)} + ) + _write_period_group(h5_file, "household_id", {2024: np.array([51, 52])}) + _write_period_group( + h5_file, "household_is_puf_clone", {2024: np.array([0, 1], dtype=np.int8)} + ) + + sampled_data = { + "person_id": {2024: np.array([12, 11], dtype=np.int32)}, + "tax_unit_id": {2024: np.array([22, 21], dtype=np.int32)}, + "spm_unit_id": {2024: np.array([32, 31], dtype=np.int32)}, + "family_id": {2024: np.array([41, 42], dtype=np.int32)}, + "household_id": {2024: np.array([52, 51], dtype=np.int32)}, + } + + _attach_clone_origin_flags(sampled_data, source_path) + + assert sampled_data["person_is_puf_clone"][2024].tolist() == [1, 0] + assert sampled_data["tax_unit_is_puf_clone"][2024].tolist() == [0, 1] + assert sampled_data["spm_unit_is_puf_clone"][2024].tolist() == [1, 0] + assert sampled_data["family_is_puf_clone"][2024].tolist() == [1, 1] + assert sampled_data["household_is_puf_clone"][2024].tolist() == [1, 0] + + +def test_attach_clone_origin_flags_rejects_missing_sampled_id(tmp_path): + source_path = tmp_path / "enhanced_cps_2024.h5" + with h5py.File(source_path, "w") as h5_file: + for entity_name, id_base in [ + ("person", 10), + ("tax_unit", 20), + ("spm_unit", 30), + ("family", 40), + ("household", 50), + ]: + _write_period_group( + h5_file, + f"{entity_name}_id", + {2024: np.array([id_base], dtype=np.int32)}, + ) + _write_period_group( + h5_file, + f"{entity_name}_is_puf_clone", + {2024: np.array([0], dtype=np.int8)}, + ) + + sampled_data = { + "person_id": {2024: np.array([999], dtype=np.int32)}, + "tax_unit_id": {2024: np.array([20], dtype=np.int32)}, + "spm_unit_id": {2024: np.array([30], dtype=np.int32)}, + "family_id": {2024: np.array([40], dtype=np.int32)}, + "household_id": {2024: np.array([50], dtype=np.int32)}, + } + + with pytest.raises(KeyError, match="person_is_puf_clone"): + _attach_clone_origin_flags(sampled_data, source_path) diff --git a/tests/unit/test_upload_completed_datasets.py b/tests/unit/test_upload_completed_datasets.py index 7a602c340..34265ec18 100644 --- a/tests/unit/test_upload_completed_datasets.py +++ b/tests/unit/test_upload_completed_datasets.py @@ -7,6 +7,7 @@ import policyengine_us_data.storage.upload_completed_datasets as upload_module from policyengine_us_data.storage.upload_completed_datasets import ( DatasetValidationError, + upload_datasets, validate_dataset, ) import policyengine_us_data.utils.dataset_validation as _dv_mod @@ -153,3 +154,79 @@ def test_validate_dataset_infers_time_period_for_flat_h5(tmp_path, monkeypatch): validate_dataset(file_path) assert _TimePeriodCheckingAggregateMicrosimulation.last_dataset.time_period == 2024 + + +def test_upload_datasets_includes_clone_diagnostics_sidecar(tmp_path, monkeypatch): + storage_folder = tmp_path / "storage" + calibration_folder = storage_folder / "calibration" + calibration_folder.mkdir(parents=True) + + cps_path = storage_folder / "cps_2024.h5" + enhanced_path = storage_folder / "enhanced_cps_2024.h5" + diagnostics_path = enhanced_path.with_suffix(".clone_diagnostics.json") + small_path = storage_folder / "small_enhanced_cps_2024.h5" + policy_db = calibration_folder / "policy_data.db" + + for path in [cps_path, enhanced_path, diagnostics_path, small_path, policy_db]: + path.write_text("placeholder") + + monkeypatch.setattr( + upload_module, + "CPS_2024", + SimpleNamespace(file_path=cps_path), + ) + monkeypatch.setattr( + upload_module, + "EnhancedCPS_2024", + SimpleNamespace(file_path=enhanced_path), + ) + monkeypatch.setattr(upload_module, "STORAGE_FOLDER", storage_folder) + monkeypatch.setattr(upload_module, "validate_dataset", lambda file_path: None) + + uploaded_files = [] + + def _capture_upload(*, files, **kwargs): + uploaded_files.extend(files) + + monkeypatch.setattr(upload_module, "upload_data_files", _capture_upload) + + upload_datasets(require_enhanced_cps=True) + + assert uploaded_files == [ + cps_path, + policy_db, + enhanced_path, + diagnostics_path, + small_path, + ] + + +def test_upload_datasets_requires_clone_diagnostics_sidecar(tmp_path, monkeypatch): + storage_folder = tmp_path / "storage" + calibration_folder = storage_folder / "calibration" + calibration_folder.mkdir(parents=True) + + cps_path = storage_folder / "cps_2024.h5" + enhanced_path = storage_folder / "enhanced_cps_2024.h5" + small_path = storage_folder / "small_enhanced_cps_2024.h5" + policy_db = calibration_folder / "policy_data.db" + + for path in [cps_path, enhanced_path, small_path, policy_db]: + path.write_text("placeholder") + + monkeypatch.setattr( + upload_module, + "CPS_2024", + SimpleNamespace(file_path=cps_path), + ) + monkeypatch.setattr( + upload_module, + "EnhancedCPS_2024", + SimpleNamespace(file_path=enhanced_path), + ) + monkeypatch.setattr(upload_module, "STORAGE_FOLDER", storage_folder) + monkeypatch.setattr(upload_module, "validate_dataset", lambda file_path: None) + monkeypatch.setattr(upload_module, "upload_data_files", lambda **kwargs: None) + + with pytest.raises(FileNotFoundError, match="clone_diagnostics"): + upload_datasets(require_enhanced_cps=True) From 1a0678a5fbf9384fcc861fb6bbece4ad6503b387 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:17:17 -0400 Subject: [PATCH 07/10] Handle period keys in small ECPS clone flags --- .../datasets/cps/small_enhanced_cps.py | 10 ++++- tests/unit/test_small_enhanced_cps.py | 41 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/policyengine_us_data/datasets/cps/small_enhanced_cps.py b/policyengine_us_data/datasets/cps/small_enhanced_cps.py index e825d080b..6dd6a02b7 100644 --- a/policyengine_us_data/datasets/cps/small_enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/small_enhanced_cps.py @@ -27,6 +27,14 @@ def _load_saved_period_array( raise KeyError(f"{variable_name} missing period {period}") +def _normalize_annual_period_key(period_key) -> int: + if isinstance(period_key, (int, np.integer)): + return int(period_key) + if hasattr(period_key, "year"): + return int(period_key.year) + return int(str(period_key)) + + def _attach_clone_origin_flags( data: dict[str, dict[int, np.ndarray]], source_dataset_path, @@ -43,7 +51,7 @@ def _attach_clone_origin_flags( source_ids_by_period = {} source_flags_by_period = {} for period_key in data[id_variable]: - period = int(period_key) + period = _normalize_annual_period_key(period_key) source_ids = _load_saved_period_array(source_h5, id_variable, period) source_flags = _load_saved_period_array(source_h5, flag_name, period) if len(source_ids) != len(source_flags): diff --git a/tests/unit/test_small_enhanced_cps.py b/tests/unit/test_small_enhanced_cps.py index f9a926831..9a5089cc0 100644 --- a/tests/unit/test_small_enhanced_cps.py +++ b/tests/unit/test_small_enhanced_cps.py @@ -1,5 +1,6 @@ import h5py import numpy as np +import pandas as pd import pytest from policyengine_us_data.datasets.cps.small_enhanced_cps import ( @@ -87,3 +88,43 @@ def test_attach_clone_origin_flags_rejects_missing_sampled_id(tmp_path): with pytest.raises(KeyError, match="person_is_puf_clone"): _attach_clone_origin_flags(sampled_data, source_path) + + +def test_attach_clone_origin_flags_accepts_period_keys(tmp_path): + source_path = tmp_path / "enhanced_cps_2024.h5" + with h5py.File(source_path, "w") as h5_file: + _write_period_group(h5_file, "person_id", {2024: np.array([11, 12])}) + _write_period_group( + h5_file, + "person_is_puf_clone", + {2024: np.array([0, 1], dtype=np.int8)}, + ) + for entity_name, id_base in [ + ("tax_unit", 20), + ("spm_unit", 30), + ("family", 40), + ("household", 50), + ]: + _write_period_group( + h5_file, + f"{entity_name}_id", + {2024: np.array([id_base], dtype=np.int32)}, + ) + _write_period_group( + h5_file, + f"{entity_name}_is_puf_clone", + {2024: np.array([0], dtype=np.int8)}, + ) + + period = pd.Period("2024", freq="Y") + sampled_data = { + "person_id": {period: np.array([12, 11], dtype=np.int32)}, + "tax_unit_id": {period: np.array([20], dtype=np.int32)}, + "spm_unit_id": {period: np.array([30], dtype=np.int32)}, + "family_id": {period: np.array([40], dtype=np.int32)}, + "household_id": {period: np.array([50], dtype=np.int32)}, + } + + _attach_clone_origin_flags(sampled_data, source_path) + + assert sampled_data["person_is_puf_clone"][period].tolist() == [1, 0] From f83a464301084f4b3f2f02617f1e3661227a2e7a Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 13:09:00 -0400 Subject: [PATCH 08/10] Align validation with upload artifacts --- .../storage/upload_completed_datasets.py | 16 ++++++---- tests/unit/test_upload_completed_datasets.py | 30 +++++++++++++++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index e88504ab2..1ea24d2c3 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -50,6 +50,14 @@ class DatasetValidationError(Exception): pass +def _enhanced_dataset_files() -> list[Path]: + return [ + EnhancedCPS_2024.file_path, + clone_diagnostics_path(EnhancedCPS_2024.file_path), + STORAGE_FOLDER / "small_enhanced_cps_2024.h5", + ] + + def validate_dataset(file_path: Path) -> None: """Validate a dataset file before upload. @@ -186,11 +194,7 @@ def upload_datasets(require_enhanced_cps: bool = True): CPS_2024.file_path, STORAGE_FOLDER / "calibration" / "policy_data.db", ] - enhanced_files = [ - EnhancedCPS_2024.file_path, - clone_diagnostics_path(EnhancedCPS_2024.file_path), - STORAGE_FOLDER / "small_enhanced_cps_2024.h5", - ] + enhanced_files = _enhanced_dataset_files() if require_enhanced_cps: required_files.extend(enhanced_files) @@ -235,7 +239,7 @@ def validate_all_datasets(): def validate_built_datasets(require_enhanced_cps: bool = True): required_files = [CPS_2024.file_path] if require_enhanced_cps: - required_files.append(EnhancedCPS_2024.file_path) + required_files.extend(_enhanced_dataset_files()) for file_path in required_files: if not file_path.exists(): diff --git a/tests/unit/test_upload_completed_datasets.py b/tests/unit/test_upload_completed_datasets.py index 34265ec18..4c9de9f53 100644 --- a/tests/unit/test_upload_completed_datasets.py +++ b/tests/unit/test_upload_completed_datasets.py @@ -9,6 +9,7 @@ DatasetValidationError, upload_datasets, validate_dataset, + validate_built_datasets, ) import policyengine_us_data.utils.dataset_validation as _dv_mod from policyengine_us_data.utils.dataset_validation import validate_dataset_contract @@ -230,3 +231,32 @@ def test_upload_datasets_requires_clone_diagnostics_sidecar(tmp_path, monkeypatc with pytest.raises(FileNotFoundError, match="clone_diagnostics"): upload_datasets(require_enhanced_cps=True) + + +def test_validate_built_datasets_requires_clone_diagnostics_sidecar( + tmp_path, monkeypatch +): + storage_folder = tmp_path / "storage" + cps_path = storage_folder / "cps_2024.h5" + enhanced_path = storage_folder / "enhanced_cps_2024.h5" + small_path = storage_folder / "small_enhanced_cps_2024.h5" + + storage_folder.mkdir(parents=True) + for path in [cps_path, enhanced_path, small_path]: + path.write_text("placeholder") + + monkeypatch.setattr( + upload_module, + "CPS_2024", + SimpleNamespace(file_path=cps_path), + ) + monkeypatch.setattr( + upload_module, + "EnhancedCPS_2024", + SimpleNamespace(file_path=enhanced_path), + ) + monkeypatch.setattr(upload_module, "STORAGE_FOLDER", storage_folder) + monkeypatch.setattr(upload_module, "validate_dataset", lambda file_path: None) + + with pytest.raises(FileNotFoundError, match="clone_diagnostics"): + validate_built_datasets(require_enhanced_cps=True) From 2bdbb79007fcde577e65524d228f1a994b8b541c Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 11 Apr 2026 20:34:13 -0400 Subject: [PATCH 09/10] Fix clone diagnostics weight mapping --- .../datasets/cps/enhanced_cps.py | 29 ++++++++- .../test_enhanced_cps_clone_diagnostics.py | 64 +++++++++++++++++++ 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 656fce122..87bf28186 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -217,6 +217,27 @@ def build_clone_diagnostics_for_saved_dataset( sim = Microsimulation(dataset=dataset_cls) dataset_path = Path(dataset_cls.file_path) + return build_clone_diagnostics_for_simulation( + sim, + dataset_path=dataset_path, + period=period, + ) + + +def build_clone_diagnostics_for_simulation( + sim, + *, + dataset_path: str | Path, + period: int, +) -> dict[str, float]: + """Build clone diagnostics from a simulation and saved clone-flag arrays. + + The enhanced CPS save path preserves zeroed person/spm-unit weight inputs on + the clone half. For diagnostics, always map the calibrated household weights + to persons/SPM units explicitly instead of reading those stale entity-level + weight inputs back from disk. + """ + person_reported_in_poverty = _to_numpy( sim.calculate("spm_unit_net_income_reported", period=period, map_to="person") ) < _to_numpy( @@ -231,13 +252,17 @@ def build_clone_diagnostics_for_saved_dataset( person_is_puf_clone=_load_saved_period_array( dataset_path, "person_is_puf_clone", period ), - person_weight=_to_numpy(sim.calculate("person_weight", period=period)), + person_weight=_to_numpy( + sim.calculate("household_weight", period=period, map_to="person") + ), person_in_poverty=_to_numpy(sim.calculate("person_in_poverty", period=period)), person_reported_in_poverty=person_reported_in_poverty, spm_unit_is_puf_clone=_load_saved_period_array( dataset_path, "spm_unit_is_puf_clone", period ), - spm_unit_weight=_to_numpy(sim.calculate("spm_unit_weight", period=period)), + spm_unit_weight=_to_numpy( + sim.calculate("household_weight", period=period, map_to="spm_unit") + ), spm_unit_capped_work_childcare_expenses=_to_numpy( sim.calculate("spm_unit_capped_work_childcare_expenses", period=period) ), diff --git a/tests/unit/test_enhanced_cps_clone_diagnostics.py b/tests/unit/test_enhanced_cps_clone_diagnostics.py index a0261aac3..9ddcc1998 100644 --- a/tests/unit/test_enhanced_cps_clone_diagnostics.py +++ b/tests/unit/test_enhanced_cps_clone_diagnostics.py @@ -1,8 +1,10 @@ from pathlib import Path +import numpy as np import pytest from policyengine_us_data.datasets.cps.enhanced_cps import ( + build_clone_diagnostics_for_simulation, build_clone_diagnostics_payload, compute_clone_diagnostics_summary, clone_diagnostics_path, @@ -43,6 +45,68 @@ def test_compute_clone_diagnostics_summary(): ) +def test_build_clone_diagnostics_for_simulation_maps_household_weights( + monkeypatch, +): + class FakeResult: + def __init__(self, values): + self.values = np.asarray(values) + + class FakeSim: + def calculate(self, variable, period=None, map_to=None): + lookup = { + ("spm_unit_net_income_reported", "person"): [1000.0, 300.0, 100.0], + ("spm_unit_spm_threshold", "person"): [500.0, 200.0, 200.0], + ("household_weight", None): [9.0, 1.0], + ("household_weight", "person"): [9.0, 1.0, 1.0], + ("household_weight", "spm_unit"): [9.0, 1.0], + # Trap values: diagnostics should not read these stale inputs. + ("person_weight", None): [9.0, 0.0, 0.0], + ("spm_unit_weight", None): [9.0, 0.0], + ("person_in_poverty", None): [False, True, True], + ("spm_unit_capped_work_childcare_expenses", None): [0.0, 6000.0], + ("spm_unit_pre_subsidy_childcare_expenses", None): [0.0, 5000.0], + ("spm_unit_taxes", None): [100.0, 9000.0], + ("spm_unit_market_income", None): [1000.0, 8000.0], + } + return FakeResult(lookup[(variable, map_to)]) + + saved_arrays = { + "household_is_puf_clone": np.array([False, True]), + "person_is_puf_clone": np.array([False, True, True]), + "spm_unit_is_puf_clone": np.array([False, True]), + } + + monkeypatch.setattr( + "policyengine_us_data.datasets.cps.enhanced_cps._load_saved_period_array", + lambda dataset_path, variable_name, period: saved_arrays[variable_name], + ) + + diagnostics = build_clone_diagnostics_for_simulation( + FakeSim(), + dataset_path=Path("enhanced_cps_2024.h5"), + period=2024, + ) + + assert diagnostics["clone_household_weight_share_pct"] == pytest.approx(10.0) + assert diagnostics["clone_person_weight_share_pct"] == pytest.approx( + 200.0 / 11.0 + ) + assert diagnostics[ + "clone_poor_modeled_only_person_weight_share_pct" + ] == pytest.approx(100.0 / 11.0) + assert diagnostics[ + "poor_modeled_only_within_clone_person_weight_share_pct" + ] == pytest.approx(50.0) + assert diagnostics[ + "clone_childcare_exceeds_pre_subsidy_share_pct" + ] == pytest.approx(100.0) + assert diagnostics["clone_childcare_above_5000_share_pct"] == pytest.approx(100.0) + assert diagnostics["clone_taxes_exceed_market_income_share_pct"] == pytest.approx( + 100.0 + ) + + def test_build_clone_diagnostics_payload_single_period(): payload = build_clone_diagnostics_payload( {2024: {"clone_person_weight_share_pct": 12.5}} From 4742d6887ae48e8dafb55779ea95c56729dd65b0 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 11 Apr 2026 22:00:19 -0400 Subject: [PATCH 10/10] Format clone diagnostics regression test --- tests/unit/test_enhanced_cps_clone_diagnostics.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/test_enhanced_cps_clone_diagnostics.py b/tests/unit/test_enhanced_cps_clone_diagnostics.py index 9ddcc1998..2b9b1a0a2 100644 --- a/tests/unit/test_enhanced_cps_clone_diagnostics.py +++ b/tests/unit/test_enhanced_cps_clone_diagnostics.py @@ -89,9 +89,7 @@ def calculate(self, variable, period=None, map_to=None): ) assert diagnostics["clone_household_weight_share_pct"] == pytest.approx(10.0) - assert diagnostics["clone_person_weight_share_pct"] == pytest.approx( - 200.0 / 11.0 - ) + assert diagnostics["clone_person_weight_share_pct"] == pytest.approx(200.0 / 11.0) assert diagnostics[ "clone_poor_modeled_only_person_weight_share_pct" ] == pytest.approx(100.0 / 11.0)