1818import subprocess
1919import shutil
2020
# Lookup table: JAX collective task name -> regular expression for the name of
# its main HLO op. Each value is a regex source string, not a literal name.
# NOTE(review): the "." in each pattern is unescaped, so it matches any single
# character before the numeric suffix, not only a literal dot -- confirm that
# this is intended.
TARGET_TASK_NAME_COLLECTIVES_MAP = {
    "all_to_all_ici_op": r"all-to-all.[0-9]+",
    "all_gather_ici_op": r"all-gather.[0-9]+",
    "psum_ici_op": r"all-reduce.[0-9]+",
    "ppermute_ici_op": r"collective-permute.[0-9]+",
}
2128
2229def simple_timeit (f , * args , matrix_dim = None , tries = 10 , task = None , trace_dir = None ) -> float :
2330 """Simple utility to time a function for multiple runs."""
@@ -60,8 +67,15 @@ def get_trace(log_dir: str) -> dict[str, Any]:
6067 return trace
6168
6269
63- def get_metrics_from_trace (trace : dict [str , Any ], task : str ) -> float :
70+ def get_metrics_from_trace (trace : dict [str , Any ], task : str ) -> list [float ]:
71+
72+ # Check if the given task name is a collective with a corresponding TPU operation.
73+ # This is a workaround and should be reverted or refactored in the future.
74+ if task in TARGET_TASK_NAME_COLLECTIVES_MAP :
75+ task = TARGET_TASK_NAME_COLLECTIVES_MAP [task ]
76+ return get_metrics_from_trace_tpu (trace , task )
6477 event_matcher = re .compile (task )
78+
6579 if "traceEvents" not in trace :
6680 raise KeyError ("Key 'traceEvents' not found in trace." )
6781
@@ -85,6 +99,26 @@ def get_metrics_from_trace(trace: dict[str, Any], task: str) -> float:
8599 raise
86100 return durations_ms
87101
def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]:
    """Return device durations of the events matching `task` on the lowest pid.

    Events are selected when their "name" matches `task` (a regular expression
    applied with `re.match`, i.e. anchored at the start of the name only).
    Among the matches, only events whose "pid" equals the smallest matching
    "pid" are kept; that process is treated as TPU-0.

    Args:
        trace: Parsed trace dictionary containing a "traceEvents" list.
        task: Regular expression source matched against each event name.

    Returns:
        One value per kept event, in trace order, computed as
        `args["device_duration_ps"] / 1e9` (picoseconds to milliseconds, going
        by the key and variable names).

    Raises:
        KeyError: If "traceEvents" is missing from `trace`, if a matching event
            has no "pid", or if a kept event lacks
            `args["device_duration_ps"]`.
        ValueError: If no event name matches `task`.
    """
    event_matcher = re.compile(task)

    if "traceEvents" not in trace:
        raise KeyError("Key 'traceEvents' not found in trace.")

    events = [
        e
        for e in trace["traceEvents"]
        if "name" in e and event_matcher.match(e["name"])
    ]

    # Without this guard, min() below would fail on an empty sequence with a
    # message that gives no hint about which task pattern found nothing.
    if not events:
        raise ValueError(f"No trace events matched task pattern {task!r}.")

    # For each trace, find the TPU with the smallest `pid` value and consider
    # it to be TPU-0.
    min_pid = min(e["pid"] for e in events)
    try:
        durations_ms = [
            float(e["args"]["device_duration_ps"]) / 1e9
            for e in events
            if e["pid"] == min_pid
        ]
    except KeyError:
        print("KeyError: Key 'device_duration_ps' not found in the event object")
        raise
    return durations_ms
88122
89123def is_local_directory_path (dir : str ) -> bool :
90124 """
0 commit comments