Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ae6be7d
WIP: add gramian-based jac_to_grad
ValerianRey Jan 21, 2026
8bdf512
Update changelog
ValerianRey Jan 21, 2026
aaf2544
Use deque to free memory asap
ValerianRey Jan 23, 2026
64b06ad
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 23, 2026
745f707
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
5eb77f9
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
8f65caa
Use gramian_weighting in jac_to_grad
ValerianRey Jan 28, 2026
6fe15a4
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
d5cb5c2
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 29, 2026
f986950
Only optimize when no forward hooks
ValerianRey Jan 29, 2026
4cf5cbb
Make _gramian_based take aggregator instead of weighting
ValerianRey Jan 29, 2026
add549c
Add _can_skip_jacobian_combination helper function
ValerianRey Jan 29, 2026
453971a
Add test_can_skip_jacobian_combination
ValerianRey Jan 29, 2026
9d4c41c
Optimize compute_gramian for when contracted_dims=-1
ValerianRey Jan 29, 2026
48cd70b
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 30, 2026
8f2660d
Use TypeGuard in _can_skip_jacobian_combination
ValerianRey Jan 30, 2026
fc9bbcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2026
3f9a6d1
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 1, 2026
9d9cbf0
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 4, 2026
b5ca226
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
0baa914
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 5, 2026
2ed1d7c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
86be778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
4ace19e
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
2a84bef
Add ruff if-else squeezing
ValerianRey Feb 13, 2026
4b6209c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 23, 2026
1b1c660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2026
b714253
Many fixes of problems coming from the merge
ValerianRey Feb 23, 2026
2bb8ab1
Fix _can_skip_jacobian_combination
ValerianRey Feb 23, 2026
63c9dde
Make check_consistent_first_dimension work with Deque
ValerianRey Feb 23, 2026
0f85811
Improve test_can_skip_jacobian_combination
ValerianRey Feb 23, 2026
9d55215
Add optimize_gramian_computation param and add error when not compatible
ValerianRey Feb 23, 2026
456510b
Fix overloads (partly) and add missing code coverage
ValerianRey Feb 23, 2026
55c69d1
Fix overloads
ValerianRey Feb 23, 2026
8a401a3
Fix docstring
ValerianRey Feb 23, 2026
b4bf7c4
fixup what @ValerianRey did wrong
PierreQuinton Feb 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,8 @@ changelog does not include internal changes that do not affect the user.
- `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only.
Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` =>
`generalized_weighting(generalized_gramian)`.
- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
of `autojac`.
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
efficiency of `autojac`.
- Removed several unnecessary memory duplications. This should significantly improve the memory
efficiency and speed of `autojac`.
- Increased the lower bounds of the torch (from 2.0.0 to 2.3.0) and numpy (from 1.21.0
to 1.21.2) dependencies to reflect what really works with torchjd. We now also run torchjd's tests
with the dependency lower-bounds specified in `pyproject.toml`, so we should now always accurately
Expand Down
19 changes: 14 additions & 5 deletions src/torchjd/_linalg/_gramian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,20 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
first dimension).
"""

contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
# Optimization: it's faster to do that than moving dims and using tensordot, and this case
# happens very often, sometimes hundreds of times for a single jac_to_grad.
if contracted_dims == -1:
matrix = t.unsqueeze(1) if t.ndim == 1 else t.flatten(start_dim=1)

gramian = matrix @ matrix.T

else:
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)

return cast(PSDTensor, gramian)


Expand Down
98 changes: 86 additions & 12 deletions src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from collections import deque
from collections.abc import Iterable
from typing import overload
from typing import TypeGuard, cast, overload

import torch
from torch import Tensor
from torch import Tensor, nn

from torchjd._linalg import Matrix
from torchjd._linalg import Matrix, PSDMatrix, compute_gramian
from torchjd.aggregation import Aggregator, Weighting
from torchjd.aggregation._aggregator_bases import WeightedAggregator
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
from ._utils import check_consistent_first_dimension
Expand All @@ -16,7 +17,18 @@
def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: WeightedAggregator,
aggregator: GramianWeightedAggregator,
*,
retain_jac: bool = False,
optimize_gramian_computation: bool = False,
) -> Tensor: ...


# Overload for weighted aggregators in general. It does not accept the
# `optimize_gramian_computation` keyword and is declared to return a Tensor.
@overload
def jac_to_grad(
tensors: Iterable[Tensor],
/,
# Annotated with WeightedAggregator rather than GramianWeightedAggregator: overloads are checked in
# order, so a GramianWeightedAggregator is presumably matched by an earlier overload (assumes
# GramianWeightedAggregator is a subclass of WeightedAggregator -- TODO confirm).
aggregator: WeightedAggregator,
*,
retain_jac: bool = False,
) -> Tensor: ...
Expand All @@ -38,6 +50,7 @@ def jac_to_grad(
aggregator: Aggregator,
*,
retain_jac: bool = False,
optimize_gramian_computation: bool = False,
) -> Tensor | None:
r"""
Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
Expand All @@ -50,6 +63,11 @@ def jac_to_grad(
the Jacobians, ``jac_to_grad`` will also return the computed weights.
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
used. Defaults to ``False``.
:param optimize_gramian_computation: When the ``aggregator`` computes weights based on the
Gramian of the Jacobian, it's possible to skip the concatenation of the Jacobians and to
instead compute the Gramian as the sum of the Gramians of the individual Jacobians. This
saves memory (up to 50% memory saving) but can be slightly slower (up to 15%) on CUDA. We
advise trying this optimization if memory is an issue for you. Defaults to ``False``.

.. note::
This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
Expand Down Expand Up @@ -96,13 +114,46 @@ def jac_to_grad(
if len(tensors_) == 0:
raise ValueError("The `tensors` parameter cannot be empty.")

jacobians = [t.jac for t in tensors_]

jacobians = deque(t.jac for t in tensors_)
check_consistent_first_dimension(jacobians, "tensors.jac")

if not retain_jac:
_free_jacs(tensors_)

if optimize_gramian_computation:
if not _can_skip_jacobian_combination(aggregator):
raise ValueError(
"In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must "
"provide a `GramianWeightedAggregator` that doesn't have any forward hooks attached"
" to it."
)

gradients, weights = _gramian_based(aggregator, jacobians)
else:
gradients, weights = _jacobian_based(aggregator, jacobians, tensors_)
accumulate_grads(tensors_, gradients)

return weights


def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
    """
    Return whether ``aggregator`` is a ``GramianWeightedAggregator`` on which no forward hook or
    forward pre-hook is registered, neither on the aggregator itself nor on its ``weighting``.

    :param aggregator: The aggregator to inspect.
    """

    if not isinstance(aggregator, GramianWeightedAggregator):
        return False

    # NOTE(review): hooks are presumably disqualifying because the Gramian-based path does not call
    # the aggregator or its `weighting` module, so such hooks would not fire -- confirm.
    if _has_forward_hook(aggregator):
        return False

    return not _has_forward_hook(aggregator.weighting)


def _has_forward_hook(module: nn.Module) -> bool:
"""Return whether the module has any forward hook registered."""
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0


def _jacobian_based(
aggregator: Aggregator,
jacobians: deque[Tensor],
tensors: list[TensorWithJac],
) -> tuple[list[Tensor], Tensor | None]:
jacobian_matrix = _unite_jacobians(jacobians)
weights: Tensor | None = None

Expand All @@ -124,13 +175,36 @@ def capture_hook(_m: Weighting[Matrix], _i: tuple[Tensor], output: Tensor) -> No
handle.remove()
else:
gradient_vector = aggregator(jacobian_matrix)
gradients = _disunite_gradient(gradient_vector, tensors_)
accumulate_grads(tensors_, gradients)
return weights
gradients = _disunite_gradient(gradient_vector, tensors)
return gradients, weights


def _gramian_based(
    aggregator: GramianWeightedAggregator,
    jacobians: deque[Tensor],
) -> tuple[list[Tensor], Tensor]:
    """
    Computes the weights from the sum of the Gramians of the Jacobians, then turns each Jacobian
    into a gradient by contracting its first dimension with these weights.

    Returns the list of gradients (one per Jacobian, in the same order) and the weights.

    :param aggregator: The aggregator whose ``gramian_weighting`` maps the Gramian to the weights.
    :param jacobians: The Jacobians to aggregate. This deque is emptied by the function.
    """

    weights = aggregator.gramian_weighting(_compute_gramian_sum(jacobians))

    gradients: list[Tensor] = []
    while len(jacobians) > 0:
        # Popping removes the deque's reference to the Jacobian, so that it can be freed as soon
        # as its gradient has been computed.
        gradients.append(torch.tensordot(weights, jacobians.popleft(), dims=1))

    return gradients, weights


def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
    """
    Computes the sum of the Gramians of the individual Jacobians.

    The deque is only iterated over: no Jacobian is removed from it.

    :param jacobians: The Jacobians whose Gramians should be summed.
    """

    # Start from 0, exactly like the builtin `sum` does by default.
    total: Tensor | int = 0
    for jacobian in jacobians:
        total = total + compute_gramian(jacobian)

    return cast(PSDMatrix, total)


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians]
def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor:
jacobian_matrices = list[Tensor]()
while jacobians:
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1))
jacobian_matrix = torch.concat(jacobian_matrices, dim=1)
return jacobian_matrix

Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/autojac/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ def check_consistent_first_dimension(
:param jacobians: Sequence of Jacobian tensors to validate.
:param variable_name: Name of the variable to include in the error message.
"""

if len(jacobians) > 0 and not all(
jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]
jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians
):
raise ValueError(f"All Jacobians in `{variable_name}` should have the same number of rows.")

Expand Down
152 changes: 147 additions & 5 deletions tests/unit/autojac/test_jac_to_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,35 @@
from utils.tensors import tensor_

from torchjd.aggregation import (
IMTLG,
MGDA,
Aggregator,
AlignedMTL,
ConFIG,
Constant,
DualProj,
GradDrop,
Krum,
Mean,
PCGrad,
Random,
Sum,
TrimmedMean,
UPGrad,
)
from torchjd.aggregation._aggregator_bases import WeightedAggregator
from torchjd.autojac._jac_to_grad import jac_to_grad
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator
from torchjd.autojac._jac_to_grad import (
_can_skip_jacobian_combination,
_has_forward_hook,
jac_to_grad,
)


@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()])
def test_various_aggregators(aggregator: Aggregator) -> None:
@mark.parametrize(
["aggregator", "optimize"],
[(Mean(), False), (UPGrad(), True), (UPGrad(), False), (PCGrad(), True), (ConFIG(), False)],
)
def test_various_aggregators(aggregator: Aggregator, optimize: bool) -> None:
"""
Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights
should also be returned. For the others, None should be returned.
Expand All @@ -33,7 +50,11 @@ def test_various_aggregators(aggregator: Aggregator) -> None:
g1 = expected_grad[0]
g2 = expected_grad[1:]

optional_weights = jac_to_grad([t1, t2], aggregator)
if optimize:
assert isinstance(aggregator, GramianWeightedAggregator)
optional_weights = jac_to_grad([t1, t2], aggregator, optimize_gramian_computation=True)
else:
optional_weights = jac_to_grad([t1, t2], aggregator)

assert_grad_close(t1, g1)
assert_grad_close(t2, g2)
Expand Down Expand Up @@ -125,6 +146,110 @@ def test_jacs_are_freed(retain_jac: bool) -> None:
check(t2)


def test_has_forward_hook() -> None:
    """
    Tests that _has_forward_hook detects forward hooks and forward pre-hooks, and that it ignores
    backward hooks.
    """

    module = UPGrad()

    def dummy_forward_hook(_module, _input, _output) -> Tensor:
        return _output

    def dummy_forward_pre_hook(_module, _input) -> Tensor:
        return _input

    def dummy_backward_hook(_module, _grad_input, _grad_output) -> Tensor:
        return _grad_input

    def dummy_backward_pre_hook(_module, _grad_output) -> Tensor:
        return _grad_output

    # A module with no hook at all, or with backward hooks only, should give False.
    assert not _has_forward_hook(module)
    for register_backward, backward_hook in (
        (module.register_full_backward_hook, dummy_backward_hook),
        (module.register_full_backward_pre_hook, dummy_backward_pre_hook),
    ):
        register_backward(backward_hook)
        assert not _has_forward_hook(module)

    # For forward hooks, then for forward pre-hooks: the answer should be True as long as at least
    # one hook of that kind is still registered, and False again once they are all removed.
    for register_forward, forward_hook in (
        (module.register_forward_hook, dummy_forward_hook),
        (module.register_forward_pre_hook, dummy_forward_pre_hook),
    ):
        first_handle = register_forward(forward_hook)
        assert _has_forward_hook(module)
        second_handle = register_forward(forward_hook)
        assert _has_forward_hook(module)
        first_handle.remove()
        assert _has_forward_hook(module)
        second_handle.remove()
        assert not _has_forward_hook(module)


# Pairs of (aggregator, expected value of `_can_skip_jacobian_combination` for this aggregator when
# no forward hook is registered on it).
_PARAMETRIZATIONS: list[tuple[Aggregator, bool]] = [
    (aggregator, True)
    for aggregator in (
        AlignedMTL(),
        DualProj(),
        IMTLG(),
        Krum(n_byzantine=1),
        MGDA(),
        PCGrad(),
        UPGrad(),
    )
]
_PARAMETRIZATIONS += [
    (aggregator, False)
    for aggregator in (
        ConFIG(),
        Constant(tensor_([0.5, 0.5])),
        GradDrop(),
        Mean(),
        Random(),
        Sum(),
        TrimmedMean(trim_number=1),
    )
]

# CAGrad and NashMTL may fail to import; in that case they are simply left out of the
# parametrization.
try:
    from torchjd.aggregation import CAGrad

    _PARAMETRIZATIONS.append((CAGrad(c=0.5), True))
except ImportError:
    pass

try:
    from torchjd.aggregation import NashMTL

    _PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False))
