From e667c82e8c3cea9e26675ccdbbc1969430296330 Mon Sep 17 00:00:00 2001 From: Aharrypotter Date: Wed, 18 Mar 2026 01:21:26 +0800 Subject: [PATCH] [Relax][ONNX] Fix get_converter selecting wrong impl when opset < minimum supported version When no _impl_vN version exists for the given opset, the bisect-based index becomes 0; subtracting 1 gives -1, and Python negative indexing silently selects the latest (incompatible) implementation. For example, ReduceMean with opset=9 was mapped to _impl_v18 instead of raising an error, producing wrong output shapes. Fix by filtering to compatible versions first and raising NotImplementedError when none exist. Fixes #18698 --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 9 ++++-- tests/python/relax/test_frontend_onnx.py | 30 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3dc575ae778c..1aa7b36e801f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -302,8 +302,13 @@ def get_converter(cls, opset): number smaller than or equal to opset belongs to all support versions. """ versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d] - versions = sorted(versions + [opset]) - version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1] + compatible = [v for v in versions if v <= opset] + if not compatible: + raise NotImplementedError( + f"{cls.__name__} is not supported for opset {opset}. " + f"Minimum supported opset: {min(versions)}" + ) + version = max(compatible) if hasattr(cls, f"_impl_v{version}"): return getattr(cls, f"_impl_v{version}") raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented") diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index ecbc6c9e8a5e..475659852989 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -3980,5 +3980,35 @@ def test_nms_score_threshold(): ) +def test_reduce_mean(): + # opset 13: axes passed as attribute + node = helper.make_node("ReduceMean", inputs=["x"], outputs=["y"], axes=[2], keepdims=1) + graph = helper.make_graph( + [node], + "reduce_mean_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 68, 4, 18])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 68, 1, 18])], + ) + model = helper.make_model(graph, producer_name="reduce_mean_test") + check_correctness(model, opset=13) + + +def test_reduce_mean_unsupported_opset(): + # Regression test for https://github.com/apache/tvm/issues/18698. + # When opset < minimum available impl version, get_converter previously + # wrapped to -1 and silently picked the newest impl instead of raising. + node = helper.make_node("ReduceMean", inputs=["x"], outputs=["y"], axes=[2], keepdims=1) + graph = helper.make_graph( + [node], + "reduce_mean_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 68, 4, 18])], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 68, 1, 18])], + ) + model = helper.make_model(graph, producer_name="reduce_mean_test") + model.opset_import[0].version = 9 + with pytest.raises(NotImplementedError, match="not supported for opset 9"): + from_onnx(model, opset=9, keep_params_in_input=True) + + if __name__ == "__main__": tvm.testing.main()