Shift arbitrary values per row/col #1145

Open
cliffburdick wants to merge 3 commits into main from shift_arb

Conversation

@cliffburdick
Collaborator

The shift() operator can take an operator as its shift argument to provide a separate shift value for each row or column along an axis. The shift operator should be limited to rank 0 or rank 1. This feature was not working and produced incorrect results. Unit tests and doc examples have been added to cover this case.

@copy-pr-bot

copy-pr-bot bot commented Apr 3, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick changed the title from "Shift arb" to "Shift arbitrary values per row/col" on Apr 3, 2026
@cliffburdick
Collaborator Author

/build

@greptile-apps
Contributor

greptile-apps bot commented Apr 3, 2026

Greptile Summary

This PR fixes the shift() operator's per-row/column shift feature (rank-1 shift tensor), which was previously broken and produced incorrect results. The fix introduces a SHIFT_DIM_ constant to identify which dimension the 1D shift tensor indexes into, and adds static_assert guards that restrict rank-1 shifts to rank-2 inputs only, which addresses the earlier rank≥3 ambiguity concern. It also replaces the generic MATX_ASSERT_COMPATIBLE_OP_SIZES(shift_) with a more precise size check and generates correct JIT code for both the rank-0 and rank-1 shift paths. Three new test blocks and an updated doc example cover the added functionality.

Key changes:

  • include/matx/operators/shift.h: Added shift_rank_ / SHIFT_DIM_ compile-time constants; two new static_asserts guard against unsupported ranks; constructor and get_impl branch on shift_rank_ to index the 1D shift tensor via idx[SHIFT_DIM_]; JIT struct emits SHIFT_DIM_ and uses it in the generated lookup line.
  • test/00_operators/shift_test.cu: Three new test cases (per-column shift of dim 0, per-row shift of dim 1, negative per-column shift) plus reference-value verification loops.
  • docs_input/api/manipulation/rearranging/shift.rst: Description updated; second literalinclude example added.
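The dimension-selection and size-check rules described above can be sketched in plain C++. This is only an illustration under stated assumptions: `shift_index_dim` and `shift_sizes_compatible` are hypothetical names, not MatX functions, and the sketch covers only the rank-2 input case that the summary says is supported.

```cpp
#include <array>
#include <cstddef>

// Hypothetical helper (not part of MatX): for a rank-2 input, a rank-1 shift
// operator is indexed by the dimension that is NOT being shifted. This mirrors
// the rule SHIFT_DIM_ = (DIM == 0 ? 1 : 0) shown in the review summary.
constexpr int shift_index_dim(int dim) { return dim == 0 ? 1 : 0; }

// Hypothetical size check: the rank-1 shift operator needs exactly one entry
// per slice along the indexing dimension, i.e. its length must equal
// sizes[shift_index_dim(dim)].
constexpr bool shift_sizes_compatible(const std::array<std::size_t, 2>& sizes,
                                      int dim, std::size_t shift_len) {
  return shift_len == sizes[shift_index_dim(dim)];
}
```

For a 4x3 input shifted along dimension 0, the check accepts a shift operator of length 3 (one value per column) and rejects one of length 4.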

Confidence Score: 5/5

Safe to merge — the rank-1 shift logic is correct, the static_assert guards cleanly resolve the prior rank≥3 concern, and all three new test cases verify the expected modular-shift arithmetic.

All previously raised P1 concerns (SHIFT_DIM_ aliasing for rank≥3, missing rank-3 test coverage) are resolved by the new static_asserts. The only remaining findings from prior threads (JIT double-indentation, missing trailing newline) are P2 style issues that do not affect runtime correctness or compilation. No new P0/P1 issues were found during this review pass.

No files require special attention beyond the pre-existing style items already flagged in prior threads.

Important Files Changed

| Filename | Overview |
| --- | --- |
| include/matx/operators/shift.h | Core fix: adds SHIFT_DIM_ constant and branches get_impl/JIT on shift_rank_. Logic is correct for all tested rank-2 cases; JIT format string produces double-indented lookup line (pre-existing style issue from prior review thread), and SHIFT_DIM_ is emitted unconditionally in JIT even for rank-0 shifts (harmless redundancy). |
| test/00_operators/shift_test.cu | Adds three well-structured test blocks; verification arithmetic matches the implementation's modular-shift semantics. File still lacks a trailing newline (flagged in prior review thread). |
| docs_input/api/manipulation/rearranging/shift.rst | Documentation updated with accurate description of rank-1 shift semantics and second literalinclude block pointing to the new shift-test-2 example. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["shift<DIM>(op, shift_op)"] --> B{shift_rank_ == 1?}
    B -- "Yes (rank-1 tensor)" --> C["static_assert: get_rank<T1>() == 2"]
    C --> D["SHIFT_DIM_ = DIM==0 ? 1 : 0"]
    D --> E["Validate: shift_.Size(0) == sizes_[SHIFT_DIM_]"]
    E --> F["get_impl: shift = -get_value(shiftin, idx[SHIFT_DIM_])"]
    B -- "No (scalar / rank-0)" --> G["shift_rank_ == 0: pass size check"]
    G --> H["get_impl: shift = -get_value(shiftin, indices...)"]
    F --> I["shift = (shift + idx[DIM]) % sizes[DIM]"]
    H --> I
    I --> J{"shift < 0?"}
    J -- Yes --> K["shift += sizes[DIM]"]
    J -- No --> L["idx[DIM] = shift"]
    K --> L
    L --> M["return get_value(op, idx)"]

Reviews (3): Last reviewed commit: "Fixing issue with rank 3"
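The index arithmetic in the flowchart can be traced with a small host-side reference. This is a sketch in plain C++, not the MatX implementation: `shifted_index` and `shift_rows_per_column` are hypothetical names, the matrix is assumed to be row-major, and the shift direction is taken directly from the flowchart (negate the shift, add the index, wrap).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Follows the flowchart: start from -shift, add the current index, take the
// remainder, then add size if the result is negative so it lands in [0, size).
inline long long shifted_index(long long idx, long long shift, long long size) {
  assert(size > 0);
  long long s = (-shift + idx) % size;
  if (s < 0) {
    s += size;
  }
  return s;
}

// Hypothetical reference for a per-column shift along dim 0 of a row-major
// rows x cols matrix: column j uses shifts[j], so the shift operator is
// indexed by the column index (the non-shifted dimension).
inline std::vector<int> shift_rows_per_column(const std::vector<int>& in,
                                              std::size_t rows, std::size_t cols,
                                              const std::vector<long long>& shifts) {
  assert(in.size() == rows * cols);
  assert(shifts.size() == cols);
  std::vector<int> out(in.size());
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const auto src = static_cast<std::size_t>(
          shifted_index(static_cast<long long>(i), shifts[j],
                        static_cast<long long>(rows)));
      out[i * cols + j] = in[src * cols + j];
    }
  }
  return out;
}
```

With a 3x2 input whose columns are [0, 2, 4] and [1, 3, 5] and shifts {1, -1}, this reference yields columns [4, 0, 2] and [3, 5, 1]: the first column moves down by one with wraparound, and the second moves up by one.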

Comment on lines +70 to +72
static_assert(shift_rank_ <= 1, "Shift operator must be rank 0 or rank 1. Higher-rank shift operators are not supported.");
static_assert(shift_rank_ != 1 || detail::get_rank<T1>() >= 2,
"Rank-1 shift operator requires input operator of rank 2 or higher");
Contributor

P1 No test coverage for rank ≥ 3 with rank-1 shift

The static_assert only blocks rank < 2 for T1 when shift_rank_ == 1, which means rank-3+ tensors compile successfully. However, as noted in the SHIFT_DIM_ comment, rank ≥ 3 exhibits unintuitive behavior (multiple DIMs alias to the same SHIFT_DIM_). Given that this path is reachable and has subtleties, it would be worth either:

  • Adding a test that explicitly documents the rank-3 behavior, or
  • Strengthening the static_assert to detail::get_rank<T1>() == 2 to limit rank-1 shift operators to rank-2 inputs only until the higher-rank case is fully designed and tested.
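The difference between the two constraints can be checked at compile time with a small sketch. The predicate names below are hypothetical stand-ins, and the ranks are passed as plain integers rather than obtained from detail::get_rank<T1>(), so this illustrates only the logic of the conditions, not the actual MatX code.

```cpp
// Hypothetical predicate for the condition quoted in the review:
// a rank-1 shift operator requires an input of rank 2 or higher.
constexpr bool rank1_shift_allowed_loose(int shift_rank, int input_rank) {
  return shift_rank != 1 || input_rank >= 2;
}

// Hypothetical predicate for the suggested tightening:
// a rank-1 shift operator requires an input of exactly rank 2.
constexpr bool rank1_shift_allowed_strict(int shift_rank, int input_rank) {
  return shift_rank != 1 || input_rank == 2;
}

// A rank-3 input with a rank-1 shift passes the loose check but is rejected
// by the strict one, which is the case the comment is concerned about.
static_assert(rank1_shift_allowed_loose(1, 3), "loose check admits rank 3");
static_assert(!rank1_shift_allowed_strict(1, 3), "strict check rejects rank 3");

// Rank-0 (scalar) shifts are unaffected by either form.
static_assert(rank1_shift_allowed_strict(0, 3), "scalar shifts stay allowed");
```

Both forms leave scalar shifts untouched; only the rank-1 path with inputs of rank 3 or higher changes from compiling to failing with a diagnostic.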
