-
Notifications
You must be signed in to change notification settings - Fork 14
feat(autojac): Make jac_to_grad return optional weights #586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f09a485
13523cc
cad8971
f9b728d
56ad695
b701c59
5801380
d24c56a
700085d
e154273
a03f06f
caf7dd0
a985fb0
100efb2
4ef4240
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,21 +1,43 @@ | ||
| from collections.abc import Iterable | ||
| from typing import overload | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
| from torchjd.aggregation import Aggregator | ||
| from torchjd.aggregation._aggregator_bases import WeightedAggregator | ||
|
|
||
| from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac | ||
| from ._utils import check_consistent_first_dimension | ||
|
|
||
|
|
||
| @overload | ||
| def jac_to_grad( | ||
| tensors: Iterable[Tensor], | ||
| /, | ||
| aggregator: WeightedAggregator, | ||
| *, | ||
| retain_jac: bool = False, | ||
| ) -> Tensor: ... | ||
|
|
||
|
|
||
| @overload | ||
| def jac_to_grad( | ||
| tensors: Iterable[Tensor], | ||
| /, | ||
| aggregator: Aggregator, # Not a WeightedAggregator, because overloads are checked in order | ||
| *, | ||
| retain_jac: bool = False, | ||
| ) -> None: ... | ||
|
|
||
|
|
||
| def jac_to_grad( | ||
| tensors: Iterable[Tensor], | ||
| /, | ||
| aggregator: Aggregator, | ||
| *, | ||
| retain_jac: bool = False, | ||
| ) -> None: | ||
| ) -> Tensor | None: | ||
| r""" | ||
| Aggregates the Jacobians stored in the ``.jac`` fields of ``tensors`` and accumulates the result | ||
| into their ``.grad`` fields. | ||
|
|
@@ -25,6 +47,9 @@ def jac_to_grad( | |
| :param aggregator: The aggregator used to reduce the Jacobians into gradients. | ||
| :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been | ||
| used. Defaults to ``False``. | ||
| :returns: If ``aggregator`` is based on a | ||
| :class:`Weighting <torchjd.aggregation._weighting_bases.Weighting>` to combine the rows of | ||
| the Jacobians, returns the weights used for the aggregation, otherwise returns ``None``. | ||
|
|
||
| .. note:: | ||
| This function starts by "flattening" the ``.jac`` fields into matrices (i.e. flattening all | ||
|
|
@@ -48,12 +73,15 @@ def jac_to_grad( | |
| >>> y2 = (param ** 2).sum() | ||
| >>> | ||
| >>> backward([y1, y2]) # param now has a .jac field | ||
| >>> jac_to_grad([param], aggregator=UPGrad()) # param now has a .grad field | ||
| >>> weights = jac_to_grad([param], UPGrad()) # param now has a .grad field | ||
| >>> param.grad | ||
| tensor([-1., 1.]) | ||
| tensor([0.5000, 2.5000]) | ||
| >>> weights | ||
| tensor([0.5, 0.5]) | ||
|
|
||
| The ``.grad`` field of ``param`` now contains the aggregation (by UPGrad) of the Jacobian of | ||
| :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. | ||
| :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. In this case, the | ||
| weights used to combine the Jacobian are equal because there was no conflict. | ||
| """ | ||
|
|
||
| tensors_ = list[TensorWithJac]() | ||
|
|
@@ -66,7 +94,7 @@ def jac_to_grad( | |
| tensors_.append(t) | ||
|
|
||
| if len(tensors_) == 0: | ||
| return | ||
| raise ValueError("The field `tensor` cannot be empty.") | ||
|
|
||
| jacobians = [t.jac for t in tensors_] | ||
|
|
||
|
|
@@ -76,9 +104,15 @@ def jac_to_grad( | |
| _free_jacs(tensors_) | ||
|
|
||
| jacobian_matrix = _unite_jacobians(jacobians) | ||
| gradient_vector = aggregator(jacobian_matrix) | ||
| if isinstance(aggregator, WeightedAggregator): | ||
| weights = aggregator.weighting(jacobian_matrix) | ||
| gradient_vector = weights @ jacobian_matrix | ||
|
Comment on lines
+108
to
+109
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case, the hooks of the aggregator are not run. Not sure about how to resolve this without running the aggregator.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It feels like we could cheat a bit and have a boolean field
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's possible that we would benefit from having our own system of hooks for aggregators and weightings. |
||
| else: | ||
| weights = None | ||
| gradient_vector = aggregator(jacobian_matrix) | ||
| gradients = _disunite_gradient(gradient_vector, tensors_) | ||
| accumulate_grads(tensors_, gradients) | ||
| return weights | ||
|
|
||
|
|
||
| def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be nice to have type disjunction though.