Skip to content

Commit 2e26cc8

Browse files
authored
Merge pull request #57 from foundation-model-stack/fix_timing_crash
Fix crash when timings only has one element
2 parents 4adf53a + 6852ddb commit 2e26cc8

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

scripts/inference.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -671,11 +671,12 @@ def infer(use_cache, do_sample, warmup):
671671
elif args.timing == "per-token":
672672
if not warmup:
673673
dprint(f"First-token latency: {timings[0]*1000:.3f} ms")
674-
dprint(f"Average next-token latency: {np.mean(timings[1:])*1000:.3f} ms")
675674
dprint(f"Average next-token latency (including first token): {np.mean(timings)*1000:.3f} ms")
676-
dprint(f"Max next-token latency: {np.max(timings[1:])*1000:.3f} ms (token #{np.argmax(timings[1:]) + 2})")
677-
dprint(f"Min next-token latency: {np.min(timings[1:])*1000:.3f} ms (token #{np.argmin(timings[1:]) + 2})")
678-
dprint(f"Std deviation of next-token latencies: {np.std(timings[1:])*1000:.3f} ms")
675+
if len(timings) > 1:
676+
dprint(f"Average next-token latency: {np.mean(timings[1:])*1000:.3f} ms")
677+
dprint(f"Max next-token latency: {np.max(timings[1:])*1000:.3f} ms (token #{np.argmax(timings[1:]) + 2})")
678+
dprint(f"Min next-token latency: {np.min(timings[1:])*1000:.3f} ms (token #{np.argmin(timings[1:]) + 2})")
679+
dprint(f"Std deviation of next-token latencies: {np.std(timings[1:])*1000:.3f} ms")
679680
timings = [f"{t*1000:.3f}" for t in timings]
680681
dprint(f"Per-token timing information: {', '.join(timings)} ms")
681682
if len(result.shape) == 1:

0 commit comments

Comments
 (0)