iris.x: Device-side communication + iris.ops APIs. #296
Merged
Conversation
Contributor
Pull request overview
This PR introduces iris.x, a new module providing device-side tile-level primitives for fine-grained collective operations. Unlike iris.ccl which handles full tensors with internal tiling, iris.x provides composable functions that users can call from their own kernels to manage tile iteration themselves.
Key Changes:
- New `iris.x` module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
- Fused GEMM+communication operations requiring tritonBLAS (`gemm_all_reduce`, `gemm_all_gather`, etc.)
- Comprehensive test suite for the new primitives in `tests/x/`
- CI/CD modernization with a unified workflow replacing 3 separate workflows
- Documentation updates and benchmark enhancements
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| `iris/x/__init__.py` | Module initialization exposing all tile-level primitives with optional GEMM operations |
| `iris/x/all_reduce.py` | Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases |
| `iris/x/all_gather.py` | Tile-level all-gather primitive for gathering data from all ranks |
| `iris/x/all_to_all.py` | Tile-level all-to-all primitive for bidirectional data exchange |
| `iris/x/reduce_scatter.py` | Tile-level reduce-scatter that reduces and scatters to assigned ranks |
| `iris/x/gemm_all_reduce.py` | Fused GEMM + all-reduce using tritonBLAS stages |
| `iris/x/gemm_all_gather.py` | Fused GEMM + all-gather combining computation and communication |
| `iris/x/gemm_reduce_scatter.py` | Fused GEMM + reduce-scatter for column-parallel workloads |
| `iris/x/all_gather_gemm.py` | Fused all-gather + GEMM for tensor-parallel workloads |
| `iris/x/common.py` | Shared utilities for tile indexing and offset computation |
| `tests/x/test_*.py` | Comprehensive test suite validating all primitives against PyTorch references |
| `.github/workflows/iris-tests.yml` | New unified test workflow supporting multiple test directories and install methods |
| `.github/scripts/run_tests.sh` | Updated test runner with tritonBLAS installation for iris.x tests |
| `tests/ccl/test_all_reduce.py` | Modified to add explicit preamble calls for better test isolation |
| `pyproject.toml` | Added optional gemm dependency group for tritonBLAS |
| `docs/reference/examples.md` | Updated documentation with new example references |
| `benchmark/ccl/all_to_all/benchmark.py` | Added RCCL comparison benchmarking option |
This was linked to issues on Dec 10, 2025.
Collaborator
@neoblizz we should be able to use
Resolved conflicts by accepting main's changes for:
- .gitignore
- benchmark/ccl/*.py files
- docker/Dockerfile
- iris/ccl/*.py files
…eContext

Refactor all tile-based collective operations and fused GEMM operators to use the new object-oriented API, dramatically simplifying function signatures and improving code readability.

Changes:
- Collectives: all_gather, all_reduce (4 variants), reduce_scatter, all_to_all
- Fused ops: all_gather_gemm, gemm_all_gather, gemm_all_reduce, gemm_reduce_scatter
- Replace verbose parameter lists with OOP objects (Tile, TensorView, DeviceContext)
- Add tl.constexpr annotations to all GEMM kernel parameters
- Fix iris.load/atomic_add call signatures for correct argument ordering
- Net reduction: -50 lines of code across 8 files
Update all test kernels to use the new OOP API (Tile, TensorView, DeviceContext) and fix a critical tile iteration bug causing test failures at scale.

Changes:
- Rename all test kernels to the test_x_*_kernel pattern (avoids pytest warnings)
- Update kernel calls to use OOP objects instead of verbose parameters
- Fix tile iteration stride: use tl.num_programs(0) instead of 1 to prevent multiple CUs from processing the same tiles (fixes race conditions)
- Fix the all_to_all PyTorch reference to use .contiguous() chunks
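The stride fix above is easiest to see host-side. The following is an illustrative sketch (not the actual Iris source; names are hypothetical) of how a persistent tile loop partitions work when it strides by the number of launched programs instead of by 1:

```python
# Sketch: why a persistent tile loop must stride by the number of
# launched programs. With a stride of 1, every program walks the same
# tiles and races on the same output blocks.

def tiles_for_program(pid: int, num_programs: int, num_tiles: int) -> list:
    """Tiles visited by one program, mirroring the fixed pattern:
        for tile_id in range(pid, num_tiles, tl.num_programs(0)): ...
    """
    return list(range(pid, num_tiles, num_programs))

num_programs, num_tiles = 4, 10
assignment = [tiles_for_program(p, num_programs, num_tiles)
              for p in range(num_programs)]

# With the correct stride, the programs exactly partition the tiles:
# every tile is processed once and no tile is processed twice.
flat = sorted(t for tiles in assignment for t in tiles)
assert flat == list(range(num_tiles))
```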
…all_reduce function
mawad-amd reviewed on Feb 2, 2026.
Co-authored-by: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com>
KNOWN ISSUES
- other=0.0 for that. (disabled)
- cache_modifiers=".wt" support for the data to be visible before the lock is freed. (run a few times to pass)

TLDR;
Introduces iris.x (device-side tile-level primitives for custom fusion patterns) and iris.ops (high-level fused GEMM+collective operations), enabling fine-grained control and overlap of computation and communication within Triton kernels.

iris.x: Device-side Tile-Level Primitives for Fused Patterns

iris.x provides composable device-side tile-level primitives for fine-grained compute and collective operations within Triton kernels. Unlike iris.ccl, which operates on full tensors with internal tiling, iris.x gives you direct control over tile iteration, enabling custom fusion patterns and fine-grained overlap of computation and communication.

Key Differences from iris.ccl
Overview
iris.ccl vs. iris.x

Core Abstractions
TileView
Represents a tile with position and dimensions (no data):
Tile
Extends TileView with embedded data for computed results:
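A hypothetical host-side sketch of these two abstractions, assuming illustrative field names (the actual iris.x classes are Triton-side and may differ):

```python
# Hypothetical sketch of TileView and Tile; field names are illustrative.
from dataclasses import dataclass, field

@dataclass
class TileView:
    """A tile's position and shape within the global tensor -- no data."""
    pid_m: int    # tile row index
    pid_n: int    # tile column index
    block_m: int  # tile height in elements
    block_n: int  # tile width in elements

    def offsets(self):
        """Element offsets of the tile's top-left corner."""
        return self.pid_m * self.block_m, self.pid_n * self.block_n

@dataclass
class Tile(TileView):
    """A TileView plus the computed block of values it carries."""
    data: list = field(default_factory=list)  # a Triton block in real kernels

view = TileView(pid_m=2, pid_n=1, block_m=32, block_n=64)
assert view.offsets() == (64, 64)
```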
TensorView
Describes tensor memory layout for device-side access:
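A minimal host-side sketch of what a TensorView-style layout descriptor computes, assuming hypothetical field names (the real class likely stores device pointers and strides):

```python
# Hypothetical sketch of TensorView: a strided layout descriptor that
# maps a (row, col) coordinate to a linear element offset.
from dataclasses import dataclass

@dataclass
class TensorView:
    rows: int
    cols: int
    stride_row: int  # elements between consecutive rows
    stride_col: int  # elements between consecutive columns

    def element_offset(self, r: int, c: int) -> int:
        """Linear element offset of (r, c) under this layout."""
        return r * self.stride_row + c * self.stride_col

# Row-major 4x8 tensor: stride_row = cols, stride_col = 1.
t = TensorView(rows=4, cols=8, stride_row=8, stride_col=1)
assert t.element_offset(2, 3) == 19
```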
DeviceContext
Encapsulates distributed context (rank, world size, heap bases):
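Iris communicates through a symmetric heap, translating a local pointer to the corresponding address on a peer rank via per-rank heap bases. A hypothetical sketch of that idea with plain integers (field and method names are illustrative):

```python
# Hypothetical sketch of DeviceContext: translate a local heap pointer
# to the same offset inside a peer rank's symmetric heap.
from dataclasses import dataclass

@dataclass
class DeviceContext:
    rank: int
    world_size: int
    heap_bases: list  # base address of the symmetric heap on each rank

    def translate(self, local_ptr: int, to_rank: int) -> int:
        """Map a pointer in this rank's heap to the same offset on to_rank."""
        offset = local_ptr - self.heap_bases[self.rank]
        return self.heap_bases[to_rank] + offset

ctx = DeviceContext(rank=0, world_size=2, heap_bases=[0x1000, 0x9000])
# An object 0x40 bytes into rank 0's heap lives 0x40 into rank 1's heap.
assert ctx.translate(0x1040, to_rank=1) == 0x9040
```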
Collective Operations
All-Reduce
Reduce data across all ranks with support for multiple algorithms. The Tile object must contain pre-computed data in tile.data.

Algorithms:
- atomic (default): Fine-grained atomic operations; takes pre-computed tile.data and atomically adds to all ranks
- ring: Ring-based reduction; loads from src_view and performs ring reduce-scatter + all-gather
- two_shot: Two-shot algorithm; ranks divide responsibility, load from remote ranks, reduce locally, then scatter
- one_shot: One-shot algorithm; all ranks load from all remote ranks and reduce locally (duplicated work)
- spinlock: Lock-based synchronization; uses spinlocks for exclusive access during read-modify-write

All-Gather
Gather data from all ranks along a specified dimension. The Tile object must contain pre-computed data in tile.data. The function scatters the pre-computed tile.data to all ranks at rank-specific offsets along the gather dimension.
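The rank-specific offsetting can be sketched host-side with plain lists (a semantic reference only, not the device implementation):

```python
# Reference semantics of a tile-level all-gather along the column
# dimension: each rank contributes its shard at offset rank * shard_len.

def all_gather_cols(shards):
    """shards[r] is rank r's pre-computed data; returns the gathered row."""
    world_size = len(shards)
    shard_len = len(shards[0])
    out = [0] * (world_size * shard_len)
    for rank, shard in enumerate(shards):
        out[rank * shard_len:(rank + 1) * shard_len] = shard
    return out

# Two ranks, each holding 3 columns of "pre-computed tile data".
assert all_gather_cols([[1, 2, 3], [4, 5, 6]]) == [1, 2, 3, 4, 5, 6]
```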
Personalized all-to-all exchange. Can use either Tile or TileView, since it loads data from src_view internally. Each rank sends N_per_rank columns to every other rank and receives N_per_rank columns from every other rank.
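The exchange pattern reduces to a block transpose: what rank r sends to rank d, rank d receives from rank r. A host-side semantic reference with lists:

```python
# Reference semantics of personalized all-to-all: rank r sends chunk d
# of its buffer to rank d and receives chunk r from every rank.

def all_to_all(send):
    """send[r][d] = chunk rank r sends to rank d; returns recv[d][r]."""
    world_size = len(send)
    return [[send[src][dst] for src in range(world_size)]
            for dst in range(world_size)]

send = [[[0], [1]],   # rank 0 sends [0] to rank 0 and [1] to rank 1
        [[2], [3]]]   # rank 1 sends [2] to rank 0 and [3] to rank 1
recv = all_to_all(send)
assert recv == [[[0], [2]], [[1], [3]]]
```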
Reduce-Scatter
Reduce and scatter results to assigned ranks. The Tile object must contain pre-computed data in tile.data. Each rank reduces only its assigned contiguous block of tiles using a two-shot approach, storing results locally.
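The "each rank reduces only its assigned block" rule can be checked host-side (a semantic reference; the contiguous block assignment shown here is illustrative):

```python
# Reference semantics of a two-shot reduce-scatter: rank r sums the
# contributions from all ranks, but only over its own 1/world_size slice.

def reduce_scatter(contribs):
    """contribs[r] is rank r's full-length input; returns per-rank slices."""
    world_size = len(contribs)
    n = len(contribs[0])
    block = n // world_size
    out = []
    for rank in range(world_size):
        lo, hi = rank * block, (rank + 1) * block
        out.append([sum(c[i] for c in contribs) for i in range(lo, hi)])
    return out

contribs = [[1, 2, 3, 4], [10, 20, 30, 40]]
assert reduce_scatter(contribs) == [[11, 22], [33, 44]]
```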
iris.ops: High-Level Fused GEMM+CCL Operations

While iris.x provides low-level tile primitives for custom fusion patterns, iris.ops offers high-level torch-like APIs for common fused GEMM+collective patterns. These operations automatically handle tiling, memory management, and hardware parameter inference.

Available Operations
Key Features

- FusedConfig for tuning block sizes and algorithms

Example Usage
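The original example code is not preserved here, but the math a fused GEMM + all-reduce must satisfy can be checked host-side in plain Python: each rank multiplies its K-shard of A and B, and the all-reduce sums the partial products into the full A @ B on every rank. This only verifies the semantics, not the iris.ops call itself:

```python
# Semantic check for GEMM + all-reduce: summing per-rank partial GEMMs
# over a split K dimension reproduces the full matmul.

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def add(x, y):
    return [[xi + yi for xi, yi in zip(rx, ry)] for rx, ry in zip(x, y)]

a = [[1, 2, 3, 4]]            # 1x4
b = [[1], [1], [1], [1]]      # 4x1
# Split the K dimension across 2 "ranks".
a0, a1 = [[1, 2]], [[3, 4]]
b0, b1 = [[1], [1]], [[1], [1]]
partials = [matmul(a0, b0), matmul(a1, b1)]
assert add(*partials) == matmul(a, b)   # == [[10]]
```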
When to Use iris.ops vs. iris.x

- Use iris.ops when a common fused GEMM+collective pattern fits your workload; tune it with FusedConfig if needed.
- Use iris.x when you need a custom fusion pattern or direct control over tile iteration.

Usage Example

Here's a complete example showing custom tile iteration with all-reduce using
iris.x:

Test Plan
unittests
iris.x: All tests pass! 💯
iris.ops:
- test_matmul_all_reduce.py
- test_matmul_reduce_scatter.py
- test_all_gather_matmul.py