# TODO(MLETORCH-2048): Remove if possible, or rework this to match the minimal
# tolerance difference between architectures once TOSA is updated; alternatively
# investigate/update atol in the failing tests.
def _adjust_tosa_aarch64_atol(compile_spec: ArmCompileSpec, atol: float) -> float:
    """Return the absolute tolerance to use, widened for TOSA on AArch64.

    The TOSA reference model is experimental on AArch64, so comparisons made
    with a ``TosaCompileSpec`` on such a host get a 10% larger tolerance.

    Args:
        compile_spec: Compile spec of the test being run; only instances of
            ``TosaCompileSpec`` are subject to the adjustment.
        atol: The absolute tolerance requested by the caller.

    Returns:
        ``atol * 1.1`` when ``compile_spec`` is a ``TosaCompileSpec`` and
        ``platform.machine()`` reports ``aarch64`` or ``arm64`` (compared
        case-insensitively); otherwise ``atol`` unchanged.

    """
    # Non-TOSA targets are never adjusted; the host is not even queried.
    if not isinstance(compile_spec, TosaCompileSpec):
        return atol

    host_arch = platform.machine().lower()
    if host_arch in ("aarch64", "arm64"):
        return atol * 1.1
    return atol