From 867fd841aeaa7f140135860606d57577e9c44d4f Mon Sep 17 00:00:00 2001 From: Per Held Date: Fri, 24 Apr 2026 10:42:28 +0200 Subject: [PATCH] Arm backend: block invalid bilinear RESIZE downscale at 1/16 TOSA bilinear RESIZE requires the downscale ratio to be strictly greater than 1/16. The Arm backend currently accepts the exact 1/16 boundary case and can emit RESIZE parameters that the TOSA reference model rejects. Move upsample operator gating into explicit Arm operator support checks. Keep nearest upsample explicitly supported there, and reject bilinear upsample cases whose computed TOSA RESIZE scale hits the invalid 1/16 boundary. Keep the corresponding validation in the fake TOSA RESIZE op so the dialect-level constraint is still enforced directly. Also add regressions for the dialect-level validation and the end-to-end Arm bilinear interpolate case with align_corners=False and scale_factor=1/16. Signed-off-by: Per Held Change-Id: I612fc7315fa4d1bd158e2f71bcaa493fcaf08c03 --- backends/arm/operator_support/__init__.py | 1 + .../tosa_profile_supported_op_lists.py | 4 -- .../arm/operator_support/upsample_support.py | 69 +++++++++++++++++++ .../arm/test/misc/test_tosa_dialect_resize.py | 36 ++++++++++ .../arm/test/ops/test_upsample_bilinear2d.py | 30 ++++++++ backends/arm/test/targets.bzl | 1 + backends/arm/tosa/dialect/ops/resize.py | 11 ++- 7 files changed, 146 insertions(+), 6 deletions(-) create mode 100644 backends/arm/operator_support/upsample_support.py create mode 100644 backends/arm/test/misc/test_tosa_dialect_resize.py diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 088cee95371..066b5462f64 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -24,5 +24,6 @@ to_dim_order_copy_support, tosa_supported_operators, unfold_copy_support, + upsample_support, where_support, ) diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index 96c164214a0..3c3aa57774f 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -87,8 +87,6 @@ exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten.upsample_bilinear2d.vec, - exir_ops.edge.aten.upsample_nearest2d.vec, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.squeeze_copy.dims, @@ -211,8 +209,6 @@ exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten.upsample_bilinear2d.vec, - exir_ops.edge.aten.upsample_nearest2d.vec, exir_ops.edge.aten.var.correction, exir_ops.edge.aten.var.dim, exir_ops.edge.aten.view_copy.default, diff --git a/backends/arm/operator_support/upsample_support.py b/backends/arm/operator_support/upsample_support.py new file mode 100644 index 00000000000..bd03a4d2b4f --- /dev/null +++ b/backends/arm/operator_support/upsample_support.py @@ -0,0 +1,69 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Provide TOSA support checks for upsample operators.""" + +import torch.fx as fx +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.rewrite_upsample import RewriteUpsamplePass +from executorch.backends.arm.common.type import ensure_type +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck): + """Provide the explicit TOSA support gate for nearest upsample.""" + + targets = [exir_ops.edge.aten.upsample_nearest2d.vec] + + def is_node_tosa_supported( + self, _node: fx.Node, _tosa_spec: TosaSpecification + ) -> bool: # type: ignore[override, misc] + return True + + +@register_tosa_support_check +class UpsampleBilinear2dSupported(SupportedTOSAOperatorCheck): + """Reject bilinear upsample cases that cannot lower to a valid TOSA + RESIZE. + """ + + targets = [exir_ops.edge.aten.upsample_bilinear2d.vec] + + def is_node_tosa_supported( + self, node: fx.Node, _tosa_spec: TosaSpecification + ) -> bool: # type: ignore[override, misc] + input_node = ensure_type(fx.Node, node.args[0]) + align_corners = ensure_type(bool, node.args[2]) + input_size_yx = get_first_fake_tensor(input_node).shape[2:] + output_size_yx = get_first_fake_tensor(node).shape[2:] + + try: + scale_y_n, scale_y_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[0], output_size_yx[0], align_corners + ) + scale_x_n, scale_x_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[1], output_size_yx[1], align_corners + ) + except RuntimeError as err: + self.reporter.report_reject(node, str(err)) + return False + + # get_resize_parameters_1d() returns the TOSA RESIZE scale fraction for + # each spatial dimension. For align_corners=False, this is the effective + # output_size / input_size ratio, so the 1/16 boundary is checked + # directly in the same representation that RESIZE lowering will use. + if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: + self.reporter.report_reject( + node, + "Bilinear RESIZE downscale must be strictly greater than 1/16", + ) + return False + + return True diff --git a/backends/arm/test/misc/test_tosa_dialect_resize.py b/backends/arm/test/misc/test_tosa_dialect_resize.py new file mode 100644 index 00000000000..91e7aad8ad9 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_resize.py @@ -0,0 +1,36 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 + +import pytest +import torch + +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_bilinear_resize_rejects_exact_one_sixteenth_downscale(): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="Bilinear RESIZE downscale must be strictly greater than 1/16", + ): + exir_ops.backend.tosa.RESIZE.default( + mode.from_tensor( + torch.randint(0, 10, (1, 3, 256, 448), dtype=torch.int8) + ), + [2, 32, 2, 32], + [15, 15], + [-15, -15], + resize_mode="bilinear", + ) diff --git a/backends/arm/test/ops/test_upsample_bilinear2d.py b/backends/arm/test/ops/test_upsample_bilinear2d.py index 3ac727e529f..d06d1688ffe 100644 --- a/backends/arm/test/ops/test_upsample_bilinear2d.py +++ b/backends/arm/test/ops/test_upsample_bilinear2d.py @@ -140,6 +140,25 @@ def forward(self, x): return self.upsample(x) +class InterpolateAlignCornersFalse(torch.nn.Module): + def __init__( + self, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + ): + super().__init__() + self.upsample = lambda x: torch.nn.functional.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode="bilinear", + align_corners=False, + ) + + def forward(self, x): + return self.upsample(x) + + @common.parametrize( "test_data", test_data_suite_tosa | test_data_suite_tosa_bf16 | test_data_suite_tosa_fp16, @@ -231,6 +250,17 @@ def test_upsample_bilinear2d_vec_tosa_FP_Interpolate( pipeline.run() +def test_upsample_bilinear2d_vec_tosa_does_not_delegate_exact_one_sixteenth_downscale(): + pipeline = OpNotSupportedPipeline[input_t1]( + InterpolateAlignCornersFalse(size=None, scale_factor=1.0 / 16.0), + (torch.randn(1, 3, 256, 448),), + {exir_op: 1}, + n_expected_delegates=0, + ) + + pipeline.run() + + @common.parametrize("test_data", test_data_suite_tosa) def test_upsample_bilinear2d_vec_tosa_INT_intropolate( test_data: torch.Tensor, diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index b9e8726f78d..52d1b651b75 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -47,6 +47,7 @@ def define_arm_tests(): "misc/test_compile_spec.py", # "misc/test_evaluate_model.py", "misc/test_pass_pipeline_config.py", + "misc/test_tosa_dialect_resize.py", "misc/test_tosa_spec.py", "misc/test_bn_relu_folding_qat.py", "misc/test_custom_partition.py", diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index f8b078c8690..47add0ffb7f 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -50,7 +50,7 @@ def _get_output_dtype( return output_dtype -def _validate_resize_parameters(scale, border): +def _validate_resize_parameters(scale, border, resize_mode): def in_int16_range(values): return all((x >= -(2**15)) and (x <= 2**15 - 1) for x in values) @@ -58,6 +58,13 @@ def in_int16_range(values): raise TosaValueError("scale is out of the int16 range", op="RESIZE") if not in_int16_range(border): raise TosaValueError("border is out of the int16 range", op="RESIZE") + if resize_mode == "bilinear": + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale + if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: + raise TosaValueError( + "Bilinear RESIZE downscale must be strictly greater than 1/16", + op="RESIZE", + ) @register_fake_tosa_op( @@ -79,7 +86,7 @@ def RESIZE( f"Input tensor must be 4D, but got {x.dim()}D", op="RESIZE" ) _validate_resize_mode(resize_mode) - _validate_resize_parameters(scale, border) + _validate_resize_parameters(scale, border, resize_mode) output_dtype = _get_output_dtype(x.dtype, tosa_spec, resize_mode) input_shape = x.shape