Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,7 @@ target/

# Local scripts (not part of package)
scripts/

# Launch directories (local only)
launch/
launch-video/
9 changes: 8 additions & 1 deletion diff_diff/honest_did.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,12 @@ def _extract_event_study_params(
)

# Extract event study effects by relative time
event_effects = results.event_study_effects
# Filter out normalization constraints (n_groups=0) and non-finite SEs
event_effects = {
t: data for t, data in results.event_study_effects.items()
if data.get('n_groups', 1) > 0
and np.isfinite(data.get('se', np.nan))
}
rel_times = sorted(event_effects.keys())

# Split into pre and post
Expand Down Expand Up @@ -1261,10 +1266,12 @@ def _estimate_max_pre_violation(
from diff_diff.staggered import CallawaySantAnnaResults
if isinstance(results, CallawaySantAnnaResults):
if results.event_study_effects:
# Filter out normalization constraints (n_groups=0, e.g. reference period)
pre_effects = [
abs(results.event_study_effects[t]['effect'])
for t in results.event_study_effects
if t < 0
and results.event_study_effects[t].get('n_groups', 1) > 0
]
if pre_effects:
return max(pre_effects)
Expand Down
6 changes: 6 additions & 0 deletions diff_diff/pretrends.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,12 @@ def _extract_pre_period_params(
)

# Get pre-period effects (negative relative times)
# Filter out normalization constraints (n_groups=0) and non-finite SEs
pre_effects = {
t: data for t, data in results.event_study_effects.items()
if t < 0
and data.get('n_groups', 1) > 0
and np.isfinite(data.get('se', np.nan))
}

if not pre_effects:
Expand All @@ -680,9 +683,12 @@ def _extract_pre_period_params(
from diff_diff.sun_abraham import SunAbrahamResults
if isinstance(results, SunAbrahamResults):
# Get pre-period effects (negative relative times)
# Filter out normalization constraints (n_groups=0) and non-finite SEs
pre_effects = {
t: data for t, data in results.event_study_effects.items()
if t < 0
and data.get('n_groups', 1) > 0
and np.isfinite(data.get('se', np.nan))
}

if not pre_effects:
Expand Down
19 changes: 19 additions & 0 deletions diff_diff/staggered_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class CallawaySantAnnaAggregationMixin:
# Type hint for anticipation attribute accessed from main class
anticipation: int

# Type hint for base_period attribute accessed from main class
base_period: str

def _aggregate_simple(
self,
group_time_effects: Dict,
Expand Down Expand Up @@ -414,6 +417,22 @@ def _aggregate_event_study(
'n_groups': len(effect_list),
}

# Add reference period for universal base period mode (matches R did package)
# The reference period e = -1 - anticipation has effect = 0 by construction
# Only add if there are actual computed effects (guard against empty data)
if getattr(self, 'base_period', 'varying') == "universal":
ref_period = -1 - self.anticipation
# Only inject reference if we have at least one real effect
if event_study_effects and ref_period not in event_study_effects:
event_study_effects[ref_period] = {
'effect': 0.0,
'se': np.nan, # Undefined - no data, normalization constraint
't_stat': np.nan, # Undefined - normalization constraint
'p_value': np.nan,
'conf_int': (np.nan, np.nan), # NaN propagation for undefined inference
'n_groups': 0, # No groups contribute - fixed by construction
}

return event_study_effects

def _aggregate_by_group(
Expand Down
43 changes: 32 additions & 11 deletions diff_diff/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,17 @@ def plot_event_study(
effect = effects.get(period, np.nan)
std_err = se.get(period, np.nan)

if np.isnan(effect) or np.isnan(std_err):
# Skip entries with NaN effect, but allow NaN SE (will plot without error bars)
if np.isnan(effect):
continue

ci_lower = effect - critical_value * std_err
ci_upper = effect + critical_value * std_err
# Compute CI only if SE is finite
if np.isfinite(std_err):
ci_lower = effect - critical_value * std_err
ci_upper = effect + critical_value * std_err
else:
ci_lower = np.nan
ci_upper = np.nan

plot_data.append({
'period': period,
Expand Down Expand Up @@ -244,13 +250,20 @@ def plot_event_study(
ref_x = period_to_x[reference_period]
ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1)

# Plot error bars
yerr = [df['effect'] - df['ci_lower'], df['ci_upper'] - df['effect']]
ax.errorbar(
x_vals, df['effect'], yerr=yerr,
fmt='none', color=color, capsize=capsize, linewidth=linewidth,
capthick=linewidth, zorder=2
)
# Plot error bars (only for entries with finite CI)
has_ci = df['ci_lower'].notna() & df['ci_upper'].notna()
if has_ci.any():
df_with_ci = df[has_ci]
x_with_ci = [period_to_x[p] for p in df_with_ci['period']]
yerr = [
df_with_ci['effect'] - df_with_ci['ci_lower'],
df_with_ci['ci_upper'] - df_with_ci['effect']
]
ax.errorbar(
x_with_ci, df_with_ci['effect'], yerr=yerr,
fmt='none', color=color, capsize=capsize, linewidth=linewidth,
capthick=linewidth, zorder=2
)

# Plot point estimates
for i, row in df.iterrows():
Expand Down Expand Up @@ -351,7 +364,15 @@ def _extract_plot_data(

# Reference period is typically -1 for event study
if reference_period is None:
reference_period = -1
# Detect reference period from n_groups=0 marker (normalization constraint)
# This handles anticipation > 0 where reference is at e = -1 - anticipation
for period, effect_data in results.event_study_effects.items():
if effect_data.get('n_groups', 1) == 0:
reference_period = period
break
# Fallback to -1 if no marker found (backward compatibility)
if reference_period is None:
reference_period = -1

if pre_periods is None:
pre_periods = [p for p in periods if p < 0]
Expand Down
3 changes: 3 additions & 0 deletions docs/methodology/REGISTRY.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ Aggregations:
- "universal": All comparisons use g-anticipation-1 as base
- Both produce identical post-treatment ATT(g,t); differ only pre-treatment
- Matches R `did::att_gt()` base_period parameter
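- **Event study output**: With "universal", includes reference period (e=-1-anticipation)
with effect=0, se=NaN, t_stat=NaN, p_value=NaN, conf_int=(NaN, NaN), and n_groups=0.
Inference fields are NaN since this is a normalization constraint, not an estimated effect;
n_groups=0 is the marker downstream consumers (HonestDiD, pre-trends power, plotting) use to
detect and skip it. Only added when at least one real effect exists and the reference
period is not already present.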
- Base period interaction with Sun-Abraham comparison:
- CS with `base_period="varying"` produces different pre-treatment estimates than SA
- This is expected: CS uses consecutive comparisons, SA uses fixed reference (e=-1-anticipation)
Expand Down
77 changes: 77 additions & 0 deletions tests/test_honest_did.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,83 @@ def test_very_large_M(self, mock_multiperiod_results):
assert isinstance(results, HonestDiDResults)
assert results.ci_width > 0

def test_callaway_santanna_universal_base_period(self):
    """HonestDiD must ignore the injected reference period under a universal base.

    With ``base_period="universal"`` the event study contains a reference
    entry (e=-1) that carries ``n_groups=0`` and ``se=NaN``. That entry has
    to be dropped before the HonestDiD analysis, otherwise it would
    contaminate the vcov matrix.
    """
    from diff_diff import CallawaySantAnna, generate_staggered_data

    # Fit a Callaway-Sant'Anna event study using the universal base period
    panel = generate_staggered_data(n_units=200, n_periods=10, seed=42)
    estimator = CallawaySantAnna(base_period="universal")
    results = estimator.fit(
        panel,
        outcome='outcome',
        unit='unit',
        time='period',
        first_treat='first_treat',
        aggregate='event_study',
    )

    # The reference period must be present and carry an undefined SE
    event_study = results.event_study_effects
    assert -1 in event_study
    assert np.isnan(event_study[-1]['se'])

    # Fitting must succeed, i.e. the reference period was filtered out
    sensitivity = HonestDiD(method='relative_magnitude', M=1.0)
    bounds = sensitivity.fit(results)

    # Both confidence-interval endpoints must be valid (finite) numbers
    assert isinstance(bounds, HonestDiDResults)
    assert np.isfinite(bounds.ci_lb)
    assert np.isfinite(bounds.ci_ub)

def test_max_pre_violation_excludes_reference_period(self):
    """Reference period (effect=0, n_groups=0) is excluded from max pre-violation.

    With a universal base period, the reference period e=-1 is a
    normalization constraint with ``n_groups=0``. It must not be used in
    ``_estimate_max_pre_violation`` because its effect is artificially set
    to 0, which would collapse the relative-magnitude bounds incorrectly.
    """
    from diff_diff import CallawaySantAnna, generate_staggered_data

    # Generate data and fit with universal base period
    data = generate_staggered_data(n_units=200, n_periods=10, seed=42)
    cs = CallawaySantAnna(base_period="universal")
    results = cs.fit(
        data,
        outcome='outcome',
        unit='unit',
        time='period',
        first_treat='first_treat',
        aggregate='event_study',
    )

    # Verify reference period exists and is marked with n_groups=0
    effects = results.event_study_effects
    assert -1 in effects
    assert effects[-1]['n_groups'] == 0

    honest = HonestDiD(method='relative_magnitude', M=1.0)

    # Pre-periods that were actually estimated (reference marker excluded)
    real_pre_periods = [
        t for t, entry in effects.items()
        if t < 0 and entry.get('n_groups', 1) > 0
    ]

    # Fail loudly instead of silently skipping the check: without any
    # estimated pre-period this test would otherwise pass vacuously.
    assert real_pre_periods, (
        "expected at least one estimated pre-period besides the reference period"
    )

    # Max violation must reflect actual pre-period coefficients, which are
    # non-zero due to sampling variation, not the reference period's 0.
    max_violation = honest._estimate_max_pre_violation(results, real_pre_periods)
    assert max_violation > 0, (
        "max_pre_violation should be > 0 when real pre-periods exist"
    )


# =============================================================================
# Tests for Visualization (without matplotlib)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_pretrends.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,38 @@ def test_unsupported_results_type_raises(self):
with pytest.raises(TypeError, match="Unsupported results type"):
pt.fit("not a results object")

def test_callaway_santanna_universal_base_period(self):
    """PreTrendsPower must ignore the injected reference period under a universal base.

    With ``base_period="universal"`` the event study contains a reference
    entry (e=-1) that carries ``n_groups=0`` and ``se=NaN``. That entry has
    to be dropped before the pre-trends power analysis, otherwise it would
    contaminate the vcov matrix.
    """
    from diff_diff import CallawaySantAnna, generate_staggered_data

    # Fit a Callaway-Sant'Anna event study using the universal base period
    panel = generate_staggered_data(n_units=200, n_periods=10, seed=42)
    estimator = CallawaySantAnna(base_period="universal")
    results = estimator.fit(
        panel,
        outcome='outcome',
        unit='unit',
        time='period',
        first_treat='first_treat',
        aggregate='event_study',
    )

    # The reference period must be present and carry an undefined SE
    event_study = results.event_study_effects
    assert -1 in event_study
    assert np.isnan(event_study[-1]['se'])

    # Fitting must succeed, i.e. the reference period was filtered out
    power_analysis = PreTrendsPower()
    power_results = power_analysis.fit(results)

    # Power must be a valid number and at least one real pre-period must remain
    assert np.isfinite(power_results.power)
    assert power_results.n_pre_periods >= 1


# =============================================================================
# Tests for visualization (without rendering)
Expand Down
Loading