@@ -38,10 +38,10 @@ def load_remote(
3838 if i == skip :
3939 start = read_realtime ()
4040 tl .store (mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets , start , time_stmp_mask )
41-
41+
4242 # iris.load(data + offsets, curr_rank, peer_rank,heap_bases, data_mask)
4343 from_base = tl .load (heap_bases + curr_rank )
44- to_base = tl .load (heap_bases + peer_rank )
44+ to_base = tl .load (heap_bases + peer_rank )
4545 offset = tl .cast (data + offsets , tl .uint64 ) - from_base
4646 translated_ptr = tl .cast (tl .cast (to_base , tl .pointer_type (tl .int8 )) + offset , (data + offsets ).dtype )
4747 result = tl .load (translated_ptr , mask = data_mask , cache_modifier = ".cv" , volatile = True )
@@ -240,15 +240,14 @@ def print_run_settings(
240240 grid = lambda meta : (1 ,)
241241 for source_rank in range (num_ranks ):
242242 for destination_rank in range (num_ranks ):
243- if cur_rank in [source_rank , destination_rank ]:
244- peer_for_me = destination_rank if cur_rank == source_rank else source_rank
243+ if cur_rank == source_rank :
245244 load_remote [grid ](
246245 source_buffer ,
247246 BUFFER_LEN ,
248247 skip ,
249248 niter ,
250249 cur_rank ,
251- peer_for_me ,
250+ destination_rank ,
252251 BLOCK_SIZE ,
253252 heap_bases ,
254253 mm_begin_timestamp ,
@@ -258,13 +257,16 @@ def print_run_settings(
258257
259258 mm_begin_cpu = mm_begin_timestamp .cpu ().numpy ()
260259 mm_end_cpu = mm_end_timestamp .cpu ().numpy ()
260+
261+ gpu_freq = iris .hip .get_wall_clock_rate (cur_rank )
262+
261263 for destination_rank in range (num_ranks ):
262264 delta = mm_end_cpu [destination_rank , :] - mm_begin_cpu [destination_rank , :]
263- avg_ns = float (delta .sum () / max (1 , delta .size ) / max (1 , niter ))
264- local_latency [destination_rank ] = avg_ns
265+ avg_cc = float (delta .sum () / max (1 , delta .size ) / max (1 , niter ))
266+ local_latency [destination_rank ] = avg_cc * 1e6 / gpu_freq
265267
266268 latency_matrix = mpi_allgather (local_latency .cpu ())
267269
268270 if cur_rank == 0 :
269271 save_results (latency_matrix , args ["output_file" ])
270- print ("Benchmark complete." )
272+ print ("Benchmark complete." )
0 commit comments