Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,13 @@ def get_converter(cls, opset):
number smaller than or equal to opset belongs to all support versions.
"""
versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]
versions = sorted(versions + [opset])
version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]
compatible = [v for v in versions if v <= opset]
if not compatible:
raise NotImplementedError(
f"{cls.__name__} is not supported for opset {opset}. "
f"Minimum supported opset: {min(versions)}"
)
Comment on lines +306 to +310
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic to raise an error for unsupported opsets is a great improvement. However, there's an edge case that could lead to an unhandled exception. If an operator has no _impl_vN methods, the versions list will be empty. In this scenario, min(versions) will raise a ValueError. It's better to handle this case explicitly to provide a more informative error message and prevent the crash.

Suggested change
if not compatible:
raise NotImplementedError(
f"{cls.__name__} is not supported for opset {opset}. "
f"Minimum supported opset: {min(versions)}"
)
if not compatible:
if not versions:
raise NotImplementedError(
f"{cls.__name__} is not supported for opset {opset}, as no implementations are available."
)
raise NotImplementedError(
f"{cls.__name__} is not supported for opset {opset}. "
f"Minimum supported opset: {min(versions)}"
)

version = max(compatible)
if hasattr(cls, f"_impl_v{version}"):
return getattr(cls, f"_impl_v{version}")
raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented")
Expand Down
30 changes: 30 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3980,5 +3980,35 @@ def test_nms_score_threshold():
)


def test_reduce_mean():
# opset 13: axes passed as attribute
node = helper.make_node("ReduceMean", inputs=["x"], outputs=["y"], axes=[2], keepdims=1)
graph = helper.make_graph(
[node],
"reduce_mean_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 68, 4, 18])],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 68, 1, 18])],
)
model = helper.make_model(graph, producer_name="reduce_mean_test")
check_correctness(model, opset=13)


def test_reduce_mean_unsupported_opset():
# Regression test for https://github.com/apache/tvm/issues/18698.
# When opset < minimum available impl version, get_converter previously
# wrapped to -1 and silently picked the newest impl instead of raising.
node = helper.make_node("ReduceMean", inputs=["x"], outputs=["y"], axes=[2], keepdims=1)
graph = helper.make_graph(
[node],
"reduce_mean_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 68, 4, 18])],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 68, 1, 18])],
)
model = helper.make_model(graph, producer_name="reduce_mean_test")
model.opset_import[0].version = 9
with pytest.raises(NotImplementedError, match="not supported for opset 9"):
from_onnx(model, opset=9, keep_params_in_input=True)


Comment on lines +3983 to +4012
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new tests test_reduce_mean and test_reduce_mean_unsupported_opset are great for ensuring correctness and preventing regressions. However, they share a significant amount of code for model creation. To improve maintainability and reduce duplication, you could extract the model creation logic into a helper function.

def _get_reduce_mean_model():
    """Helper to create a ReduceMean ONNX model for testing."""
    node = helper.make_node("ReduceMean", inputs=["x"], outputs=["y"], axes=[2], keepdims=1)
    graph = helper.make_graph(
        [node],
        "reduce_mean_test",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 68, 4, 18])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 68, 1, 18])],
    )
    return helper.make_model(graph, producer_name="reduce_mean_test")


def test_reduce_mean():
    # opset 13: axes passed as attribute
    model = _get_reduce_mean_model()
    check_correctness(model, opset=13)


def test_reduce_mean_unsupported_opset():
    # Regression test for https://github.com/apache/tvm/issues/18698.
    # When opset < minimum available impl version, get_converter previously
    # wrapped to -1 and silently picked the newest impl instead of raising.
    model = _get_reduce_mean_model()
    model.opset_import[0].version = 9
    with pytest.raises(NotImplementedError, match="not supported for opset 9"):
        from_onnx(model, opset=9, keep_params_in_input=True)

if __name__ == "__main__":
tvm.testing.main()