From ee48692809b7ca1f8aed19403e3836fa82dfc794 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Mon, 22 Jun 2026 11:25:01 -0700 Subject: [PATCH] Fix compile_cache_test fragility by filtering cache files for jit_train_step. Previously, test_train_step_cache_hit asserted that the compilation cache contained exactly 1 file, assuming only would be cached. However, in some environments (like nightly workflows), JAX also caches other compiled functions (like ), causing the test to fail. This fix changes the assertion to filter the cache files for the prefix before verifying that only one compilation occurred, making the test robust against other functions being cached. --- tests/unit/compile_cache_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/unit/compile_cache_test.py b/tests/unit/compile_cache_test.py index b94cd0dec6..a1f35b98a1 100644 --- a/tests/unit/compile_cache_test.py +++ b/tests/unit/compile_cache_test.py @@ -106,8 +106,11 @@ def test_train_step_cache_hit(): "cache was not writeable or the JAX cache configuration was ignored." ) - assert len(cache_files) == 1, ( - f"Expected exactly 1 JAX compilation cache file, but found {len(cache_files)}: {cache_files}. " + train_step_cache_files = [f for f in cache_files if f.startswith("jit_train_step")] + assert len(train_step_cache_files) == 1, ( + f"Expected exactly 1 JAX compilation cache file for 'jit_train_step', " + f"but found {len(train_step_cache_files)}: {train_step_cache_files} " + f"(all cache files: {cache_files}). " "This indicates a cache miss where AOT compilation and runtime execution generated different keys, " "causing train_step to be compiled twice (double-compilation regression)." )