@@ -29,14 +29,14 @@ def ping_pong(
2929 pid = tl .program_id (0 )
3030 block_start = pid * BLOCK_SIZE
3131 offsets = block_start + tl .arange (0 , BLOCK_SIZE )
32-
32+
3333 data_mask = offsets < len
3434 flag_mask = offsets < 1
3535 time_stmp_mask = offsets < 1
3636
3737 for i in range (iter + skip ):
38- if ( i == skip ) :
39- start = read_realtime ();
38+ if i == skip :
39+ start = read_realtime ()
4040 tl .atomic_xchg (mm_begin_timestamp_ptr + offsets , start , time_stmp_mask )
4141 if curr_rank == (i + 1 ) % 2 :
4242 while tl .load (flag , cache_modifier = ".cv" , volatile = True ) != i + 1 :
@@ -50,7 +50,7 @@ def ping_pong(
5050 iris .put (flag + offsets , flag + offsets , curr_rank , peer , heap_bases , flag_mask )
5151 while tl .load (flag , cache_modifier = ".cv" , volatile = True ) != i + 1 :
5252 pass
53- stop = read_realtime ();
53+ stop = read_realtime ()
5454 tl .atomic_xchg (mm_end_timestamp_ptr + offsets , stop , time_stmp_mask )
5555
5656
@@ -66,7 +66,7 @@ def ping_pong(
6666@pytest .mark .parametrize (
6767 "heap_size" ,
6868 [
69- (1 << 33 ),
69+ (1 << 33 ),
7070 ],
7171)
7272def test_load_bench (dtype , heap_size ):
@@ -77,22 +77,34 @@ def test_load_bench(dtype, heap_size):
7777 assert num_ranks == 2
7878
7979 BLOCK_SIZE = 1
80- BUFFER_LEN = 64 * 1024
80+ BUFFER_LEN = 64 * 1024
8181
8282 iter = 200
8383 skip = 20
8484 mm_begin_timestamp = torch .zeros (BLOCK_SIZE , dtype = torch .int64 , device = "cuda" )
85- mm_end_timestamp = torch .zeros (BLOCK_SIZE , dtype = torch .int64 , device = "cuda" )
85+ mm_end_timestamp = torch .zeros (BLOCK_SIZE , dtype = torch .int64 , device = "cuda" )
8686
8787 source_buffer = shmem .ones (BUFFER_LEN , dtype = dtype )
8888 result_buffer = shmem .zeros_like (source_buffer )
89- flag = shmem .ones (1 , dtype = dtype )
89+ flag = shmem .ones (1 , dtype = dtype )
9090
9191 grid = lambda meta : (1 ,)
92- ping_pong [grid ](source_buffer , result_buffer , BUFFER_LEN , skip , iter , flag , cur_rank , BLOCK_SIZE , heap_bases ,mm_begin_timestamp , mm_end_timestamp )
92+ ping_pong [grid ](
93+ source_buffer ,
94+ result_buffer ,
95+ BUFFER_LEN ,
96+ skip ,
97+ iter ,
98+ flag ,
99+ cur_rank ,
100+ BLOCK_SIZE ,
101+ heap_bases ,
102+ mm_begin_timestamp ,
103+ mm_end_timestamp ,
104+ )
93105 shmem .barrier ()
94106 begin_val = mm_begin_timestamp .cpu ().item ()
95107 end_val = mm_end_timestamp .cpu ().item ()
96- with open (f' timestamps_{ cur_rank } .txt' , 'w' ) as f :
108+ with open (f" timestamps_{ cur_rank } .txt" , "w" ) as f :
97109 f .write (f"mm_begin_timestamp: { begin_val } \n " )
98110 f .write (f"mm_end_timestamp: { end_val } \n " )
0 commit comments