diff --git a/tests/unit/compile_cache_test.py b/tests/unit/compile_cache_test.py index b94cd0dec6..a1f35b98a1 100644 --- a/tests/unit/compile_cache_test.py +++ b/tests/unit/compile_cache_test.py @@ -106,8 +106,11 @@ def test_train_step_cache_hit(): "cache was not writeable or the JAX cache configuration was ignored." ) - assert len(cache_files) == 1, ( - f"Expected exactly 1 JAX compilation cache file, but found {len(cache_files)}: {cache_files}. " + train_step_cache_files = [f for f in cache_files if f.startswith("jit_train_step")] + assert len(train_step_cache_files) == 1, ( + f"Expected exactly 1 JAX compilation cache file for 'jit_train_step', " + f"but found {len(train_step_cache_files)}: {train_step_cache_files} " + f"(all cache files: {cache_files}). " "This indicates a cache miss where AOT compilation and runtime execution generated different keys, " "causing train_step to be compiled twice (double-compilation regression)." )