feat(aggregation): Add GradVac aggregator #638
Merged
Changes from all commits (25 commits):
- 5eebcf7 feat(aggregation): add GradVac aggregator (rkhosrowshahi)
- a588c93 chore: Remove outdated doctesting stuff (#639) (ValerianRey)
- 9d65f63 chore: Add governance documentation (#637) (PierreQuinton)
- 3ab336c refactor(gradvac): literal group types, eps/beta rules, and plotter UX (rkhosrowshahi)
- e53849e refactor(gradvac): base on GramianWeightedAggregator with GradVacWeig… (rkhosrowshahi)
- 4909964 fix: update type hint for update_gradient_coordinate function (rkhosrowshahi)
- a39f343 test(gradvac): cover beta setter success path for codecov (rkhosrowshahi)
- 0359e60 Rename some variables in test_gradvac.py (ValerianRey)
- 1da5f6e Add comment about why we move to cpu (ValerianRey)
- 21d55f9 Add GradVac to the aggregator table in README (ValerianRey)
- 17b1dd5 Add changelog entry (ValerianRey)
- 02a826b Merge branch 'main' into feature/gradvac (ValerianRey)
- f4e8e60 Remove seed setting in test_aggregator_output (ValerianRey)
- 75c89c1 fix(aggregation): Add fallback in NashMTL (#640) (ValerianRey)
- b100c8b Merge branch 'main' into feature/gradvac (ValerianRey)
- 193ffa6 Merge branch 'main' of https://github.com/TorchJD/torchjd into featur… (rkhosrowshahi)
- 9ffdd13 Revert plot test refactors; keep GradVac in interactive plotter (rkhosrowshahi)
- 50525a1 Merge branch 'main' into feature/gradvac (21f6b74) (rkhosrowshahi)
- e626475 docs(aggregation): add grouping usage example and fix GradVac note (rkhosrowshahi)
- a244d2b docs(changelog): split Unreleased into Added and Fixed for GradVac an… (rkhosrowshahi)
- 1933dea Merge branch 'main' into feature/gradvac (rkhosrowshahi)
- 59e7942 Remove grouping example (ValerianRey)
- e7d4981 Exclude properties from the documentation (ValerianRey)
- f028d54 Remove docstring of setters (not needed anymore) (ValerianRey)
- d62a474 Simplify parameter explanation (ValerianRey)
@@ -0,0 +1,14 @@

```rst
:hide-toc:

GradVac
=======

.. autoclass:: torchjd.aggregation.GradVac
    :members:
    :undoc-members:
    :exclude-members: forward, eps, beta

.. autoclass:: torchjd.aggregation.GradVacWeighting
    :members:
    :undoc-members:
    :exclude-members: forward, eps, beta
```
@@ -0,0 +1,190 @@

```python
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class GradVac(GramianWeightedAggregator):
    r"""
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
    Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
    Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
    <https://openreview.net/forum?id=F1vEjWK-lH_>`_.

    For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at
    random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the
    (possibly already modified) gradient of task :math:`i` and the original gradient of task
    :math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When
    :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of
    :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
    :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated
    vector is the sum of the modified rows.

    This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
    the number of tasks or dtype changes.

    :param beta: EMA decay for :math:`\hat{\phi}`.
    :param eps: Small non-negative constant added to denominators.

    .. note::
        For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
        using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
        you need reproducibility.
    """

    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
        weighting = GradVacWeighting(beta=beta, eps=eps)
        super().__init__(weighting)
        self._gradvac_weighting = weighting
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    @property
    def beta(self) -> float:
        return self._gradvac_weighting.beta

    @beta.setter
    def beta(self, value: float) -> None:
        self._gradvac_weighting.beta = value

    @property
    def eps(self) -> float:
        return self._gradvac_weighting.eps

    @eps.setter
    def eps(self, value: float) -> None:
        self._gradvac_weighting.eps = value

    def reset(self) -> None:
        """Clears EMA state so the next forward starts from zero targets."""
        self._gradvac_weighting.reset()

    def __repr__(self) -> str:
        return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"
```
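Since the diff can be hard to read in isolation, here is a minimal usage sketch. It assumes GradVac follows the same call convention as torchjd's other aggregators (a callable module applied to a matrix whose rows are the per-task gradients); the Jacobian values below are made up for illustration.

```python
import torch

from torchjd.aggregation import GradVac

# GradVac shuffles the task visiting order via the global PyTorch RNG (see the
# docstring note above), so seed it if you need reproducibility.
torch.manual_seed(0)

aggregator = GradVac(beta=0.5, eps=1e-8)

# Hypothetical Jacobian: three tasks (rows), five parameters (columns).
jacobian = torch.tensor(
    [
        [4.0, 1.0, 1.0, 1.0, 0.0],
        [-1.0, 1.0, 1.0, 1.0, 0.0],
        [0.5, 0.5, 0.5, 0.5, 0.5],
    ]
)

aggregated = aggregator(jacobian)  # One combined gradient of shape (5,).
print(aggregated)

# The EMA targets persist across calls; reset before changing the number of tasks.
aggregator.reset()
```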
```python
class GradVacWeighting(Weighting[PSDMatrix]):
    r"""
    :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
    :class:`~torchjd.aggregation.GradVac`.

    All required quantities (gradient norms, cosine similarities, and their updates after the
    vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
    If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then:

    .. math::

        \|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top, \qquad
        g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}

    where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w g_j`
    then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow immediately.

    This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
    the number of tasks or dtype changes.

    :param beta: EMA decay for :math:`\hat{\phi}`.
    :param eps: Small non-negative constant added to denominators.
    """

    def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
        super().__init__()
        if not (0.0 <= beta <= 1.0):
            raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
        if eps < 0.0:
            raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")

        self._beta = beta
        self._eps = eps
        self._phi_t: Tensor | None = None
        self._state_key: tuple[int, torch.dtype] | None = None

    @property
    def beta(self) -> float:
        return self._beta

    @beta.setter
    def beta(self, value: float) -> None:
        if not (0.0 <= value <= 1.0):
            raise ValueError(f"Attribute `beta` must be in [0, 1]. Found beta={value!r}.")
        self._beta = value

    @property
    def eps(self) -> float:
        return self._eps

    @eps.setter
    def eps(self, value: float) -> None:
        if value < 0.0:
            raise ValueError(f"Attribute `eps` must be non-negative. Found eps={value!r}.")
        self._eps = value

    def reset(self) -> None:
        """Clears EMA state so the next forward starts from zero targets."""
        self._phi_t = None
        self._state_key = None
```
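The docstring's claim that norms and dot products can be read off the Gramian alone is easy to sanity-check numerically. A quick sketch with an arbitrary, made-up Jacobian and coefficient row:

```python
import torch

torch.manual_seed(0)
J = torch.randn(3, 7)  # Rows are the original per-task gradients g_k.
G = J @ J.T            # Gramian: G[k, l] = g_k . g_l.

# Arbitrary coefficient row: g_i^PC = sum_k c[k] * g_k.
c = torch.tensor([1.0, 0.3, -0.2])
g_pc = c @ J

# ||g_i^PC||^2 == c G c^T and g_i^PC . g_j == (c @ G)[j], as stated in the docstring.
assert torch.allclose(g_pc @ g_pc, c @ G @ c)
assert torch.allclose(g_pc @ J[1], (c @ G)[1])
```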
```python
    def forward(self, gramian: PSDMatrix, /) -> Tensor:
        # Move all computations to CPU to avoid transferring memory between CPU and GPU at
        # each iteration.
        device = gramian.device
        dtype = gramian.dtype
        cpu = torch.device("cpu")

        G = cast(PSDMatrix, gramian.to(device=cpu))
        m = G.shape[0]

        self._ensure_state(m, dtype)
        phi_t = cast(Tensor, self._phi_t)

        beta = self._beta
        eps = self._eps

        # C[i, :] holds coefficients such that g_i^PC = sum_k C[i, k] * g_k (original gradients).
        # Initially each modified gradient equals the original, so C = I.
        C = torch.eye(m, device=cpu, dtype=dtype)

        for i in range(m):
            # Dot products of g_i^PC with every original g_j, shape (m,).
            cG = C[i] @ G

            others = [j for j in range(m) if j != i]
            perm = torch.randperm(len(others))
            shuffled_js = [others[idx] for idx in perm.tolist()]

            for j in shuffled_js:
                dot_ij = cG[j]
                norm_i_sq = (cG * C[i]).sum()
                norm_i = norm_i_sq.clamp(min=0.0).sqrt()
                norm_j = G[j, j].clamp(min=0.0).sqrt()
                denom = norm_i * norm_j + eps
                phi_ijk = dot_ij / denom

                phi_hat = phi_t[i, j]
                if phi_ijk < phi_hat:
                    sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
                    sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
                    denom_w = norm_j * sqrt_1_hat2 + eps
                    w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
                    C[i, j] = C[i, j] + w
                    cG = cG + w * G[j]

                phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk

        weights = C.sum(dim=0)
        return weights.to(device)

    def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
        key = (m, dtype)
        if self._state_key != key or self._phi_t is None:
            self._phi_t = torch.zeros(m, m, dtype=dtype)
            self._state_key = key
```
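For reviewers comparing `forward` against the paper, the closed-form correction it implements can be written out explicitly. This is a transcription of the code above, writing φ_ij for `phi_ijk` and φ̂_ij for the EMA target `phi_hat`:

```math
w = \frac{\|g_i^{(\mathrm{PC})}\| \left( \hat{\phi}_{ij} \sqrt{1 - \phi_{ij}^2} - \phi_{ij} \sqrt{1 - \hat{\phi}_{ij}^2} \right)}{\|g_j\| \sqrt{1 - \hat{\phi}_{ij}^2} + \varepsilon},
\qquad
g_i^{(\mathrm{PC})} \leftarrow g_i^{(\mathrm{PC})} + w \, g_j
```

The correction fires only when φ_ij < φ̂_ij; in either case the EMA target is then refreshed to (1 − β) φ̂_ij + β φ_ij, matching the last line of the inner loop.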