Skip to content

Commit 4e09b70

Browse files
committed
Add get_metrics_from_trace_tpu function to extract time from tpu trace.
1 parent 4295384 commit 4e09b70

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"python-envs.defaultEnvManager": "ms-python.python:pyenv",
3+
"python-envs.pythonProjects": []
4+
}

src/benchmark_utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
import subprocess
1919
import shutil
2020

21+
# Maps a JAX (collective) benchmark task name to a regex for the name of its
# main HLO op in a trace. The separator before the numeric suffix is escaped
# so only a literal "." matches (e.g. "all-reduce.12"), not any character.
TARGET_TASK_NAME_COLLECTIVES_MAP = {
    "all_to_all_ici_op": r"all-to-all\.[0-9]+",
    "all_gather_ici_op": r"all-gather\.[0-9]+",
    "psum_ici_op": r"all-reduce\.[0-9]+",
    "ppermute_ici_op": r"collective-permute\.[0-9]+",
}
2128

2229
def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
2330
"""Simple utility to time a function for multiple runs."""
@@ -60,8 +67,15 @@ def get_trace(log_dir: str) -> dict[str, Any]:
6067
return trace
6168

6269

63-
def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
70+
def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
71+
72+
# Check if the given task name is a collective with a corresponding TPU operation.
73+
# This is a workaround and should be reverted or refactored in future.
74+
if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
75+
task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
76+
return get_metrics_from_trace_tpu(trace, task)
6477
event_matcher = re.compile(task)
78+
6579
if "traceEvents" not in trace:
6680
raise KeyError("Key 'traceEvents' not found in trace.")
6781

@@ -85,6 +99,26 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
8599
raise
86100
return durations_ms
87101

102+
def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]:
    """Return device durations, in milliseconds, of matching events on one TPU.

    Events are selected when their "name" matches the regex `task` (anchored
    at the start of the name via `re.match`). Among the matching events, only
    those with the smallest "pid" are kept; that process is treated as TPU-0.

    Args:
        trace: Trace dictionary holding a "traceEvents" list of event dicts.
        task: Regular expression matched against each event's "name".

    Returns:
        One duration per selected event, converted from the event's
        `args["device_duration_ps"]` (picoseconds) to milliseconds.

    Raises:
        KeyError: If "traceEvents" is missing from `trace`, or a selected
            event lacks "pid", "args" or "device_duration_ps".
        ValueError: If no event name matches `task`.
    """
    event_matcher = re.compile(task)

    if "traceEvents" not in trace:
        raise KeyError("Key 'traceEvents' not found in trace.")

    events = [
        e
        for e in trace["traceEvents"]
        if "name" in e and event_matcher.match(e["name"])
    ]
    # Fail with an actionable message instead of the opaque
    # "min() arg is an empty sequence" that an empty match set would cause.
    if not events:
        raise ValueError(f"No trace events found matching task pattern {task!r}.")

    # For each trace, find the TPU with the smallest `pid` value and consider
    # it to be TPU-0.
    min_pid = min(e["pid"] for e in events)
    events_from_min_pid = [e for e in events if e["pid"] == min_pid]
    try:
        # 1 ms == 1e9 ps.
        durations_ms = [
            float(e["args"]["device_duration_ps"]) / 1e9
            for e in events_from_min_pid
        ]
    except KeyError:
        print("KeyError: Key 'device_duration_ps' not found in the event object")
        raise
    return durations_ms
88122

89123
def is_local_directory_path(dir: str) -> bool:
90124
"""

0 commit comments

Comments
 (0)