|
1 | 1 | """ |
2 | 2 | One-Shot All-Reduce Example |
3 | | -======================================== |
| 3 | +=========================== |
4 | 4 | This example demonstrates how to implement a one-shot pulling all-reduce operation |
5 | 5 | using Helion and PyTorch's distributed capabilities. It includes a Helion kernel |
6 | 6 | demonstrating how to do cross-device synchronization using symmetric memory signal pads |
|
10 | 10 | # %% |
11 | 11 | # Imports |
12 | 12 | # ------- |
| 13 | + |
| 14 | +# %% |
13 | 15 | from __future__ import annotations |
14 | 16 |
|
15 | 17 | import os |
|
24 | 26 |
|
25 | 27 | # %% |
26 | 28 | # Workaround until symmetric memory natively supports extracting dev_ptrs as tensors: from_blob |
| 29 | + |
| 30 | +# %% |
27 | 31 | from_blob_cpp = """ |
28 | 32 | #include <cuda.h> |
29 | 33 | #include <cuda_runtime.h> |
@@ -72,7 +76,10 @@ def dev_array_to_tensor_short( |
72 | 76 |
|
73 | 77 | # %% |
74 | 78 | # One Shot All-Reduce Kernel Implementation |
75 | | -# ---------------------------------------- |
| 79 | +# ----------------------------------------- |
| 80 | + |
| 81 | + |
| 82 | +# %% |
76 | 83 | @helion.jit( |
77 | 84 | config=helion.Config( |
78 | 85 | block_sizes=[8192], |
@@ -159,7 +166,10 @@ def one_shot_all_reduce_kernel( |
159 | 166 |
|
160 | 167 | # %% |
161 | 168 | # Extract tensors from symmetric memory handler |
162 | | -# ---------------------------------------- |
| 169 | +# --------------------------------------------- |
| 170 | + |
| 171 | + |
| 172 | +# %% |
163 | 173 | def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor: |
164 | 174 | """ |
165 | 175 | Prepares symmetric memory tensors for Helion one-shot all-reduce kernel. |
@@ -203,7 +213,10 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor: |
203 | 213 |
|
204 | 214 | # %% |
205 | 215 | # Testing Function |
206 | | -# ---------------------------------------- |
| 216 | +# ---------------- |
| 217 | + |
| 218 | + |
| 219 | +# %% |
207 | 220 | def test(N: int, device: torch.device, dtype: torch.dtype) -> None: |
208 | 221 | """ |
209 | 222 | Test the Helion all-reduce implementation against PyTorch's reference implementation. |
|
0 commit comments