Skip to content

Commit 4295384

Browse files
authored
Fix error in src/benchmark_utils.py (#24)
1 parent 19ace8e commit 4295384

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/benchmark_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import shutil
2020

2121

22-
def simple_timeit(f, *args, matrix_dim, tries=10, task=None, trace_dir=None) -> float:
22+
def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
2323
"""Simple utility to time a function for multiple runs."""
2424
assert task is not None
2525

@@ -97,15 +97,20 @@ def is_local_directory_path(dir: str) -> bool:
9797
return dir.startswith("/") or dir.startswith("./") or dir.startswith("../")
9898

9999

100-
def timeit_from_trace(f, *args, matrix_dim, tries=10, task=None, trace_dir=None) -> float:
100+
def timeit_from_trace(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
101101
"""
102102
Time a function with jax.profiler and get the run time from the trace.
103103
"""
104104
LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace"
105105

106106
jax.block_until_ready(f(*args)) # warm it up!
107107

108-
trace_name = f"{task}_dim_{matrix_dim}"
108+
if matrix_dim is not None:
109+
trace_name = f"{task}_dim_{matrix_dim}"
110+
else:
111+
trace_name = f"t_{task}_" + "".join(
112+
random.choices(string.ascii_uppercase + string.digits, k=10)
113+
)
109114

110115
trace_full_dir = f"{trace_dir}/{trace_name}"
111116
tmp_trace_dir = trace_full_dir

0 commit comments

Comments
 (0)