diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index 242371cbebe..401a0594d98 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -release/2.12 +release/2.13 diff --git a/backends/arm/_passes/convert_full_like_to_full_pass.py b/backends/arm/_passes/convert_full_like_to_full_pass.py index f7a94424228..80be0472e67 100644 --- a/backends/arm/_passes/convert_full_like_to_full_pass.py +++ b/backends/arm/_passes/convert_full_like_to_full_pass.py @@ -42,7 +42,9 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, args, kwargs, meta) tensor = args[0].data - full_args = (list(tensor.shape), args[1]) + # Each entry is an int (static dim) or an aten.sym_size.int ProxyValue (dynamic dim). + size_args = self.call_size_operator_all(args[0], meta) + full_args = (size_args, args[1]) full_kwargs = {"dtype": tensor.dtype} return super().call_operator( exir_ops.edge.aten.full.default, full_args, full_kwargs, meta diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index b11c6ac6ab3..e510d058dd1 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -55,7 +55,27 @@ def call(self, graph_module): # weights have shape (Co, Ci) weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1] + # ``Graph.create_node`` rejects raw SymInts in call_function + # args; ``materialize_symints`` walks each value's sympy + # expression and emits the FX subgraph that recomputes it from + # existing producers (placeholders / sym ops). One call shares + # the symbol→Proxy hash-cons across all three shape lists so + # repeated symbols become a single subgraph. with graph_module.graph.inserting_before(node): + n_in, n_w = ( + len(input_reshaped_shape), + len(weights_reshaped_shape), + ) + all_sizes = graph_module.graph.materialize_symints( + [ + *input_reshaped_shape, + *weights_reshaped_shape, + *output_shape, + ] + ) + input_reshaped_shape = all_sizes[:n_in] + weights_reshaped_shape = all_sizes[n_in : n_in + n_w] + output_size = all_sizes[n_in + n_w :] # Reshape input to 4D with shape (N, Ci, 1, 1) input_reshaped = create_node( graph=graph_module.graph, @@ -102,7 +122,7 @@ def call(self, graph_module): output = create_node( graph=graph_module.graph, op_target=exir_ops.edge.aten.view_copy.default, - args=(conv, list(output_shape)), + args=(conv, output_size), kwargs={}, from_node=node, inherit_qparams=False, diff --git a/backends/arm/_passes/decompose_select.py b/backends/arm/_passes/decompose_select.py index 9af42d27e5d..15b3586e6ee 100644 --- a/backends/arm/_passes/decompose_select.py +++ b/backends/arm/_passes/decompose_select.py @@ -55,10 +55,18 @@ def call(self, graph_module: torch.fx.GraphModule): index = size_at_dim - abs(index) with graph_module.graph.inserting_before(node): + # ``Graph.create_node`` rejects raw SymInts in call_function + # args; ``materialize_symints`` lifts symbolic ``index`` / + # ``index + 1`` (from negative-index wrap-around against a + # dynamic shape) into graph nodes in a single call (shares + # producer-discovery + hash-cons). Static ints pass through. + start_arg, end_arg = graph_module.graph.materialize_symints( + [index, index + 1] + ) slice_node = create_node( graph_module.graph, slice_op, - (input_node, dim, index, index + 1), + (input_node, dim, start_arg, end_arg), from_node=node, inherit_qparams=False, ) diff --git a/backends/arm/_passes/insert_dynamic_padding.py b/backends/arm/_passes/insert_dynamic_padding.py index 22de1262e83..add414d4525 100644 --- a/backends/arm/_passes/insert_dynamic_padding.py +++ b/backends/arm/_passes/insert_dynamic_padding.py @@ -36,12 +36,13 @@ class InsertDynamicPaddingPass(ArmOpTargetedPass): def _is_dynamic_padding( self, padding: ProxyValue | list[int] | tuple[int, ...] ) -> bool: - return (isinstance(padding, ProxyValue) and is_shape_op_node(padding.node)) or ( - ( - isinstance(padding, (list, tuple)) - and any(isinstance(p, torch.SymInt) for p in padding) + if isinstance(padding, ProxyValue) and is_shape_op_node(padding.node): + return True + if isinstance(padding, (list, tuple)): + return any( + isinstance(p, (torch.SymInt, ProxyValue)) for p in padding ) - ) + return False def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue: if op not in self.target_ops: diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index 2b32bd760e4..454b5e64134 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -607,6 +607,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 dilation, ) + # Compute fake tensor BEFORE materializing SymInts into FX nodes, + # since the underlying op expects ints/SymInts (not FX Nodes). + bias_fake_tensor = get_first_fake_tensor(bias) if bias else None + tosa_node_fake_tensor = target_op( + input_tensor_for_tosa_fake, + weight_fake_tensor, + bias_fake_tensor, + *conv_args[3:], + ) + + # ``Graph.create_node`` rejects raw SymInts in call_function args. + # If ``pad`` contains symbolic entries, materialize them into FX + # nodes so the TOSA conv node references the producing graph + # subgraph instead of holding raw SymInts. + if isinstance(pad, list) and any( + isinstance(p, torch.SymInt) for p in pad + ): + with graph_module.graph.inserting_before(node): + materialized_pad = graph_module.graph.materialize_symints(pad) + new_conv_args = list(conv_args) + new_conv_args[4] = materialized_pad + conv_args = tuple(new_conv_args) + with graph_module.graph.inserting_after(node): tosa_op = create_node( graph=graph_module.graph, @@ -615,13 +638,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 from_node=node, inherit_qparams=True, ) - bias_fake_tensor = get_first_fake_tensor(bias) if bias else None - tosa_node_fake_tensor = target_op( - input_tensor_for_tosa_fake, - weight_fake_tensor, - bias_fake_tensor, - *conv_args[3:], - ) tosa_op.meta["val"] = tosa_node_fake_tensor node_replacement, node_replacement_fake_tensor = ( diff --git a/backends/arm/_passes/size_adjust_input_pass.py b/backends/arm/_passes/size_adjust_input_pass.py index 1c331b9c329..b5d7ab88dbc 100644 --- a/backends/arm/_passes/size_adjust_input_pass.py +++ b/backends/arm/_passes/size_adjust_input_pass.py @@ -233,10 +233,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: parent_node = node.args[0] with graph_module.graph.inserting_before(node): + # ``Graph.create_node`` rejects raw SymInts in + # call_function args. Aggregate every slice arg across + # all entries and lift via a single ``materialize_symints`` + # call -- the helper walks the graph once for producer + # discovery, so a single call amortises that cost and lets + # symints with shared sub-expressions get hash-consed into + # one subgraph. Plain ints pass through unchanged. + flat = [a for args in slice_args for a in args] + materialized = iter(graph.materialize_symints(flat)) + # Pop exactly len(args) values from the materialized iterator + # and pack them back into a tuple -- regroups the flat list + # of lifted values into the original (dim, start, end) shape. + lifted_slice_args = [ + tuple(next(materialized) for _ in args) for args in slice_args + ] + last_node = cast(torch.fx.Node, parent_node) - for args in slice_args: + for args in lifted_slice_args: slice_node = create_node( - graph, slice_op, (last_node,) + args, from_node=node + graph, + slice_op, + (last_node,) + args, + from_node=node, ) last_node = slice_node node.replace_input_with(cast(torch.fx.Node, parent_node), last_node) diff --git a/backends/arm/test/passes/test_decompose_select_pass.py b/backends/arm/test/passes/test_decompose_select_pass.py index 8702fb086da..1590a89cf1d 100644 --- a/backends/arm/test/passes/test_decompose_select_pass.py +++ b/backends/arm/test/passes/test_decompose_select_pass.py @@ -63,11 +63,18 @@ def test_decompose_select_negative_symbolic_index_uses_symbolic_sub() -> None: slice_node = slice_nodes[0] assert slice_node.args[1] == 1 - assert slice_node.args[2] != -1 - assert isinstance(slice_node.args[2], torch.SymInt) - assert isinstance(slice_node.args[3], torch.SymInt) - assert str(slice_node.args[2]).endswith(" - 1") - assert str(slice_node.args[3]) in str(slice_node.args[2]) + # Start/end are now FX nodes (materialized from SymInts) rather than raw + # SymInts, since Graph.create_node rejects raw symbolic leaves in + # call_function args. The original SymInt is preserved in meta['val']. + start_arg, end_arg = slice_node.args[2], slice_node.args[3] + assert isinstance(start_arg, torch.fx.Node) + assert isinstance(end_arg, torch.fx.Node) + start_val = start_arg.meta["val"] + end_val = end_arg.meta["val"] + assert isinstance(start_val, torch.SymInt) + assert isinstance(end_val, torch.SymInt) + assert str(start_val).endswith(" - 1") + assert str(end_val) in str(start_val) assert squeeze_nodes[0].args == (slice_node, [1]) result.graph_module.graph.lint() diff --git a/backends/arm/test/passes/test_insert_dynamic_padding_pass.py b/backends/arm/test/passes/test_insert_dynamic_padding_pass.py index 64594403dae..9c2796b70e3 100644 --- a/backends/arm/test/passes/test_insert_dynamic_padding_pass.py +++ b/backends/arm/test/passes/test_insert_dynamic_padding_pass.py @@ -50,7 +50,13 @@ def test_insert_dynamic_padding(): n for n in nodes if n.target == exir_ops.backend.tosa.CONV2D.default ) initial_padding = conv_node.args[4] - assert any(isinstance(p, torch.SymInt) for p in initial_padding) + # SymInts now appear as FX nodes (with the SymInt stored in meta['val']) + # so that Graph.create_node does not reject raw SymInts in call_function args. + assert any(isinstance(p, torch.fx.Node) for p in initial_padding) + initial_padding_vals = [ + p.meta["val"] if isinstance(p, torch.fx.Node) else p + for p in initial_padding + ] edge_model = edge_model.transform( [ @@ -70,5 +76,5 @@ def test_insert_dynamic_padding(): pad_list = padding_node.args[1].meta["val"] assert len(pad_list) == 8 assert pad_list[:2] == [0, 0] # N-padding - assert pad_list[2:6] == initial_padding # HW-padding in NHWC order + assert pad_list[2:6] == initial_padding_vals # HW-padding in NHWC order assert pad_list[6:] == [0, 0] # C-padding diff --git a/exir/pass_base.py b/exir/pass_base.py index f93dd75d156..c6cff19ddb8 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -562,6 +562,45 @@ def call_operator( ) -> ProxyValue: return self._fx("call_function", op, args, kwargs, meta) + def call_size_operator( + self, + tensor_proxy: ProxyValue, + dim: int, + meta: NodeMetadata, + ) -> Union[ProxyValue, int]: + """Read ``tensor_proxy.size(dim)`` as a value usable in graph args. + Returns a plain ``int`` for static dims; emits and returns an + ``aten.sym_size.int`` ``ProxyValue`` for SymInt dims (since + ``Graph.create_node`` rejects raw SymInts in call_function args). + """ + size = tensor_proxy.data.shape[dim] + if isinstance(size, torch.SymInt): + new_proxy = self.call_operator( + torch.ops.aten.sym_size.int, (tensor_proxy, dim), {}, meta + ) + # Mirror source's "example_value" if present, so the new node + # matches the surrounding graph's meta-key convention. "val" + # is already set by call_operator → _fx → set_metadata. + if "example_value" in tensor_proxy.node.meta: + new_proxy.node.meta["example_value"] = new_proxy.node.meta["val"] + return new_proxy + return int(size) + + def call_size_operator_all( + self, + tensor_proxy: ProxyValue, + meta: NodeMetadata, + ) -> list[Union[ProxyValue, int]]: + """Return all dims of ``tensor_proxy.shape`` as a list of values + usable in graph args. Each entry is an ``int`` (static dim) or an + ``aten.sym_size.int`` ``ProxyValue`` (dynamic dim) — see + ``call_size_operator``. + """ + return [ + self.call_size_operator(tensor_proxy, d, meta) + for d in range(len(tensor_proxy.data.shape)) + ] + def call_sym( self, target: Fn, diff --git a/exir/tests/test_dynamic_shape_propagation.py b/exir/tests/test_dynamic_shape_propagation.py index 5697501039b..12c0b2ec823 100644 --- a/exir/tests/test_dynamic_shape_propagation.py +++ b/exir/tests/test_dynamic_shape_propagation.py @@ -115,7 +115,7 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] x = super().call_operator( exir_ops.edge.aten.view_copy.default, - (x, list(x.data.shape) + [1]), + (x, self.call_size_operator_all(x, meta) + [1]), {}, meta, ) @@ -123,7 +123,7 @@ def call_operator(self, op, args, kwargs, meta): w = args[1] w = super().call_operator( exir_ops.edge.aten.view_copy.default, - (w, list(w.data.shape) + [1]), + (w, self.call_size_operator_all(w, meta) + [1]), {}, meta, ) @@ -144,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta): ) x = super().call_operator( exir_ops.edge.aten.view_copy.default, - (x, list(x.data.shape)[:-1]), + (x, self.call_size_operator_all(x, meta)[:-1]), {}, meta, )