Merged
25 commits
5eebcf7
feat(aggregation): add GradVac aggregator
rkhosrowshahi Apr 9, 2026
a588c93
chore: Remove outdated doctesting stuff (#639)
ValerianRey Apr 11, 2026
9d65f63
chore: Add governance documentation (#637)
PierreQuinton Apr 11, 2026
3ab336c
refactor(gradvac): literal group types, eps/beta rules, and plotter UX
rkhosrowshahi Apr 12, 2026
e53849e
refactor(gradvac): base on GramianWeightedAggregator with GradVacWeig…
rkhosrowshahi Apr 12, 2026
4909964
fix: update type hint for update_gradient_coordinate function
rkhosrowshahi Apr 12, 2026
a39f343
test(gradvac): cover beta setter success path for codecov
rkhosrowshahi Apr 12, 2026
0359e60
Rename some variables in test_gradvac.py
ValerianRey Apr 12, 2026
1da5f6e
Add comment about why we move to cpu
ValerianRey Apr 12, 2026
21d55f9
Add GradVac to the aggregator table in README
ValerianRey Apr 12, 2026
17b1dd5
Add changelog entry
ValerianRey Apr 12, 2026
02a826b
Merge branch 'main' into feature/gradvac
ValerianRey Apr 12, 2026
f4e8e60
Remove seed setting in test_aggregator_output
ValerianRey Apr 12, 2026
75c89c1
fix(aggregation): Add fallback in NashMTL (#640)
ValerianRey Apr 13, 2026
b100c8b
Merge branch 'main' into feature/gradvac
ValerianRey Apr 13, 2026
193ffa6
Merge branch 'main' of https://github.com/TorchJD/torchjd into featur…
rkhosrowshahi Apr 13, 2026
9ffdd13
Revert plot test refactors; keep GradVac in interactive plotter
rkhosrowshahi Apr 13, 2026
50525a1
Merge branch 'main' into feature/gradvac (21f6b74)
rkhosrowshahi Apr 13, 2026
e626475
docs(aggregation): add grouping usage example and fix GradVac note
rkhosrowshahi Apr 13, 2026
a244d2b
docs(changelog): split Unreleased into Added and Fixed for GradVac an…
rkhosrowshahi Apr 13, 2026
1933dea
Merge branch 'main' into feature/gradvac
rkhosrowshahi Apr 13, 2026
59e7942
Remove grouping example
ValerianRey Apr 14, 2026
e7d4981
Exclude properties from the documentation
ValerianRey Apr 14, 2026
f028d54
Remove docstring of setters (not needed anymore)
ValerianRey Apr 14, 2026
d62a474
Simplify parameter explanation
ValerianRey Apr 14, 2026
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added `GradVac` and `GradVacWeighting` from
[Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).

### Fixed

- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example
1 change: 1 addition & 0 deletions README.md
@@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo
| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) |
| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
14 changes: 14 additions & 0 deletions docs/source/docs/aggregation/gradvac.rst
@@ -0,0 +1,14 @@
:hide-toc:

GradVac
=======

.. autoclass:: torchjd.aggregation.GradVac
:members:
:undoc-members:
:exclude-members: forward, eps, beta

.. autoclass:: torchjd.aggregation.GradVacWeighting
:members:
:undoc-members:
:exclude-members: forward, eps, beta
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
@@ -35,6 +35,7 @@ Abstract base classes
dualproj.rst
flattening.rst
graddrop.rst
gradvac.rst
imtl_g.rst
krum.rst
mean.rst
3 changes: 3 additions & 0 deletions src/torchjd/aggregation/__init__.py
@@ -66,6 +66,7 @@
from ._dualproj import DualProj, DualProjWeighting
from ._flattening import Flattening
from ._graddrop import GradDrop
from ._gradvac import GradVac, GradVacWeighting
from ._imtl_g import IMTLG, IMTLGWeighting
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
@@ -92,6 +93,8 @@
"Flattening",
"GeneralizedWeighting",
"GradDrop",
"GradVac",
"GradVacWeighting",
"IMTLG",
"IMTLGWeighting",
"Krum",
190 changes: 190 additions & 0 deletions src/torchjd/aggregation/_gradvac.py
@@ -0,0 +1,190 @@
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class GradVac(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
<https://openreview.net/forum?id=F1vEjWK-lH_>`_.

For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at
random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the
(possibly already modified) gradient of task :math:`i` and the original gradient of task
:math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When
:math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of
:math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
:math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated
vector is the sum of the modified rows.

This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
the number of tasks or dtype changes.

:param beta: EMA decay for :math:`\hat{\phi}`.
:param eps: Small non-negative constant added to denominators.

.. note::
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
you need reproducibility.
"""

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
weighting = GradVacWeighting(beta=beta, eps=eps)
super().__init__(weighting)
self._gradvac_weighting = weighting
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def beta(self) -> float:
return self._gradvac_weighting.beta

@beta.setter
def beta(self, value: float) -> None:
self._gradvac_weighting.beta = value

@property
def eps(self) -> float:
return self._gradvac_weighting.eps

@eps.setter
def eps(self, value: float) -> None:
self._gradvac_weighting.eps = value

def reset(self) -> None:
"""Clears EMA state so the next forward starts from zero targets."""

self._gradvac_weighting.reset()

def __repr__(self) -> str:
return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"
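The closed-form correction described in the docstring can be checked numerically. The following is a minimal, self-contained sketch in plain Python (the names `g_i_new` and `cos_new` are ours, not part of the PR): starting from orthogonal gradients and an EMA target of 0.5, the computed weight `w` moves `g_i` so that its cosine similarity with `g_j` lands exactly on the target.

```python
import math

# Orthogonal task gradients: measured similarity phi is 0, EMA target is 0.5.
g_i, g_j = [1.0, 0.0], [0.0, 1.0]
phi, phi_hat = 0.0, 0.5

norm_i = math.hypot(*g_i)
norm_j = math.hypot(*g_j)

# Closed-form GradVac weight (same expression as `w` in GradVacWeighting.forward).
w = norm_i * (phi_hat * math.sqrt(1 - phi**2) - phi * math.sqrt(1 - phi_hat**2)) / (
    norm_j * math.sqrt(1 - phi_hat**2)
)

# Apply the correction g_i += w * g_j and re-measure the cosine similarity.
g_i_new = [a + w * b for a, b in zip(g_i, g_j)]
cos_new = sum(a * b for a, b in zip(g_i_new, g_j)) / (math.hypot(*g_i_new) * norm_j)
# cos_new equals phi_hat up to floating-point error.
```

This is the invariant behind the update: after the correction, the pair sits exactly at the EMA target, which the EMA rule then continues to track.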


class GradVacWeighting(Weighting[PSDMatrix]):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.

All required quantities (gradient norms, cosine similarities, and their updates after the
vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then:

.. math::

\|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad
g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}

where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w
g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow
immediately.

This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
the number of tasks or dtype changes.

:param beta: EMA decay for :math:`\hat{\phi}`.
:param eps: Small non-negative constant added to denominators.
"""

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
super().__init__()
if not (0.0 <= beta <= 1.0):
raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
if eps < 0.0:
raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")

self._beta = beta
self._eps = eps
self._phi_t: Tensor | None = None
self._state_key: tuple[int, torch.dtype] | None = None

@property
def beta(self) -> float:
return self._beta

@beta.setter
def beta(self, value: float) -> None:
if not (0.0 <= value <= 1.0):
raise ValueError(f"Attribute `beta` must be in [0, 1]. Found beta={value!r}.")
self._beta = value

@property
def eps(self) -> float:
return self._eps

@eps.setter
def eps(self, value: float) -> None:
if value < 0.0:
raise ValueError(f"Attribute `eps` must be non-negative. Found eps={value!r}.")
self._eps = value

def reset(self) -> None:
"""Clears EMA state so the next forward starts from zero targets."""

self._phi_t = None
self._state_key = None

def forward(self, gramian: PSDMatrix, /) -> Tensor:
# Perform all computations on the CPU to avoid transferring memory between CPU and
# GPU at each iteration.
device = gramian.device
dtype = gramian.dtype
cpu = torch.device("cpu")

G = cast(PSDMatrix, gramian.to(device=cpu))
m = G.shape[0]

self._ensure_state(m, dtype)
phi_t = cast(Tensor, self._phi_t)

beta = self._beta
eps = self._eps

# C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients).
# Initially each modified gradient equals the original, so C = I.
C = torch.eye(m, device=cpu, dtype=dtype)

for i in range(m):
# Dot products of g_i^PC with every original g_j, shape (m,).
cG = C[i] @ G

others = [j for j in range(m) if j != i]
perm = torch.randperm(len(others))
shuffled_js = [others[idx] for idx in perm.tolist()]

for j in shuffled_js:
dot_ij = cG[j]
norm_i_sq = (cG * C[i]).sum()
norm_i = norm_i_sq.clamp(min=0.0).sqrt()
norm_j = G[j, j].clamp(min=0.0).sqrt()
denom = norm_i * norm_j + eps
phi_ijk = dot_ij / denom

phi_hat = phi_t[i, j]
if phi_ijk < phi_hat:
sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
denom_w = norm_j * sqrt_1_hat2 + eps
w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
C[i, j] = C[i, j] + w
cG = cG + w * G[j]

phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk

weights = C.sum(dim=0)
return weights.to(device)

def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
key = (m, dtype)
if self._state_key != key or self._phi_t is None:
self._phi_t = torch.zeros(m, m, dtype=dtype)
self._state_key = key
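As a sanity check of the Gramian-only formulation in `GradVacWeighting`, here is a hedged NumPy re-derivation of the weighting. This is our own sketch, not the PR's API: the helper name `gradvac_weights` is hypothetical, and `j` is visited in index order instead of being shuffled, which makes it deterministic.

```python
import numpy as np

def gradvac_weights(G, phi_hat, beta=0.5, eps=1e-8):
    """Gramian-only GradVac weighting; the aggregate is weights @ J for Jacobian J."""
    m = G.shape[0]
    C = np.eye(m)                       # g_i^PC = sum_k C[i, k] * g_k
    for i in range(m):
        cG = C[i] @ G                   # dot products of g_i^PC with each original g_j
        for j in range(m):
            if j == i:
                continue
            norm_i = np.sqrt(max(cG @ C[i], 0.0))
            norm_j = np.sqrt(G[j, j])
            phi = cG[j] / (norm_i * norm_j + eps)
            target = phi_hat[i, j]
            if phi < target:            # similarity below the EMA target: correct it
                s_phi = np.sqrt(max(1.0 - phi * phi, 0.0))
                s_hat = np.sqrt(max(1.0 - target * target, 0.0))
                w = norm_i * (target * s_phi - phi * s_hat) / (norm_j * s_hat + eps)
                C[i, j] += w            # g_i^PC += w * g_j, tracked via coefficients
                cG = cG + w * G[j]
            phi_hat[i, j] = (1.0 - beta) * target + beta * phi
    return C.sum(axis=0)

# Two conflicting task gradients and a positive EMA target: both corrections
# fire, so every weight exceeds 1 (the plain-sum value).
J = np.array([[1.0, 0.0], [-0.5, 1.0]])
weights = gradvac_weights(J @ J.T, np.full((2, 2), 0.5))
```

With zero targets and mutually agreeing gradients, no correction fires and the weights reduce to all ones, matching a plain sum of the task gradients.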
2 changes: 2 additions & 0 deletions tests/plots/interactive_plotter.py
@@ -17,6 +17,7 @@
ConFIG,
DualProj,
GradDrop,
GradVac,
Mean,
NashMTL,
PCGrad,
@@ -48,6 +49,7 @@ def main() -> None:
ConFIG(),
DualProj(),
GradDrop(),
GradVac(),
IMTLG(),
Mean(),
MGDA(),