From 2a7fe18119aadd79f214f03345717a5fcd450a34 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Thu, 19 Feb 2026 10:48:35 +0100
Subject: [PATCH 1/5] Fix in `_jacobian_computer`.

The rationale here is to let torch fail if the operations were not
implemented, but to allow custom implementations.
---
 src/torchjd/autogram/_jacobian_computer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py
index 45cc71ba..8dc88dba 100644
--- a/src/torchjd/autogram/_jacobian_computer.py
+++ b/src/torchjd/autogram/_jacobian_computer.py
@@ -5,6 +5,7 @@
 import torch
 from torch import Tensor, nn
 from torch.nn import Parameter
+from torch.overrides import is_tensor_like
 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
 
 from torchjd._linalg import Matrix
@@ -83,8 +84,8 @@ def _compute_jacobian(
         /,
     ) -> Matrix:
         grad_outputs_in_dims = (0,) * len(grad_outputs)
-        args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
-        kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
+        args_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, args)
+        kwargs_in_dims = tree_map(lambda t: 0 if is_tensor_like(t) else None, kwargs)
         in_dims = (grad_outputs_in_dims, args_in_dims, kwargs_in_dims)
 
         vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
@@ -114,7 +115,7 @@ def functional_model_call(rg_params: dict[str, Parameter]) -> tuple[Tensor, ...
             ]
             output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
             flat_outputs = tree_flatten(output)[0]
-            rg_outputs = tuple(t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad)
+            rg_outputs = tuple(t for t in flat_outputs if is_tensor_like(t) and t.requires_grad)
             return rg_outputs
 
         vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]

From 562e68cc0309f88754247790c9c42e1ef616b302 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Thu, 19 Feb 2026 10:50:34 +0100
Subject: [PATCH 2/5] Change in `module_hook_manager`.

I think this one is not safe, as the `requires_grad` field might not
exist on something that is not a `Tensor`. I'm not sure about the status
of this in torch, so we should investigate some more.
---
 src/torchjd/autogram/_module_hook_manager.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py
index e2958d93..23f84cbf 100644
--- a/src/torchjd/autogram/_module_hook_manager.py
+++ b/src/torchjd/autogram/_module_hook_manager.py
@@ -4,6 +4,7 @@
 import torch
 from torch import Tensor, nn
 from torch.autograd.graph import get_gradient_edge
+from torch.overrides import is_tensor_like
 from torch.utils._pytree import PyTree, tree_flatten, tree_unflatten
 from torch.utils.hooks import RemovableHandle as TorchRemovableHandle
 
@@ -114,7 +115,7 @@ def __call__(
         rg_outputs = list[Tensor]()
         rg_output_indices = list[int]()
         for idx, output in enumerate(flat_outputs):
-            if isinstance(output, Tensor) and output.requires_grad:
+            if is_tensor_like(output) and output.requires_grad:
                 rg_outputs.append(output)
                 rg_output_indices.append(idx)
 
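A note on the first two patches: `torch.overrides.is_tensor_like` returns
True for any object whose type defines `__torch_function__`, not only for
`Tensor` subclasses. That is what lets custom implementations through in
patch 1, and also why the `requires_grad` access in patch 2 is not
obviously safe. A minimal sketch of the distinction (the `DualTensor`
class is hypothetical, made up for illustration):

    import torch
    from torch import Tensor
    from torch.overrides import is_tensor_like

    class DualTensor:
        # Not a Tensor subclass, but it opts into the torch function
        # protocol, so torch considers it tensor-like.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            return NotImplemented

    obj = DualTensor()
    print(isinstance(obj, Tensor))        # False: the old check skips it
    print(is_tensor_like(obj))            # True: the new check accepts it
    print(is_tensor_like(torch.ones(2)))  # True: plain tensors still pass
    print(hasattr(obj, "requires_grad"))  # False: the concern in patch 2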
From e425af8a7b6ad8d3d89770695786ea13ffe474ee Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Thu, 19 Feb 2026 11:06:18 +0100
Subject: [PATCH 3/5] This one makes sense, but we are forced to cast because
 they did not make `is_tensor_like` a `TypeGuard`.

I opened an issue: https://github.com/pytorch/pytorch/issues/175324
---
 src/torchjd/autojac/_backward.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py
index 5a5b0139..aa855c3f 100644
--- a/src/torchjd/autojac/_backward.py
+++ b/src/torchjd/autojac/_backward.py
@@ -1,6 +1,8 @@
 from collections.abc import Iterable, Sequence
+from typing import cast
 
 from torch import Tensor
+from torch.overrides import is_tensor_like
 
 from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
 from ._utils import (
@@ -140,7 +142,9 @@ def _create_jac_tensors_dict(
         # Transform that turns the gradients into Jacobians.
         diag = Diagonalize(tensors)
         return (diag << init)({})
-    jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors
+    jac_tensors = cast(
+        Sequence[Tensor], (opt_jac_tensors,) if is_tensor_like(opt_jac_tensors) else opt_jac_tensors
+    )
     check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
     check_matching_jac_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
     check_consistent_first_dimension(jac_tensors, "jac_tensors")
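Until that pytorch issue is resolved, a small local wrapper could
centralize the casts in patches 3-5. This is only a sketch, not part of
the series, and like `cast` it asserts more than the runtime check
guarantees, since a tensor-like object need not be a `Tensor`:

    from typing import TypeGuard

    from torch import Tensor
    from torch.overrides import is_tensor_like

    def _is_tensor_like(obj: object) -> TypeGuard[Tensor]:
        # Same runtime behavior as is_tensor_like, but narrows the static
        # type in the positive branch so call sites need no cast.
        return is_tensor_like(obj)

With this, an `if _is_tensor_like(tensors):` check narrows `tensors` to
`Tensor` without a cast; a `TypeIs` (PEP 742) would additionally narrow
the negative branch of the ternaries used in the next two patches.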
From f94ee135d700216eb77cb2bbe21b838a3b6fcd2a Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Thu, 19 Feb 2026 11:09:29 +0100
Subject: [PATCH 4/5] Same in `jac`: needs a cast, but makes sense.

---
 src/torchjd/autojac/_jac.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py
index 404bb730..00f65c29 100644
--- a/src/torchjd/autojac/_jac.py
+++ b/src/torchjd/autojac/_jac.py
@@ -1,6 +1,8 @@
 from collections.abc import Sequence
+from typing import cast
 
 from torch import Tensor
+from torch.overrides import is_tensor_like
 
 from torchjd.autojac._transform._base import Transform
 from torchjd.autojac._transform._diagonalize import Diagonalize
@@ -154,7 +156,7 @@ def jac(
         raise ValueError("`outputs` cannot be empty.")
 
     # Preserve repetitions to duplicate jacobians at the return statement
-    inputs_with_repetition = (inputs,) if isinstance(inputs, Tensor) else inputs
+    inputs_with_repetition = cast(Sequence[Tensor], (inputs,) if is_tensor_like(inputs) else inputs)
     inputs_ = OrderedSet(inputs_with_repetition)
 
     jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
@@ -180,7 +182,9 @@ def _create_jac_outputs_dict(
         # Transform that turns the gradients into Jacobians.
         diag = Diagonalize(outputs)
         return (diag << init)({})
-    jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs
+    jac_outputs = cast(
+        Sequence[Tensor], (opt_jac_outputs,) if is_tensor_like(opt_jac_outputs) else opt_jac_outputs
+    )
     check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
     check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
     check_consistent_first_dimension(jac_outputs, "jac_outputs")

From ca877b5fc930d79359b4e1ae172871015712da23 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Thu, 19 Feb 2026 11:11:25 +0100
Subject: [PATCH 5/5] Same with `autojac._utils`.

---
 src/torchjd/autojac/_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py
index 03c4a405..d3285559 100644
--- a/src/torchjd/autojac/_utils.py
+++ b/src/torchjd/autojac/_utils.py
@@ -4,6 +4,7 @@
 
 from torch import Tensor
 from torch.autograd.graph import Node
+from torch.overrides import is_tensor_like
 
 from ._transform import OrderedSet
 
@@ -20,8 +21,8 @@ def as_checked_ordered_set(
     tensors: Sequence[Tensor] | Tensor,
     variable_name: str,
 ) -> OrderedSet[Tensor]:
-    if isinstance(tensors, Tensor):
-        tensors = [tensors]
+    if is_tensor_like(tensors):
+        tensors = (cast(Tensor, tensors),)
     original_length = len(tensors)
     output = OrderedSet(tensors)
 
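A possible follow-up to the open question in patch 2, not part of this
series: treat a missing `requires_grad` attribute as False instead of
letting the hook raise. The helper name `_requires_grad` is made up for
illustration:

    from torch.overrides import is_tensor_like

    def _requires_grad(obj: object) -> bool:
        # A tensor-like only needs __torch_function__ on its type; nothing
        # guarantees a requires_grad attribute, so default to False rather
        # than raising AttributeError in the output hooks.
        return is_tensor_like(obj) and bool(getattr(obj, "requires_grad", False))

The conditions in `_module_hook_manager` and `_jacobian_computer` would
then read `if _requires_grad(output):`.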