Skip to content

Commit d61ff59

Browse files
committed
.
1 parent 29096fe commit d61ff59

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"python-envs.defaultEnvManager": "ms-python.python:pyenv",
3+
"python-envs.pythonProjects": []
4+
}

src/benchmark_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
import subprocess
1919
import shutil
2020

21-
# The dictionary to map a CPU collective function to its corresponding operation on TPU
22-
# "psum_scatter_ici_op" has different implementation according to its `matrix_dim` and the number of TPUs, so it's not considered in this mapping dictionary.
21+
# The dictionary to map a JAX (collective) function to its main HLO.
2322
TARGET_TASK_NAME_COLLECTIVES_MAP = {
2423
"all_to_all_ici_op": r"all-to-all.[0-9]+",
2524
"all_gather_ici_op": r"all-gather.[0-9]+",
2625
"psum_ici_op": r"all-reduce.[0-9]+",
27-
"ppermute_ici_op": r"collective-permute-done",
26+
"ppermute_ici_op": r"collective-permute.[0-9]+",
2827
}
2928

3029
def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
@@ -72,7 +71,7 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
7271

7372
# Check if the given task name is a collective with corresponding TPU opertion.
7473
# This is a workaround and should be reverted or refactored in future.
75-
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
74+
if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
7675
task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
7776
return get_metrics_from_trace_tpu(trace, task)
7877
event_matcher = re.compile(task)

0 commit comments

Comments
 (0)