
Commit ad03093

initial impl of latency

1 parent a187495 commit ad03093

File tree

1 file changed: +88, -61 lines

tests/examples/test_load_latency.py

Lines changed: 88 additions & 61 deletions
@@ -8,103 +8,130 @@
 import triton.language as tl
 import numpy as np
 import iris
-from examples.common.utils import read_realtime
+from iris._mpi_helpers import mpi_allgather
+# from examples.common.utils import read_realtime
 
+@triton.jit
+def read_realtime():
+    tmp = tl.inline_asm_elementwise(
+        asm="mov.u64 $0, %globaltimer;",
+        constraints=("=l"),
+        args=[],
+        dtype=tl.int64,
+        is_pure=False,
+        pack=1,
+    )
+    return tmp
+
+@triton.jit()
+def gather_latencies(
+    local_latency,
+    global_latency,
+    curr_rank,
+    num_ranks,
+    BLOCK_SIZE: tl.constexpr,
+    heap_bases: tl.tensor
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+    latency_mask = offsets < num_ranks
+    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)
 
 @triton.jit()
 def ping_pong(
     data,
-    result,
-    len,
-    iter,
+    n_elements,
     skip,
-    flag: tl.tensor,
+    niter,
+    flag,
     curr_rank,
+    peer_rank,
     BLOCK_SIZE: tl.constexpr,
     heap_bases: tl.tensor,
     mm_begin_timestamp_ptr: tl.tensor = None,
     mm_end_timestamp_ptr: tl.tensor = None,
 ):
-    peer = (curr_rank + 1) % 2
     pid = tl.program_id(0)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
 
-    data_mask = offsets < len
+    data_mask = offsets < n_elements
     flag_mask = offsets < 1
     time_stmp_mask = offsets < 1
 
-    for i in range(iter + skip):
+    for i in range(niter + skip):
         if i == skip:
             start = read_realtime()
-            tl.atomic_xchg(mm_begin_timestamp_ptr + offsets, start, time_stmp_mask)
-        if curr_rank == (i + 1) % 2:
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
+            tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
+        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
+        token_first_done = i + 1
+        token_second_done = i + 2
+        if curr_rank == first_rank:
+            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, flag_mask)
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                 pass
-            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
-            tl.store(flag + offsets, i + 1, mask=flag_mask)
-            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
         else:
-            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
-            tl.store(flag + offsets, i + 1, mask=flag_mask)
-            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                 pass
-    stop = read_realtime()
-    tl.atomic_xchg(mm_end_timestamp_ptr + offsets, stop, time_stmp_mask)
+            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask)
 
+    stop = read_realtime()
+    tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
 
-@pytest.mark.parametrize(
-    "dtype",
-    [
-        torch.int32,
-        # torch.float16,
-        # torch.bfloat16,
-        # torch.float32,
-    ],
-)
-@pytest.mark.parametrize(
-    "heap_size",
-    [
-        (1 << 33),
-    ],
-)
-def test_load_bench(dtype, heap_size):
+if __name__ == "__main__":
+    dtype = torch.int32
+    heap_size = 1 << 32
     shmem = iris.iris(heap_size)
     num_ranks = shmem.get_num_ranks()
     heap_bases = shmem.get_heap_bases()
     cur_rank = shmem.get_rank()
-    assert num_ranks == 2
 
     BLOCK_SIZE = 1
-    BUFFER_LEN = 64 * 1024
+    BUFFER_LEN = 1
 
     iter = 200
-    skip = 20
-    mm_begin_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
+    skip = 1
+    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
 
     source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
     result_buffer = shmem.zeros_like(source_buffer)
-    flag = shmem.ones(1, dtype=dtype)
+    flag = shmem.ones(1, dtype=dtype)
 
     grid = lambda meta: (1,)
-    ping_pong[grid](
-        source_buffer,
-        result_buffer,
-        BUFFER_LEN,
-        skip,
-        iter,
-        flag,
-        cur_rank,
-        BLOCK_SIZE,
-        heap_bases,
-        mm_begin_timestamp,
-        mm_end_timestamp,
-    )
-    shmem.barrier()
-    begin_val = mm_begin_timestamp.cpu().item()
-    end_val = mm_end_timestamp.cpu().item()
-    with open(f"timestamps_{cur_rank}.txt", "w") as f:
-        f.write(f"mm_begin_timestamp: {begin_val}\n")
-        f.write(f"mm_end_timestamp: {end_val}\n")
+    for source_rank in range(num_ranks):
+        for destination_rank in range(num_ranks):
+            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
+                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
+                ping_pong[grid](source_buffer,
+                                BUFFER_LEN,
+                                skip, iter,
+                                flag,
+                                cur_rank, peer_for_me,
+                                BLOCK_SIZE,
+                                heap_bases,
+                                mm_begin_timestamp,
+                                mm_end_timestamp)
+            shmem.barrier()
+
+    for destination_rank in range(num_ranks):
+        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
+
+    latency_matrix = mpi_allgather(local_latency.cpu())
+
+    if cur_rank == 0:
+        with open("latency.txt", "w") as f:
+            f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
+            for i in range(num_ranks):
+                row_entries = []
+                for j in range(num_ranks):
+                    val = float(latency_matrix[i, j])
+                    row_entries.append(f"{val:0.6f}")
+                line = f"R{i}," + ", ".join(row_entries) + "\n"
+                f.write(line)
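
For reference, a minimal post-processing sketch (not part of this commit): assuming the benchmark is launched under an MPI launcher (e.g. mpirun -np <num_ranks> python tests/examples/test_load_latency.py), rank 0 writes latency.txt in the comma-separated layout produced by the writer loop above, where row Ri, column Rj holds the per-iteration ping-pong time rank i measured against rank j, in the raw timer units returned by read_realtime. The snippet below only parses that file back and prints it per rank.

    import csv

    # Parse latency.txt written by rank 0. Header row: " ,R0, R1, ...";
    # data rows: "R<i>, v, v, ...". Diagonal entries stay 0.0 because a
    # rank never ping-pongs with itself.
    with open("latency.txt") as f:
        rows = list(csv.reader(f))

    header = [h.strip() for h in rows[0][1:]]  # destination labels R0..R{n-1}
    for row in rows[1:]:
        source = row[0].strip()
        values = [float(v) for v in row[1:]]
        print(source, dict(zip(header, values)))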

0 commit comments
