From 4078863f0195959df736588bc2402bb034cce34c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 17 Mar 2026 10:26:28 +0100 Subject: [PATCH] BUG: torch: work around torch.round not supporting complex inputs --- array_api_compat/torch/_aliases.py | 18 +++++++++++++++++- tests/test_torch.py | 9 +++++++++ torch-xfails.txt | 1 - 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 69bd3763..a5348dc2 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -912,6 +912,22 @@ def sign(x: Array, /) -> Array: return out +def round(x: Array, /, **kwargs) -> Array: + # torch.round fails for complex inputs + # https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845 + if x.dtype.is_complex: + out = kwargs.pop('out', None) + res_r = torch.round(x.real, **kwargs) + res_i = torch.round(x.imag, **kwargs) + res = res_r + 1j*res_i + if out is not None: + out.copy_(res) + return out + return res + else: + return torch.round(x, **kwargs) + + def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]: # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it # will be required to pass the indexing argument." @@ -923,7 +939,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', - 'diff', 'divide', + 'diff', 'divide', 'round', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', diff --git a/tests/test_torch.py b/tests/test_torch.py index 463dd597..35ef5dda 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -152,3 +152,12 @@ def test_argsort_stable(): t = xp.zeros(50) # should be >16 assert xp.all(xp.argsort(t) == xp.arange(50)) + + +def test_round(): + """Verify the out= argument of xp.round with complex inputs.""" + x = torch.as_tensor([1.23456786]*3) + 3.456789j + o = torch.empty(3, dtype=torch.complex64) + r = xp.round(x, decimals=1, out=o) + assert xp.all(r == o) + assert r is o diff --git a/torch-xfails.txt b/torch-xfails.txt index 84271a56..3b75972b 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -130,7 +130,6 @@ array_api_tests/test_statistical_functions.py::test_var # These functions do not yet support complex numbers -array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values