import triton.language as tl
import numpy as np
import iris
-from examples.common.utils import read_realtime
+from iris._mpi_helpers import mpi_allgather
+# from examples.common.utils import read_realtime

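+# read_realtime(): device-side timestamp helper, inlined here instead of being
+# imported from examples.common.utils. It reads the GPU's global wall-clock
+# timer through inline assembly so the ping-pong loop can be timed on-device.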
+@triton.jit
+def read_realtime():
+    tmp = tl.inline_asm_elementwise(
+        asm="mov.u64 $0, %globaltimer;",
+        constraints=("=l"),
+        args=[],
+        dtype=tl.int64,
+        is_pure=False,
+        pack=1,
+    )
+    return tmp
+
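+# gather_latencies(): copies this rank's row of measured latencies into a
+# global buffer on rank 0 via iris.put. The __main__ driver below gathers
+# results with mpi_allgather instead of launching this kernel.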
+@triton.jit()
+def gather_latencies(
+    local_latency,
+    global_latency,
+    curr_rank,
+    num_ranks,
+    BLOCK_SIZE: tl.constexpr,
+    heap_bases: tl.tensor
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+
+    latency_mask = offsets < num_ranks
+    iris.put(local_latency + offsets, global_latency + curr_rank * num_ranks + offsets, curr_rank, 0, heap_bases, mask=latency_mask)

@triton.jit()
def ping_pong(
    data,
-    result,
-    len,
-    iter,
+    n_elements,
    skip,
-    flag: tl.tensor,
+    niter,
+    flag,
    curr_rank,
+    peer_rank,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    mm_begin_timestamp_ptr: tl.tensor = None,
    mm_end_timestamp_ptr: tl.tensor = None,
):
-    peer = (curr_rank + 1) % 2
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

-    data_mask = offsets < len
+    data_mask = offsets < n_elements
    flag_mask = offsets < 1
    time_stmp_mask = offsets < 1

-    for i in range(iter + skip):
+    for i in range(niter + skip):
        if i == skip:
            start = read_realtime()
-            tl.atomic_xchg(mm_begin_timestamp_ptr + offsets, start, time_stmp_mask)
-        if curr_rank == (i + 1) % 2:
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
+            tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
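+        # Alternate which rank initiates each iteration. The initiator puts its
+        # data into the peer's buffer and raises the peer's flag to
+        # token_first_done; the responder answers and raises token_second_done
+        # so the initiator can observe a completed round trip.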
+        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
+        token_first_done = i + 1
+        token_second_done = i + 2
+        if curr_rank == first_rank:
+            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, flag_mask)
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                pass
-            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
-            tl.store(flag + offsets, i + 1, mask=flag_mask)
-            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
        else:
-            iris.put(data + offsets, result + offsets, curr_rank, peer, heap_bases, mask=data_mask)
-            tl.store(flag + offsets, i + 1, mask=flag_mask)
-            iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
-            while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
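+            # Responder: wait for the initiator's token, then send the data
+            # back and signal completion with token_second_done.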
+            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                pass
-    stop = read_realtime()
-    tl.atomic_xchg(mm_end_timestamp_ptr + offsets, stop, time_stmp_mask)
+            iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
+            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, flag_mask)

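+    # Begin/end timestamps are written into the row indexed by peer_rank so the
+    # host can compute one average latency per peer.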
+    stop = read_realtime()
+    tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)

-@pytest.mark.parametrize(
-    "dtype",
-    [
-        torch.int32,
-        # torch.float16,
-        # torch.bfloat16,
-        # torch.float32,
-    ],
-)
-@pytest.mark.parametrize(
-    "heap_size",
-    [
-        (1 << 33),
-    ],
-)
-def test_load_bench(dtype, heap_size):
+if __name__ == "__main__":
+    dtype = torch.int32
+    heap_size = 1 << 32
    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()
-    assert num_ranks == 2

    BLOCK_SIZE = 1
-    BUFFER_LEN = 64 * 1024
+    BUFFER_LEN = 1

    iter = 200
-    skip = 20
-    mm_begin_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
-    mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
+    skip = 1
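+    # Per-peer measurement state: one begin/end timestamp slot per peer rank,
+    # plus a local latency row that is later all-gathered into the full matrix.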
+    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
+
+    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    result_buffer = shmem.zeros_like(source_buffer)
    flag = shmem.ones(1, dtype=dtype)

    grid = lambda meta: (1,)
-    ping_pong[grid](
-        source_buffer,
-        result_buffer,
-        BUFFER_LEN,
-        skip,
-        iter,
-        flag,
-        cur_rank,
-        BLOCK_SIZE,
-        heap_bases,
-        mm_begin_timestamp,
-        mm_end_timestamp,
-    )
-    shmem.barrier()
-    begin_val = mm_begin_timestamp.cpu().item()
-    end_val = mm_end_timestamp.cpu().item()
-    with open(f"timestamps_{cur_rank}.txt", "w") as f:
-        f.write(f"mm_begin_timestamp: {begin_val}\n")
-        f.write(f"mm_end_timestamp: {end_val}\n")
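+    # Run a single-element ping-pong for every ordered (source, destination)
+    # pair. Only the two ranks involved launch the kernel; every rank joins the
+    # barrier so pairs are measured one at a time.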
+    for source_rank in range(num_ranks):
+        for destination_rank in range(num_ranks):
+            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
+                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
+                ping_pong[grid](source_buffer,
+                                BUFFER_LEN,
+                                skip, iter,
+                                flag,
+                                cur_rank, peer_for_me,
+                                BLOCK_SIZE,
+                                heap_bases,
+                                mm_begin_timestamp,
+                                mm_end_timestamp)
+            shmem.barrier()
+
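+    # Convert raw begin/end timestamps into an average round-trip time per
+    # iteration for each peer, in timer ticks (nanoseconds for %globaltimer).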
+    for destination_rank in range(num_ranks):
+        local_latency[destination_rank] = (mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]) / iter
+
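+    # Gather every rank's latency row into a num_ranks x num_ranks matrix;
+    # rank 0 writes it as a CSV-style table (row = measuring rank, column = peer).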
+    latency_matrix = mpi_allgather(local_latency.cpu())
+
+    if cur_rank == 0:
+        with open("latency.txt", "w") as f:
+            f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
+            for i in range(num_ranks):
+                row_entries = []
+                for j in range(num_ranks):
+                    val = float(latency_matrix[i, j])
+                    row_entries.append(f"{val:0.6f}")
+                line = f"R{i}," + ", ".join(row_entries) + "\n"
+                f.write(line)
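
Presumably launched with one process per GPU through an MPI launcher, e.g. mpirun -np <num_ranks> python <this_script>.py; the exact command depends on how Iris is set up.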