
Commit 183c0c3

update code
1 parent babd2d8 commit 183c0c3

File tree: 1 file changed (+7, −2 lines)

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 7 additions & 2 deletions
@@ -319,6 +319,11 @@ def _init_cpu_cache(self, args):
             raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
         key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
         value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
+        if args.cache_dtype == "block_wise_fp8":
+            cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+            cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
+            scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
+            scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
         logger.info(
             f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
         )
@@ -343,13 +348,13 @@ def _init_cpu_cache(self, args):
                 self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
                 self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
                 if args.cache_dtype == "block_wise_fp8":
-                    self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(key_need_to_allocate_bytes)
+                    self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
                     self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name])
             if value_need_to_allocate_bytes > 0:
                 self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
                 self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
                 if args.cache_dtype == "block_wise_fp8":
-                    self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(value_need_to_allocate_bytes)
+                    self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes)
                     self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name])
         logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
         self.swap_space_ready_signal.value[self.rank] = 1
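Why the separate size variables matter: for block_wise_fp8 the KV blocks themselves are presumably stored as 1-byte FP8 elements, while the per-block scales are stored in the default Paddle dtype, so the scale buffers need far fewer bytes than the key/value buffers whose sizes they previously reused. The following is a minimal sketch of the arithmetic, using hypothetical shape values and assuming key_cache_shape is laid out as [num_blocks, kv_num_heads, block_size, head_dim]; real values come from the serving config at runtime.

# Illustrative sketch only: all numbers below are hypothetical assumptions.
num_cpu_blocks = 1024                               # hypothetical CPU swap blocks
kv_num_heads, block_size, head_dim = 8, 64, 128     # assumed key_cache_shape[1:]

fp8_bytes = 1    # block_wise_fp8 presumably stores 1 byte per cached element
scale_bytes = 2  # element_size() of the default dtype, e.g. bfloat16 (assumption)

key_cache_size = kv_num_heads * block_size * head_dim   # elements per CPU block
cache_scales_size = kv_num_heads * block_size            # shape[1] * shape[2] scales per block

key_need_to_allocate_bytes = num_cpu_blocks * fp8_bytes * key_cache_size
scales_key_need_to_allocate_bytes = num_cpu_blocks * scale_bytes * cache_scales_size

# Before this commit the scale buffers were allocated with the key/value byte
# counts; with the numbers assumed here that would over-allocate them by:
print(key_need_to_allocate_bytes / scales_key_need_to_allocate_bytes)  # 64.0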
