From c32fb2dcacf0c4c7f249b863c04a1beb5a1887c4 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 24 Jan 2026 07:25:32 -0500 Subject: [PATCH 1/5] Include reference period (e=-1) in CS event study for universal base period Add the reference period to Callaway-Sant'Anna event study output when using base_period="universal". This matches R `did` package behavior and provides visual clarity in event study plots by including the normalization point. Changes: - Add base_period type hint to CallawaySantAnnaAggregationMixin - Add reference period entry (effect=0, se=0, t_stat=NaN) in _aggregate_event_study - Add 3 tests for reference period behavior (universal, varying, anticipation) - Update REGISTRY.md documentation Co-Authored-By: Claude Opus 4.5 --- diff_diff/staggered_aggregation.py | 17 +++++++ docs/methodology/REGISTRY.md | 2 + tests/test_staggered.py | 78 ++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+) diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 543b924..0e0f6a7 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -34,6 +34,9 @@ class CallawaySantAnnaAggregationMixin: # Type hint for anticipation attribute accessed from main class anticipation: int + # Type hint for base_period attribute accessed from main class + base_period: str + def _aggregate_simple( self, group_time_effects: Dict, @@ -414,6 +417,20 @@ def _aggregate_event_study( 'n_groups': len(effect_list), } + # Add reference period for universal base period mode (matches R did package) + # The reference period e = -1 - anticipation has effect = 0 by construction + if getattr(self, 'base_period', 'varying') == "universal": + ref_period = -1 - self.anticipation + if ref_period not in event_study_effects: + event_study_effects[ref_period] = { + 'effect': 0.0, + 'se': 0.0, + 't_stat': np.nan, # Undefined - normalization constraint + 'p_value': np.nan, + 'conf_int': (0.0, 0.0), + 'n_groups': 0, # No groups contribute - fixed by construction + } + return event_study_effects def _aggregate_by_group( diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index fd11469..a4cf96e 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -229,6 +229,8 @@ Aggregations: - "universal": All comparisons use g-anticipation-1 as base - Both produce identical post-treatment ATT(g,t); differ only pre-treatment - Matches R `did::att_gt()` base_period parameter + - **Event study output**: With "universal", includes reference period (e=-1-anticipation) + with effect=0, se=0, matching R behavior for complete event study plots - Base period interaction with Sun-Abraham comparison: - CS with `base_period="varying"` produces different pre-treatment estimates than SA - This is expected: CS uses consecutive comparisons, SA uses fixed reference (e=-1-anticipation) diff --git a/tests/test_staggered.py b/tests/test_staggered.py index c0e8a99..d799fa8 100644 --- a/tests/test_staggered.py +++ b/tests/test_staggered.py @@ -2821,3 +2821,81 @@ def test_aggregated_tstat_nan_when_se_zero(self): f"Group t_stat for g={g} should be effect/SE={expected_t}, " f"got {t_stat}" ) + + def test_event_study_universal_includes_reference_period(self): + """Test that universal base period includes e=-1 with effect=0.""" + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) + + assert results.event_study_effects is not None, "event_study_effects should not be None" + + # Reference period should be included + assert -1 in results.event_study_effects, ( + f"Reference period e=-1 should be in event_study_effects, " + f"got periods: {list(results.event_study_effects.keys())}" + ) + ref = results.event_study_effects[-1] + + # Effect is 0 by construction (normalization) + assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" + assert ref['se'] == 0.0, f"Reference period SE should be 0.0, got {ref['se']}" + assert np.isnan(ref['t_stat']), f"Reference period t_stat should be NaN, got {ref['t_stat']}" + assert np.isnan(ref['p_value']), f"Reference period p_value should be NaN, got {ref['p_value']}" + assert ref['conf_int'] == (0.0, 0.0), f"Reference period CI should be (0.0, 0.0), got {ref['conf_int']}" + + def test_event_study_varying_excludes_reference_period(self): + """Test that varying base period does NOT artificially add e=-1 with effect=0.""" + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + + cs = CallawaySantAnna(base_period="varying") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) + + assert results.event_study_effects is not None, "event_study_effects should not be None" + + # Varying mode: no single reference period, e=-1 computed normally or excluded + # The key is we don't artificially add a 0-effect entry + if -1 in results.event_study_effects: + # If it exists, it should be an actual computed effect, not 0.0 with n_groups=0 + assert results.event_study_effects[-1]['n_groups'] > 0, ( + "Varying mode should not artificially add e=-1 with n_groups=0" + ) + + def test_event_study_universal_with_anticipation(self): + """Test reference period with anticipation > 0.""" + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + + cs = CallawaySantAnna(base_period="universal", anticipation=1) + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) + + assert results.event_study_effects is not None, "event_study_effects should not be None" + + # With anticipation=1, reference is e=-2 + assert -2 in results.event_study_effects, ( + f"With anticipation=1, reference period e=-2 should be in event_study_effects, " + f"got periods: {list(results.event_study_effects.keys())}" + ) + ref = results.event_study_effects[-2] + assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" From 3b0666d7a94b2525d906e4da9419ad8298d78d6a Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 24 Jan 2026 07:47:40 -0500 Subject: [PATCH 2/5] Address code review feedback: use NaN for reference period inference fields - Change conf_int from (0.0, 0.0) to (np.nan, np.nan) for consistency with NaN-propagation policy (P0 fix) - Change se from 0.0 to np.nan since reference period is a normalization constraint, not an estimated effect - Add guard to only inject reference period when real effects exist (P1 fix) - Update tests to expect NaN for se and conf_int - Update REGISTRY.md to document NaN inference for reference period - Add launch/ and launch-video/ to .gitignore Co-Authored-By: Claude Opus 4.5 --- .gitignore | 4 +++ diff_diff/staggered_aggregation.py | 8 +++--- docs/methodology/REGISTRY.md | 3 ++- tests/test_staggered.py | 41 ++++++++++++++++++++++++++++-- 4 files changed, 50 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d0212da..d4d86f3 100644 --- a/.gitignore +++ b/.gitignore @@ -76,3 +76,7 @@ target/ # Local scripts (not part of package) scripts/ + +# Launch directories (local only) +launch/ +launch-video/ diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 0e0f6a7..1419413 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -419,15 +419,17 @@ def _aggregate_event_study( # Add reference period for universal base period mode (matches R did package) # The reference period e = -1 - anticipation has effect = 0 by construction + # Only add if there are actual computed effects (guard against empty data) if getattr(self, 'base_period', 'varying') == "universal": ref_period = -1 - self.anticipation - if ref_period not in event_study_effects: + # Only inject reference if we have at least one real effect + if event_study_effects and ref_period not in event_study_effects: event_study_effects[ref_period] = { 'effect': 0.0, - 'se': 0.0, + 'se': np.nan, # Undefined - no data, normalization constraint 't_stat': np.nan, # Undefined - normalization constraint 'p_value': np.nan, - 'conf_int': (0.0, 0.0), + 'conf_int': (np.nan, np.nan), # NaN propagation for undefined inference 'n_groups': 0, # No groups contribute - fixed by construction } diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index a4cf96e..ec148a0 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -230,7 +230,8 @@ Aggregations: - Both produce identical post-treatment ATT(g,t); differ only pre-treatment - Matches R `did::att_gt()` base_period parameter - **Event study output**: With "universal", includes reference period (e=-1-anticipation) - with effect=0, se=0, matching R behavior for complete event study plots + with effect=0, se=NaN, conf_int=(NaN, NaN). Inference fields are NaN since this is + a normalization constraint, not an estimated effect. Only added when real effects exist. - Base period interaction with Sun-Abraham comparison: - CS with `base_period="varying"` produces different pre-treatment estimates than SA - This is expected: CS uses consecutive comparisons, SA uses fixed reference (e=-1-anticipation) diff --git a/tests/test_staggered.py b/tests/test_staggered.py index d799fa8..382fea8 100644 --- a/tests/test_staggered.py +++ b/tests/test_staggered.py @@ -2847,10 +2847,14 @@ def test_event_study_universal_includes_reference_period(self): # Effect is 0 by construction (normalization) assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" - assert ref['se'] == 0.0, f"Reference period SE should be 0.0, got {ref['se']}" + # Inference fields are NaN - this is a normalization constraint, not an estimated effect + assert np.isnan(ref['se']), f"Reference period SE should be NaN, got {ref['se']}" assert np.isnan(ref['t_stat']), f"Reference period t_stat should be NaN, got {ref['t_stat']}" assert np.isnan(ref['p_value']), f"Reference period p_value should be NaN, got {ref['p_value']}" - assert ref['conf_int'] == (0.0, 0.0), f"Reference period CI should be (0.0, 0.0), got {ref['conf_int']}" + assert np.isnan(ref['conf_int'][0]) and np.isnan(ref['conf_int'][1]), ( + f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" + ) + assert ref['n_groups'] == 0, f"Reference period n_groups should be 0, got {ref['n_groups']}" def test_event_study_varying_excludes_reference_period(self): """Test that varying base period does NOT artificially add e=-1 with effect=0.""" @@ -2899,3 +2903,36 @@ def test_event_study_universal_with_anticipation(self): ) ref = results.event_study_effects[-2] assert ref['effect'] == 0.0, f"Reference period effect should be 0.0, got {ref['effect']}" + # Inference fields are NaN - normalization constraint + assert np.isnan(ref['se']), f"Reference period SE should be NaN, got {ref['se']}" + assert np.isnan(ref['conf_int'][0]) and np.isnan(ref['conf_int'][1]), ( + f"Reference period CI should be (NaN, NaN), got {ref['conf_int']}" + ) + + def test_event_study_universal_no_effects_raises_error(self): + """Test that estimator raises error when no effects can be computed. + + This ensures the reference period injection code (which has an empty guard) + is never reached with empty effects - the estimator fails fast instead. + """ + import pandas as pd + + # Create minimal data with only never-treated units + # This ensures no ATT(g,t) can be computed (no treatment groups) + data = pd.DataFrame({ + 'unit': [1, 1, 2, 2, 3, 3], + 'time': [1, 2, 1, 2, 1, 2], + 'outcome': [1.0, 1.1, 1.2, 1.3, 1.4, 1.5], + 'first_treat': [0, 0, 0, 0, 0, 0], # All never-treated + }) + + cs = CallawaySantAnna(base_period="universal") + with pytest.raises(ValueError, match="Could not estimate any group-time effects"): + cs.fit( + data, + outcome='outcome', + unit='unit', + time='time', + first_treat='first_treat', + aggregate='event_study' + ) From a34445b1a122d1c97ef7f408223d0153bf82c748 Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 24 Jan 2026 08:01:16 -0500 Subject: [PATCH 3/5] Filter normalization constraints from downstream consumers Address code review feedback: the reference period (e=-1) with NaN SE could contaminate vcov matrices in pretrends and HonestDiD consumers. Changes: - pretrends.py: Filter out entries with n_groups=0 or non-finite SE when extracting pre-period effects for CallawaySantAnna and SunAbraham results - honest_did.py: Filter out entries with n_groups=0 or non-finite SE when extracting event study effects for CallawaySantAnna results - Add tests verifying base_period="universal" works correctly with both PreTrendsPower and HonestDiD (reference period is properly excluded) Co-Authored-By: Claude Opus 4.5 --- diff_diff/honest_did.py | 7 ++++++- diff_diff/pretrends.py | 6 ++++++ tests/test_honest_did.py | 33 +++++++++++++++++++++++++++++++++ tests/test_pretrends.py | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 1 deletion(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 612a3dd..9ce20db 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -584,7 +584,12 @@ def _extract_event_study_params( ) # Extract event study effects by relative time - event_effects = results.event_study_effects + # Filter out normalization constraints (n_groups=0) and non-finite SEs + event_effects = { + t: data for t, data in results.event_study_effects.items() + if data.get('n_groups', 1) > 0 + and np.isfinite(data.get('se', np.nan)) + } rel_times = sorted(event_effects.keys()) # Split into pre and post diff --git a/diff_diff/pretrends.py b/diff_diff/pretrends.py index fb88e96..d9bc1c6 100644 --- a/diff_diff/pretrends.py +++ b/diff_diff/pretrends.py @@ -656,9 +656,12 @@ def _extract_pre_period_params( ) # Get pre-period effects (negative relative times) + # Filter out normalization constraints (n_groups=0) and non-finite SEs pre_effects = { t: data for t, data in results.event_study_effects.items() if t < 0 + and data.get('n_groups', 1) > 0 + and np.isfinite(data.get('se', np.nan)) } if not pre_effects: @@ -680,9 +683,12 @@ def _extract_pre_period_params( from diff_diff.sun_abraham import SunAbrahamResults if isinstance(results, SunAbrahamResults): # Get pre-period effects (negative relative times) + # Filter out normalization constraints (n_groups=0) and non-finite SEs pre_effects = { t: data for t, data in results.event_study_effects.items() if t < 0 + and data.get('n_groups', 1) > 0 + and np.isfinite(data.get('se', np.nan)) } if not pre_effects: diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index b82cba1..5df181d 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -681,6 +681,39 @@ def test_very_large_M(self, mock_multiperiod_results): assert isinstance(results, HonestDiDResults) assert results.ci_width > 0 + def test_callaway_santanna_universal_base_period(self): + """Test that reference period (e=-1) is correctly filtered out with universal base period. + + The reference period has n_groups=0 and se=NaN, so it should be excluded + from HonestDiD analysis to avoid contaminating the vcov matrix. + """ + from diff_diff import CallawaySantAnna, generate_staggered_data + + # Generate data and fit with universal base period + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Verify reference period exists with NaN SE + assert -1 in results.event_study_effects + assert np.isnan(results.event_study_effects[-1]['se']) + + # HonestDiD should work without errors (reference period filtered out) + honest = HonestDiD(method='relative_magnitude', M=1.0) + bounds = honest.fit(results) + + # Should have valid (non-NaN) results + assert isinstance(bounds, HonestDiDResults) + assert np.isfinite(bounds.ci_lb) + assert np.isfinite(bounds.ci_ub) + # ============================================================================= # Tests for Visualization (without matplotlib) diff --git a/tests/test_pretrends.py b/tests/test_pretrends.py index c7e8db3..ba2c0a3 100644 --- a/tests/test_pretrends.py +++ b/tests/test_pretrends.py @@ -795,6 +795,38 @@ def test_unsupported_results_type_raises(self): with pytest.raises(TypeError, match="Unsupported results type"): pt.fit("not a results object") + def test_callaway_santanna_universal_base_period(self): + """Test that reference period (e=-1) is correctly filtered out with universal base period. + + The reference period has n_groups=0 and se=NaN, so it should be excluded + from pre-trends power analysis to avoid contaminating the vcov matrix. + """ + from diff_diff import CallawaySantAnna, generate_staggered_data + + # Generate data and fit with universal base period + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Verify reference period exists with NaN SE + assert -1 in results.event_study_effects + assert np.isnan(results.event_study_effects[-1]['se']) + + # PreTrendsPower should work without errors (reference period filtered out) + pt = PreTrendsPower() + power_results = pt.fit(results) + + # Should have valid (non-NaN) results + assert np.isfinite(power_results.power) + assert power_results.n_pre_periods >= 1 + # ============================================================================= # Tests for visualization (without rendering) From d3da585b1dca093cd43dc3f2377a420d4751270a Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 24 Jan 2026 08:26:32 -0500 Subject: [PATCH 4/5] Address code review round 3: fix max-pre-violation and event study plotting - Filter n_groups=0 entries in _estimate_max_pre_violation to prevent reference period (effect=0) from collapsing RM bounds - Allow NaN SE entries in event study plots (show point without error bars) - Add regression test for max_pre_violation exclusion behavior - Add visualization tests for reference period with NaN SE Co-Authored-By: Claude Opus 4.5 --- diff_diff/honest_did.py | 2 ++ diff_diff/visualization.py | 33 +++++++++++++------ tests/test_honest_did.py | 44 ++++++++++++++++++++++++++ tests/test_visualization.py | 63 +++++++++++++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 10 deletions(-) diff --git a/diff_diff/honest_did.py b/diff_diff/honest_did.py index 9ce20db..81c8caf 100644 --- a/diff_diff/honest_did.py +++ b/diff_diff/honest_did.py @@ -1266,10 +1266,12 @@ def _estimate_max_pre_violation( from diff_diff.staggered import CallawaySantAnnaResults if isinstance(results, CallawaySantAnnaResults): if results.event_study_effects: + # Filter out normalization constraints (n_groups=0, e.g. reference period) pre_effects = [ abs(results.event_study_effects[t]['effect']) for t in results.event_study_effects if t < 0 + and results.event_study_effects[t].get('n_groups', 1) > 0 ] if pre_effects: return max(pre_effects) diff --git a/diff_diff/visualization.py b/diff_diff/visualization.py index fbf6ce4..537ddec 100644 --- a/diff_diff/visualization.py +++ b/diff_diff/visualization.py @@ -197,11 +197,17 @@ def plot_event_study( effect = effects.get(period, np.nan) std_err = se.get(period, np.nan) - if np.isnan(effect) or np.isnan(std_err): + # Skip entries with NaN effect, but allow NaN SE (will plot without error bars) + if np.isnan(effect): continue - ci_lower = effect - critical_value * std_err - ci_upper = effect + critical_value * std_err + # Compute CI only if SE is finite + if np.isfinite(std_err): + ci_lower = effect - critical_value * std_err + ci_upper = effect + critical_value * std_err + else: + ci_lower = np.nan + ci_upper = np.nan plot_data.append({ 'period': period, @@ -244,13 +250,20 @@ def plot_event_study( ref_x = period_to_x[reference_period] ax.axvline(x=ref_x, color='gray', linestyle=':', linewidth=1, zorder=1) - # Plot error bars - yerr = [df['effect'] - df['ci_lower'], df['ci_upper'] - df['effect']] - ax.errorbar( - x_vals, df['effect'], yerr=yerr, - fmt='none', color=color, capsize=capsize, linewidth=linewidth, - capthick=linewidth, zorder=2 - ) + # Plot error bars (only for entries with finite CI) + has_ci = df['ci_lower'].notna() & df['ci_upper'].notna() + if has_ci.any(): + df_with_ci = df[has_ci] + x_with_ci = [period_to_x[p] for p in df_with_ci['period']] + yerr = [ + df_with_ci['effect'] - df_with_ci['ci_lower'], + df_with_ci['ci_upper'] - df_with_ci['effect'] + ] + ax.errorbar( + x_with_ci, df_with_ci['effect'], yerr=yerr, + fmt='none', color=color, capsize=capsize, linewidth=linewidth, + capthick=linewidth, zorder=2 + ) # Plot point estimates for i, row in df.iterrows(): diff --git a/tests/test_honest_did.py b/tests/test_honest_did.py index 5df181d..5172ac7 100644 --- a/tests/test_honest_did.py +++ b/tests/test_honest_did.py @@ -714,6 +714,50 @@ def test_callaway_santanna_universal_base_period(self): assert np.isfinite(bounds.ci_lb) assert np.isfinite(bounds.ci_ub) + def test_max_pre_violation_excludes_reference_period(self): + """Test that reference period (effect=0, n_groups=0) is excluded from max pre-violation. + + With universal base period, the reference period e=-1 is a normalization constraint + with n_groups=0. It should not be used in _estimate_max_pre_violation because + its effect is artificially set to 0, which would collapse RM bounds incorrectly. + """ + from diff_diff import CallawaySantAnna, generate_staggered_data + + # Generate data with universal base period + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Verify reference period exists with n_groups=0 + assert -1 in results.event_study_effects + assert results.event_study_effects[-1]['n_groups'] == 0 + + # The max pre-violation calculation should exclude the reference period + honest = HonestDiD(method='relative_magnitude', M=1.0) + + # Get pre_periods excluding reference (n_groups=0) + real_pre_periods = [ + t for t in results.event_study_effects + if t < 0 and results.event_study_effects[t].get('n_groups', 1) > 0 + ] + + # If there are real pre-periods, max_violation should be > 0 + # (based on actual pre-period effects, not the reference period's effect=0) + if real_pre_periods: + max_violation = honest._estimate_max_pre_violation(results, real_pre_periods) + # Max violation should reflect actual pre-period coefficients, not 0 + # The actual effects are non-zero due to sampling variation + assert max_violation > 0, ( + "max_pre_violation should be > 0 when real pre-periods exist" + ) + # ============================================================================= # Tests for Visualization (without matplotlib) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index d4736a0..604df6c 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -222,6 +222,69 @@ def test_error_invalid_results_type(self): with pytest.raises(TypeError, match="Cannot extract plot data"): plot_event_study("invalid") + def test_plot_with_nan_se_reference_period(self): + """Test that reference period with NaN SE is plotted without error bars. + + With universal base period, the reference period (e=-1) has SE=NaN. + The plot should show the point estimate without error bars rather than + skipping it entirely. + """ + pytest.importorskip("matplotlib") + import matplotlib.pyplot as plt + + # Create data with NaN SE for reference period + effects = {-2: 0.1, -1: 0.0, 0: 0.5, 1: 0.6} + se = {-2: 0.1, -1: np.nan, 0: 0.15, 1: 0.15} # NaN SE at reference period + + ax = plot_event_study( + effects=effects, + se=se, + reference_period=-1, + show=False + ) + + # Verify the plot was created successfully + assert ax is not None + + # Verify all 4 periods are plotted (including reference with NaN SE) + # The x-axis should have 4 tick labels + xtick_labels = [t.get_text() for t in ax.get_xticklabels()] + assert len(xtick_labels) == 4 + assert '-1' in xtick_labels + + plt.close() + + def test_plot_cs_universal_base_period(self): + """Test plotting CallawaySantAnna results with universal base period. + + The reference period (e=-1) should appear in the plot even though + it has SE=NaN. + """ + pytest.importorskip("matplotlib") + import matplotlib.pyplot as plt + from diff_diff import generate_staggered_data + + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal") + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Should not raise even with NaN SE in reference period + ax = plot_event_study(results, show=False) + assert ax is not None + + # Verify reference period (-1) is in the plot + xtick_labels = [t.get_text() for t in ax.get_xticklabels()] + assert '-1' in xtick_labels + + plt.close() + class TestPlotEventStudyIntegration: """Integration tests for event study plotting.""" From 98ba6db004b98811bddef989ca43e8940286ebbc Mon Sep 17 00:00:00 2001 From: igerber Date: Sat, 24 Jan 2026 08:51:55 -0500 Subject: [PATCH 5/5] Address code review round 4: fix reference period detection with anticipation Fix _extract_plot_data() to detect reference period from n_groups=0 marker instead of hardcoding -1. This correctly handles anticipation > 0 where the reference period is at e = -1 - anticipation (e.g., e=-2 when anticipation=1). Co-Authored-By: Claude Opus 4.5 --- diff_diff/visualization.py | 10 +++++++++- tests/test_visualization.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/diff_diff/visualization.py b/diff_diff/visualization.py index 537ddec..6584b51 100644 --- a/diff_diff/visualization.py +++ b/diff_diff/visualization.py @@ -364,7 +364,15 @@ def _extract_plot_data( # Reference period is typically -1 for event study if reference_period is None: - reference_period = -1 + # Detect reference period from n_groups=0 marker (normalization constraint) + # This handles anticipation > 0 where reference is at e = -1 - anticipation + for period, effect_data in results.event_study_effects.items(): + if effect_data.get('n_groups', 1) == 0: + reference_period = period + break + # Fallback to -1 if no marker found (backward compatibility) + if reference_period is None: + reference_period = -1 if pre_periods is None: pre_periods = [p for p in periods if p < 0] diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 604df6c..4f78254 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -285,6 +285,39 @@ def test_plot_cs_universal_base_period(self): plt.close() + def test_plot_cs_with_anticipation(self): + """Test plotting CallawaySantAnna results with anticipation > 0. + + When anticipation=1, the reference period should be e=-2, not e=-1. + """ + pytest.importorskip("matplotlib") + import matplotlib.pyplot as plt + from diff_diff import generate_staggered_data + + data = generate_staggered_data(n_units=200, n_periods=10, seed=42) + cs = CallawaySantAnna(base_period="universal", anticipation=1) + results = cs.fit( + data, + outcome='outcome', + unit='unit', + time='period', + first_treat='first_treat', + aggregate='event_study' + ) + + # Reference period should be at e=-2 (not e=-1) with anticipation=1 + assert -2 in results.event_study_effects + assert results.event_study_effects[-2]['n_groups'] == 0 + + ax = plot_event_study(results, show=False) + assert ax is not None + + # Verify -2 is in the plot (the true reference period) + xtick_labels = [t.get_text() for t in ax.get_xticklabels()] + assert '-2' in xtick_labels + + plt.close() + class TestPlotEventStudyIntegration: """Integration tests for event study plotting."""