Clearer error when shape dimension overflows int32#3425
Clearer error when shape dimension overflows int32#3425serenposh wants to merge 5 commits intoml-explore:mainfrom
Conversation
Previously `mx.zeros(2**31)` (and `ones`/`full`) raised a generic
nanobind error:
TypeError: zeros(): incompatible function arguments. ...
Invoked with types: int
The underlying cause is that `mx::ShapeElem` is `int32_t`, so values
>= 2**31 can't be converted via the `int`/`mx::Shape` variant that
nanobind sees — but the user gets no hint of this.
Widen the Python-side shape acceptance for `full`, `zeros`, and `ones`
to `int64_t` / `vector<int64_t>` and validate each dimension through
`check_shape_dim`, which now reports the offending value and the
supported range:
ValueError: Shape dimension 2147483648 is outside the supported
range [-2147483648, 2147483647]. MLX currently uses 32-bit
integers for shape dimensions.
This does not raise the underlying int32 shape limit — only the
diagnostic when users hit it.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
zcbenz
left a comment
There was a problem hiding this comment.
Thanks for trying to fix this, checking the lower limit feels the correct fix but this PR only covers a few ops while we would need to fix all the ops that take shapes. I think a better approach is to check the overflow in python/src/small_vector.h.
Per review feedback on ml-explore#2681, move the int32 overflow check into the SmallVector type caster (python/src/small_vector.h) so it applies to every op that takes an mx::Shape, not just the three creation ops. For narrow integer element types (int32, int16, ...) the caster now widens each element through `long long`, validates against the element type's range, and throws `nanobind::value_error` on overflow — nanobind then surfaces a clean Python ValueError that names the offending value and the valid range: mx.reshape(a, [2**31]) mx.broadcast_to(a, [2**31, 1]) mx.zeros([2**31]) # -> ValueError: Shape dimension 2147483648 is outside the # supported range [-2147483648, 2147483647]. ... Because the SmallVector caster throws, it can't live inside a `std::variant` — nanobind's variant caster is marked noexcept and would call std::terminate on any escaping exception. So `zeros`, `ones` and `full` are split into two nb::def overloads each (scalar int64_t + mx::Shape) instead of using `variant<int, mx::Shape>`. The scalar overload still routes through `check_shape_dim` for the same clean error on `mx.zeros(2**31)`. Broaden the Python test to exercise reshape / broadcast_to / negative overflow in addition to the three creation ops. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Thanks for the review! Pushed a follow-up (70509dd) that moves the check to The caster now widens each narrow-integer shape element through >>> mx.reshape(a, [2**31])
ValueError: Shape dimension 2147483648 is outside the supported range [-2147483648, 2147483647]. ...
>>> mx.broadcast_to(a, [2**31, 1])
ValueError: Shape dimension 2147483648 is outside the supported range ...One wrinkle — because the caster now throws, it can't live inside a Test coverage broadened to |
|
Can you fix the lint error? |
|
I tracked the failing CPU/Windows jobs to half-precision mean() reducing in half precision. The latest commit, e9fcdaf, promotes float16/bfloat16 reductions to float32 inside mean(), and the previously failing local CPU random tests now pass again: test random uniform and test random normal.If you get a chance, could you please take another look and re-approve if everything looks good on your side? |
|
Which failing test do you mean? I only saw this failing test in CI: |
Summary
mx.zeros(2**31)(andones/full) previously raised a generic nanobind error that gave the user no hint of the real problem:The underlying cause is that
mx::ShapeElemisint32_t, so any dimension>= 2**31can't be converted via theint/mx::Shapevariant that nanobind sees — but nothing in the error points at the shape or the 32-bit limit.After this PR:
Closes #2681.
Changes
python/src/convert.{h,cpp}:check_shape_dimnow reports the offending value and the valid range, and catches negative overflow too. It's exposed in the header so other bindings can reuse it.python/src/ops.cpp:full,zeros, andonesacceptvariant<int64_t, vector<int64_t>>and route through a newto_shapehelper that validates each dim viacheck_shape_dim.python/tests/test_ops.py: addstest_shape_overflow_errorcovering the scalar and sequence paths for all three constructors.Scope
This PR does not raise the underlying
int32shape limit — the tracking issue calls out thatmx::ShapeElem→int64_twould be a much larger migration. It only improves the diagnostic so users hitting the limit understand what they hit.Test plan
python -m unittest python.tests.test_ops.TestOps— 139 tests pass locally (CPU build, macOS arm64).test_shape_overflow_errorverifies both the scalar (mx.zeros(2**31)) and sequence (mx.zeros([2**31])) paths forzeros,ones, andfull.🤖 Generated with Claude Code