@@ -31,7 +31,15 @@ def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None
3131 assert task is not None
3232
3333 if trace_dir :
34- return timeit_from_trace (f , * args , matrix_dim = matrix_dim , tries = tries , task = task , trace_dir = trace_dir )
34+ try :
35+ outcomes_ms = timeit_from_trace (
36+ f , * args , matrix_dim = matrix_dim , tries = tries , task = task , trace_dir = trace_dir
37+ )
38+ if outcomes_ms is not None :
39+ return outcomes_ms
40+ print ("Warning: timeit_from_trace returned empty results. Falling back to manual timing." )
41+ except Exception as e :
42+ print (f"Warning: Failed to get metrics from trace due to: { e } . Falling back to manual timing." )
3543
3644 outcomes_ms = []
3745 jax .block_until_ready (f (* args )) # warm it up!
@@ -71,9 +79,13 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
7179
7280 # Check if the given task name is a collective with corresponding TPU operation.
7381 # This is a workaround and should be reverted or refactored in future.
82+ # If task is not present in the map, fall back to the default behavior of measuring the timing from the CPU end.
7483 if task in TARGET_TASK_NAME_COLLECTIVES_MAP :
75- task = TARGET_TASK_NAME_COLLECTIVES_MAP [task ]
76- return get_metrics_from_trace_tpu (trace , task )
84+ try :
85+ task = TARGET_TASK_NAME_COLLECTIVES_MAP [task ]
86+ return get_metrics_from_trace_tpu (trace , task )
87+ except :
88+ return None
7789 event_matcher = re .compile (task )
7890
7991 if "traceEvents" not in trace :
0 commit comments