Skip to content

Commit afd28a0

Browse files
committed
Add exception handling for get_metrics_from_trace_tpu.
1 parent 44727bc commit afd28a0

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

src/benchmark_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,15 @@ def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None
3131
assert task is not None
3232

3333
if trace_dir:
34-
return timeit_from_trace(f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir)
34+
try:
35+
outcomes_ms = timeit_from_trace(
36+
f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir
37+
)
38+
if outcomes_ms is not None:
39+
return outcomes_ms
40+
print("Warning: timeit_from_trace returned empty results. Falling back to manual timing.")
41+
except Exception as e:
42+
print(f"Warning: Failed to get metrics from trace due to: {e}. Falling back to manual timing.")
3543

3644
outcomes_ms = []
3745
jax.block_until_ready(f(*args)) # warm it up!
@@ -71,9 +79,13 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
7179

7280
# Check if the given task name is a collective with a corresponding TPU operation.
7381
# This is a workaround and should be reverted or refactored in future.
82+
# If task is not present in the map, fall back to the default behavior to measure the timing from the CPU end.
7483
if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
75-
task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
76-
return get_metrics_from_trace_tpu(trace, task)
84+
try:
85+
task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
86+
return get_metrics_from_trace_tpu(trace, task)
87+
except:
88+
return None
7789
event_matcher = re.compile(task)
7890

7991
if "traceEvents" not in trace:

0 commit comments

Comments
 (0)