|
19 | 19 | import shutil |
20 | 20 |
|
21 | 21 |
|
22 | | -def simple_timeit(f, *args, matrix_dim, tries=10, task=None, trace_dir=None) -> float: |
| 22 | +def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float: |
23 | 23 | """Simple utility to time a function for multiple runs.""" |
24 | 24 | assert task is not None |
25 | 25 |
|
@@ -97,15 +97,20 @@ def is_local_directory_path(dir: str) -> bool: |
97 | 97 | return dir.startswith("/") or dir.startswith("./") or dir.startswith("../") |
98 | 98 |
|
99 | 99 |
|
100 | | -def timeit_from_trace(f, *args, matrix_dim, tries=10, task=None, trace_dir=None) -> float: |
| 100 | +def timeit_from_trace(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float: |
101 | 101 | """ |
102 | 102 | Time a function with jax.profiler and get the run time from the trace. |
103 | 103 | """ |
104 | 104 | LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace" |
105 | 105 |
|
106 | 106 | jax.block_until_ready(f(*args)) # warm it up! |
107 | 107 |
|
108 | | - trace_name = f"{task}_dim_{matrix_dim}" |
| 108 | + if matrix_dim is not None: |
| 109 | + trace_name = f"{task}_dim_{matrix_dim}" |
| 110 | + else: |
| 111 | + trace_name = f"t_{task}_" + "".join( |
| 112 | + random.choices(string.ascii_uppercase + string.digits, k=10) |
| 113 | + ) |
109 | 114 |
|
110 | 115 | trace_full_dir = f"{trace_dir}/{trace_name}" |
111 | 116 | tmp_trace_dir = trace_full_dir |
|
0 commit comments