@@ -319,6 +319,11 @@ def _init_cpu_cache(self, args):
319319 raise ValueError (f"Unsupported cache dtype: { args .cache_dtype } " )
320320 key_need_to_allocate_bytes = args .num_cpu_blocks * cache_bytes * key_cache_size
321321 value_need_to_allocate_bytes = args .num_cpu_blocks * cache_bytes * value_cache_size
322+ if args .cache_dtype == "block_wise_fp8" :
323+ cache_scales = paddle .empty (shape = [], dtype = paddle .get_default_dtype ())
324+ cache_scales_size = self .key_cache_shape [1 ] * self .key_cache_shape [2 ]
325+ scales_key_need_to_allocate_bytes = args .num_cpu_blocks * cache_scales .element_size () * cache_scales_size
326+ scales_value_need_to_allocate_bytes = args .num_cpu_blocks * cache_scales .element_size () * cache_scales_size
322327 logger .info (
323328 f"[rank { self .rank } /{ self .n_ranks } ] ..swap space size : { (key_need_to_allocate_bytes + value_need_to_allocate_bytes ) / 1024 ** 3 :.2f} GB"
324329 )
@@ -343,13 +348,13 @@ def _init_cpu_cache(self, args):
343348 self .cpu_cache_kvs [key_name ] = cuda_host_alloc (key_need_to_allocate_bytes )
344349 self .k_dst_ptrs .append (self .cpu_cache_kvs [key_name ])
345350 if args .cache_dtype == "block_wise_fp8" :
346- self .cpu_cache_kvs [key_cache_scales_name ] = cuda_host_alloc (key_need_to_allocate_bytes )
351+ self .cpu_cache_kvs [key_cache_scales_name ] = cuda_host_alloc (scales_key_need_to_allocate_bytes )
347352 self .k_scales_ptrs .append (self .cpu_cache_kvs [key_cache_scales_name ])
348353 if value_need_to_allocate_bytes > 0 :
349354 self .cpu_cache_kvs [val_name ] = cuda_host_alloc (value_need_to_allocate_bytes )
350355 self .v_dst_ptrs .append (self .cpu_cache_kvs [val_name ])
351356 if args .cache_dtype == "block_wise_fp8" :
352- self .cpu_cache_kvs [value_cache_scales_name ] = cuda_host_alloc (value_need_to_allocate_bytes )
357+ self .cpu_cache_kvs [value_cache_scales_name ] = cuda_host_alloc (scales_value_need_to_allocate_bytes )
353358 self .v_scales_ptrs .append (self .cpu_cache_kvs [value_cache_scales_name ])
354359 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] ✅ swap space (cpu cache) is ready!" )
355360 self .swap_space_ready_signal .value [self .rank ] = 1
0 commit comments