Skip to content

Commit 29096fe

Browse files
committed
Add get_metrics_from_trace_tpu function to extract time from tpu trace.
1 parent 4295384 commit 29096fe

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

src/benchmark_utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818
import subprocess
1919
import shutil
2020

21+
# The dictionary to map a CPU collective function to its corresponding operation on TPU
22+
# "psum_scatter_ici_op" has different implementation according to its `matrix_dim` and the number of TPUs, so it's not considered in this mapping dictionary.
23+
TARGET_TASK_NAME_COLLECTIVES_MAP = {
24+
"all_to_all_ici_op": r"all-to-all.[0-9]+",
25+
"all_gather_ici_op": r"all-gather.[0-9]+",
26+
"psum_ici_op": r"all-reduce.[0-9]+",
27+
"ppermute_ici_op": r"collective-permute-done",
28+
}
2129

2230
def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
2331
"""Simple utility to time a function for multiple runs."""
@@ -60,8 +68,15 @@ def get_trace(log_dir: str) -> dict[str, Any]:
6068
return trace
6169

6270

63-
def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
71+
def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
72+
73+
# Check if the given task name is a collective with corresponding TPU opertion.
74+
# This is a workaround and should be reverted or refactored in future.
75+
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
76+
task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
77+
return get_metrics_from_trace_tpu(trace, task)
6478
event_matcher = re.compile(task)
79+
6580
if "traceEvents" not in trace:
6681
raise KeyError("Key 'traceEvents' not found in trace.")
6782

@@ -85,6 +100,26 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
85100
raise
86101
return durations_ms
87102

103+
def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]:
104+
event_matcher = re.compile(task)
105+
106+
if "traceEvents" not in trace:
107+
raise KeyError("Key 'traceEvents' not found in trace.")
108+
109+
events = []
110+
for e in trace["traceEvents"]:
111+
if "name" in e and event_matcher.match(e["name"]):
112+
events.append(e)
113+
114+
# For each trace, find the TPU with smallest `pid` value and consider it to be TPU-0
115+
min_pid = min([e["pid"] for e in events])
116+
events_from_min_pid = [e for e in events if e["pid"] == min_pid]
117+
try:
118+
durations_ms = [float(e["args"]["device_duration_ps"]) / 1e9 for e in events_from_min_pid]
119+
except KeyError:
120+
print("KeyError: Key 'device_duration_ps' not found in the event object")
121+
raise
122+
return durations_ms
88123

89124
def is_local_directory_path(dir: str) -> bool:
90125
"""

0 commit comments

Comments
 (0)