except ImportError:
    pass


@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS)
def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool) -> None:
    """
    Tests that _can_skip_jacobian_combination gives the expected answer for each aggregator, and
    that it answers False for as long as a forward hook or a forward pre-hook is registered on the
    aggregator or, for a GramianWeightedAggregator, on its weighting.
    """

    def check_hooks_prevent_skipping(module) -> None:
        # Forward hook: skipping is refused while it is registered, and the expected answer is
        # restored once it is removed.
        handle = module.register_forward_hook(lambda _module, _input, output: output)
        assert not _can_skip_jacobian_combination(aggregator)
        handle.remove()
        assert _can_skip_jacobian_combination(aggregator) == expected

        # Same check with a forward pre-hook.
        handle = module.register_forward_pre_hook(lambda _module, args: args)
        assert not _can_skip_jacobian_combination(aggregator)
        handle.remove()
        assert _can_skip_jacobian_combination(aggregator) == expected

    assert _can_skip_jacobian_combination(aggregator) == expected
    check_hooks_prevent_skipping(aggregator)

    if isinstance(aggregator, GramianWeightedAggregator):
        check_hooks_prevent_skipping(aggregator.weighting)


def test_noncontiguous_jac() -> None:
"""Tests that jac_to_grad works when the .jac field is non-contiguous."""

Expand Down Expand Up @@ -185,3 +310,20 @@ def hook_inner(_module: Any, _input: Any, weights: Tensor) -> Tensor:

weights = jac_to_grad([t], aggregator)
assert_close(weights, aggregator.weighting(jac))


def test_optimize_gramian_computation_error() -> None:
    """
    Tests that jac_to_grad raises a ValueError when optimize_gramian_computation=True is used with
    an aggregator that is not compatible with it.
    """

    jacobian = tensor_([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])

    t1 = tensor_(1.0, requires_grad=True)
    t1.__setattr__("jac", jacobian[:, 0])

    t2 = tensor_([2.0, 3.0], requires_grad=True)
    t2.__setattr__("jac", jacobian[:, 1:])

    tensors = [t1, t2]
    aggregator = ConFIG()

    with raises(ValueError):
        jac_to_grad(tensors, aggregator, optimize_gramian_computation=True)  # ty:ignore[invalid-argument-type]