Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pertpy/tools/_differential_gene_expression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,15 +836,16 @@ def _get_significance(p_val):
df = results_df.pivot(index=contrast_col, columns=symbol_col, values=log2fc_col)[var_names]

plt.figure(figsize=figsize)
# Ensure tick labels are shown by default (fixes #755)
# Users can override via heatmap_kwargs
default_heatmap_kwargs = {"xticklabels": True, "yticklabels": True}
default_heatmap_kwargs.update(heatmap_kwargs)
sns.heatmap(df, **default_heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"})
sns.heatmap(df, **heatmap_kwargs, cmap="coolwarm", center=0, cbar_kws={"label": "Log2 fold change"})

_size = {"< 0.001": marker_size, "< 0.01": math.floor(marker_size / 2), "< 0.1": math.floor(marker_size / 4)}
x_locs, x_labels = plt.xticks()[0], [label.get_text() for label in plt.xticks()[1]]
y_locs, y_labels = plt.yticks()[0], [label.get_text() for label in plt.yticks()[1]]
# Calculate locations directly from DataFrame instead of extracting from rendered plot (fixes #755)
# Seaborn places cell centers at 0.5, 1.5, 2.5, etc.
# NOTE: This assumes a non-clustered heatmap. If using clustermap, coordinates would need reordering.
x_locs = np.arange(len(df.columns)) + 0.5
x_labels = df.columns.tolist()
y_locs = np.arange(len(df.index)) + 0.5
y_labels = df.index.tolist()

for _i, row in results_df.iterrows():
if row["significance"] != "n.s.":
Expand Down
27 changes: 16 additions & 11 deletions tests/tools/_differential_gene_expression/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,17 @@ def test_model_cond(test_adata_minimal, MockLinearModel, formula, cond_kwargs, e
assert actual_contrast.index.tolist() == mod.design.columns.tolist()


def test_plot_multicomparison_fc_default_figsize(MockLinearModel, test_adata_minimal):
"""Test that plot_multicomparison_fc works with default figsize.
def test_plot_multicomparison_fc_many_genes(MockLinearModel, test_adata_minimal):
"""Test that plot_multicomparison_fc works even when heatmap hides tick labels.

Regression test for issue #755.
When using default figsize, seaborn heatmap may not show xticklabels,
causing a ValueError when trying to plot significance markers.
When using small figsize or many genes, seaborn heatmap hides xticklabels.
The old code extracted labels from the rendered plot, causing ValueError.
The fix calculates positions directly from the DataFrame.
"""
# Create mock results similar to what compare_groups would return
# Create mock results with many genes to force label hiding
results = []
genes = ["GENE1", "GENE2", "GENE3", "IL5", "IL6", "IL10"]
genes = [f"GENE{i}" for i in range(50)] # 50 genes will force label hiding
contrasts = ["contrast1", "contrast2"]

for contrast in contrasts:
Expand All @@ -110,8 +111,8 @@ def test_plot_multicomparison_fc_default_figsize(MockLinearModel, test_adata_min
{
"contrast": contrast,
"variable": gene,
"log_fc": 1.5 + i * 0.3,
"adj_p_value": 0.001 if i < 3 else 0.05,
"log_fc": 1.5 + i * 0.05,
"adj_p_value": 0.001 if i < 10 else 0.05,
}
)

Expand All @@ -120,7 +121,11 @@ def test_plot_multicomparison_fc_default_figsize(MockLinearModel, test_adata_min
# Create a mock model instance
mod = MockLinearModel(test_adata_minimal, "~condition")

# This should not raise ValueError: 'IL5' is not in list
# even with default figsize (which may cause heatmap to hide labels)
fig = mod.plot_multicomparison_fc(results_df, return_fig=True)
# This should not raise ValueError even with small figsize
# that causes seaborn to hide tick labels
fig = mod.plot_multicomparison_fc(results_df, figsize=(6, 4), return_fig=True)
assert fig is not None

# Also test with heatmap_kwargs that explicitly hide labels
fig = mod.plot_multicomparison_fc(results_df, xticklabels=False, return_fig=True)
assert fig is not None
Loading