From cbdea70e6e04c75ca156b878502b77b8dd3ca77e Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 13 Oct 2025 14:33:43 -0700 Subject: [PATCH] Add progress bar for precompiling stack-info: PR: https://github.com/pytorch/helion/pull/919, branch: jansel/stack/174 --- helion/autotuner/base_search.py | 19 +++++++++++++++++-- helion/autotuner/pattern_search.py | 4 ++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index 63248d2f7..80b3801dd 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -326,7 +326,10 @@ def parallel_benchmark( self.start_precompile_and_check_for_hangs, zip(configs, fns, strict=True), ) - ] + ], + desc=f"{desc} precompiling" + if self.settings.autotune_progress_bar + else None, ) else: is_workings = [True] * len(configs) @@ -336,7 +339,7 @@ def parallel_benchmark( iterator = iter_with_progress( zip(configs, fns, is_workings, strict=True), total=len(configs), - description=desc, + description=f"{desc}: exploring neighbors", enabled=self.settings.autotune_progress_bar, ) for config, fn, is_working in iterator: @@ -725,6 +728,7 @@ def __call__(self) -> bool: @staticmethod def wait_for_all( futures: list[PrecompileFuture], + desc: str | None = None, ) -> list[bool]: """ Wait for all precompile futures to complete. @@ -735,10 +739,21 @@ def wait_for_all( Returns: A list of boolean values indicating completion status. """ + progress = iter_with_progress( + range(len(futures)), + total=len(futures), + description=desc, + enabled=desc is not None, + ) + next(progress, None) # display the progress bar immediately + progress_left = len(futures) remaining = [f for f in futures if f.ok is None] try: while remaining: remaining = PrecompileFuture._wait_for_all_step(remaining) + while progress_left > len(remaining): + next(progress, None) + progress_left -= 1 except Exception: for f in remaining: if (p := f.process) is not None: diff --git a/helion/autotuner/pattern_search.py b/helion/autotuner/pattern_search.py index 8d3910879..477e1dd0f 100644 --- a/helion/autotuner/pattern_search.py +++ b/helion/autotuner/pattern_search.py @@ -102,11 +102,11 @@ def _autotune(self) -> Config: unbenchmarked = [m for m in self.population if len(m.perfs) == 0] if unbenchmarked: self.parallel_benchmark_population( - unbenchmarked, desc=f"Generation {generation}: Exploring neighbors" + unbenchmarked, desc=f"Generation {generation}:" ) # higher-accuracy rebenchmark self.rebenchmark_population( - self.population, desc=f"Generation {generation}: Verifying top configs" + self.population, desc=f"Generation {generation}: verifying top configs" ) # Log final statistics for this generation self.log(f"Generation {generation} complete:", self.statistics)