diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index dc69bfda..40118fea 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -1,3 +1,5 @@ +from collections.abc import Callable + import numpy as np import torch from plotly import graph_objects as go @@ -7,14 +9,22 @@ class Plotter: - def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None: - self.aggregators = aggregators + def __init__( + self, + aggregator_factories: dict[str, Callable[[], Aggregator]], + selected_keys: list[str], + matrix: torch.Tensor, + seed: int = 0, + ) -> None: + self._aggregator_factories = aggregator_factories + self.selected_keys = selected_keys self.matrix = matrix self.seed = seed def make_fig(self) -> Figure: torch.random.manual_seed(self.seed) - results = [agg(self.matrix) for agg in self.aggregators] + aggregators = [self._aggregator_factories[key]() for key in self.selected_keys] + results = [agg(self.matrix) for agg in aggregators] fig = go.Figure() @@ -23,14 +33,19 @@ def make_fig(self) -> Figure: fig.add_trace(cone) for i in range(len(self.matrix)): - scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}") + scatter = make_vector_scatter( + self.matrix[i], + "blue", + f"g{i + 1}", + textposition="top right", + ) fig.add_trace(scatter) for i in range(len(results)): scatter = make_vector_scatter( results[i], "black", - str(self.aggregators[i]), + self.selected_keys[i], showlegend=True, dash=True, ) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 2411e4c3..a722ccc1 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -1,17 +1,20 @@ import logging import os import webbrowser +from collections.abc import Callable from threading import Timer import numpy as np import torch from dash import Dash, Input, Output, callback, dcc, html from plotly.graph_objs import Figure +from typing_extensions import Unpack from plots._utils import Plotter, angle_to_coord, coord_to_angle from torchjd.aggregation import ( IMTLG, MGDA, + Aggregator, AlignedMTL, CAGrad, ConFIG, @@ -31,6 +34,14 @@ MAX_LENGTH = 25.0 +def _format_angle_display(angle: float) -> str: + return f"{np.degrees(angle):.1f}°" + + +def _format_length_display(r: float) -> str: + return f"{r:.2f}" + + def main() -> None: log = logging.getLogger("werkzeug") log.setLevel(logging.CRITICAL) @@ -43,27 +54,30 @@ def main() -> None: ], ) - aggregators = [ - AlignedMTL(), - CAGrad(c=0.5), - ConFIG(), - DualProj(), - GradDrop(), - GradVac(), - IMTLG(), - Mean(), - MGDA(), - NashMTL(n_tasks=matrix.shape[0]), - PCGrad(), - Random(), - Sum(), - TrimmedMean(trim_number=1), - UPGrad(), - ] - - aggregators_dict = {str(aggregator): aggregator for aggregator in aggregators} - - plotter = Plotter([], matrix) + n_tasks = matrix.shape[0] + aggregator_factories: dict[str, Callable[[], Aggregator]] = { + "AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"), + "AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"), + "AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"), + str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5), + str(ConFIG()): lambda: ConFIG(), + str(DualProj()): lambda: DualProj(), + str(GradDrop()): lambda: GradDrop(), + str(GradVac()): lambda: GradVac(), + str(IMTLG()): lambda: IMTLG(), + str(Mean()): lambda: Mean(), + str(MGDA()): lambda: MGDA(), + str(NashMTL(n_tasks=n_tasks)): lambda: NashMTL(n_tasks=n_tasks), + str(PCGrad()): lambda: PCGrad(), + str(Random()): lambda: Random(), + str(Sum()): lambda: Sum(), + str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1), + str(UPGrad()): lambda: UPGrad(), + } + + aggregator_strings = list(aggregator_factories.keys()) + + plotter = Plotter(aggregator_factories, [], matrix) app = Dash(__name__) @@ -98,7 +112,6 @@ def main() -> None: gradient_slider_inputs.append(Input(angle_input, "value")) gradient_slider_inputs.append(Input(r_input, "value")) - aggregator_strings = [str(aggregator) for aggregator in aggregators] checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist") control_div = html.Div( @@ -117,22 +130,32 @@ def update_seed(value: int) -> Figure: plotter.seed = value return plotter.make_fig() + n_gradients = len(matrix) + gradient_value_outputs: list[Output] = [] + for i in range(n_gradients): + gradient_value_outputs.append(Output(f"g{i + 1}-angle-value", "children")) + gradient_value_outputs.append(Output(f"g{i + 1}-length-value", "children")) + @callback( Output("aggregations-fig", "figure", allow_duplicate=True), + *gradient_value_outputs, *gradient_slider_inputs, prevent_initial_call=True, ) - def update_gradient_coordinate(*values: str) -> Figure: + def update_gradient_coordinate(*values: str) -> tuple[Figure, Unpack[tuple[str, ...]]]: values_ = [float(value) for value in values] + display_parts: list[str] = [] for j in range(len(values_) // 2): angle = values_[2 * j] r = values_[2 * j + 1] x, y = angle_to_coord(angle, r) plotter.matrix[j, 0] = x plotter.matrix[j, 1] = y + display_parts.append(_format_angle_display(angle)) + display_parts.append(_format_length_display(r)) - return plotter.make_fig() + return (plotter.make_fig(), *display_parts) @callback( Output("aggregations-fig", "figure", allow_duplicate=True), @@ -140,9 +163,7 @@ def update_gradient_coordinate(*values: str) -> Figure: prevent_initial_call=True, ) def update_aggregators(value: list[str]) -> Figure: - aggregator_keys = value - new_aggregators = [aggregators_dict[key] for key in aggregator_keys] - plotter.aggregators = new_aggregators + plotter.selected_keys = list(value) return plotter.make_fig() Timer(1, open_browser).start() @@ -175,11 +196,56 @@ def make_gradient_div( style={"width": "250px"}, ) + label_style: dict[str, str | int] = { + "display": "inline-block", + "width": "52px", + "margin-right": "8px", + "vertical-align": "middle", + } + value_style: dict[str, str] = { + "display": "inline-block", + "margin-left": "10px", + "min-width": "140px", + "font-family": "monospace", + "font-size": "13px", + "vertical-align": "middle", + } + row_style: dict[str, str] = {"display": "block", "margin-bottom": "6px"} div = html.Div( [ - html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}), - angle_input, - r_input, + dcc.Markdown( + f"$g_{{{i + 1}}}$", + mathjax=True, + style={ + "margin": "0 0 6px 0", + "font-weight": "bold", + "display": "block", + }, + ), + html.Div( + [ + html.Span("Angle", style=label_style), + angle_input, + html.Span( + id=f"g{i + 1}-angle-value", + children=_format_angle_display(angle), + style=value_style, + ), + ], + style=row_style, + ), + html.Div( + [ + html.Span("Length", style=label_style), + r_input, + html.Span( + id=f"g{i + 1}-length-value", + children=_format_length_display(r), + style=value_style, + ), + ], + style={**row_style, "margin-bottom": "12px"}, + ), ], ) return div, angle_input, r_input