diff --git a/.gitignore b/.gitignore index d0212da..d4d86f3 100644 --- a/.gitignore +++ b/.gitignore @@ -76,3 +76,7 @@ target/ # Local scripts (not part of package) scripts/ + +# Launch directories (local only) +launch/ +launch-video/ diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 612a3dd..81c8caf 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -584,7 +584,12 @@ def _extract_event_study_params( ) # Extract event study effects by relative time - event_effects = results.event_study_effects + # Filter out normalization constraints (n_groups=0) and non-finite SEs + event_effects = { + t: data for t, data in results.event_study_effects.items() + if data.get('n_groups', 1) > 0 + and np.isfinite(data.get('se', np.nan)) + } rel_times = sorted(event_effects.keys()) # Split into pre and post @@ -1261,10 +1266,12 @@ def _estimate_max_pre_violation( from diff_diff.staggered import CallawaySantAnnaResults if isinstance(results, CallawaySantAnnaResults): if results.event_study_effects: + # Filter out normalization constraints (n_groups=0, e.g. reference period) pre_effects = [ abs(results.event_study_effects[t]['effect']) for t in results.event_study_effects if t < 0 + and results.event_study_effects[t].get('n_groups', 1) > 0 ] if pre_effects: return max(pre_effects) diff --git a/diff_diff/pretrends.py b/diff_diff/pretrends.py index fb88e96..d9bc1c6 100644 --- a/diff_diff/pretrends.py +++ b/diff_diff/pretrends.py @@ -656,9 +656,12 @@ def _extract_pre_period_params( ) # Get pre-period effects (negative relative times) + # Filter out normalization constraints (n_groups=0) and non-finite SEs pre_effects = { t: data for t, data in results.event_study_effects.items() if t < 0 + and data.get('n_groups', 1) > 0 + and np.isfinite(data.get('se', np.nan)) } if not pre_effects: @@ -680,9 +683,12 @@ def _extract_pre_period_params( from diff_diff.sun_abraham import SunAbrahamResults if isinstance(results, SunAbrahamResults): # Get pre-period effects (negative relative times) + # Filter out normalization constraints (n_groups=0) and non-finite SEs pre_effects = { t: data for t, data in results.event_study_effects.items() if t < 0 + and data.get('n_groups', 1) > 0 + and np.isfinite(data.get('se', np.nan)) } if not pre_effects: diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 543b924..1419413 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -34,6 +34,9 @@ class CallawaySantAnnaAggregationMixin: # Type hint for anticipation attribute accessed from main class anticipation: int + # Type hint for base_period attribute accessed from main class + base_period: str + def _aggregate_simple( self, group_time_effects: Dict, @@ -414,6 +417,22 @@ def _aggregate_event_study( 'n_groups': len(effect_list), } + # Add reference period for universal base period mode (matches R did package) + # The reference period e = -1 - anticipation has effect = 0 by construction + # Only add if there are actual computed effects (guard against empty data) + if getattr(self, 'base_period', 'varying') == "universal": + ref_period = -1 - self.anticipation + # Only inject reference if we have at least one real effect + if event_study_effects and ref_period not in event_study_effects: + event_study_effects[ref_period] = { + 'effect': 0.0, + 'se': np.nan, # Undefined - no data, normalization constraint + 't_stat': np.nan, # Undefined - normalization constraint + 'p_value': np.nan, + 'conf_int': (np.nan, np.nan), # NaN propagation for undefined inference + 'n_groups': 0, # No groups contribute - fixed by construction + } + return event_study_effects def _aggregate_by_group( diff --git a/diff_diff/visualization.py b/diff_diff/visualization.py index fbf6ce4..6584b51 100644 --- a/diff_diff/visualization.py +++ b/diff_diff/visualization.py @@ -197,11 +197,17 @@ def plot_event_study( effect = effects.get(period, np.nan) std_err = se.get(period, np.nan) - if np.isnan(effect) or np.isnan(std_err): + # Skip entries with NaN effect, but allow NaN SE (will plot without error bars) + if np.isnan(effect): continue - ci_lower = effect - critical_value * std_err - ci_upper = effect + critical_value * std_err + # Compute CI only if SE is finite + if np.isfinite(std_err): + ci_lower = effect - critical_value * std_err + ci_upper = effect + critical_value * std_err + else: + ci_lower = np.nan + ci_upper = np.nan plot_data.append({ 'period': period, @@ -244,13 +250,20 @@ def plot_event_study( ref_x = period_to_x[reference_period] ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1) - # Plot error bars - yerr = [df['effect'] - df['ci_lower'], df['ci_upper'] - df['effect']] - ax.errorbar( - x_vals, df['effect'], yerr=yerr, - fmt='none', color=color, capsize=capsize, linewidth=linewidth, - capthick=linewidth, zorder=2 - ) + # Plot error bars (only for entries with finite CI) + has_ci = df['ci_lower'].notna() & df['ci_upper'].notna() + if has_ci.any(): + df_with_ci = df[has_ci] + x_with_ci = [period_to_x[p] for p in df_with_ci['period']] + yerr = [ + df_with_ci['effect'] - df_with_ci['ci_lower'], + df_with_ci['ci_upper'] - df_with_ci['effect'] + ] + ax.errorbar( + x_with_ci, df_with_ci['effect'], yerr=yerr, + fmt='none', color=color, capsize=capsize, linewidth=linewidth, + capthick=linewidth, zorder=2 + ) # Plot point estimates for i, row in df.iterrows(): @@ -351,7 +364,15 @@ def _extract_plot_data( # Reference period is typically -1 for event study if reference_period is None: - reference_period = -1 + # Detect reference period from n_groups=0 marker (normalization constraint) + # This handles anticipation > 0 where reference is at e = -1 - anticipation + for period, effect_data in results.event_study_effects.items(): + if effect_data.get('n_groups', 1) == 0: + reference_period = period + break + # Fallback to -1 if no marker found (backward compatibility) + if reference_period is None: + reference_period = -1 if pre_periods is None: pre_periods = [p for p in periods if p < 0] diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index fd11469..ec148a0 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -229,6 +229,9 @@ Aggregations: - "universal": All comparisons use g-anticipation-1 as base - Both produce identical post-treatment ATT(g,t); differ only pre-treatment - Matches R `did::att_gt()` base_period parameter + - **Event study output**: With "universal", includes reference period (e=-1-anticipation) + with effect=0, se=NaN, conf_int=(NaN, NaN). Inference fields are NaN since this is + a normalization constraint, not an estimated effect. Only added when real effects exist. - Base period interaction with Sun-Abraham comparison: - CS with `base_period="varying"` produces different pre-treatment estimates than SA - This is expected: CS uses consecutive comparisons, SA uses fixed reference (e=-1-anticipation) diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index b82cba1..5172ac7 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -681,6 +681,83 @@ def test_very_large_M(self, mock_multiperiod_results): assert isinstance(results, HonestDiDResults) assert results.ci_width > 0 + def test_callaway_santanna_universal_base_period(self): + """Test that reference period (e=-1) is correctly filtered out with universal base period. + + The reference period has n_groups=0 and se=NaN, so it should be excluded + from HonestDiD analysis to avoid contaminating the vcov matrix. + """ + from diff_diff import CallawaySantAnna, generate_staggered_data + + # Generate data and fit with universal base period + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Verify reference period exists with NaN SE + assert -1 in results.event_study_effects + assert np.isnan(results.event_study_effects[-1]['se']) + + # HonestDiD should work without errors (reference period filtered out) + honest = HonestDiD(method='relative_magnitude', M=1.0) + bounds = honest.fit(results) + + # Should have valid (non-NaN) results + assert isinstance(bounds, HonestDiDResults) + assert np.isfinite(bounds.ci_lb) + assert np.isfinite(bounds.ci_ub) + + def test_max_pre_violation_excludes_reference_period(self): + """Test that reference period (effect=0, n_groups=0) is excluded from max pre-violation. + + With universal base period, the reference period e=-1 is a normalization constraint + with n_groups=0. It should not be used in _estimate_max_pre_violation because + its effect is artificially set to 0, which would collapse RM bounds incorrectly. + """ + from diff_diff import CallawaySantAnna, generate_staggered_data + + # Generate data with universal base period + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Verify reference period exists with n_groups=0 + assert -1 in results.event_study_effects + assert results.event_study_effects[-1]['n_groups'] == 0 + + # The max pre-violation calculation should exclude the reference period + honest = HonestDiD(method='relative_magnitude', M=1.0) + + # Get pre_periods excluding reference (n_groups=0) + real_pre_periods = [ + t for t in results.event_study_effects + if t < 0 and results.event_study_effects[t].get('n_groups', 1) > 0 + ] + + # If there are real pre-periods, max_violation should be > 0 + # (based on actual pre-period effects, not the reference period's effect=0) + if real_pre_periods: + max_violation = honest._estimate_max_pre_violation(results, real_pre_periods) + # Max violation should reflect actual pre-period coefficients, not 0 + # The actual effects are non-zero due to sampling variation + assert max_violation > 0, ( + "max_pre_violation should be > 0 when real pre-periods exist" + ) + # ============================================================================= # Tests for Visualization (without matplotlib) diff --git a/tests/test_pretrends.py b/tests/test_pretrends.py index c7e8db3..ba2c0a3 100644 --- a/tests/test_pretrends.py +++ b/tests/test_pretrends.py @@ -795,6 +795,38 @@ def test_unsupported_results_type_raises(self): with pytest.raises(TypeError, match="Unsupported results type"): pt.fit("not a results object") + def test_callaway_santanna_universal_base_period(self): + """Test that reference period (e=-1) is correctly filtered out with universal base period. + + The reference period has n_groups=0 and se=NaN, so it should be excluded + from pre-trends power analysis to avoid contaminating the vcov matrix. + """ + from diff_diff import CallawaySantAnna, generate_staggered_data + + # Generate data and fit with universal base period + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Verify reference period exists with NaN SE + assert -1 in results.event_study_effects + assert np.isnan(results.event_study_effects[-1]['se']) + + # PreTrendsPower should work without errors (reference period filtered out) + pt = PreTrendsPower() + power_results = pt.fit(results) + + # Should have valid (non-NaN) results + assert np.isfinite(power_results.power) + assert power_results.n_pre_periods >= 1 + # ============================================================================= # Tests for visualization (without rendering) diff --git a/tests/test_staggered.py b/tests/test_staggered.py index c0e8a99..382fea8 100644 --- a/tests/test_staggered.py +++ b/tests/test_staggered.py @@ -2821,3 +2821,118 @@ def test_aggregated_tstat_nan_when_se_zero(self): f"Group t_stat for g={g} should be effect/SE={expected_t}, " f"got {t_stat}" ) + + def test_event_study_universal_includes_reference_period(self): + """Test that universal base period includes e=-1 with effect=0.""" + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) + + assert results.event_study_effects is not None, "event_study_effects should not be None" + + # Reference period should be included + assert -1 in results.event_study_effects, ( + f"Reference period e=-1 should be in event_study_effects, " + f"got periods: {list(results.event_study_effects.keys())}" + ) + ref = results.event_study_effects[-1] + + # Effect is 0 by construction (normalization) + assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" + # Inference fields are NaN - this is a normalization constraint, not an estimated effect + assert np.isnan(ref['se']), f"Reference period SE should be NaN, got {ref['se']}" + assert np.isnan(ref['t_stat']), f"Reference period t_stat should be NaN, got {ref['t_stat']}" + assert np.isnan(ref['p_value']), f"Reference period p_value should be NaN, got {ref['p_value']}" + assert np.isnan(ref['conf_int'][0]) and np.isnan(ref['conf_int'][1]), ( + f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" + ) + assert ref['n_groups'] == 0, f"Reference period n_groups should be 0, got {ref['n_groups']}" + + def test_event_study_varying_excludes_reference_period(self): + """Test that varying base period does NOT artificially add e=-1 with effect=0.""" + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + + cs = CallawaySantAnna(base_period="varying") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) + + assert results.event_study_effects is not None, "event_study_effects should not be None" + + # Varying mode: no single reference period, e=-1 computed normally or excluded + # The key is we don't artificially add a 0-effect entry + if -1 in results.event_study_effects: + # If it exists, it should be an actual computed effect, not 0.0 with n_groups=0 + assert results.event_study_effects[-1]['n_groups'] > 0, ( + "Varying mode should not artificially add e=-1 with n_groups=0" + ) + + def test_event_study_universal_with_anticipation(self): + """Test reference period with anticipation > 0.""" + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + + cs = CallawaySantAnna(base_period="universal", anticipation=1) + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) + + assert results.event_study_effects is not None, "event_study_effects should not be None" + + # With anticipation=1, reference is e=-2 + assert -2 in results.event_study_effects, ( + f"With anticipation=1, reference period e=-2 should be in event_study_effects, " + f"got periods: {list(results.event_study_effects.keys())}" + ) + ref = results.event_study_effects[-2] + assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" + # Inference fields are NaN - normalization constraint + assert np.isnan(ref['se']), f"Reference period SE should be NaN, got {ref['se']}" + assert np.isnan(ref['conf_int'][0]) and np.isnan(ref['conf_int'][1]), ( + f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" + ) + + def test_event_study_universal_no_effects_raises_error(self): + """Test that estimator raises error when no effects can be computed. + + This ensures the reference period injection code (which has an empty guard) + is never reached with empty effects - the estimator fails fast instead. + """ + import pandas as pd + + # Create minimal data with only never-treated units + # This ensures no ATT(g,t) can be computed (no treatment groups) + data = pd.DataFrame({ + 'unit': [1, 1, 2, 2, 3, 3], + 'time': [1, 2, 1, 2, 1, 2], + 'outcome': [1.0, 1.1, 1.2, 1.3, 1.4, 1.5], + 'first_treat': [0, 0, 0, 0, 0, 0], # All never-treated + }) + + cs = CallawaySantAnna(base_period="universal") + with pytest.raises(ValueError, match="Could not estimate any group-time effects"): + cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index d4736a0..4f78254 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -222,6 +222,102 @@ def test_error_invalid_results_type(self): with pytest.raises(TypeError, match="Cannot extract plot data"): plot_event_study("invalid") + def test_plot_with_nan_se_reference_period(self): + """Test that reference period with NaN SE is plotted without error bars. + + With universal base period, the reference period (e=-1) has SE=NaN. + The plot should show the point estimate without error bars rather than + skipping it entirely. + """ + pytest.importorskip("matplotlib") + import matplotlib.pyplot as plt + + # Create data with NaN SE for reference period + effects = {-2: 0.1, -1: 0.0, 0: 0.5, 1: 0.6} + se = {-2: 0.1, -1: np.nan, 0: 0.15, 1: 0.15} # NaN SE at reference period + + ax = plot_event_study( + effects=effects, + se=se, + reference_period=-1, + show=False + ) + + # Verify the plot was created successfully + assert ax is not None + + # Verify all 4 periods are plotted (including reference with NaN SE) + # The x-axis should have 4 tick labels + xtick_labels = [t.get_text() for t in ax.get_xticklabels()] + assert len(xtick_labels) == 4 + assert '-1' in xtick_labels + + plt.close() + + def test_plot_cs_universal_base_period(self): + """Test plotting CallawaySantAnna results with universal base period. + + The reference period (e=-1) should appear in the plot even though + it has SE=NaN. + """ + pytest.importorskip("matplotlib") + import matplotlib.pyplot as plt + from diff_diff import generate_staggered_data + + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Should not raise even with NaN SE in reference period + ax = plot_event_study(results, show=False) + assert ax is not None + + # Verify reference period (-1) is in the plot + xtick_labels = [t.get_text() for t in ax.get_xticklabels()] + assert '-1' in xtick_labels + + plt.close() + + def test_plot_cs_with_anticipation(self): + """Test plotting CallawaySantAnna results with anticipation > 0. + + When anticipation=1, the reference period should be e=-2, not e=-1. + """ + pytest.importorskip("matplotlib") + import matplotlib.pyplot as plt + from diff_diff import generate_staggered_data + + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal", anticipation=1) + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Reference period should be at e=-2 (not e=-1) with anticipation=1 + assert -2 in results.event_study_effects + assert results.event_study_effects[-2]['n_groups'] == 0 + + ax = plot_event_study(results, show=False) + assert ax is not None + + # Verify -2 is in the plot (the true reference period) + xtick_labels = [t.get_text() for t in ax.get_xticklabels()] + assert '-2' in xtick_labels + + plt.close() + class TestPlotEventStudyIntegration: """Integration tests for event study plotting."""