1 change: 1 addition & 0 deletions changelog.d/826.added.md
@@ -0,0 +1 @@
Preserve Forbes top-tail residence states through PUF local geography assignment.
127 changes: 113 additions & 14 deletions policyengine_us_data/calibration/clone_and_assign.py
@@ -100,6 +100,7 @@ def assign_random_geography(
household_agi: np.ndarray = None,
cd_agi_targets: dict = None,
agi_threshold_pctile: float = 90.0,
fixed_state_fips: np.ndarray = None,
) -> GeographyAssignment:
"""Assign random census block geography to cloned
CPS records.
@@ -114,12 +115,20 @@ def assign_random_geography(
dataset.
n_clones: Number of clones (default 10).
seed: Random seed for reproducibility.
fixed_state_fips: Optional state FIPS per base record. Positive
values constrain every clone of that record to blocks in the
requested state; zero or missing values remain unrestricted.

Returns:
GeographyAssignment with arrays of length
n_records * n_clones.
"""
blocks, cds, states, probs = load_global_block_distribution()
fixed_states = _validate_fixed_state_fips(
fixed_state_fips,
n_records=n_records,
available_states=states,
)

n_total = n_records * n_clones
rng = np.random.default_rng(seed)
@@ -137,7 +146,30 @@ def assign_random_geography(
threshold,
)

def _sample(size, mask_slice=None):
state_draw_cache: dict[tuple[int, str], tuple[np.ndarray, np.ndarray]] = {}

def _state_draw_inputs(state: int, probability_source: str):
key = (int(state), probability_source)
cached = state_draw_cache.get(key)
if cached is not None:
return cached

state_indices = np.flatnonzero(states == state)
base_probs = agi_probs if probability_source == "agi" else probs
state_probs = base_probs[state_indices].astype(np.float64)
if not np.isfinite(state_probs).all() or state_probs.sum() <= 0:
state_probs = probs[state_indices].astype(np.float64)
if not np.isfinite(state_probs).all() or state_probs.sum() <= 0:
state_probs = np.ones(len(state_indices), dtype=np.float64)
state_probs = state_probs / state_probs.sum()
state_draw_cache[key] = (state_indices, state_probs)
return state_indices, state_probs

def _sample_state(state: int, size: int, probability_source: str):
state_indices, state_probs = _state_draw_inputs(state, probability_source)
return rng.choice(state_indices, size=size, p=state_probs)

def _sample_unrestricted(size, mask_slice=None):
"""Sample block indices, using AGI-weighted probs for extreme HHs."""
if (
extreme_mask is not None
@@ -155,17 +187,53 @@ def _sample(size, mask_slice=None):
return out
return rng.choice(len(blocks), size=size, p=probs)

def _sample(size, mask_slice=None, fixed_slice=None):
out = np.empty(size, dtype=np.int64)
remaining = np.ones(size, dtype=bool)

if fixed_slice is not None:
fixed_slice = np.asarray(fixed_slice, dtype=np.int32)
for state in np.unique(fixed_slice[fixed_slice > 0]):
state_mask = fixed_slice == state
if mask_slice is not None and agi_probs is not None:
extreme_state_mask = state_mask & mask_slice
normal_state_mask = state_mask & ~mask_slice
if extreme_state_mask.any():
out[extreme_state_mask] = _sample_state(
int(state),
int(extreme_state_mask.sum()),
"agi",
)
if normal_state_mask.any():
out[normal_state_mask] = _sample_state(
int(state),
int(normal_state_mask.sum()),
"pop",
)
else:
out[state_mask] = _sample_state(
int(state),
int(state_mask.sum()),
"pop",
)
remaining[state_mask] = False

if remaining.any():
remaining_mask = mask_slice[remaining] if mask_slice is not None else None
out[remaining] = _sample_unrestricted(int(remaining.sum()), remaining_mask)
return out

indices = np.empty(n_total, dtype=np.int64)

# Clone 0: draw with no collision constraint (fixed states still apply)
indices[:n_records] = _sample(n_records, extreme_mask)
indices[:n_records] = _sample(n_records, extreme_mask, fixed_states)

assigned_cds = np.empty((n_clones, n_records), dtype=object)
assigned_cds[0] = cds[indices[:n_records]]

for clone_idx in range(1, n_clones):
start = clone_idx * n_records
clone_indices = _sample(n_records, extreme_mask)
clone_indices = _sample(n_records, extreme_mask, fixed_states)
clone_cds = cds[clone_indices]

collisions = np.zeros(n_records, dtype=bool)
@@ -178,18 +246,11 @@ def _sample(size, mask_slice=None):
break
bad_mask = collisions
if extreme_mask is not None and agi_probs is not None:
bad_ext = bad_mask & extreme_mask
bad_norm = bad_mask & ~extreme_mask
if bad_ext.sum() > 0:
clone_indices[bad_ext] = rng.choice(
len(blocks), size=bad_ext.sum(), p=agi_probs
)
if bad_norm.sum() > 0:
clone_indices[bad_norm] = rng.choice(
len(blocks), size=bad_norm.sum(), p=probs
)
replacement = _sample(n_records, extreme_mask, fixed_states)
clone_indices[bad_mask] = replacement[bad_mask]
else:
clone_indices[collisions] = rng.choice(len(blocks), size=n_bad, p=probs)
replacement = _sample(n_records, fixed_slice=fixed_states)
clone_indices[collisions] = replacement[collisions]
clone_cds = cds[clone_indices]
collisions = np.zeros(n_records, dtype=bool)
for prev in range(clone_idx):
@@ -209,6 +270,44 @@ def _sample(size, mask_slice=None):
)


def _validate_fixed_state_fips(
fixed_state_fips: np.ndarray | None,
n_records: int,
available_states: np.ndarray,
) -> np.ndarray | None:
"""Validate optional record-level state constraints."""

if fixed_state_fips is None:
return None

fixed = np.asarray(fixed_state_fips)
if len(fixed) != n_records:
raise ValueError(
"fixed_state_fips must have one value per base record: "
f"got {len(fixed)} for {n_records} records."
)

fixed = np.nan_to_num(fixed.astype(float), nan=0.0).astype(np.int32)
positive = np.unique(fixed[fixed > 0])
if len(positive) == 0:
return None

available = set(np.asarray(available_states, dtype=np.int32).tolist())
missing = [int(state) for state in positive if int(state) not in available]
if missing:
raise ValueError(
"fixed_state_fips contains states absent from the block "
f"distribution: {missing}"
)

logger.info(
"Preserving fixed state geography for %d of %d records",
int((fixed > 0).sum()),
n_records,
)
return fixed


def save_geography(geography: GeographyAssignment, path) -> None:
"""Save a GeographyAssignment to a compressed .npz file.

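Usage note on the new `fixed_state_fips` argument: the sketch below is illustrative and not part of this diff. The keyword names for `n_records`, `n_clones`, and `seed` are assumed from the docstring above, and the record counts, seed, and FIPS codes are made up.

import numpy as np

from policyengine_us_data.calibration.clone_and_assign import (
    assign_random_geography,
)

# Illustrative inputs: three base records. Record 1 must stay in
# California (FIPS 6) and record 2 in Texas (FIPS 48); a zero leaves
# every clone of that record unrestricted.
fixed_state_fips = np.array([0, 6, 48], dtype=np.int32)

geography = assign_random_geography(
    n_records=3,
    n_clones=10,
    seed=42,
    fixed_state_fips=fixed_state_fips,
)

# The returned arrays have length n_records * n_clones, laid out
# clone-by-clone, so a reshape recovers a (clone, record) grid.
state_by_clone = geography.state_fips.reshape(10, 3)
assert (state_by_clone[:, 1] == 6).all()
assert (state_by_clone[:, 2] == 48).all()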
85 changes: 83 additions & 2 deletions policyengine_us_data/calibration/unified_calibration.py
@@ -1054,6 +1054,80 @@ def compute_diagnostics(
)


def _raw_time_period_array(
raw_dataset: dict,
variable: str,
time_period: int,
) -> np.ndarray | None:
"""Extract one variable array from a raw Dataset.load_dataset() dict."""

if variable not in raw_dataset:
return None

values = raw_dataset[variable]
if isinstance(values, dict):
if time_period in values:
values = values[time_period]
elif str(time_period) in values:
values = values[str(time_period)]
else:
return None

try:
return np.asarray(values[...])
except (TypeError, ValueError):
return np.asarray(values)


def _extract_forbes_state_fips_overrides(
raw_dataset: dict,
time_period: int,
n_records: int,
) -> np.ndarray | None:
"""Return fixed-state overrides for Forbes synthetic PUF households."""

from policyengine_us_data.datasets.puf.aggregate_record_utils import (
SYNTHETIC_RECID_START,
)

household_id = _raw_time_period_array(raw_dataset, "household_id", time_period)
forbes_state_fips = _raw_time_period_array(
raw_dataset,
"forbes_state_fips",
time_period,
)
if household_id is None or forbes_state_fips is None:
return None
if len(household_id) != n_records or len(forbes_state_fips) != n_records:
logger.info(
"Skipping Forbes fixed-state overrides because "
"household_id/forbes_state_fips "
"lengths do not match household records: %s/%s vs %s",
len(household_id),
len(forbes_state_fips),
n_records,
)
return None

forbes_state_fips = np.nan_to_num(
np.asarray(forbes_state_fips, dtype=float),
nan=0.0,
).astype(np.int32)
household_id = np.asarray(household_id, dtype=float)

fixed_mask = (forbes_state_fips > 0) & (household_id >= SYNTHETIC_RECID_START)
if not fixed_mask.any():
return None

fixed_state_fips = np.zeros(n_records, dtype=np.int32)
fixed_state_fips[fixed_mask] = forbes_state_fips[fixed_mask]
logger.info(
"Detected %d Forbes synthetic households with fixed state_fips",
int(fixed_mask.sum()),
)
return fixed_state_fips


def run_calibration(
dataset_path: str,
db_path: str,
@@ -1193,7 +1267,8 @@ def run_calibration(
logger.info("Loading dataset from %s", dataset_path)
sim = Microsimulation(dataset=dataset_path)
n_records = len(sim.calculate("household_id", map_to="household").values)
raw_keys = sim.dataset.load_dataset()["household_id"]
raw_dataset = sim.dataset.load_dataset()
raw_keys = raw_dataset["household_id"]
if isinstance(raw_keys, dict):
time_period = int(next(iter(raw_keys)))
else:
@@ -1221,6 +1296,11 @@ def run_calibration(
"Loaded %d CD AGI targets for conditional assignment",
len(cd_agi_targets),
)
fixed_state_fips = _extract_forbes_state_fips_overrides(
raw_dataset=raw_dataset,
time_period=time_period,
n_records=n_records,
)

# Step 2: Clone and assign geography
logger.info(
@@ -1235,6 +1315,7 @@
seed=seed,
household_agi=base_agi,
cd_agi_targets=cd_agi_targets,
fixed_state_fips=fixed_state_fips,
)

# Step 3: Source imputation (if requested)
@@ -1245,7 +1326,7 @@

base_states = geography.state_fips[:n_records]

raw_data = sim.dataset.load_dataset()
raw_data = raw_dataset
data_dict = {}
for var in raw_data:
val = raw_data[var]
16 changes: 15 additions & 1 deletion policyengine_us_data/datasets/puf/disaggregate_puf.py
@@ -18,7 +18,10 @@
import numpy as np
import pandas as pd
from . import aggregate_record_utils as utils
from .forbes_backbone import build_forbes_top_tail_bucket
from .forbes_backbone import (
FORBES_TOP_TAIL_METADATA_DEFAULTS,
build_forbes_top_tail_bucket,
)

logger = logging.getLogger(__name__)

@@ -60,6 +63,17 @@ def disaggregate_aggregate_records(
if agg_mask.sum() == 0:
return puf

if use_forbes_top_tail:
missing_metadata = [
column
for column in FORBES_TOP_TAIL_METADATA_DEFAULTS
if column not in puf.columns
]
if missing_metadata:
puf = puf.copy()
for column in missing_metadata:
puf[column] = FORBES_TOP_TAIL_METADATA_DEFAULTS[column]

agg_rows = puf[agg_mask].copy().set_index("RECID")
regular = puf[~agg_mask].copy()
amount_columns = _get_amount_columns(puf.columns)