diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index a30c4d63..7cfba422 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -353,7 +353,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): s0 = time.perf_counter() # Restore original profiler setting for the profiling run config.get_keys()["enable_profiler"] = original_enable_profiler - if max_utils.profiler_enabled(config): + if original_enable_profiler: # Injecting user requested XLA tracing flags xla_flags = os.environ.get("XLA_FLAGS", "") new_flags = "--xla_enable_mxu_trace=true --xla_jf_dump_llo_html=true --xla_tpu_enable_llo_profiling=true"