diff --git a/dash_bio/component_factory/_clustergram.py b/dash_bio/component_factory/_clustergram.py index 8ec5f106..07244431 100644 --- a/dash_bio/component_factory/_clustergram.py +++ b/dash_bio/component_factory/_clustergram.py @@ -558,11 +558,11 @@ def figure(self, computed_traces=None): col_dendro_traces_max_y = np.concatenate(col_dendro_traces_y).max() # ensure that everything is aligned properly - # with the heatmap + # with the heatmap and dendrograms zoom synchronously yaxis9 = fig["layout"]["yaxis9"] # pylint: disable=invalid-sequence-index - yaxis9.update(scaleanchor="y11") + yaxis9.update(matches="y11") xaxis3 = fig["layout"]["xaxis3"] # pylint: disable=invalid-sequence-index - xaxis3.update(scaleanchor="x11") + xaxis3.update(matches="x11") if len(tickvals_col) == 0: tickvals_col = [10 * i + 5 for i in range(len(self._column_ids))] @@ -576,9 +576,10 @@ def figure(self, computed_traces=None): showticklabels=True, side="bottom", showline=False, - range=[min(tickvals_col) - 5, max(tickvals_col) + 5] + range=[min(tickvals_col) - 5, max(tickvals_col) + 5], # workaround for autoscale issues above; otherwise # the graph cuts off and must be scaled manually + fixedrange=False ) if len(tickvals_row) == 0: @@ -591,18 +592,18 @@ def figure(self, computed_traces=None): tickfont=self._tick_font, showticklabels=True, side="right", - showline=False, + showline=False ) # hide labels, if necessary for label in self._hidden_labels: fig["layout"][label].update(ticks="", showticklabels=False) - row_colors_heatmap = self._get_row_colors_heatmap() + row_colors_heatmap = self._get_row_colors_heatmap(tickvals_row) if row_colors_heatmap is not None: - fig.append_trace(self._get_row_colors_heatmap(), 3, 2) + fig.append_trace(row_colors_heatmap, 3, 2) - col_colors_heatmap = self._get_column_colors_heatmap() + col_colors_heatmap = self._get_column_colors_heatmap(tickvals_col) if col_colors_heatmap is not None: fig.append_trace(col_colors_heatmap, 2, 3) @@ -712,6 +713,27 @@ def figure(self, computed_traces=None): domain=[0, 1 - col_ratio - col_colors_ratio] ) + # Link color heatmap axes to main heatmap for zoom synchronization + # Using 'matches' to ensure the same coordinate system + # Row colors (yaxis10) matches main heatmap y-axis (yaxis11) + if len(tickvals_row) > 0: + fig["layout"]["yaxis10"].update( + matches="y11", + range=[min(tickvals_row), max(tickvals_row)], + tickmode="array", + tickvals=[], + ticktext=[] + ) + # Similar setup for column colors: xaxis7 matches main heatmap x-axis (xaxis11) + if len(tickvals_col) > 0: + fig["layout"]["xaxis7"].update( + matches="x11", + range=[min(tickvals_col), max(tickvals_col)], + tickmode="array", + tickvals=[], + ticktext=[] + ) + fig["layout"][ "legend" ] = dict( # pylint: disable=unsupported-assignment-operation @@ -833,7 +855,7 @@ def _get_clusters(self): return (Zcol, Zrow) - def _get_row_colors_heatmap(self): + def _get_row_colors_heatmap(self, tickvals_row=None): colors = self._row_colors if colors is None: @@ -854,14 +876,21 @@ def _get_row_colors_heatmap(self): z = [[i] for i in range(len(colors))] - return go.Heatmap( - z=z, - colorscale=colorscale, - colorbar={"xpad": 100}, - showscale=False - ) + heatmap_kwargs = { + "z": z, + "colorscale": colorscale, + "colorbar": {"xpad": 100}, + "showscale": False + } + + # Use the same y-coordinates as the main heatmap for proper + # zoom synchronization + if tickvals_row is not None: + heatmap_kwargs["y"] = tickvals_row - def _get_column_colors_heatmap(self): + return go.Heatmap(**heatmap_kwargs) + + def _get_column_colors_heatmap(self, tickvals_col=None): colors = self._column_colors if colors is None: @@ -882,12 +911,19 @@ def _get_column_colors_heatmap(self): z = [[i * 5 for i in range(len(colors))]] - return go.Heatmap( - z=z, - colorscale=colorscale, - colorbar={"xpad": 100}, - showscale=False - ) + heatmap_kwargs = { + "z": z, + "colorscale": colorscale, + "colorbar": {"xpad": 100}, + "showscale": False + } + + # Use the same x-coordinates as the main heatmap for proper + # zoom synchronization + if tickvals_col is not None: + heatmap_kwargs["x"] = tickvals_col + + return go.Heatmap(**heatmap_kwargs) def _compute_clustered_data(self): """Get the traces that need to be plotted for the row and column