
iris.x: Device-side communication + iris.ops APIs. #296

Merged: neoblizz merged 58 commits into main from muhosama/iris-x on Feb 3, 2026

Conversation

@neoblizz (Member) commented Dec 9, 2025

KNOWN ISSUES

  • This PR does not support unaligned matrices; we need other=0.0 support in the masked loads for that. (Disabled for now.)
  • The spinlock algorithm has a race condition; I believe we need cache_modifier=".wt" support so the data is visible before the lock is released. (Run a few times to pass.)
  • The tile-based ring-reduce algorithm does not work yet and is disabled for now; I have no idea yet how I will eventually implement it.
  • This PR depends on tritonBLAS, which must be installed as editable (a plain pip install isn't working).

TL;DR

Introduces iris.x (device-side tile-level primitives for custom fusion patterns) and iris.ops (high-level fused GEMM+collective operations), enabling fine-grained control and overlap of computation and communication within Triton kernels.

iris.x: Device-side Tile-Level Primitives for Fused Patterns

iris.x provides composable device-side tile-level primitives for fine-grained compute and collective operations within Triton kernels. Unlike iris.ccl which operates on full tensors with internal tiling, iris.x gives you direct control over tile iteration, enabling custom fusion patterns and fine-grained overlap of computation and communication.

Key Differences from iris.ccl

# iris.ccl: High-level, operates on full tensors
shmem.all_reduce(input_tensor, output_tensor)

# iris.x: Low-level, operates on tiles within your kernel
@triton.jit
def my_kernel(...):
    tile = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)
    # ... compute data ...
    tile_with_data = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, data)
    iris.x.all_reduce_atomic(tile_with_data, dst_view, ctx)

Overview

| Feature  | iris.ccl                            | iris.x                          |
|----------|-------------------------------------|---------------------------------|
| Level    | Host-side, operates on full tensors | Device-side, operates on tiles  |
| Tiling   | Automatic, internal                 | Manual, user-controlled         |
| Control  | Simple, high-level                  | Fine-grained, low-level         |
| Use Case | General collectives                 | Custom fusion, overlap patterns |

Core Abstractions

TileView

Represents a tile with position and dimensions (no data):

tile = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)

Tile

Extends TileView with embedded data for computed results:

tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, data)

TensorView

Describes tensor memory layout for device-side access:

view = iris.x.TensorView(ptr, M, N, stride_m, stride_n)

DeviceContext

Encapsulates distributed context (rank, world size, heap bases):

ctx = iris.x.DeviceContext(rank, world_size, heap_bases)

Collective Operations

All-Reduce

Reduce data across all ranks with support for multiple algorithms. The Tile object must contain pre-computed data in tile.data:

# Standalone API with specific algorithms
iris.x.all_reduce_atomic(tile, dst_view, ctx)     # Default: atomic
iris.x.all_reduce_ring(tile, src_view, dst_view, ctx)       # Ring algorithm
iris.x.all_reduce_two_shot(tile, src_view, dst_view, locks, start_tile, stride, ctx)   # Two-shot
iris.x.all_reduce_one_shot(tile, src_view, dst_view, locks, ctx)   # One-shot
iris.x.all_reduce_spinlock(tile, dst_view, locks, ctx)  # Spinlock

Algorithms:

  • atomic (default): Fine-grained atomic operations - takes pre-computed tile.data and atomically adds to all ranks
  • ring: Ring-based reduction - loads from src_view and performs ring reduce-scatter + all-gather
  • two_shot: Two-shot algorithm - ranks divide responsibility, load from remote ranks, reduce locally, then scatter
  • one_shot: One-shot algorithm - all ranks load from all remote ranks and reduce locally (duplicated work)
  • spinlock: Lock-based synchronization - uses spinlocks for exclusive access during read-modify-write
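
For concreteness, here is a minimal sketch of the one-shot variant dropped into the same persistent tile loop used in the complete example at the end of this description. The kernel name and the locks_ptr argument are illustrative: the lock buffer is assumed to be a per-tile array on the symmetric heap, and the exact layout all_reduce_one_shot expects is not restated here.

import triton
import triton.language as tl
import iris.x

@triton.jit
def one_shot_all_reduce_kernel(  # illustrative name
    src_ptr, dst_ptr, locks_ptr,
    M, N, stride_m, stride_n,
    heap_bases: tl.tensor,
    cur_rank: tl.constexpr, world_size: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
    src_view = iris.x.TensorView(src_ptr, M, N, stride_m, stride_n)
    dst_view = iris.x.TensorView(dst_ptr, M, N, stride_m, stride_n)

    num_tiles_n = tl.cdiv(N, BLOCK_N)
    total_tiles = tl.cdiv(M, BLOCK_M) * num_tiles_n
    for tile_id in range(tl.program_id(0), total_tiles, tl.num_programs(0)):
        pid_m = tile_id // num_tiles_n
        pid_n = tile_id % num_tiles_n

        # Load this rank's partial result and wrap it in a Tile
        view = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)
        ptrs, mask = src_view.tile_ptr(view)
        partial = tl.load(ptrs, mask=mask)
        tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, partial)

        # One-shot: every rank reads this tile from all peers via src_view
        # and reduces locally into dst_view (duplicated work, single pass)
        iris.x.all_reduce_one_shot(tile, src_view, dst_view, locks_ptr, ctx)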

All-Gather

Gather data from all ranks along a specified dimension. The Tile object must contain pre-computed data in tile.data:

# Standalone API
iris.x.all_gather(tile, dst_view, dim, ctx)

The function scatters the pre-computed tile.data to all ranks at rank-specific offsets along the gather dimension.
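
As a minimal sketch (assuming dim=0 selects the row dimension and that the destination view describes the full gathered tensor of world_size * M_local rows; the kernel name and dim semantics are illustrative assumptions), each rank can push tiles of its local shard into the gathered output:

import triton
import triton.language as tl
import iris.x

@triton.jit
def all_gather_shard_kernel(  # illustrative name
    shard_ptr, gathered_ptr,
    M_local, N,
    shard_stride_m, shard_stride_n,
    out_stride_m, out_stride_n,
    heap_bases: tl.tensor,
    cur_rank: tl.constexpr, world_size: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
    # Source view covers this rank's shard; destination view covers the
    # full gathered tensor (world_size * M_local rows) -- assumed layout
    src_view = iris.x.TensorView(shard_ptr, M_local, N, shard_stride_m, shard_stride_n)
    dst_view = iris.x.TensorView(gathered_ptr, M_local * world_size, N, out_stride_m, out_stride_n)

    num_tiles_n = tl.cdiv(N, BLOCK_N)
    total_tiles = tl.cdiv(M_local, BLOCK_M) * num_tiles_n
    for tile_id in range(tl.program_id(0), total_tiles, tl.num_programs(0)):
        pid_m = tile_id // num_tiles_n
        pid_n = tile_id % num_tiles_n

        # Load (or compute) the local tile, then scatter it to every rank's
        # gathered buffer at this rank's offset along the gather dimension
        view = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)
        ptrs, mask = src_view.tile_ptr(view)
        data = tl.load(ptrs, mask=mask)
        tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, data)
        iris.x.all_gather(tile, dst_view, 0, ctx)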

All-to-All

Personalized all-to-all exchange. Can use either Tile or TileView since it loads data from src_view internally:

# Standalone API
tile = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)
iris.x.all_to_all(tile, src_view, dst_view, N_per_rank, ctx)

Each rank sends N_per_rank columns to every other rank and receives N_per_rank columns from every other rank.
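
A minimal sketch of the exchange, assuming the source and destination are both (M, N) with N = world_size * N_per_rank columns (the kernel name and layout are illustrative assumptions). Since the primitive loads from src_view internally, a TileView without data is enough:

import triton
import triton.language as tl
import iris.x

@triton.jit
def all_to_all_kernel(  # illustrative name
    src_ptr, dst_ptr,
    M, N, N_per_rank,
    stride_m, stride_n,
    heap_bases: tl.tensor,
    cur_rank: tl.constexpr, world_size: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
    src_view = iris.x.TensorView(src_ptr, M, N, stride_m, stride_n)
    dst_view = iris.x.TensorView(dst_ptr, M, N, stride_m, stride_n)

    num_tiles_n = tl.cdiv(N, BLOCK_N)
    total_tiles = tl.cdiv(M, BLOCK_M) * num_tiles_n
    for tile_id in range(tl.program_id(0), total_tiles, tl.num_programs(0)):
        pid_m = tile_id // num_tiles_n
        pid_n = tile_id % num_tiles_n

        # No data on the tile: all_to_all loads from src_view itself and
        # routes each N_per_rank-wide column block to its destination rank
        view = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)
        iris.x.all_to_all(view, src_view, dst_view, N_per_rank, ctx)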

Reduce-Scatter

Reduce and scatter results to assigned ranks. The Tile object must contain pre-computed data in tile.data:

# Standalone API
iris.x.reduce_scatter(tile, src_view, dst_view, locks, ctx)

Each rank reduces only its assigned contiguous block of tiles using a two-shot approach, storing results locally.
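
A minimal sketch feeding per-rank partial results through the primitive, using the signature above; the kernel name and locks_ptr (assumed to be a per-tile lock array on the symmetric heap) are illustrative assumptions:

import triton
import triton.language as tl
import iris.x

@triton.jit
def reduce_scatter_kernel(  # illustrative name
    partial_ptr, out_ptr, locks_ptr,
    M, N, stride_m, stride_n,
    heap_bases: tl.tensor,
    cur_rank: tl.constexpr, world_size: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
    src_view = iris.x.TensorView(partial_ptr, M, N, stride_m, stride_n)
    dst_view = iris.x.TensorView(out_ptr, M, N, stride_m, stride_n)

    num_tiles_n = tl.cdiv(N, BLOCK_N)
    total_tiles = tl.cdiv(M, BLOCK_M) * num_tiles_n
    for tile_id in range(tl.program_id(0), total_tiles, tl.num_programs(0)):
        pid_m = tile_id // num_tiles_n
        pid_n = tile_id % num_tiles_n

        # Wrap this rank's partial result; the primitive reduces only the
        # tiles assigned to this rank (two-shot style) and stores them locally
        view = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)
        ptrs, mask = src_view.tile_ptr(view)
        partial = tl.load(ptrs, mask=mask)
        tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, partial)
        iris.x.reduce_scatter(tile, src_view, dst_view, locks_ptr, ctx)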

iris.ops: High-Level Fused GEMM+CCL Operations

While iris.x provides low-level tile primitives for custom fusion patterns, iris.ops offers high-level torch-like APIs for common fused GEMM+collective patterns. These operations automatically handle tiling, memory management, and hardware parameter inference.

Available Operations

# Via shmem.ops namespace (recommended)
shmem.ops.matmul_all_reduce(output, A, B)           # GEMM + All-Reduce
shmem.ops.all_gather_matmul(output, A_sharded, B)   # All-Gather + GEMM  
shmem.ops.matmul_all_gather(output, A, B)           # GEMM + All-Gather
shmem.ops.matmul_reduce_scatter(output, A, B)       # GEMM + Reduce-Scatter

Key Features

  • Automatic inference: Dimensions, strides, and hardware parameters inferred from tensors
  • Configurable: Optional FusedConfig for tuning block sizes and algorithms
  • Workspace management: Optional pre-allocated workspace for repeated calls
  • Async support: Optional async mode for overlapping computation

Example Usage

import torch
import iris

# Initialize Iris with symmetric heap
shmem = iris.iris(heap_size=2**33)

# Example problem sizes
M, K, N = 8192, 8192, 8192

# Allocate tensors on the symmetric heap
A = shmem.randn((M, K), dtype=torch.float16)
B = shmem.randn((K, N), dtype=torch.float16)
output = shmem.zeros((M, N), dtype=torch.float16)

# Fused GEMM + All-Reduce (for data parallelism)
shmem.ops.matmul_all_reduce(output, A, B)

# Fused GEMM + Reduce-Scatter (for tensor parallelism);
# world_size is the number of ranks participating in the job
N_local = N // world_size
output_local = shmem.zeros((M, N_local), dtype=torch.float16)
shmem.ops.matmul_reduce_scatter(output_local, A, B)

When to Use iris.ops vs iris.x

| Use Case                          | Choose                    |
|-----------------------------------|---------------------------|
| Standard GEMM+collective patterns | iris.ops                  |
| Custom fusion patterns            | iris.x                    |
| Maximum performance tuning        | iris.ops with FusedConfig |
| Research/experimentation          | iris.x                    |
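
To make the tuning row concrete, here is a hypothetical sketch of passing a FusedConfig. The class is named above, but the import path, field names (block_m, block_n, algorithm), and the config keyword are illustrative assumptions, not the exact API introduced in this PR:

from iris.ops import FusedConfig  # assumed import path

# Hypothetical tuning knobs -- field names are illustrative only
config = FusedConfig(block_m=128, block_n=128, algorithm="two_shot")

# Same call as in the example above, with an explicit config instead of the defaults
shmem.ops.matmul_all_reduce(output, A, B, config=config)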

iris.x Usage Example

Here's a complete example showing custom tile iteration with all-reduce using iris.x:

import triton
import triton.language as tl
import iris.x

@triton.jit
def custom_kernel(
    input_ptr, output_ptr,
    M, N,
    stride_m, stride_n,
    heap_bases: tl.tensor,
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Setup device context
    ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases)
    
    # Create tensor views
    src_view = iris.x.TensorView(input_ptr, M, N, stride_m, stride_n)
    dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n)
    
    # Get program ID
    pid = tl.program_id(0)
    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)
    total_tiles = num_tiles_m * num_tiles_n
    
    # Persistent tile iteration
    for tile_id in range(pid, total_tiles, tl.num_programs(0)):
        # Compute tile coordinates
        pid_m = tile_id // num_tiles_n
        pid_n = tile_id % num_tiles_n
        
        # Create tile view
        tile_view = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N)
        
        # *** Your custom computation here ***
        # Load tile from src_view
        src_ptr, mask = src_view.tile_ptr(tile_view)
        data = tl.load(src_ptr, mask=mask)
        
        # Do computation on data...
        result = data * 2.0  # Example computation
        
        # Create tile with computed data
        tile = iris.x.Tile(pid_m, pid_n, BLOCK_M, BLOCK_N, result)
        
        # Perform all-reduce on this tile
        iris.x.all_reduce_atomic(tile, dst_view, ctx)
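
For completeness, a sketch of a host-side launch for custom_kernel, following the allocation pattern from the iris.ops example above. The accessor names (get_rank, get_num_ranks, get_heap_bases, barrier) are assumptions about the Iris host API, and the persistent grid sizing is just one common choice; treat this as an illustration of the wiring, not the exact launch code from this PR's tests.

import torch
import triton
import iris

shmem = iris.iris(heap_size=2**33)
rank = shmem.get_rank()             # assumed accessor
world_size = shmem.get_num_ranks()  # assumed accessor

M, N = 4096, 4096
BLOCK_M, BLOCK_N = 64, 64
inp = shmem.randn((M, N), dtype=torch.float32)
out = shmem.zeros((M, N), dtype=torch.float32)

# Persistent launch: one program per CU; each program strides over tiles
num_cus = torch.cuda.get_device_properties(0).multi_processor_count
custom_kernel[(num_cus,)](
    inp, out, M, N,
    inp.stride(0), inp.stride(1),
    shmem.get_heap_bases(),          # assumed accessor
    cur_rank=rank, world_size=world_size,
    BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
)
shmem.barrier()  # assumed: synchronize ranks before reading `out`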

Test Plan

Unit tests

| Test File | Result | Notes |
|---|---|---|
| test_arange.py | ✅ 50/50 | Fully passing |
| test_atomic_and_gluon.py | ✅ 72/72 | 6841454! (was failing before) |
| test_atomic_and_triton.py | ✅ 72/72 | Still passing |
| test_atomic_cas_gluon.py | ✅ 27/27 | 6841454! (was failing before) |
| test_atomic_cas_triton.py | ✅ 27/27 | Still passing |
| test_atomic_xchg_gluon.py | ✅ 27/27 | 6841454! (was failing before) |
| test_atomic_xchg_triton.py | ✅ 27/27 | Still passing |
| test_broadcast_gluon.py | ✅ 13/13 | Fully passing |
| test_broadcast_triton.py | ✅ 13/13 | Fully passing |
| test_empty.py | ✅ 84/84 | Fully passing |
| test_full.py | ✅ 87/87 | Fully passing |
| test_get_num_xcc.py | ✅ 1/1 | Fully passing |
| test_iris_helpers.py | ✅ 2/2 | Fully passing |
| test_linspace.py | ✅ 63/63 | Fully passing |
| test_logging.py | ✅ 7/7 | Fully passing |
| test_ones.py | ✅ 80/80 | Fully passing |
| test_put_gluon.py | ✅ 16/16 | Fully passing |
| test_put_triton.py | ✅ 16/16 | Fully passing |
| test_rand.py | ✅ 50/50 | Fully passing |
| test_randint.py | ✅ 66/66 | Fully passing |
| test_randn.py | ✅ 50/50 | Fully passing |
| test_store_gluon.py | ✅ 16/16 | Fully passing |
| test_store_triton.py | ✅ 16/16 | Fully passing |
| test_zeros.py | ✅ 79/79 | Fully passing |
| test_zeros_like.py | ✅ 81/81 | Fully passing |

iris.x

All tests pass! 💯

iris.ops

| Test File | 1 Rank | 2 Ranks | 4 Ranks | 8 Ranks | Total Tests |
|---|---|---|---|---|---|
| test_matmul_all_reduce.py | ✅ 25/25 | ✅ 25/25 | ✅ 25/25 | ✅ 25/25 | 100 tests |
| test_matmul_reduce_scatter.py | ✅ 4/4 | ✅ 4/4 | ✅ 4/4 | ✅ 4/4 | 16 tests |
| test_all_gather_matmul.py | ✅ 2/2 | ✅ 2/2 | ✅ 2/2 | ✅ 2/2 | 8 tests |

@github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels Dec 9, 2025
@neoblizz neoblizz changed the base branch from main to muhosama/ccl-more December 9, 2025 20:32
@neoblizz neoblizz requested a review from Copilot December 9, 2025 20:33

Copilot AI left a comment


Pull request overview

This PR introduces iris.x, a new module providing device-side tile-level primitives for fine-grained collective operations. Unlike iris.ccl which handles full tensors with internal tiling, iris.x provides composable functions that users can call from their own kernels to manage tile iteration themselves.

Key Changes:

  • New iris.x module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
  • Fused GEMM+Communication operations requiring tritonBLAS (gemm_all_reduce, gemm_all_gather, etc.)
  • Comprehensive test suite for new primitives in tests/x/
  • CI/CD modernization with unified workflow replacing 3 separate workflows
  • Documentation updates and benchmark enhancements

Reviewed changes

Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.

| File | Description |
|---|---|
| iris/x/__init__.py | Module initialization exposing all tile-level primitives with optional GEMM operations |
| iris/x/all_reduce.py | Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases |
| iris/x/all_gather.py | Tile-level all-gather primitive for gathering data from all ranks |
| iris/x/all_to_all.py | Tile-level all-to-all primitive for bidirectional data exchange |
| iris/x/reduce_scatter.py | Tile-level reduce-scatter that reduces and scatters to assigned ranks |
| iris/x/gemm_all_reduce.py | Fused GEMM + all-reduce using tritonBLAS stages |
| iris/x/gemm_all_gather.py | Fused GEMM + all-gather combining computation and communication |
| iris/x/gemm_reduce_scatter.py | Fused GEMM + reduce-scatter for column-parallel workloads |
| iris/x/all_gather_gemm.py | Fused all-gather + GEMM for tensor-parallel workloads |
| iris/x/common.py | Shared utilities for tile indexing and offset computation |
| tests/x/test_*.py | Comprehensive test suite validating all primitives against PyTorch references |
| .github/workflows/iris-tests.yml | New unified test workflow supporting multiple test directories and install methods |
| .github/scripts/run_tests.sh | Updated test runner with tritonBLAS installation for iris.x tests |
| tests/ccl/test_all_reduce.py | Modified to add explicit preamble calls for better test isolation |
| pyproject.toml | Added optional gemm dependency group for tritonBLAS |
| docs/reference/examples.md | Updated documentation with new example references |
| benchmark/ccl/all_to_all/benchmark.py | Added RCCL comparison benchmarking option |

@mawad-amd (Collaborator) commented:

@neoblizz we should be able to use aggregate to clean up the device-side APIs. See https://godbolt.org/z/hY3oWfW1x

Resolved conflicts by accepting main's changes for:
- .gitignore
- benchmark/ccl/*.py files
- docker/Dockerfile
- iris/ccl/*.py files
…eContext

Refactor all tile-based collective operations and fused GEMM operators to use
new object-oriented API, dramatically simplifying function signatures and
improving code readability.

Changes:
- Collectives: all_gather, all_reduce (4 variants), reduce_scatter, all_to_all
- Fused ops: all_gather_gemm, gemm_all_gather, gemm_all_reduce, gemm_reduce_scatter
- Replace verbose parameter lists with OOP objects (Tile, TensorView, DeviceContext)
- Add tl.constexpr annotations to all GEMM kernel parameters
- Fix iris.load/atomic_add call signatures for correct argument ordering
- Net reduction: -50 lines of code across 8 files
Update all test kernels to use new OOP API (Tile, TensorView, DeviceContext)
and fix critical tile iteration bug causing test failures at scale.

Changes:
- Rename all test kernels to test_x_*_kernel pattern (avoids pytest warnings)
- Update kernel calls to use OOP objects instead of verbose parameters
- Fix tile iteration stride: use tl.num_programs(0) instead of 1 to prevent
  multiple CUs from processing the same tiles (fixes race conditions)
- Fix all_to_all PyTorch reference to use .contiguous() chunks
@neoblizz neoblizz changed the title iris.x: Device-side communication + .x APIs. iris.x: Device-side communication + .ops APIs. Jan 29, 2026
@neoblizz neoblizz changed the title iris.x: Device-side communication + .ops APIs. iris.x: Device-side communication + iris.ops APIs. Feb 1, 2026
Co-authored-by: Muhammad Awad <112003944+mawad-amd@users.noreply.github.com>
@neoblizz neoblizz merged commit 31339bf into main Feb 3, 2026
59 of 78 checks passed
@neoblizz neoblizz deleted the muhosama/iris-x branch February 3, 2026 02:21