Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions src/torchjd/autojac/_jac_to_grad.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,43 @@
from collections.abc import Iterable
from typing import overload

import torch
from torch import Tensor

from torchjd.aggregation import Aggregator
from torchjd.aggregation._aggregator_bases import WeightedAggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
from ._utils import check_consistent_first_dimension


# Overload: when the aggregator combines the rows of the Jacobian via a
# weighting (WeightedAggregator), the weights used for the aggregation are
# returned as a Tensor.
@overload
def jac_to_grad(
    tensors: Iterable[Tensor],
    /,
    aggregator: WeightedAggregator,
    *,
    retain_jac: bool = False,
) -> Tensor: ...


@overload
def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order
Copy link
Contributor

@PierreQuinton PierreQuinton Feb 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice to have type disjunction though.

*,
retain_jac: bool = False,
) -> None: ...


def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: Aggregator,
*,
retain_jac: bool = False,
) -> None:
) -> Tensor | None:
r"""
Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result
into their ``.grad`` fields.
Expand All @@ -25,6 +47,9 @@ def jac_to_grad(
:param aggregator: The aggregator used to reduce the Jacobians into gradients.
:param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been
used. Defaults to ``False``.
:returns: If ``aggregator`` is based on a
:class:`Weighting <torchjd.aggregation._weighting_bases.Weighting>` to combine the rows of
the Jacobians, returns the weights used for the aggregation, otherwise returns ``None``.

.. note::
This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all
Expand All @@ -48,12 +73,15 @@ def jac_to_grad(
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2]) # param now has a .jac field
>>> jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field
>>> weights = jac_to_grad([param], UPGrad()) # param now has a .grad field
>>> param.grad
tensor([-1., 1.])
tensor([0.5000, 2.5000])
>>> weights
tensor([0.5000, 0.5000])

The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. In this case, the
weights used to combine the Jacobian are equal because there was no conflict.
"""

tensors_ = list[TensorWithJac]()
Expand All @@ -66,7 +94,7 @@ def jac_to_grad(
tensors_.append(t)

if len(tensors_) == 0:
return
raise ValueError("The parameter `tensors` cannot be empty.")

jacobians = [t.jac for t in tensors_]

Expand All @@ -76,9 +104,15 @@ def jac_to_grad(
_free_jacs(tensors_)

jacobian_matrix = _unite_jacobians(jacobians)
gradient_vector = aggregator(jacobian_matrix)
if isinstance(aggregator, WeightedAggregator):
weights = aggregator.weighting(jacobian_matrix)
gradient_vector = weights @ jacobian_matrix
Comment on lines +108 to +109
Copy link
Contributor

@PierreQuinton PierreQuinton Feb 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, the hooks of the aggregator are not run. Not sure about how to resolve this without running the aggregator.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like we could cheat a bit and have a boolean field return_weights to forward in WeightedAggregators that we don't expose, on True, it would then also return the weights, and run the hooks. There is probably a better solution.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its possible that we would benefit from having our own system of hooks for aggregators and weightings.

else:
weights = None
gradient_vector = aggregator(jacobian_matrix)
gradients = _disunite_gradient(gradient_vector, tensors_)
accumulate_grads(tensors_, gradients)
return weights


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
Expand Down
4 changes: 3 additions & 1 deletion tests/doc/test_jac_to_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
the obtained `.grad` field.
"""

from torch.testing import assert_close
from utils.asserts import assert_grad_close


Expand All @@ -17,6 +18,7 @@ def test_jac_to_grad() -> None:
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()
backward([y1, y2]) # param now has a .jac field
jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field
weights = jac_to_grad([param], UPGrad()) # param now has a .grad field

assert_grad_close(param, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)
assert_close(weights, torch.tensor([0.5, 0.5]), rtol=0.0, atol=0.0)
29 changes: 24 additions & 5 deletions tests/unit/autojac/test_jac_to_grad.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
from pytest import mark, raises
from torch.testing import assert_close
from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
from utils.tensors import tensor_

from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
from torchjd.aggregation import (
Aggregator,
ConFIG,
Mean,
PCGrad,
UPGrad,
)
from torchjd.aggregation._aggregator_bases import WeightedAggregator
from torchjd.autojac._jac_to_grad import jac_to_grad


@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad(), ConFIG()])
def test_various_aggregators(aggregator: Aggregator) -> None:
"""Tests that jac_to_grad works for various aggregators."""
"""
Tests that jac_to_grad works for various aggregators. For those that are weighted, the weights
should also be returned. For the others, None should be returned.
"""

t1 = tensor_(1.0, requires_grad=True)
t2 = tensor_([2.0, 3.0], requires_grad=True)
Expand All @@ -19,11 +30,18 @@ def test_various_aggregators(aggregator: Aggregator) -> None:
g1 = expected_grad[0]
g2 = expected_grad[1:]

jac_to_grad([t1, t2], aggregator)
optional_weights = jac_to_grad([t1, t2], aggregator)

assert_grad_close(t1, g1)
assert_grad_close(t2, g2)

if isinstance(aggregator, WeightedAggregator):
assert optional_weights is not None
expected_weights = aggregator.weighting(jac)
assert_close(optional_weights, expected_weights)
else:
assert optional_weights is None


def test_single_tensor() -> None:
"""Tests that jac_to_grad works when a single tensor is provided."""
Expand Down Expand Up @@ -82,7 +100,8 @@ def test_row_mismatch() -> None:
def test_no_tensors() -> None:
"""Tests that jac_to_grad raises a ValueError when an empty list of tensors is provided."""

jac_to_grad([], aggregator=UPGrad())
with raises(ValueError):
jac_to_grad([], UPGrad())


@mark.parametrize("retain_jac", [True, False])
Expand Down