[Relax][ONNX] Add ONNX Backend Tests for systematic frontend coverage #19515
Aharrypotter wants to merge 2 commits into apache:main
Conversation
…frontend coverage (apache#19505) Add a test harness that wraps the official ONNX Backend Test Suite (Node Tests) around the Relax ONNX importer. This gives systematic, spec-aligned coverage of 116 operators with 533 passing tests, replacing hand-written edge-case models with standardized protobuf test data. The runner follows the standard `onnx.backend.base.Backend` interface, using `from_onnx()` → `DecomposeOpsForInference()` → `LegalizeOps()` → `tvm.compile()` → `VirtualMachine` to execute each test case. Known failures are tracked via `xfail` by category (trig precision, quantization edge cases, dynamic split, etc.).
Code Review
This pull request adds a systematic verification suite for the Relax ONNX importer using the official ONNX Backend Test Suite, including a new pytest marker and a backend adapter. Review feedback identifies a bug where tvm.runtime.Tensor is used instead of tvm.runtime.NDArray and a logic error in input mapping where initializers should be filtered out from the graph inputs to align with the ONNX test runner's behavior.
tests/python/relax/test_frontend_onnx_backend.py (92)
tvm.runtime.Tensor is not a valid class in the TVM Python API. It should be tvm.runtime.NDArray (or tvm.nd.NDArray). Using the incorrect class name will result in an AttributeError when the test runner attempts to verify outputs.
if isinstance(output, (tvm.runtime.NDArray, np.ndarray)):
tests/python/relax/test_frontend_onnx_backend.py (123)
In ONNX, model.graph.input includes both model inputs and initializers (which serve as default values). The ONNX backend test runner typically only provides positional values in the inputs list for elements that are not initializers. Mapping positional inputs to graph.input directly in run() will lead to an incorrect mapping if initializers are interspersed in the input list. Filtering graph_input_names here to exclude initializers ensures that the positional mapping in TVMRelaxBackendRep.run aligns with the test runner's behavior.
initializer_names = {t.name for t in model.graph.initializer}
graph_input_names = [inp.name for inp in model.graph.input if inp.name not in initializer_names]
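To see why the suggested filter matters, here is a dependency-free illustration of the misalignment; the stand-in objects and value names are invented for demonstration and do not come from the PR.

```python
from types import SimpleNamespace

# Stand-in for an ONNX graph: "weight" appears in graph.input AND as an
# initializer (i.e. it has a default value baked into the model).
graph = SimpleNamespace(
    input=[SimpleNamespace(name="x"),
           SimpleNamespace(name="weight"),
           SimpleNamespace(name="y")],
    initializer=[SimpleNamespace(name="weight")],
)

# The backend test runner only supplies values for non-initializer inputs.
positional_inputs = ["x_value", "y_value"]

# Naive mapping: "weight" wrongly consumes y's value, and "y" gets nothing.
naive = dict(zip([i.name for i in graph.input], positional_inputs))
assert naive == {"x": "x_value", "weight": "y_value"}

# Filtered mapping (the review's suggestion): skip initializer names first.
init_names = {t.name for t in graph.initializer}
real_inputs = [i.name for i in graph.input if i.name not in init_names]
fixed = dict(zip(real_inputs, positional_inputs))
assert fixed == {"x": "x_value", "y": "y_value"}
```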
Thanks for the review. Addressing both points:

1. `tvm.runtime.Tensor` is valid in current TVM; it is backed by `tvm_ffi`:

>>> import tvm
>>> hasattr(tvm.runtime, "Tensor")
True
>>> tvm.runtime.Tensor.__mro__
(<class 'tvm.runtime._tensor.Tensor'>, <class 'tvm_ffi.core.Tensor'>, ...)

2. Excluding initializers from `graph.input`: the ONNX Backend Test data does not use initializers; all inputs (including weights) are provided as positional test inputs, so the current positional mapping logic is correct for this test data format. That said, adding the initializer filter is a good defensive measure for robustness: if the test data format ever changes or a third-party test suite uses initializers, the filter would prevent silent misalignment.
Let's try to run the new tests in CI and see if they increase CI pressure. @Aharrypotter

@tvm-bot run slow tests
This PR introduces a test runner that reuses the official ONNX Backend Test suite to systematically cover relax.frontend.onnx.
- Node-level test filtering via BackendTest._test_items
- ONNX backend pytest marker
- SKIP_SLOW_TESTS support
- Documented xfails for known importer gaps
Force-pushed from 9b3785e to a7b6c01.
Just curious, can we completely move to the new backend tests, or do we still need to maintain the old ones?
I think it's fine to move to new backend tests @mshr-h |
I checked Sequence, Attention, and Quantization locally. Quantization has a few passing cases, but enabling it cleanly would require very specific per-test include patterns. Attention also needs separate work: the ONNX backend tests use the standard Q/K/V Attention form, while the current Relax converter seems to support the older Microsoft-style packed-QKV path. Sequence has similar issues, mostly around runtime sequence inputs and dynamic positions. Given that, I would keep this PR focused on the initial stable subset and track these categories as follow-up items.

I think we can move in that direction, but probably not completely in one step. The ONNX Backend Tests are very useful for standard operator semantic coverage, and they can replace some duplicated hand-written semantic tests over time. However, I agree that having two ONNX frontend test files could be confusing unless the boundary is clear. My intended split is that the backend tests cover standard ONNX operator semantics, while test_frontend_onnx.py keeps the TVM-specific importer tests. For this PR, I would keep the scope to landing the backend-test runner and the initial stable subset. Then, in a follow-up, we can audit test_frontend_onnx.py for duplicated semantic coverage.
Summary
Introduce a test runner that reuses the official ONNX Backend Test Suite to systematically verify the Relax ONNX importer. This complements the existing hand-written tests in test_frontend_onnx.py by providing spec-aligned coverage of standard ONNX operator semantics.

Towards #19505
Motivation
The existing test_frontend_onnx.py has 187 hand-written tests that validate TVM-specific importer behavior (parameter handling, name sanitization, dynamic shapes, Relax IR structure). However, it relies on ONNX Runtime as the reference and cannot systematically cover all edge cases defined in the ONNX specification.

The ONNX Backend Test Suite provides 1653+ node-level tests with protobuf reference inputs/outputs. It is the industry standard for validating ONNX importers/exporters (used by ONNX Runtime, TensorFlow, PyTorch). Reusing it gives Relax a living, upstream-aligned correctness baseline.
What this PR adds
tests/python/relax/test_frontend_onnx_backend.py: a backend adapter (TVMRelaxBackend) that implements the onnx.backend.base.Backend interface, wiring from_onnx() → DecomposeOpsForInference() → LegalizeOps() → tvm.compile() → VirtualMachine.

Coverage
72 operators with 388 test cases, all passing. Only operators where every ONNX node test passes are included — no xfail markers.
Operators not yet covered include: cast (exotic dtypes), reduce ops (edge cases), reshape/resize/attention (complex behavior), quantization, and several others with known importer gaps. These can be added incrementally as the importer improves.
Test results
388 passed, 3216 skipped (CUDA variants + operators not yet in allowlist), 0 failed, 0 xfailed
CI impact
Design decisions
- test_frontend_onnx.py remains unchanged. Backend tests cover standard ONNX semantics; hand-written tests continue to cover TVM-specific behavior (dynamic shapes, Relax IR structure, importer options).
- Test selection uses backend_test.include() with ^-anchored regex patterns. No access to private ONNX APIs.
- include() patterns use ^test_{op}(?:_.*)?(?:_cpu|_cuda)$, which can cause false matches when a short op name is a prefix of a longer one (e.g. log vs log_softmax). Affected ops (log, max, relu) are excluded until a more precise matching strategy is adopted.
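The prefix-collision issue described above can be reproduced directly with Python's re module; the helper name below is invented for illustration, but the pattern shape is the one quoted in this PR.

```python
import re

def include_pattern(op: str) -> str:
    # Pattern shape used by the allowlist, per the PR description.
    return rf"^test_{op}(?:_.*)?(?:_cpu|_cuda)$"

log_pat = re.compile(include_pattern("log"))

assert log_pat.match("test_log_cpu")          # intended: the bare op test
assert log_pat.match("test_log_example_cpu")  # intended: an op variant
# False positive: "log" is a prefix of "log_softmax", so the optional
# (?:_.*)? group happily swallows "_softmax" and log_softmax tests match too.
assert log_pat.match("test_log_softmax_cpu")
```

This is why ops like log, max, and relu are excluded for now: any tighter match would need either an explicit list of each op's known variants or a negative lookahead over colliding op names.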