From 33c75c2323aee9df7a49de05c61825b63db43b79 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Tue, 19 May 2026 21:20:17 +0800 Subject: [PATCH 1/5] [Relax] Normalize negative concat axis in --- .../reorder_permute_dims_after_concat.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index bc542ccf91ef..fc083fcb3e06 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -151,8 +151,19 @@ std::tuple)>> auto concat_attrs = concat_call->attrs.as(); TVM_FFI_ICHECK(concat_attrs); - auto old_concat_axis = [&]() -> size_t { return concat_attrs->axis.value_or(0); }(); - Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis]; + auto permute_dims_axes = get_permute_dims_axes(all_permute_dims[0]); + + int64_t old_concat_axis = concat_attrs->axis.value_or(0); + int64_t ndim = static_cast(permute_dims_axes.size()); + if (old_concat_axis < 0) { + old_concat_axis += ndim; + } + TVM_FFI_ICHECK_GE(old_concat_axis, 0) << "concat axis " << old_concat_axis + << " out of range for " << ndim << "-D input"; + TVM_FFI_ICHECK_LT(old_concat_axis, ndim) << "concat axis " << old_concat_axis + << " out of range for " << ndim << "-D input"; + + Integer new_concat_axis = permute_dims_axes[static_cast(old_concat_axis)]; auto new_concat = concat(Tuple(args), new_concat_axis->value); auto new_permute_dims = permute_dims(new_concat, permute_axes); From 267df6f89d2fd766d227572b382e55b4ef1b494c Mon Sep 17 00:00:00 2001 From: cchung100m Date: Tue, 19 May 2026 22:52:20 +0800 Subject: [PATCH 2/5] [Relax] Fix lint error --- src/relax/transform/reorder_permute_dims_after_concat.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index fc083fcb3e06..01eadc37f376 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -158,10 +158,10 @@ std::tuple)>> if (old_concat_axis < 0) { old_concat_axis += ndim; } - TVM_FFI_ICHECK_GE(old_concat_axis, 0) << "concat axis " << old_concat_axis - << " out of range for " << ndim << "-D input"; - TVM_FFI_ICHECK_LT(old_concat_axis, ndim) << "concat axis " << old_concat_axis - << " out of range for " << ndim << "-D input"; + TVM_FFI_ICHECK_GE(old_concat_axis, 0) + << "concat axis " << old_concat_axis << " out of range for " << ndim << "-D input"; + TVM_FFI_ICHECK_LT(old_concat_axis, ndim) + << "concat axis " << old_concat_axis << " out of range for " << ndim << "-D input"; Integer new_concat_axis = permute_dims_axes[static_cast(old_concat_axis)]; From 6bd197d89f45bae61d42c1181516c8e4db3ffefe Mon Sep 17 00:00:00 2001 From: cchung100m Date: Tue, 19 May 2026 22:52:57 +0800 Subject: [PATCH 3/5] [Relax] Add test case: class TestNegativeConcatAxis(Base) --- ...sform_reorder_permute_dims_after_concat.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py index f93daa4c1e00..dfa9bcee39f8 100644 --- a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py +++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py @@ -261,5 +261,34 @@ def main( return out +class TestNegativeConcatAxis(Base): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([1, 4, 8, 8], "float32"), + y: R.Tensor([1, 4, 8, 8], "float32"), + ): + with.R.dataflow(): + xt = R.permute_dims(x, axes=[0, 2, 3, 1]) + yt = R.permute_dims(y, axes=[0, 2, 3, 1]) + out = R.concat([xt, yt], axis=-1) + R.output(out) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([1, 4, 8, 8], "float32"), + y: R.Tensor([1, 4, 8, 8], "float32"), + ): + with.R.dataflow(): + merged = R.concat([x, y], axis=1) + out = R.permute_dims(merged, axes=[0, 2, 3, 1]) + R.output(out) + return out + + if __name__ == "__main__": tvm.testing.main() From 01f1710d65e758e0823db557e728ea9e9bbef954 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Wed, 20 May 2026 08:40:09 +0800 Subject: [PATCH 4/5] Fix typo --- .../relax/test_transform_reorder_permute_dims_after_concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py index dfa9bcee39f8..b87dc8693b80 100644 --- a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py +++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py @@ -283,7 +283,7 @@ def main( x: R.Tensor([1, 4, 8, 8], "float32"), y: R.Tensor([1, 4, 8, 8], "float32"), ): - with.R.dataflow(): + with R.dataflow(): merged = R.concat([x, y], axis=1) out = R.permute_dims(merged, axes=[0, 2, 3, 1]) R.output(out) From eaa92409d8eae15d0b980c2ab77cf616869dff54 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Wed, 20 May 2026 08:40:34 +0800 Subject: [PATCH 5/5] Fix typo --- .../relax/test_transform_reorder_permute_dims_after_concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py index b87dc8693b80..2da6cfcda99b 100644 --- a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py +++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py @@ -269,7 +269,7 @@ def main( x: R.Tensor([1, 4, 8, 8], "float32"), y: R.Tensor([1, 4, 8, 8], "float32"), ): - with.R.dataflow(): + with R.dataflow(): xt = R.permute_dims(x, axes=[0, 2, 3, 1]) yt = R.permute_dims(y, axes=[0, 2, 3, 1]) out = R.concat([xt, yt], axis=-1)