diff --git a/pertpy/tools/_differential_gene_expression/_base.py b/pertpy/tools/_differential_gene_expression/_base.py index d879c61d..a2f4c2ce 100644 --- a/pertpy/tools/_differential_gene_expression/_base.py +++ b/pertpy/tools/_differential_gene_expression/_base.py @@ -836,15 +836,16 @@ def _get_significance(p_val): df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names] plt.figure(figsize=figsize) - # Ensure tick labels are shown by default (fixes #755) - # Users can override via heatmap_kwargs - default_heatmap_kwargs = {"xticklabels": True, "yticklabels": True} - default_heatmap_kwargs.update(heatmap_kwargs) - sns.heatmap(df, **default_heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"}) + sns.heatmap(df, **heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"}) _size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)} - x_locs, x_labels = plt.xticks()[0], [label.get_text() for label in plt.xticks()[1]] - y_locs, y_labels = plt.yticks()[0], [label.get_text() for label in plt.yticks()[1]] + # Calculate locations directly from DataFrame instead of extracting from rendered plot (fixes #755) + # Seaborn places cell centers at 0.5, 1.5, 2.5, etc. + # NOTE: This assumes a non-clustered heatmap. If using clustermap, coordinates would need reordering. + x_locs = np.arange(len(df.columns)) + 0.5 + x_labels = df.columns.tolist() + y_locs = np.arange(len(df.index)) + 0.5 + y_labels = df.index.tolist() for _i, row in results_df.iterrows(): if row["significance"] != "n.s.": diff --git a/tests/tools/_differential_gene_expression/test_base.py b/tests/tools/_differential_gene_expression/test_base.py index a1bd2130..50bc0b5f 100644 --- a/tests/tools/_differential_gene_expression/test_base.py +++ b/tests/tools/_differential_gene_expression/test_base.py @@ -92,16 +92,17 @@ def test_model_cond(test_adata_minimal, MockLinearModel, formula, cond_kwargs, e assert actual_contrast.index.tolist() == mod.design.columns.tolist() -def test_plot_multicomparison_fc_default_figsize(MockLinearModel, test_adata_minimal): - """Test that plot_multicomparison_fc works with default figsize. +def test_plot_multicomparison_fc_many_genes(MockLinearModel, test_adata_minimal): + """Test that plot_multicomparison_fc works even when heatmap hides tick labels. Regression test for issue #755. - When using default figsize, seaborn heatmap may not show xticklabels, - causing a ValueError when trying to plot significance markers. + When using small figsize or many genes, seaborn heatmap hides xticklabels. + The old code extracted labels from the rendered plot, causing ValueError. + The fix calculates positions directly from the DataFrame. """ - # Create mock results similar to what compare_groups would return + # Create mock results with many genes to force label hiding results = [] - genes = ["GENE1", "GENE2", "GENE3", "IL5", "IL6", "IL10"] + genes = [f"GENE{i}" for i in range(50)] # 50 genes will force label hiding contrasts = ["contrast1", "contrast2"] for contrast in contrasts: @@ -110,8 +111,8 @@ def test_plot_multicomparison_fc_default_figsize(MockLinearModel, test_adata_min { "contrast": contrast, "variable": gene, - "log_fc": 1.5 + i * 0.3, - "adj_p_value": 0.001 if i < 3 else 0.05, + "log_fc": 1.5 + i * 0.05, + "adj_p_value": 0.001 if i < 10 else 0.05, } ) @@ -120,7 +121,11 @@ def test_plot_multicomparison_fc_default_figsize(MockLinearModel, test_adata_min # Create a mock model instance mod = MockLinearModel(test_adata_minimal, "~condition") - # This should not raise ValueError: 'IL5' is not in list - # even with default figsize (which may cause heatmap to hide labels) - fig = mod.plot_multicomparison_fc(results_df, return_fig=True) + # This should not raise ValueError even with small figsize + # that causes seaborn to hide tick labels + fig = mod.plot_multicomparison_fc(results_df, figsize=(6, 4), return_fig=True) + assert fig is not None + + # Also test with heatmap_kwargs that explicitly hide labels + fig = mod.plot_multicomparison_fc(results_df, xticklabels=False, return_fig=True) assert fig is not None