Skip to content

Commit a187495

Browse files
Apply Ruff auto-fixes
1 parent 5fabfef commit a187495

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

tests/examples/test_load_latency.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ def ping_pong(
2929
pid = tl.program_id(0)
3030
block_start = pid * BLOCK_SIZE
3131
offsets = block_start + tl.arange(0, BLOCK_SIZE)
32-
32+
3333
data_mask = offsets < len
3434
flag_mask = offsets < 1
3535
time_stmp_mask = offsets < 1
3636

3737
for i in range(iter + skip):
38-
if (i == skip):
39-
start = read_realtime();
38+
if i == skip:
39+
start = read_realtime()
4040
tl.atomic_xchg(mm_begin_timestamp_ptr + offsets, start, time_stmp_mask)
4141
if curr_rank == (i + 1) % 2:
4242
while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
@@ -50,7 +50,7 @@ def ping_pong(
5050
iris.put(flag + offsets, flag + offsets, curr_rank, peer, heap_bases, flag_mask)
5151
while tl.load(flag, cache_modifier=".cv", volatile=True) != i + 1:
5252
pass
53-
stop = read_realtime();
53+
stop = read_realtime()
5454
tl.atomic_xchg(mm_end_timestamp_ptr + offsets, stop, time_stmp_mask)
5555

5656

@@ -66,7 +66,7 @@ def ping_pong(
6666
@pytest.mark.parametrize(
6767
"heap_size",
6868
[
69-
(1 << 33),
69+
(1 << 33),
7070
],
7171
)
7272
def test_load_bench(dtype, heap_size):
@@ -77,22 +77,34 @@ def test_load_bench(dtype, heap_size):
7777
assert num_ranks == 2
7878

7979
BLOCK_SIZE = 1
80-
BUFFER_LEN = 64*1024
80+
BUFFER_LEN = 64 * 1024
8181

8282
iter = 200
8383
skip = 20
8484
mm_begin_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
85-
mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
85+
mm_end_timestamp = torch.zeros(BLOCK_SIZE, dtype=torch.int64, device="cuda")
8686

8787
source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
8888
result_buffer = shmem.zeros_like(source_buffer)
89-
flag = shmem.ones(1, dtype=dtype)
89+
flag = shmem.ones(1, dtype=dtype)
9090

9191
grid = lambda meta: (1,)
92-
ping_pong[grid](source_buffer, result_buffer, BUFFER_LEN, skip, iter, flag, cur_rank, BLOCK_SIZE, heap_bases,mm_begin_timestamp, mm_end_timestamp)
92+
ping_pong[grid](
93+
source_buffer,
94+
result_buffer,
95+
BUFFER_LEN,
96+
skip,
97+
iter,
98+
flag,
99+
cur_rank,
100+
BLOCK_SIZE,
101+
heap_bases,
102+
mm_begin_timestamp,
103+
mm_end_timestamp,
104+
)
93105
shmem.barrier()
94106
begin_val = mm_begin_timestamp.cpu().item()
95107
end_val = mm_end_timestamp.cpu().item()
96-
with open(f'timestamps_{cur_rank}.txt', 'w') as f:
108+
with open(f"timestamps_{cur_rank}.txt", "w") as f:
97109
f.write(f"mm_begin_timestamp: {begin_val}\n")
98110
f.write(f"mm_end_timestamp: {end_val}\n")

0 commit comments

Comments
 (0)