Commit b52e1bd

[Cherry-Pick][Feature] dy-c8 prefix caching (#4918)

* c8 prefix caching
* update code
* update code
* update cache trans
* update code
* update code

1 parent: f637ba7

File tree

4 files changed: +100 −5 lines


custom_ops/gpu_ops/swap_cache_batch.cu

Lines changed: 4 additions & 1 deletion
@@ -35,7 +35,10 @@ void SwapCacheImplAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors
   const int64_t max_block_num_gpu = cache_shape[0];
   const int64_t num_heads = cache_shape[1];
   const int64_t block_size = cache_shape[2];
-  const int64_t head_dim = cache_shape[3];
+  int64_t head_dim = 1;
+  if (cache_shape.size() == 4) {
+    head_dim = cache_shape[3];
+  }
   const int64_t cache_stride = num_heads * block_size * head_dim;

   auto stream = cache_gpu.stream();
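The kernel previously hard-coded a 4-D cache layout of [max_block_num_gpu, num_heads, block_size, head_dim]; with this change the same swap path also accepts 3-D tensors (for example the per-block scale tensors introduced in this commit) by treating the missing head_dim as 1. A minimal Python sketch of the stride logic, with a hypothetical helper name:

def cache_stride(cache_shape):
    # Elements per block: works for 4-D KV caches
    # [blocks, heads, block_size, head_dim] and for 3-D scale tensors
    # [blocks, heads, block_size], where head_dim falls back to 1.
    num_heads, block_size = cache_shape[1], cache_shape[2]
    head_dim = cache_shape[3] if len(cache_shape) == 4 else 1
    return num_heads * block_size * head_dim

assert cache_stride([1024, 8, 64, 128]) == 8 * 64 * 128  # KV data blocks
assert cache_stride([1024, 8, 64]) == 8 * 64             # per-block scales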

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 81 additions & 3 deletions
@@ -83,7 +83,6 @@ def parse_args():
         "--cache_dtype",
         type=str,
         default="bfloat16",
-        choices=["uint8", "bfloat16"],
         help="cache dtype",
     )
     parser.add_argument(
@@ -115,6 +114,8 @@ def __init__(self, args):
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
         self.gpu_cache_v_tensors = []
+        self.gpu_cache_scales_k_tensors = []
+        self.gpu_cache_scales_v_tensors = []
         self.speculative_config = SpeculativeConfig(args.speculative_config)
         self.num_extra_layers = self.speculative_config.num_extra_cache_layer
         self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
@@ -126,6 +127,7 @@ def __init__(self, args):
         self.n_ranks = args.mp_num
         self.rank = rank
         self.device = device
+        self.cache_dtype = args.cache_dtype

         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -137,8 +139,11 @@ def __init__(self, args):
         )

         self.num_cpu_blocks = args.num_cpu_blocks
+        if args.cache_dtype == "block_wise_fp8":
+            cache_type = "uint8"
+        else:
+            cache_type = args.cache_dtype

-        cache_type = args.cache_dtype
         for i in range(args.num_layers + self.num_extra_layers):
             num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks

@@ -164,7 +169,6 @@ def __init__(self, args):
                 dtype=cache_type,
             )
             self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
-
             set_data_ipc(
                 self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
                 f"key_caches_{i}_rank{rank}.device{device}",
@@ -173,6 +177,32 @@ def __init__(self, args):
                 self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
                 f"value_caches_{i}_rank{rank}.device{device}",
             )
+
+            if args.cache_dtype == "block_wise_fp8":
+                self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
+                    shape=[num_gpu_blocks, args.kv_num_head, args.block_size],
+                    fill_value=0,
+                    dtype=paddle.get_default_dtype(),
+                )
+                self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
+                    shape=[num_gpu_blocks, args.kv_num_head, args.block_size],
+                    fill_value=0,
+                    dtype=paddle.get_default_dtype(),
+                )
+                self.gpu_cache_scales_k_tensors.append(
+                    self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"]
+                )
+                self.gpu_cache_scales_v_tensors.append(
+                    self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"]
+                )
+                set_data_ipc(
+                    self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"],
+                    f"key_cache_scales_{i}_rank{rank}.device{device}",
+                )
+                set_data_ipc(
+                    self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"],
+                    f"value_cache_scales_{i}_rank{rank}.device{device}",
+                )
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"device :{self.device}")
         logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -181,6 +211,8 @@ def __init__(self, args):
         paddle.set_device("cpu")
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
+        self.k_scales_ptrs = []
+        self.v_scales_ptrs = []
         for i in range(args.num_layers + self.num_extra_layers):
             self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
@@ -190,6 +222,14 @@ def __init__(self, args):
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
             )
             self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
+            self.cpu_cache_kvs[f"key_caches_scales_{i}_rank{rank}"] = cuda_host_alloc(
+                args.num_cpu_blocks * args.bytes_per_layer_per_block
+            )
+            self.k_scales_ptrs.append(self.cpu_cache_kvs[f"key_caches_scales_{i}_rank{rank}"])
+            self.cpu_cache_kvs[f"value_caches_scales_{i}_rank{rank}"] = cuda_host_alloc(
+                args.num_cpu_blocks * args.bytes_per_layer_per_block
+            )
+            self.v_scales_ptrs.append(self.cpu_cache_kvs[f"value_caches_scales_{i}_rank{rank}"])

         cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
         self.cache_ready_signal = IPCSignal(
@@ -388,6 +428,25 @@ def _transfer_data(
                     self.device,
                     0,
                 )
+                if self.cache_dtype == "block_wise_fp8":
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_k_tensors,
+                        self.k_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        0,
+                    )
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_v_tensors,
+                        self.v_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        0,
+                    )

             elif event_type.value == CacheStatus.SWAP2GPU.value:
                 swap_cache_all_layers(
@@ -408,6 +467,25 @@ def _transfer_data(
                     self.device,
                     1,
                 )
+                if self.cache_dtype == "block_wise_fp8":
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_k_tensors,
+                        self.k_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        1,
+                    )
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_v_tensors,
+                        self.v_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        1,
+                    )
             else:
                 logger.warning(
                     f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"

fastdeploy/config.py

Lines changed: 3 additions & 1 deletion
@@ -997,7 +997,9 @@ def __init__(self, args):
             self.enable_hierarchical_cache = True

         if self.model_cfg is not None:
-            if self.model_cfg.quantization_config is not None:
+            if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict):
+                self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype)
+            elif self.model_cfg.quantization_config is not None:
                 self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype)
             if (
                 hasattr(self.model_cfg, "num_key_value_heads")
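cache_dtype is now resolved from a dict-valued model_cfg.quantization first, falling back to quantization_config, so kv_cache_quant_type (e.g. block_wise_fp8) can come from either place. A standalone sketch of that resolution order, assuming the attribute names from the diff (the helper itself is hypothetical):

def resolve_cache_dtype(model_cfg, default="bfloat16"):
    cache_dtype = default
    if model_cfg is None:
        return cache_dtype
    quantization = getattr(model_cfg, "quantization", None)
    quantization_config = getattr(model_cfg, "quantization_config", None)
    if quantization is not None and isinstance(quantization, dict):
        # e.g. {"kv_cache_quant_type": "block_wise_fp8"}
        cache_dtype = quantization.get("kv_cache_quant_type", cache_dtype)
    elif quantization_config is not None:
        cache_dtype = quantization_config.get("kv_cache_quant_type", cache_dtype)
    return cache_dtype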

fastdeploy/worker/gpu_model_runner.py

Lines changed: 12 additions & 0 deletions
@@ -1050,6 +1050,18 @@ def initialize_kv_cache(self, profile: bool = False) -> None:
                     value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
                     cache_kvs_list.append(value_cache)

+                    if kv_cache_quant_type == "block_wise_fp8":
+                        scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
+                        scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
+                        key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                        key_scale_cache = share_external_data(key_scale_cache, scale_key_cache_name, kv_cache_scale_shape)
+                        cache_kvs_list.append(key_scale_cache)
+                        value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                        value_scale_cache = share_external_data(
+                            value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
+                        )
+                        cache_kvs_list.append(value_scale_cache)
+
                 self.share_inputs["caches"] = cache_kvs_list
             else:
                 for i in range(self.model_config.num_hidden_layers):
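On the worker side, the runner maps the IPC-shared scale tensors by name (key_cache_scales_{i}_rank{local_rank}.device{device_id}) and appends them right after each layer's key/value caches, so under block_wise_fp8 the caches list interleaves data and scale tensors per layer. A small sketch of the resulting ordering, inferred from the appends above (the helper is illustrative only):

def cache_list_layout(num_layers, block_wise_fp8=True):
    # Order in which tensors are appended to share_inputs["caches"].
    layout = []
    for i in range(num_layers):
        layout += [f"key_caches_{i}", f"value_caches_{i}"]
        if block_wise_fp8:
            layout += [f"key_cache_scales_{i}", f"value_cache_scales_{i}"]
    return layout

print(cache_list_layout(2))
# ['key_caches_0', 'value_caches_0', 'key_cache_scales_0', 'value_cache_scales_0',
#  'key_caches_1', 'value_caches_1', 'key_cache_scales_1', 'value_cache_scales_1']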
