From 5cfc187ec13cae376e8a70f712e7fdedfff79726 Mon Sep 17 00:00:00 2001 From: xiziqiao Date: Tue, 21 Apr 2026 13:36:38 -0700 Subject: [PATCH] fix: avoid clearing in-flight pipeline states in custom kernel cache write_signature() generates different source code for the same kernel name when input dtypes change. The old code detected this source mismatch and called clear_library(), which deallocates cached pipeline states that may still be referenced by an in-flight command buffer, causing use-after-free (Metal validation: 'command buffer references deallocated object'). Fix: use source-hash-dependent library cache keys so different source variants coexist without evicting each other. Removes the clear_library path entirely. Fixes #3347 --- mlx/backend/metal/custom_kernel.cpp | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 6d33ff5007..d729f8967e 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -361,22 +361,16 @@ void CustomKernel::eval_gpu( auto& d = metal::device(s.device); - { - // Clear kernels from the device library cache if needed - auto& kernel_cache = cache(); - if (auto it = kernel_cache.libraries.find(name_); - it != kernel_cache.libraries.end()) { - if (it->second != source_) { - auto& d = metal::device(s.device); - d.clear_library(name_); - it->second = source_; - } - } else { - kernel_cache.libraries.emplace(name_, source_); - } - } + // Use a source-dependent library key so different source variants + // (e.g. from write_signature picking different dtype qualifiers across + // calls) coexist without evicting each other. Clearing a library while + // its pipeline states are still referenced by an in-flight command + // buffer causes use-after-free. 
+ auto source_hash = std::hash<std::string>{}(source_); + auto lib_key = name_ + "_" + std::to_string(source_hash); - auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); + auto lib = + d.get_library(lib_key, [this] { return metal::utils() + source_; }); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel);