From ae6be7d38d49430218d3b3952b882e408028b1f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 21 Jan 2026 20:37:32 +0100 Subject: [PATCH 01/26] WIP: add gramian-based jac_to_grad --- src/torchjd/autojac/_jac_to_grad.py | 35 ++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 61427467..9671880d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,9 +1,12 @@ from collections.abc import Iterable +from typing import cast import torch from torch import Tensor -from torchjd.aggregation import Aggregator +from torchjd._linalg import PSDMatrix, compute_gramian +from torchjd.aggregation import Aggregator, Weighting +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac @@ -71,10 +74,36 @@ def jac_to_grad( if not retain_jac: _free_jacs(tensors_) + if isinstance(aggregator, GramianWeightedAggregator): + # When it's possible, avoid the concatenation of the jacobians that can be very costly in + # memory. + gradients = _gramian_based(aggregator.weighting.weighting, jacobians, tensors_) + else: + gradients = _jacobian_based(aggregator, jacobians, tensors_) + accumulate_grads(tensors_, gradients) + + +def _jacobian_based( + aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac] +) -> list[Tensor]: jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) - accumulate_grads(tensors_, gradients) + gradients = _disunite_gradient(gradient_vector, jacobians, tensors) + return gradients + + +def _gramian_based( + weighting: Weighting[PSDMatrix], jacobians: list[Tensor], tensors: list[TensorWithJac] +) -> list[Tensor]: + gramian = _compute_gramian_sum(jacobians) + weights = weighting(gramian) + gradients = [torch.tensordot(weights, jacobian, dims=1) for jacobian in jacobians] + return gradients + + +def _compute_gramian_sum(jacobians: list[Tensor]) -> PSDMatrix: + gramian = sum([compute_gramian(matrix) for matrix in jacobians]) + return cast(PSDMatrix, gramian) def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: From 8bdf51239213c0aa1310724d13b4f64b9a883fa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 21 Jan 2026 20:39:14 +0100 Subject: [PATCH 02/26] Update changelog --- CHANGELOG.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d520e901..35416d6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,10 +35,8 @@ changelog does not include internal changes that do not affect the user. jac_to_grad(shared_module.parameters(), aggregator) ``` -- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency - of `autojac`. -- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory - efficiency of `autojac`. +- Removed several unnecessary memory duplications. This should significantly improve the memory + efficiency and speed of `autojac`. ## [0.8.1] - 2026-01-07 From aaf2544c2c990ea417128196af62d75ab2de881f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 23 Jan 2026 15:09:47 +0100 Subject: [PATCH 03/26] Use deque to free memory asap --- src/torchjd/autojac/_jac_to_grad.py | 31 ++++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 9671880d..2888679a 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,3 +1,4 @@ +from collections import deque from collections.abc import Iterable from typing import cast @@ -66,9 +67,9 @@ def jac_to_grad( if len(tensors_) == 0: return - jacobians = [t.jac for t in tensors_] + jacobians = deque(t.jac for t in tensors_) - if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]]): + if not all([jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians]): raise ValueError("All Jacobians should have the same number of rows.") if not retain_jac: @@ -84,37 +85,43 @@ def jac_to_grad( def _jacobian_based( - aggregator: Aggregator, jacobians: list[Tensor], tensors: list[TensorWithJac] + aggregator: Aggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, jacobians, tensors) + gradients = _disunite_gradient(gradient_vector, tensors) return gradients def _gramian_based( - weighting: Weighting[PSDMatrix], jacobians: list[Tensor], tensors: list[TensorWithJac] + weighting: Weighting[PSDMatrix], jacobians: deque[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: gramian = _compute_gramian_sum(jacobians) weights = weighting(gramian) - gradients = [torch.tensordot(weights, jacobian, dims=1) for jacobian in jacobians] + + gradients = list[Tensor]() + while jacobians: + jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap + gradients.append(torch.tensordot(weights, jacobian, dims=1)) + return gradients -def _compute_gramian_sum(jacobians: list[Tensor]) -> PSDMatrix: +def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix: gramian = sum([compute_gramian(matrix) for matrix in jacobians]) return cast(PSDMatrix, gramian) -def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: - jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians] +def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor: + jacobian_matrices = list[Tensor]() + while jacobians: + jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap + jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1)) jacobian_matrix = torch.concat(jacobian_matrices, dim=1) return jacobian_matrix -def _disunite_gradient( - gradient_vector: Tensor, jacobians: list[Tensor], tensors: list[TensorWithJac] -) -> list[Tensor]: +def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac]) -> list[Tensor]: gradient_vectors = gradient_vector.split([t.numel() for t in tensors]) gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors)] return gradients From 8f65caa24ac09015bc160bfb8b2abb7928e07b9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 28 Jan 2026 16:58:23 +0100 Subject: [PATCH 04/26] Use gramian_weighting in jac_to_grad --- src/torchjd/autojac/_jac_to_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 2888679a..a50aeb7a 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -78,7 +78,7 @@ def jac_to_grad( if isinstance(aggregator, GramianWeightedAggregator): # When it's possible, avoid the concatenation of the jacobians that can be very costly in # memory. - gradients = _gramian_based(aggregator.weighting.weighting, jacobians, tensors_) + gradients = _gramian_based(aggregator.gramian_weighting, jacobians, tensors_) else: gradients = _jacobian_based(aggregator, jacobians, tensors_) accumulate_grads(tensors_, gradients) From f986950c043b3c1bd8618f77f7f8fa1533f401ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 29 Jan 2026 16:45:31 +0100 Subject: [PATCH 05/26] Only optimize when no forward hooks --- src/torchjd/autojac/_jac_to_grad.py | 9 +++-- tests/unit/autojac/test_jac_to_grad.py | 47 +++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index a50aeb7a..7642e793 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -3,7 +3,7 @@ from typing import cast import torch -from torch import Tensor +from torch import Tensor, nn from torchjd._linalg import PSDMatrix, compute_gramian from torchjd.aggregation import Aggregator, Weighting @@ -75,7 +75,7 @@ def jac_to_grad( if not retain_jac: _free_jacs(tensors_) - if isinstance(aggregator, GramianWeightedAggregator): + if isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator): # When it's possible, avoid the concatenation of the jacobians that can be very costly in # memory. gradients = _gramian_based(aggregator.gramian_weighting, jacobians, tensors_) @@ -84,6 +84,11 @@ def jac_to_grad( accumulate_grads(tensors_, gradients) +def _has_forward_hook(module: nn.Module) -> bool: + """Return whether the module has any forward hook registered.""" + return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0 + + def _jacobian_based( aggregator: Aggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 60ea6838..b450ba87 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -3,7 +3,7 @@ from utils.tensors import tensor_ from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad -from torchjd.autojac._jac_to_grad import jac_to_grad +from torchjd.autojac._jac_to_grad import _has_forward_hook, jac_to_grad @mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) @@ -101,3 +101,48 @@ def test_jacs_are_freed(retain_jac: bool): check = assert_has_jac if retain_jac else assert_has_no_jac check(t1) check(t2) + + +def test_has_forward_hook(): + """Tests that _has_forward_hook correctly detects the presence of forward hooks.""" + + module = UPGrad() + + def dummy_forward_hook(_module, _input, _output): + return _output + + def dummy_forward_pre_hook(_module, _input): + return _input + + def dummy_backward_hook(_module, _grad_input, _grad_output): + return _grad_input + + def dummy_backward_pre_hook(_module, _grad_output): + return _grad_output + + # Module with no hooks or backward hooks only should return False + assert not _has_forward_hook(module) + module.register_full_backward_hook(dummy_backward_hook) + assert not _has_forward_hook(module) + module.register_full_backward_pre_hook(dummy_backward_pre_hook) + assert not _has_forward_hook(module) + + # Module with forward hook should return True + handle1 = module.register_forward_hook(dummy_forward_hook) + assert _has_forward_hook(module) + handle2 = module.register_forward_hook(dummy_forward_hook) + assert _has_forward_hook(module) + handle1.remove() + assert _has_forward_hook(module) + handle2.remove() + assert not _has_forward_hook(module) + + # Module with forward pre-hook should return True + handle3 = module.register_forward_pre_hook(dummy_forward_pre_hook) + assert _has_forward_hook(module) + handle4 = module.register_forward_pre_hook(dummy_forward_pre_hook) + assert _has_forward_hook(module) + handle3.remove() + assert _has_forward_hook(module) + handle4.remove() + assert not _has_forward_hook(module) From 4cf5cbba00d2e58c3d88f201b45f38fae60b27b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 29 Jan 2026 16:49:18 +0100 Subject: [PATCH 06/26] Make _gramian_based take aggregator instead of weighting --- src/torchjd/autojac/_jac_to_grad.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 7642e793..fbb113d8 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -6,7 +6,7 @@ from torch import Tensor, nn from torchjd._linalg import PSDMatrix, compute_gramian -from torchjd.aggregation import Aggregator, Weighting +from torchjd.aggregation import Aggregator from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac @@ -78,7 +78,7 @@ def jac_to_grad( if isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator): # When it's possible, avoid the concatenation of the jacobians that can be very costly in # memory. - gradients = _gramian_based(aggregator.gramian_weighting, jacobians, tensors_) + gradients = _gramian_based(aggregator, jacobians, tensors_) else: gradients = _jacobian_based(aggregator, jacobians, tensors_) accumulate_grads(tensors_, gradients) @@ -99,8 +99,9 @@ def _jacobian_based( def _gramian_based( - weighting: Weighting[PSDMatrix], jacobians: deque[Tensor], tensors: list[TensorWithJac] + aggregator: GramianWeightedAggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac] ) -> list[Tensor]: + weighting = aggregator.gramian_weighting gramian = _compute_gramian_sum(jacobians) weights = weighting(gramian) From add549c6b76a9ac38eec55c68a82035fa61ffd75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 29 Jan 2026 16:56:22 +0100 Subject: [PATCH 07/26] Add _can_skip_jacobian_combination helper function Tbh I don't like it very much (because it's an extra function + some cast is required) but it's the only way to easily test that the correct aggregators use the optimized _gramian_based method. I also tried using return type hint of TypeGuard[GramianWeightedAggergator] instead of bool for _can_skip_jacobian_combination, but it's not really correct since we also check that the aggregator has no forward hook, so that TypeGuard would be really weird. So in the end we have to use this cast. --- src/torchjd/autojac/_jac_to_grad.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index fbb113d8..0be908c0 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -75,15 +75,17 @@ def jac_to_grad( if not retain_jac: _free_jacs(tensors_) - if isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator): - # When it's possible, avoid the concatenation of the jacobians that can be very costly in - # memory. - gradients = _gramian_based(aggregator, jacobians, tensors_) + if _can_skip_jacobian_combination(aggregator): + gradients = _gramian_based(cast(GramianWeightedAggregator, aggregator), jacobians, tensors_) else: gradients = _jacobian_based(aggregator, jacobians, tensors_) accumulate_grads(tensors_, gradients) +def _can_skip_jacobian_combination(aggregator: Aggregator) -> bool: + return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator) + + def _has_forward_hook(module: nn.Module) -> bool: """Return whether the module has any forward hook registered.""" return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0 From 453971a564da4d7bb8d3fe3fc742a7c3c9ddf7b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 29 Jan 2026 17:19:38 +0100 Subject: [PATCH 08/26] Add test_can_skip_jacobian_combination --- tests/unit/autojac/test_jac_to_grad.py | 72 +++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index b450ba87..73c946a3 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -2,8 +2,28 @@ from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac from utils.tensors import tensor_ -from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad -from torchjd.autojac._jac_to_grad import _has_forward_hook, jac_to_grad +from torchjd.aggregation import ( + IMTLG, + MGDA, + Aggregator, + AlignedMTL, + ConFIG, + Constant, + DualProj, + GradDrop, + Krum, + Mean, + PCGrad, + Random, + Sum, + TrimmedMean, + UPGrad, +) +from torchjd.autojac._jac_to_grad import ( + _can_skip_jacobian_combination, + _has_forward_hook, + jac_to_grad, +) @mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) @@ -146,3 +166,51 @@ def dummy_backward_pre_hook(_module, _grad_output): assert _has_forward_hook(module) handle4.remove() assert not _has_forward_hook(module) + + +_PARAMETRIZATIONS = [ + (AlignedMTL(), True), + (DualProj(), True), + (IMTLG(), True), + (Krum(n_byzantine=1), True), + (MGDA(), True), + (PCGrad(), True), + (UPGrad(), True), + (ConFIG(), False), + (Constant(tensor_([0.5, 0.5])), False), + (GradDrop(), False), + (Mean(), False), + (Random(), False), + (Sum(), False), + (TrimmedMean(trim_number=1), False), +] + +try: + from torchjd.aggregation import CAGrad + + _PARAMETRIZATIONS.append((CAGrad(c=0.5), True)) +except ImportError: + pass + +try: + from torchjd.aggregation import NashMTL + + _PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False)) +except ImportError: + pass + + +@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS) +def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool): + """ + Tests that _can_skip_jacobian_combination correctly identifies when optimization can be used. + """ + + assert _can_skip_jacobian_combination(aggregator) == expected + handle = aggregator.register_forward_hook(lambda module, input, output: output) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + handle = aggregator.register_forward_pre_hook(lambda module, input: input) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected From 9d4c41c692d77a9b57085a585c468312b8fbcdbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 29 Jan 2026 17:56:33 +0100 Subject: [PATCH 09/26] Optimize compute_gramian for when contracted_dims=-1 --- src/torchjd/_linalg/_gramian.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index 1d7bebff..7ede9b60 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -30,11 +30,22 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor: first dimension). """ - contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim - indices_source = list(range(t.ndim - contracted_dims)) - indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1)) - transposed = t.movedim(indices_source, indices_dest) - gramian = torch.tensordot(t, transposed, dims=contracted_dims) + # Optimization: it's faster to do that than moving dims and using tensordot, and this case + # happens very often, sometimes hundreds of times for a single jac_to_grad. + if contracted_dims == -1: + if t.ndim == 1: + matrix = t.unsqueeze(1) + else: + matrix = t.flatten(start_dim=1) + + gramian = matrix @ matrix.T + + else: + contracted_dims = contracted_dims if 0 <= contracted_dims else contracted_dims + t.ndim + indices_source = list(range(t.ndim - contracted_dims)) + indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1)) + transposed = t.movedim(indices_source, indices_dest) + gramian = torch.tensordot(t, transposed, dims=contracted_dims) return cast(PSDTensor, gramian) From 8f2660d300b7d0bda0aab954f5f5fb810a5704b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Fri, 30 Jan 2026 03:24:27 +0100 Subject: [PATCH 10/26] Use TypeGuard in _can_skip_jacobian_combination Co-authored-by: Pierre Quinton --- src/torchjd/autojac/_jac_to_grad.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 0be908c0..d0bb177d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,6 +1,6 @@ from collections import deque from collections.abc import Iterable -from typing import cast +from typing import cast, TypeGuard import torch from torch import Tensor, nn @@ -76,13 +76,13 @@ def jac_to_grad( _free_jacs(tensors_) if _can_skip_jacobian_combination(aggregator): - gradients = _gramian_based(cast(GramianWeightedAggregator, aggregator), jacobians, tensors_) + gradients = _gramian_based(aggregator, jacobians, tensors_) else: gradients = _jacobian_based(aggregator, jacobians, tensors_) accumulate_grads(tensors_, gradients) -def _can_skip_jacobian_combination(aggregator: Aggregator) -> bool: +def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]: return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator) From fc9bbcf02c19a6018b23e924baee6d618cf9a702 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Jan 2026 02:24:51 +0000 Subject: [PATCH 11/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchjd/autojac/_jac_to_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index d0bb177d..7bbf8f64 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,6 +1,6 @@ from collections import deque from collections.abc import Iterable -from typing import cast, TypeGuard +from typing import TypeGuard, cast import torch from torch import Tensor, nn From b5ca22635a3a65544202ca969e510aafda859da8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 16:31:13 +0000 Subject: [PATCH 12/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93a85a0f..8b9979ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,7 +45,7 @@ changelog does not include internal changes that do not affect the user. mtl_backward(losses, features) jac_to_grad(shared_module.parameters(), aggregator) ``` - + - Removed several unnecessary memory duplications. This should significantly improve the memory efficiency and speed of `autojac`. - Increased the lower bounds of the torch (from 2.0.0 to 2.3.0) and numpy (from 1.21.0 From 86be77882a1e36fef058c07ee389977cb02cb697 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:19:43 +0000 Subject: [PATCH 13/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchjd/autojac/_jac_to_grad.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 19c594d5..954667d7 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -94,7 +94,9 @@ def _has_forward_hook(module: nn.Module) -> bool: def _jacobian_based( - aggregator: Aggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac] + aggregator: Aggregator, + jacobians: deque[Tensor], + tensors: list[TensorWithJac], ) -> list[Tensor]: jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) @@ -103,7 +105,9 @@ def _jacobian_based( def _gramian_based( - aggregator: GramianWeightedAggregator, jacobians: deque[Tensor], tensors: list[TensorWithJac] + aggregator: GramianWeightedAggregator, + jacobians: deque[Tensor], + tensors: list[TensorWithJac], ) -> list[Tensor]: weighting = aggregator.gramian_weighting gramian = _compute_gramian_sum(jacobians) @@ -131,7 +135,7 @@ def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor: return jacobian_matrix -def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac],) -> list[Tensor]: +def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac]) -> list[Tensor]: gradient_vectors = gradient_vector.split([t.numel() for t in tensors]) gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)] return gradients From 2a84bef5471894b32984b7703818490bd6e19da0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 13 Feb 2026 19:05:40 +0100 Subject: [PATCH 14/26] Add ruff if-else squeezing --- src/torchjd/_linalg/_gramian.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/torchjd/_linalg/_gramian.py b/src/torchjd/_linalg/_gramian.py index 862885ac..7eb9acd3 100644 --- a/src/torchjd/_linalg/_gramian.py +++ b/src/torchjd/_linalg/_gramian.py @@ -38,10 +38,7 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor: # Optimization: it's faster to do that than moving dims and using tensordot, and this case # happens very often, sometimes hundreds of times for a single jac_to_grad. if contracted_dims == -1: - if t.ndim == 1: - matrix = t.unsqueeze(1) - else: - matrix = t.flatten(start_dim=1) + matrix = t.unsqueeze(1) if t.ndim == 1 else t.flatten(start_dim=1) gramian = matrix @ matrix.T From 1b1c660a788013e33cb73d0dc17279bf1d833b31 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Feb 2026 14:07:02 +0000 Subject: [PATCH 15/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchjd/autojac/_jac_to_grad.py | 2 +- tests/unit/autojac/test_jac_to_grad.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 060730ee..5d831b2d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -1,6 +1,6 @@ from collections import deque from collections.abc import Iterable -from typing import TypeGuard, overload, cast +from typing import TypeGuard, cast, overload import torch from torch import Tensor, nn diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 4b7239f4..fc989663 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -23,14 +23,13 @@ TrimmedMean, UPGrad, ) +from torchjd.aggregation._aggregator_bases import WeightedAggregator from torchjd.autojac._jac_to_grad import ( _can_skip_jacobian_combination, _has_forward_hook, jac_to_grad, ) -from torchjd.aggregation._aggregator_bases import WeightedAggregator - @mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()]) def test_various_aggregators(aggregator: Aggregator) -> None: @@ -231,7 +230,7 @@ def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool): assert not _can_skip_jacobian_combination(aggregator) handle.remove() assert _can_skip_jacobian_combination(aggregator) == expected - + def test_noncontiguous_jac() -> None: """Tests that jac_to_grad works when the .jac field is non-contiguous.""" From b7142535670967a225fe15d5cf7e2d7255da3266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 15:17:02 +0100 Subject: [PATCH 16/26] Many fixes of problems coming from the merge --- src/torchjd/autojac/_jac_to_grad.py | 8 +++++--- tests/unit/autojac/test_jac_to_grad.py | 16 ++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 5d831b2d..2141a223 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -104,11 +104,13 @@ def jac_to_grad( _free_jacs(tensors_) if _can_skip_jacobian_combination(aggregator): - gradients, weights = _gramian_based(aggregator, jacobians, tensors_) + gradients, weights = _gramian_based(aggregator, jacobians) else: gradients, weights = _jacobian_based(aggregator, jacobians, tensors_) accumulate_grads(tensors_, gradients) + return weights + def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]: return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator) @@ -116,6 +118,7 @@ def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianW def _has_forward_hook(module: nn.Module) -> bool: """Return whether the module has any forward hook registered.""" + # TODO: also check hooks on the outer weighting return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0 @@ -145,14 +148,13 @@ def capture_hook(_m: Weighting[Matrix], _i: tuple[Tensor], output: Tensor) -> No handle.remove() else: gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, tensors_) + gradients = _disunite_gradient(gradient_vector, tensors) return gradients, weights def _gramian_based( aggregator: GramianWeightedAggregator, jacobians: deque[Tensor], - tensors: list[TensorWithJac], ) -> tuple[list[Tensor], Tensor]: weighting = aggregator.gramian_weighting gramian = _compute_gramian_sum(jacobians) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index fc989663..9891f58d 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -139,21 +139,21 @@ def test_jacs_are_freed(retain_jac: bool) -> None: check(t2) -def test_has_forward_hook(): +def test_has_forward_hook() -> None: """Tests that _has_forward_hook correctly detects the presence of forward hooks.""" module = UPGrad() - def dummy_forward_hook(_module, _input, _output): + def dummy_forward_hook(_module, _input, _output) -> Tensor: return _output - def dummy_forward_pre_hook(_module, _input): + def dummy_forward_pre_hook(_module, _input) -> Tensor: return _input - def dummy_backward_hook(_module, _grad_input, _grad_output): + def dummy_backward_hook(_module, _grad_input, _grad_output) -> Tensor: return _grad_input - def dummy_backward_pre_hook(_module, _grad_output): + def dummy_backward_pre_hook(_module, _grad_output) -> Tensor: return _grad_output # Module with no hooks or backward hooks only should return False @@ -217,16 +217,16 @@ def dummy_backward_pre_hook(_module, _grad_output): @mark.parametrize("aggregator, expected", _PARAMETRIZATIONS) -def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool): +def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool) -> None: """ Tests that _can_skip_jacobian_combination correctly identifies when optimization can be used. """ assert _can_skip_jacobian_combination(aggregator) == expected - handle = aggregator.register_forward_hook(lambda module, input, output: output) + handle = aggregator.register_forward_hook(lambda _module, _input, output: output) assert not _can_skip_jacobian_combination(aggregator) handle.remove() - handle = aggregator.register_forward_pre_hook(lambda module, input: input) + handle = aggregator.register_forward_pre_hook(lambda _module, input: input) assert not _can_skip_jacobian_combination(aggregator) handle.remove() assert _can_skip_jacobian_combination(aggregator) == expected From 2bb8ab11daaa600110a6441fc13ab50fdda49651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 15:26:54 +0100 Subject: [PATCH 17/26] Fix _can_skip_jacobian_combination --- src/torchjd/autojac/_jac_to_grad.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 2141a223..967a2f89 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -113,12 +113,15 @@ def jac_to_grad( def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]: - return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator) + return ( + isinstance(aggregator, GramianWeightedAggregator) + and not _has_forward_hook(aggregator) + and not _has_forward_hook(aggregator.weighting) + ) def _has_forward_hook(module: nn.Module) -> bool: """Return whether the module has any forward hook registered.""" - # TODO: also check hooks on the outer weighting return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0 From 63c9dde3505c97a93a5824d06f66a96ca3215f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 15:36:01 +0100 Subject: [PATCH 18/26] Make check_consistent_first_dimension work with Deque --- src/torchjd/autojac/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index d3285559..fdcc7ce1 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -113,8 +113,9 @@ def check_consistent_first_dimension( :param jacobians: Sequence of Jacobian tensors to validate. :param variable_name: Name of the variable to include in the error message. """ + if len(jacobians) > 0 and not all( - jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:] + jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians ): raise ValueError(f"All Jacobians in `{variable_name}` should have the same number of rows.") From 0f858111c1f6ef475a8833a4fcc290a6c906185d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 15:40:16 +0100 Subject: [PATCH 19/26] Improve test_can_skip_jacobian_combination --- tests/unit/autojac/test_jac_to_grad.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 9891f58d..dc819838 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -23,7 +23,7 @@ TrimmedMean, UPGrad, ) -from torchjd.aggregation._aggregator_bases import WeightedAggregator +from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator from torchjd.autojac._jac_to_grad import ( _can_skip_jacobian_combination, _has_forward_hook, @@ -226,11 +226,22 @@ def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool) - handle = aggregator.register_forward_hook(lambda _module, _input, output: output) assert not _can_skip_jacobian_combination(aggregator) handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected handle = aggregator.register_forward_pre_hook(lambda _module, input: input) assert not _can_skip_jacobian_combination(aggregator) handle.remove() assert _can_skip_jacobian_combination(aggregator) == expected + if isinstance(aggregator, GramianWeightedAggregator): + handle = aggregator.weighting.register_forward_hook(lambda _module, _input, output: output) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected + handle = aggregator.weighting.register_forward_pre_hook(lambda _module, input: input) + assert not _can_skip_jacobian_combination(aggregator) + handle.remove() + assert _can_skip_jacobian_combination(aggregator) == expected + def test_noncontiguous_jac() -> None: """Tests that jac_to_grad works when the .jac field is non-contiguous.""" From 9d55215e05e4d6c52fb1ab2b94282c2bf40d05ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 15:51:55 +0100 Subject: [PATCH 20/26] Add optimize_gramian_computation param and add error when not compatible --- src/torchjd/autojac/_jac_to_grad.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 967a2f89..ef67aafc 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -39,6 +39,7 @@ def jac_to_grad( aggregator: Aggregator, *, retain_jac: bool = False, + optimize_gramian_computation: bool = False, ) -> Tensor | None: r""" Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result @@ -51,6 +52,13 @@ def jac_to_grad( the Jacobians, ``jac_to_grad`` will also return the computed weights. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. + :param optimize_gramian_computation: When the ``aggregator`` is a + :class:`GramianWeightedAggregator ` + (e.g. :class:`UPGrad `), it's possible to skip the + concatenation of the Jacobians and to instead compute the Gramian as the sum of the Gramians + of the individual Jacobians. This saves memory (up to 50% memory saving) but can be slightly + slower (up to 15%) on CUDA. We advise to try this optimization if memory is an issue for + you. Defaults to ``False``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all @@ -103,7 +111,14 @@ def jac_to_grad( if not retain_jac: _free_jacs(tensors_) - if _can_skip_jacobian_combination(aggregator): + if optimize_gramian_computation: + if not _can_skip_jacobian_combination(aggregator): + raise ValueError( + "In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must " + "provide a `GramianWeightedAggregator` that doesn't have any forward hooks attached" + " to it." + ) + gradients, weights = _gramian_based(aggregator, jacobians) else: gradients, weights = _jacobian_based(aggregator, jacobians, tensors_) From 456510b79e65fe025e6863a1d73c022f2640ed71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 16:05:20 +0100 Subject: [PATCH 21/26] Fix overloads (partly) and add missing code coverage --- src/torchjd/autojac/_jac_to_grad.py | 2 ++ tests/unit/autojac/test_jac_to_grad.py | 26 +++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index ef67aafc..8c8ec123 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -20,6 +20,7 @@ def jac_to_grad( aggregator: WeightedAggregator, *, retain_jac: bool = False, + optimize_gramian_computation: bool = False, ) -> Tensor: ... @@ -30,6 +31,7 @@ def jac_to_grad( aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order *, retain_jac: bool = False, + optimize_gramian_computation: bool = False, ) -> None: ... diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index dc819838..0c5df145 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -31,8 +31,11 @@ ) -@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()]) -def test_various_aggregators(aggregator: Aggregator) -> None: +@mark.parametrize( + ["aggregator", "opt"], + [(Mean(), False), (UPGrad(), True), (UPGrad(), False), (PCGrad(), True), (ConFIG(), False)], +) +def test_various_aggregators(aggregator: Aggregator, opt: bool) -> None: """ Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights should also be returned. For the others, None should be returned. @@ -47,7 +50,7 @@ def test_various_aggregators(aggregator: Aggregator) -> None: g1 = expected_grad[0] g2 = expected_grad[1:] - optional_weights = jac_to_grad([t1, t2], aggregator) + optional_weights = jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=opt) assert_grad_close(t1, g1) assert_grad_close(t2, g2) @@ -303,3 +306,20 @@ def hook_inner(_module: Any, _input: Any, weights: Tensor) -> Tensor: weights = jac_to_grad([t], aggregator) assert_close(weights, aggregator.weighting(jac)) + + +def test_optimize_gramian_computation_error() -> None: + """ + Tests that using optimize_gramian_computation on an incompatible aggregator raises an error. + """ + + aggregator = ConFIG() + + t1 = tensor_(1.0, requires_grad=True) + t2 = tensor_([2.0, 3.0], requires_grad=True) + jac = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]) + t1.__setattr__("jac", jac[:, 0]) + t2.__setattr__("jac", jac[:, 1:]) + + with raises(ValueError): + jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=True) From 55c69d1744796cba82e022dbddc2dacabe2b80ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 16:30:21 +0100 Subject: [PATCH 22/26] Fix overloads --- src/torchjd/autojac/_jac_to_grad.py | 13 +++++++++++-- tests/unit/autojac/test_jac_to_grad.py | 10 +++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 8c8ec123..8436041b 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -17,13 +17,23 @@ def jac_to_grad( tensors: Iterable[Tensor], /, - aggregator: WeightedAggregator, + aggregator: GramianWeightedAggregator, *, retain_jac: bool = False, optimize_gramian_computation: bool = False, ) -> Tensor: ... +@overload +def jac_to_grad( + tensors: Iterable[Tensor], + /, + aggregator: WeightedAggregator, # Not a GramianWA, because overloads are checked in order + *, + retain_jac: bool = False, +) -> Tensor: ... + + @overload def jac_to_grad( tensors: Iterable[Tensor], @@ -31,7 +41,6 @@ def jac_to_grad( aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order *, retain_jac: bool = False, - optimize_gramian_computation: bool = False, ) -> None: ... diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 0c5df145..742e8c23 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -35,7 +35,7 @@ ["aggregator", "opt"], [(Mean(), False), (UPGrad(), True), (UPGrad(), False), (PCGrad(), True), (ConFIG(), False)], ) -def test_various_aggregators(aggregator: Aggregator, opt: bool) -> None: +def test_various_aggregators(aggregator: Aggregator, optimize: bool) -> None: """ Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights should also be returned. For the others, None should be returned. @@ -50,7 +50,11 @@ def test_various_aggregators(aggregator: Aggregator, opt: bool) -> None: g1 = expected_grad[0] g2 = expected_grad[1:] - optional_weights = jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=opt) + if optimize: + assert isinstance(aggregator, GramianWeightedAggregator) + optional_weights = jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=True) + else: + optional_weights = jac_to_grad([t1, t2], aggregator) assert_grad_close(t1, g1) assert_grad_close(t2, g2) @@ -322,4 +326,4 @@ def test_optimize_gramian_computation_error() -> None: t2.__setattr__("jac", jac[:, 1:]) with raises(ValueError): - jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=True) + jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=True) # ty:ignore[invalid-argument-type] From 8a401a3baf346892ac70c60986f5d82ec6ae2367 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 16:32:39 +0100 Subject: [PATCH 23/26] Fix docstring --- src/torchjd/autojac/_jac_to_grad.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 8436041b..1aee6ee1 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -63,13 +63,11 @@ def jac_to_grad( the Jacobians, ``jac_to_grad`` will also return the computed weights. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. - :param optimize_gramian_computation: When the ``aggregator`` is a - :class:`GramianWeightedAggregator ` - (e.g. :class:`UPGrad `), it's possible to skip the - concatenation of the Jacobians and to instead compute the Gramian as the sum of the Gramians - of the individual Jacobians. This saves memory (up to 50% memory saving) but can be slightly - slower (up to 15%) on CUDA. We advise to try this optimization if memory is an issue for - you. Defaults to ``False``. + :param optimize_gramian_computation: When the ``aggregator`` computes weights based on the + Gramian of the Jacobian, it's possible to skip the concatenation of the Jacobians and to + instead compute the Gramian as the sum of the Gramians of the individual Jacobians. This + saves memory (up to 50% memory saving) but can be slightly slower (up to 15%) on CUDA. We + advise to try this optimization if memory is an issue for you. Defaults to ``False``. .. note:: This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all From b4bf7c41aa3e70ab4cdeaf063b3479eb33de969d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 23 Feb 2026 16:41:09 +0100 Subject: [PATCH 24/26] fixup what @ValerianRey did wrong --- tests/unit/autojac/test_jac_to_grad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 742e8c23..2db0b7f3 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -32,7 +32,7 @@ @mark.parametrize( - ["aggregator", "opt"], + ["aggregator", "optimize"], [(Mean(), False), (UPGrad(), True), (UPGrad(), False), (PCGrad(), True), (ConFIG(), False)], ) def test_various_aggregators(aggregator: Aggregator, optimize: bool) -> None: From 24a991aaaeda0d1bcf8a687d198c20fce86e59d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 18:24:55 +0100 Subject: [PATCH 25/26] Improve error message --- src/torchjd/autojac/_jac_to_grad.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 1aee6ee1..7492d392 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -124,8 +124,8 @@ def jac_to_grad( if not _can_skip_jacobian_combination(aggregator): raise ValueError( "In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must " - "provide a `GramianWeightedAggregator` that doesn't have any forward hooks attached" - " to it." + "provide an `Aggregator` that computes weights based on the Gramian of the Jacobian" + " (e.g. `UPGrad`) and that doesn't have any forward hooks attached to it." ) gradients, weights = _gramian_based(aggregator, jacobians) From 2ea44a4b5717ef224eee65353d39a4a2bd7a27c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 23 Feb 2026 18:31:41 +0100 Subject: [PATCH 26/26] Improve docstring --- src/torchjd/autojac/_jac_to_grad.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 7492d392..5947a998 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -70,10 +70,20 @@ def jac_to_grad( advise to try this optimization if memory is an issue for you. Defaults to ``False``. .. note:: - This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all - of their dimensions except the first one), then concatenates those matrices into a combined - Jacobian matrix. The aggregator is then used on this matrix, which returns a combined - gradient vector, that is split and reshaped to fit into the ``.grad`` fields of the tensors. + When ``optimize_gramian_computation=False``, this function starts by "flattening" the + ``.jac`` fields into matrices (i.e. flattening all of their dimensions except the first + one), then concatenates those matrices into a combined Jacobian matrix. The ``aggregator`` + is then used on this matrix, which returns a combined gradient vector, that is split and + reshaped to fit into the ``.grad`` fields of the tensors. + + .. note:: + When ``optimize_gramian_computation=True``, this function computes and sums the Gramian + of each individual ``.jac`` field, iteratively. The inner weighting of the ``aggregator`` is + then used to extract some weights from the obtained Gramian, used to compute a linear + combination of the rows of each ``.jac`` field, to be stored into the corresponding + ``.grad`` field. This is mathematically equivalent to the approach with + ``optimize_gramian_computation=False``, but saves memory by not having to hold the + concatenated Jacobian matrix in memory at any time. .. admonition:: Example