src/torchjd/autogram/_jacobian_computer.py (7 changes: 4 additions & 3 deletions)
@@ -5,6 +5,7 @@
 import torch
 from torch import Tensor, nn
 from torch.nn import Parameter
+from torch.overrides import is_tensor_like
 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
 
 from torchjd._linalg import Matrix
@@ -83,8 +84,8 @@ def _compute_jacobian(
         /,
     ) -> Matrix:
         grad_outputs_in_dims = (0,) * len(grad_outputs)
-        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
-        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
+        args_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, args)
+        kwargs_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, kwargs)
         in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
         vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
 
@@ -114,7 +115,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...]
             ]
             output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
             flat_outputs = tree_flatten(output)[0]
-            rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad)
+            rg_outputs = tuple(t for t in flat_outputs if is_tensor_like(t) and t.requires_grad)
             return rg_outputs
 
         vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]
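The theme of this PR: isinstance(t, Tensor) misses duck-typed tensors, while torch.overrides.is_tensor_like(t) accepts any object whose type defines __torch_function__ (including Tensor itself). A minimal sketch of the difference; the TensorLike class below is a made-up illustration, not torchjd code:

import torch
from torch import Tensor
from torch.overrides import is_tensor_like

class TensorLike:
    # Hypothetical duck-typed tensor: defining __torch_function__ on the
    # type is enough for is_tensor_like to treat instances as tensors.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

t = torch.ones(2)
print(isinstance(t, Tensor), is_tensor_like(t))    # True True
tl = TensorLike()
print(isinstance(tl, Tensor), is_tensor_like(tl))  # False True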
src/torchjd/autogram/_module_hook_manager.py (3 changes: 2 additions & 1 deletion)
@@ -4,6 +4,7 @@
 import torch
 from torch import Tensor, nn
 from torch.autograd.graph import get_gradient_edge
+from torch.overrides import is_tensor_like
 from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten
 from torch.utils.hooks import RemovableHandle as TorchRemovableHandle
 
@@ -114,7 +115,7 @@ def __call__(
         rg_outputs = list[Tensor]()
         rg_output_indices = list[int]()
         for idx, output in enumerate(flat_outputs):
-            if isinstance(output, Tensor) and output.requires_grad:
+            if is_tensor_like(output) and output.requires_grad:
                 rg_outputs.append(output)
                 rg_output_indices.append(idx)
 
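This hook filters a module's flattened outputs down to the differentiable ones. A small standalone sketch of the same filtering step; the output dict is made up for illustration:

import torch
from torch.overrides import is_tensor_like
from torch.utils._pytree import tree_flatten

# Made-up module output mixing a differentiable tensor, a plain tensor,
# and a non-tensor, as a forward hook might receive it.
output = {
    "logits": torch.randn(3, requires_grad=True),
    "hidden": torch.randn(3),  # does not require grad
    "step": 7,
}
flat_outputs, _ = tree_flatten(output)
rg_outputs = [t for t in flat_outputs if is_tensor_like(t) and t.requires_grad]
print(len(rg_outputs))  # 1: only the requires-grad tensor is kept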
src/torchjd/autojac/_backward.py (6 changes: 5 additions & 1 deletion)
@@ -1,6 +1,8 @@
 from collections.abc import Iterable, Sequence
+from typing import cast
 
 from torch import Tensor
+from torch.overrides import is_tensor_like
 
 from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
 from ._utils import (
@@ -140,7 +142,9 @@ def _create_jac_tensors_dict(
         # Transform that turns the gradients into Jacobians.
         diag = Diagonalize(tensors)
         return (diag << init)({})
-    jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors
+    jac_tensors = cast(
+        Sequence[Tensor], (opt_jac_tensors,) if is_tensor_like(opt_jac_tensors) else opt_jac_tensors
+    )
     check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
     check_matching_jac_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
     check_consistent_first_dimension(jac_tensors, "jac_tensors")
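Unlike isinstance, is_tensor_like is not a TypeGuard, so mypy cannot narrow Tensor | Sequence[Tensor] on its own; that is presumably why the cast was added alongside the new check. A standalone sketch of the same normalization pattern (normalize is a hypothetical name, not torchjd API):

from collections.abc import Sequence
from typing import cast

import torch
from torch import Tensor
from torch.overrides import is_tensor_like

def normalize(tensors: Tensor | Sequence[Tensor]) -> Sequence[Tensor]:
    # A lone tensor(-like) input becomes a one-element tuple; the cast
    # only informs the type checker and does nothing at runtime.
    return cast(Sequence[Tensor], (tensors,) if is_tensor_like(tensors) else tensors)

print(len(normalize(torch.ones(2))))        # 1
print(len(normalize([torch.ones(2)] * 3)))  # 3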
src/torchjd/autojac/_jac.py (8 changes: 6 additions & 2 deletions)
@@ -1,6 +1,8 @@
 from collections.abc import Sequence
+from typing import cast
 
 from torch import Tensor
+from torch.overrides import is_tensor_like
 
 from torchjd.autojac._transform._base import Transform
 from torchjd.autojac._transform._diagonalize import Diagonalize
@@ -154,7 +156,7 @@ def jac(
         raise ValueError("`outputs` cannot be empty.")
 
     # Preserve repetitions to duplicate jacobians at the return statement
-    inputs_with_repetition = (inputs,) if isinstance(inputs, Tensor) else inputs
+    inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs)
     inputs_ = OrderedSet(inputs_with_repetition)
 
     jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
@@ -180,7 +182,9 @@ def _create_jac_outputs_dict(
         # Transform that turns the gradients into Jacobians.
         diag = Diagonalize(outputs)
         return (diag << init)({})
-    jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs
+    jac_outputs = cast(
+        Sequence[Tensor], (opt_jac_outputs,) if is_tensor_like(opt_jac_outputs) else opt_jac_outputs
+    )
     check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
     check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
     check_consistent_first_dimension(jac_outputs, "jac_outputs")
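Note the design choice visible in this hunk: jac deduplicates inputs for differentiation (OrderedSet) but keeps inputs_with_repetition so repeated inputs get their Jacobian duplicated on return. A conceptual sketch of that split, using dict.fromkeys as a stand-in for torchjd's internal OrderedSet:

import torch

a, b = torch.ones(1), torch.zeros(1)
inputs_with_repetition = (a, b, a)  # repetitions preserved for the output
deduped = tuple(dict.fromkeys(inputs_with_repetition))  # order-preserving dedup
print(len(inputs_with_repetition), len(deduped))  # 3 2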
src/torchjd/autojac/_utils.py (5 changes: 3 additions & 2 deletions)
@@ -4,6 +4,7 @@
 
 from torch import Tensor
 from torch.autograd.graph import Node
+from torch.overrides import is_tensor_like
 
 from ._transform import OrderedSet
 
@@ -20,8 +21,8 @@ def as_checked_ordered_set(
     tensors: Sequence[Tensor] | Tensor,
     variable_name: str,
 ) -> OrderedSet[Tensor]:
-    if isinstance(tensors, Tensor):
-        tensors = [tensors]
+    if is_tensor_like(tensors):
+        tensors = (cast(Tensor, tensors),)
 
     original_length = len(tensors)
     output = OrderedSet(tensors)
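One property the single-tensor branch relies on: is_tensor_like inspects the type of its argument, so containers of tensors are not themselves tensor-like and fall through to the sequence path. A quick check:

import torch
from torch.overrides import is_tensor_like

print(is_tensor_like(torch.ones(2)))     # True
print(is_tensor_like([torch.ones(2)]))   # False: a list is not tensor-like
print(is_tensor_like((torch.ones(2),)))  # False: neither is a tuple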