import torch
import triton
import triton.language as tl

import iris
from iris._mpi_helpers import mpi_allgather

# from examples.common.utils import read_realtime
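

# read_realtime() samples the GPU's free-running realtime counter from inside a
# kernel via inline assembly (on AMD hardware this is typically s_memrealtime).
# Values are raw counter ticks; tick deltas must be divided by the counter
# frequency to obtain seconds.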
@triton.jit
def read_realtime():
    tmp = tl.inline_asm_elementwise(
        # ... asm string, constraints, and dtype arguments elided in this excerpt ...
    )
    return tmp

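
# gather_latencies: each rank writes its row of the num_ranks x num_ranks
# latency matrix into rank 0's heap with a one-sided iris.put. In this excerpt
# the host gathers results with mpi_allgather instead, so this kernel is a
# device-side alternative.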
@triton.jit()
def gather_latencies(
    local_latency, global_latency, curr_rank, num_ranks, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    latency_mask = offsets < num_ranks
    # One-sided put: copy this rank's latency row into row curr_rank of
    # global_latency on rank 0, guarded by latency_mask.
    iris.put(
        local_latency + offsets,
        global_latency + curr_rank * num_ranks + offsets,
        curr_rank,
        0,
        heap_bases,
        mask=latency_mask,
    )

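# Hypothetical host-side launch of gather_latencies (not shown in this excerpt,
# which gathers with mpi_allgather instead); global_latency is assumed to be a
# symmetric-heap buffer of num_ranks * num_ranks floats:
#
#     global_latency = shmem.zeros(num_ranks * num_ranks, dtype=torch.float32)
#     gather_latencies[(1,)](local_latency, global_latency, cur_rank, num_ranks,
#                            BLOCK_SIZE, heap_bases)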


# ping_pong: the two ranks bounce a BUFFER_LEN-element message back and forth;
# each side records begin/end timestamps with read_realtime() and publishes
# them via atomic_xchg so the host can average per-iteration latency.
@triton.jit()
def ping_pong(
    # ... parameters elided in this excerpt ...
):
    # ... warm-up and loop setup elided; the lines below run inside the timed
    # iteration loop over i ...
    start = read_realtime()
    tl.atomic_xchg(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
    # Alternate which rank leads each iteration so both directions get timed.
    first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
    token_first_done = i + 1
    token_second_done = i + 2
    if curr_rank == first_rank:
        iris.put(data + offsets, data + offsets, curr_rank, peer_rank, heap_bases, mask=data_mask)
    # ... flag signaling and the second rank's side of the exchange elided in
    # this excerpt ...
    stop = read_realtime()
    tl.atomic_xchg(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)

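# Hypothetical launch command: Iris benchmarks run one process per GPU under
# MPI, so something like
#
#     mpirun -np 8 python ping_pong_benchmark.py
#
# (script name assumed) produces an 8x8 latency matrix.
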
if __name__ == "__main__":
    dtype = torch.int32
    heap_size = 1 << 32  # 4 GiB symmetric heap
    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
    # ... elided: cur_rank, heap_bases, BUFFER_LEN, and BLOCK_SIZE are defined
    # here ...
    iter = 200  # timed iterations per pair; 'skip' warm-up iterations come first
    skip = 1
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    result_buffer = shmem.zeros_like(source_buffer)
    flag = shmem.ones(1, dtype=dtype)

    grid = lambda meta: (1,)
    # Time every ordered (source, destination) pair; only the two ranks in the
    # pair launch the kernel, and every rank joins the barrier afterwards.
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
                ping_pong[grid](
                    source_buffer,
                    BUFFER_LEN,
                    skip,
                    iter,
                    flag,
                    cur_rank,
                    peer_for_me,
                    BLOCK_SIZE,
                    heap_bases,
                    mm_begin_timestamp,
                    mm_end_timestamp,
                )
            shmem.barrier()

    # Average per-iteration ping-pong time for each destination, in raw
    # counter ticks.
    for destination_rank in range(num_ranks):
        local_latency[destination_rank] = (
            mm_end_timestamp.cpu()[destination_rank] - mm_begin_timestamp.cpu()[destination_rank]
        ) / iter

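    # The values above are in raw counter ticks. A hedged conversion sketch,
    # assuming a 100 MHz realtime counter (common on AMD GPUs, but
    # hardware-dependent):
    #
    #     TICKS_PER_US = 100.0  # 100 ticks per microsecond (assumed)
    #     local_latency_us = local_latency.cpu() / TICKS_PER_US
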
    # Gather every rank's row so each rank holds the full matrix.
    latency_matrix = mpi_allgather(local_latency.cpu())

    if cur_rank == 0:
        # Dump the matrix as CSV: row i holds rank i's latency to each peer j.
        with open("latency.txt", "w") as f:
            f.write(" ," + ", ".join(f"R{j}" for j in range(num_ranks)) + "\n")
            for i in range(num_ranks):
                row_entries = []
                for j in range(num_ranks):
                    val = float(latency_matrix[i, j])
                    row_entries.append(f"{val:0.6f}")
                line = f"R{i}," + ", ".join(row_entries) + "\n"
                f.write(line)
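
# Illustrative latency.txt for two ranks (values are made up; the diagonal is
# zero because a rank is never paired with itself):
#
#     ,R0, R1
#    R0,0.000000, 1234.567890
#    R1,1234.567890, 0.000000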