diff --git a/fastdeploy/cache_manager/v1/cache_controller.py b/fastdeploy/cache_manager/v1/cache_controller.py index 53b7292179f..049a261ae2b 100644 --- a/fastdeploy/cache_manager/v1/cache_controller.py +++ b/fastdeploy/cache_manager/v1/cache_controller.py @@ -104,6 +104,9 @@ def __init__(self, config: "FDConfig", local_rank: int, device_id: int): # NUMA binding flag self._numa_bound = False + self._is_mla = getattr(self.model_config, "kv_lora_rank", 0) > 0 + self._is_dsa = self._is_mla and getattr(self.model_config, "index_head_dim", 0) > 0 + @property def write_policy(self) -> Optional[str]: """Get the write policy for cache operations.""" @@ -230,13 +233,22 @@ def _get_cache_names(self, layer_idx: int) -> Dict[str, str]: """ local_rank = self._local_rank % self.parallel_config.tensor_parallel_size - return { - "key": f"key_caches_{layer_idx}_rank{local_rank}.device{self._device_id}", - "value": f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}", - "key_scale": f"key_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}", - "value_scale": f"value_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}", + names = { + "key": f"key_cache_{layer_idx}_rank{local_rank}.device{self._device_id}", } + if self._is_dsa: + names["indexer"] = f"indexer_caches_{layer_idx}_rank{local_rank}.device{self._device_id}" + elif self._is_mla: + pass # MLA: only key, no value, no indexer + else: + # GQA/MHA: key + value + optional scales + names["value"] = f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}" + names["key_scale"] = f"key_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}" + names["value_scale"] = f"value_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}" + + return names + # ============ KV Cache Management ============ def get_kv_caches(self) -> Optional[Dict[str, Any]]: @@ -255,7 +267,7 @@ def initialize_kv_cache( num_gpu_blocks: int, ) -> List[Any]: """ - Initialize KV Cache 
tensors. + Initialize KV Cache tensors (GQA/MHA only). Create KV Cache tensors on GPU for storing attention Key and Value. @@ -266,37 +278,40 @@ def initialize_kv_cache( Returns: cache_kvs_list: KV Cache tensor list in [key_cache_layer0, value_cache_layer0, ...] order. """ - # Get kv cache quantization type - kv_cache_quant_type = self._get_kv_cache_quant_type() + # Dispatch to specialized initializers for MLA/DSA + if self._is_dsa: + return self.initialize_dsa_kv_cache(attn_backend, num_gpu_blocks) + elif self._is_mla: + return self.initialize_mla_kv_cache(attn_backend, num_gpu_blocks) - # Get kv cache shape + # GQA/MHA path + kv_cache_quant_type = self._get_kv_cache_quant_type() key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type ) + cache_dtype = self.model_config.dtype - # Get scale shape for block_wise_fp8 quantization + # Scale shape for block_wise_fp8 quantization kv_cache_scale_shape = None if self._is_fp8_quantization(kv_cache_quant_type): kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] - logger.info(f"Initializing kv cache for all layers. 
num_layers={self._num_layers}") + logger.info( + f"Initializing GQA kv cache: num_layers={self._num_layers}, " + f"key_shape={key_cache_shape}, value_shape={value_cache_shape}" + ) cache_kvs_list = [] for i in range(self._num_layers): - # Generate cache names cache_names = self._get_cache_names(i) - logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") - - # Create key cache and value cache - key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype) + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype) + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["value"]] = val_cache cache_kvs_list.extend([key_cache, val_cache]) - # Create scale caches for block_wise_fp8 quantization if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() @@ -309,14 +324,108 @@ def initialize_kv_cache( cache_kvs_list.extend([key_cache_scales, val_cache_scales]) paddle.device.cuda.empty_cache() - logger.info("kv cache is initialized!") + logger.info("GQA kv cache initialized!") - # Share cache_kvs_map with transfer manager for data transfer operations self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map) - - # Initialize host cache self.initialize_host_cache(attn_backend) + return cache_kvs_list + + def initialize_mla_kv_cache( + self, + attn_backend: Any, + num_gpu_blocks: int, + ) -> List[Any]: + """ + Initialize MLA KV Cache tensors (key only, no value). + + Args: + attn_backend: Attention backend instance for getting kv cache shape. + num_gpu_blocks: Maximum number of blocks on GPU. 
+ Returns: + cache_kvs_list: KV Cache tensor list in [key_layer0, key_layer1, ...] order. + """ + kv_cache_quant_type = self._get_kv_cache_quant_type() + key_cache_shape, _ = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + cache_dtype = self.model_config.dtype + + # NOTE: set_data_ipc pins tensor storage so paddle allocator cannot + # reuse/migrate it. Without pinning, CUDAGraph capture records a + # data_ptr that allocator may later mark reusable, corrupting replay. + # Align with V0 path (gpu_model_runner.initialize_kv_cache). + from fastdeploy.model_executor.ops.gpu import set_data_ipc + + logger.info(f"Initializing MLA kv cache: num_layers={self._num_layers}, " f"key_shape={key_cache_shape}") + cache_kvs_list = [] + + for i in range(self._num_layers): + cache_names = self._get_cache_names(i) + + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) + set_data_ipc(key_cache, cache_names["key"]) + self.cache_kvs_map[cache_names["key"]] = key_cache + cache_kvs_list.append(key_cache) + + paddle.device.cuda.empty_cache() + logger.info("MLA kv cache initialized!") + + self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map) + return cache_kvs_list + + def initialize_dsa_kv_cache( + self, + attn_backend: Any, + num_gpu_blocks: int, + ) -> List[Any]: + """ + Initialize DSA KV Cache tensors (key + indexer, two pools). + + Creates interleaved [key, indexer, key, indexer, ...] layout. + Future HiSparse extension: add host_blocks parameter for key host backup. + + Args: + attn_backend: Attention backend instance for getting kv cache shape. + num_gpu_blocks: Maximum number of blocks on GPU. + + Returns: + cache_kvs_list: KV Cache tensor list in [key_layer0, indexer_layer0, ...] order. 
+ """ + key_cache_shape, _, indexer_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type="uint8" + ) + cache_dtype = "uint8" + + # NOTE: set_data_ipc pins tensor storage so paddle allocator cannot + # reuse/migrate it. Without pinning, CUDAGraph capture records a + # data_ptr that allocator may later mark reusable, corrupting replay. + # Align with V0 path (gpu_model_runner.initialize_kv_cache). + from fastdeploy.model_executor.ops.gpu import set_data_ipc + + logger.info( + f"Initializing DSA kv cache: num_layers={self._num_layers}, " + f"key_shape={key_cache_shape}, indexer_shape={indexer_cache_shape}" + ) + cache_kvs_list = [] + + for i in range(self._num_layers): + cache_names = self._get_cache_names(i) + + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) + set_data_ipc(key_cache, cache_names["key"]) + self.cache_kvs_map[cache_names["key"]] = key_cache + + indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype) + set_data_ipc(indexer_cache, cache_names["indexer"]) + self.cache_kvs_map[cache_names["indexer"]] = indexer_cache + + cache_kvs_list.extend([key_cache, indexer_cache]) + + paddle.device.cuda.empty_cache() + logger.info("DSA kv cache initialized!") + + self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map) return cache_kvs_list def initialize_mtp_kv_cache( @@ -346,31 +455,50 @@ def initialize_mtp_kv_cache( """ kv_cache_quant_type = self._get_kv_cache_quant_type() - key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( - max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type - ) + if self._is_dsa: + kv_cache_quant_type = "uint8" + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + cache_dtype = "uint8" + else: + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_gpu_blocks, 
kv_cache_quant_type=kv_cache_quant_type + ) + indexer_cache_shape = [] + cache_dtype = self.model_config.dtype kv_cache_scale_shape = None - if self._is_fp8_quantization(kv_cache_quant_type): + if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type): kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] logger.info( f"[CacheController] Initializing MTP kv cache for {num_mtp_layers} layers " f"(layer_offset={layer_offset}, num_gpu_blocks={num_gpu_blocks})." + f"is_dsa = {self._is_dsa}, _is_mla = {self._is_mla}." ) cache_kvs_list = [] for i in range(layer_offset, layer_offset + num_mtp_layers): cache_names = self._get_cache_names(i) - key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype) + key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_dtype) self.cache_kvs_map[cache_names["key"]] = key_cache - val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype) - self.cache_kvs_map[cache_names["value"]] = val_cache - cache_kvs_list.extend([key_cache, val_cache]) - - if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: + if value_cache_shape: + val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["value"]] = val_cache + cache_kvs_list.extend([key_cache, val_cache]) + elif indexer_cache_shape: + # DSA: key + indexer + indexer_cache = paddle.full(shape=indexer_cache_shape, fill_value=0, dtype=cache_dtype) + self.cache_kvs_map[cache_names["indexer"]] = indexer_cache + cache_kvs_list.extend([key_cache, indexer_cache]) + else: + # MLA: only key, no value, no indexer + cache_kvs_list.append(key_cache) + + if not self._is_mla and self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape: key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) @@ -542,9 +670,16 @@ def initialize_host_cache( 
kv_cache_quant_type = self._get_kv_cache_quant_type() # Get kv cache shape (pass num_host_blocks as max_num_blocks for host cache) - key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( - max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type - ) + if self._is_dsa: + kv_cache_quant_type = "uint8" + key_cache_shape, _, indexer_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type + ) + value_cache_shape = [] + else: + key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape( + max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type + ) # Calculate cache sizes (elements per block per layer) key_cache_size = key_cache_shape[1] * key_cache_shape[2] * key_cache_shape[3] @@ -554,8 +689,11 @@ def initialize_host_cache( value_cache_size = 0 # Get cache dtype and bytes per element - cache_dtype = self.cache_config.cache_dtype - cache_item_bytes = self.cache_config.get_cache_bytes(cache_dtype) + if self._is_dsa: + cache_item_bytes = 1 + else: + cache_dtype = self.cache_config.cache_dtype + cache_item_bytes = self.cache_config.get_cache_bytes(cache_dtype) # Calculate total bytes to allocate key_need_to_allocate_bytes = num_host_blocks * cache_item_bytes * key_cache_size diff --git a/fastdeploy/cache_manager/v1/transfer_manager.py b/fastdeploy/cache_manager/v1/transfer_manager.py index f4ed0bb6539..85992abbf96 100644 --- a/fastdeploy/cache_manager/v1/transfer_manager.py +++ b/fastdeploy/cache_manager/v1/transfer_manager.py @@ -130,6 +130,10 @@ def __init__( self._storage_connector = create_storage_connector(self.cache_config) self._transfer_connector = create_transfer_connector(self.cache_config) + # ============ MLA & DSA ============ + self._is_mla = getattr(config.model_config, "kv_lora_rank", 0) > 0 + self._is_dsa = self._is_mla and getattr(config.model_config, "index_head_dim", 0) > 0 + # ============ Cache Map Setters ============ @property @@ 
-169,17 +173,24 @@ def _build_device_layer_indices(self) -> None: self._device_value_scales = [] for layer_idx in range(self._num_layers): - key_name = f"key_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - + key_name = f"key_cache_{layer_idx}_rank{self._local_rank}.device{self._device_id}" self._device_key_caches.append(self._cache_kvs_map.get(key_name)) - self._device_value_caches.append(self._cache_kvs_map.get(val_name)) - if self._is_fp8_quantization(): - self._device_key_scales.append(self._cache_kvs_map.get(key_scale_name)) - self._device_value_scales.append(self._cache_kvs_map.get(val_scale_name)) + if self._is_dsa: + # DSA: indexer treated as "value" for swap purposes + idx_name = f"indexer_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._device_value_caches.append(self._cache_kvs_map.get(idx_name)) + elif not self._is_mla: + # GQA: has value caches + val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._device_value_caches.append(self._cache_kvs_map.get(val_name)) + + if self._is_fp8_quantization(): + key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._device_key_scales.append(self._cache_kvs_map.get(key_scale_name)) + self._device_value_scales.append(self._cache_kvs_map.get(val_scale_name)) + # MLA: no value caches to add @property def host_cache_kvs_map(self) -> Dict[str, Any]: @@ -215,17 +226,24 @@ def _build_host_layer_indices(self) -> None: self._host_value_scales_ptrs = [] for layer_idx in range(self._num_layers): - key_name = 
f"key_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" - + key_name = f"key_cache_{layer_idx}_rank{self._local_rank}.device{self._device_id}" self._host_key_ptrs.append(self._host_cache_kvs_map.get(key_name, 0)) - self._host_value_ptrs.append(self._host_cache_kvs_map.get(val_name, 0)) - if self._is_fp8_quantization(): - self._host_key_scales_ptrs.append(self._host_cache_kvs_map.get(key_scale_name, 0)) - self._host_value_scales_ptrs.append(self._host_cache_kvs_map.get(val_scale_name, 0)) + if self._is_dsa: + # DSA: indexer treated as "value" for swap purposes + idx_name = f"indexer_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._host_value_ptrs.append(self._host_cache_kvs_map.get(idx_name, 0)) + elif not self._is_mla: + # GQA: has value host cache + val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._host_value_ptrs.append(self._host_cache_kvs_map.get(val_name, 0)) + + if self._is_fp8_quantization(): + key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}" + self._host_key_scales_ptrs.append(self._host_cache_kvs_map.get(key_scale_name, 0)) + self._host_value_scales_ptrs.append(self._host_cache_kvs_map.get(val_scale_name, 0)) + # MLA: no value host pointers to add # ============ Metadata Properties ============ @@ -329,16 +347,24 @@ def _swap_all_layers( self._device_id, mode, ) - swap_cache_all_layers( - self._device_value_caches, - self._host_value_ptrs, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) - if 
self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: + # Value/indexer cache: GQA has value, DSA has indexer (MLA-only has neither) + if self._device_value_caches and self._host_value_ptrs: + swap_cache_all_layers( + self._device_value_caches, + self._host_value_ptrs, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, + ) + # Scale cache is only used in GQA + fp8 quantization + if ( + not self._is_mla + and self._is_fp8_quantization() + and self._device_key_scales + and self._host_key_scales_ptrs + ): swap_cache_all_layers( self._device_key_scales, self._host_key_scales_ptrs, @@ -389,13 +415,10 @@ def _swap_single_layer( try: key_cache = self.get_device_key_cache(layer_idx) - value_cache = self.get_device_value_cache(layer_idx) - if key_cache is None or value_cache is None: + if key_cache is None: return False - key_ptr = self.get_host_key_ptr(layer_idx) - value_ptr = self.get_host_value_ptr(layer_idx) - if key_ptr == 0 or value_ptr == 0: + if (key_ptr := self.get_host_key_ptr(layer_idx)) == 0: return False swap_cache_per_layer( @@ -407,15 +430,21 @@ def _swap_single_layer( self._device_id, mode, ) - swap_cache_per_layer( - value_cache, - value_ptr, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) + + if not self._is_mla or self._is_dsa: + value_cache = self.get_device_value_cache(layer_idx) + value_ptr = self.get_host_value_ptr(layer_idx) + if value_cache is None or value_ptr == 0: + return False + swap_cache_per_layer( + value_cache, + value_ptr, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, + ) return True except Exception: import traceback @@ -466,16 +495,24 @@ def _swap_all_layers_async( self._device_id, mode, ) - swap_cache_all_layers( - self._device_value_caches, - self._host_value_ptrs, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) - if self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs: + # 
Value/indexer cache: GQA has value, DSA has indexer (both in _device_value_caches) + # MLA has neither, so _device_value_caches is empty + if self._device_value_caches and self._host_value_ptrs: + swap_cache_all_layers( + self._device_value_caches, + self._host_value_ptrs, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, + ) + if ( + not self._is_mla + and self._is_fp8_quantization() + and self._device_key_scales + and self._host_key_scales_ptrs + ): swap_cache_all_layers( self._device_key_scales, self._host_key_scales_ptrs, @@ -527,13 +564,10 @@ def _swap_single_layer_async( stream = self._output_stream if mode == 0 else self._input_stream key_cache = self.get_device_key_cache(layer_idx) - value_cache = self.get_device_value_cache(layer_idx) - if key_cache is None or value_cache is None: + if key_cache is None: return False - key_ptr = self.get_host_key_ptr(layer_idx) - value_ptr = self.get_host_value_ptr(layer_idx) - if key_ptr == 0 or value_ptr == 0: + if (key_ptr := self.get_host_key_ptr(layer_idx)) == 0: return False try: @@ -548,15 +582,22 @@ def _swap_single_layer_async( self._device_id, mode, ) - swap_cache_per_layer_async( - value_cache, - value_ptr, - self._num_host_blocks, - device_block_ids, - host_block_ids, - self._device_id, - mode, - ) + + if not self._is_mla or self._is_dsa: + # GQA: swap value; DSA: swap indexer (stored in value slot) + value_cache = self.get_device_value_cache(layer_idx) + value_ptr = self.get_host_value_ptr(layer_idx) + if value_cache is None or value_ptr == 0: + return False + swap_cache_per_layer_async( + value_cache, + value_ptr, + self._num_host_blocks, + device_block_ids, + host_block_ids, + self._device_id, + mode, + ) return True except Exception: import traceback