@@ -83,7 +83,6 @@ def parse_args():
8383 "--cache_dtype" ,
8484 type = str ,
8585 default = "bfloat16" ,
86- choices = ["uint8" , "bfloat16" ],
8786 help = "cache dtype" ,
8887 )
8988 parser .add_argument (
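Dropping `choices` is what admits the new `block_wise_fp8` value, but it also means any string is now accepted and only fails later at use sites. A minimal sketch (whitelist contents assumed, not taken from the PR) of keeping explicit validation while still allowing the new dtype:

```python
import argparse

# Hypothetical whitelist so typos still fail fast at parse time.
SUPPORTED_CACHE_DTYPES = ("uint8", "bfloat16", "block_wise_fp8")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cache_dtype",
    type=str,
    default="bfloat16",
    choices=SUPPORTED_CACHE_DTYPES,
    help="cache dtype",
)

args = parser.parse_args(["--cache_dtype", "block_wise_fp8"])
print(args.cache_dtype)  # block_wise_fp8
```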
@@ -115,6 +114,8 @@ def __init__(self, args):
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
         self.gpu_cache_v_tensors = []
+        self.gpu_cache_scales_k_tensors = []
+        self.gpu_cache_scales_v_tensors = []
         self.speculative_config = SpeculativeConfig(args.speculative_config)
         self.num_extra_layers = self.speculative_config.num_extra_cache_layer
         self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
@@ -126,6 +127,7 @@ def __init__(self, args):
         self.n_ranks = args.mp_num
         self.rank = rank
         self.device = device
+        self.cache_dtype = args.cache_dtype
 
         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -137,8 +139,11 @@ def __init__(self, args):
         )
 
         self.num_cpu_blocks = args.num_cpu_blocks
+        if args.cache_dtype == "block_wise_fp8":
+            cache_type = "uint8"
+        else:
+            cache_type = args.cache_dtype
 
-        cache_type = args.cache_dtype
         for i in range(args.num_layers + self.num_extra_layers):
             num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
 
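The new branch replaces the old one-liner: with `block_wise_fp8`, the KV payload tensors are allocated as raw `uint8` bytes while the logical dtype is tracked separately via `self.cache_dtype`. A standalone restatement of that mapping (helper name hypothetical):

```python
def resolve_cache_type(cache_dtype: str) -> str:
    """Block-wise FP8 payloads are stored as raw bytes (uint8);
    every other cache dtype maps to itself, mirroring the hunk above."""
    return "uint8" if cache_dtype == "block_wise_fp8" else cache_dtype

assert resolve_cache_type("block_wise_fp8") == "uint8"
assert resolve_cache_type("bfloat16") == "bfloat16"
```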
@@ -164,7 +169,6 @@ def __init__(self, args):
                 dtype=cache_type,
             )
             self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
-
             set_data_ipc(
                 self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
                 f"key_caches_{i}_rank{rank}.device{device}",
@@ -173,6 +177,32 @@ def __init__(self, args):
                 self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
                 f"value_caches_{i}_rank{rank}.device{device}",
             )
+
+            if args.cache_dtype == "block_wise_fp8":
+                self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
+                    shape=[num_gpu_blocks, args.kv_num_head, args.block_size],
+                    fill_value=0,
+                    dtype=paddle.get_default_dtype(),
+                )
+                self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
+                    shape=[num_gpu_blocks, args.kv_num_head, args.block_size],
+                    fill_value=0,
+                    dtype=paddle.get_default_dtype(),
+                )
+                self.gpu_cache_scales_k_tensors.append(
+                    self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"]
+                )
+                self.gpu_cache_scales_v_tensors.append(
+                    self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"]
+                )
+                set_data_ipc(
+                    self.gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"],
+                    f"key_cache_scales_{i}_rank{rank}.device{device}",
+                )
+                set_data_ipc(
+                    self.gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"],
+                    f"value_cache_scales_{i}_rank{rank}.device{device}",
+                )
         cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
         logger.info(f"device :{self.device}")
         logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
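Each scale tensor is shaped `[num_gpu_blocks, kv_num_head, block_size]`, i.e. one scale per head per token slot in a block, and is registered over IPC exactly like the payload caches. Note that the trailing `numel() * 1` accounting assumes one byte per element, so the bfloat16 scale tensors appear to be undercounted in the logged `cache_kv_size_byte`. A back-of-envelope sketch of the scale overhead (all numbers illustrative):

```python
# Illustrative sizing, not values from the PR.
num_gpu_blocks, kv_num_head, block_size, head_dim = 1024, 8, 64, 128

fp8_payload_bytes = num_gpu_blocks * kv_num_head * block_size * head_dim  # 1 byte per fp8 element
scale_elems = num_gpu_blocks * kv_num_head * block_size                   # one scale per head per slot
scale_bytes = scale_elems * 2  # assuming a bfloat16 default dtype (2 bytes)

print(
    f"payload: {fp8_payload_bytes / 2**20:.1f} MiB, "
    f"scales: {scale_bytes / 2**20:.2f} MiB "
    f"({scale_bytes / fp8_payload_bytes:.2%} overhead)"
)
# payload: 64.0 MiB, scales: 1.00 MiB (1.56% overhead)
```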
@@ -181,6 +211,8 @@ def __init__(self, args):
         paddle.set_device("cpu")
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
+        self.k_scales_ptrs = []
+        self.v_scales_ptrs = []
         for i in range(args.num_layers + self.num_extra_layers):
             self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
@@ -190,6 +222,14 @@ def __init__(self, args):
                 args.num_cpu_blocks * args.bytes_per_layer_per_block
             )
             self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
+            self.cpu_cache_kvs[f"key_caches_scales_{i}_rank{rank}"] = cuda_host_alloc(
+                args.num_cpu_blocks * args.bytes_per_layer_per_block
+            )
+            self.k_scales_ptrs.append(self.cpu_cache_kvs[f"key_caches_scales_{i}_rank{rank}"])
+            self.cpu_cache_kvs[f"value_caches_scales_{i}_rank{rank}"] = cuda_host_alloc(
+                args.num_cpu_blocks * args.bytes_per_layer_per_block
+            )
+            self.v_scales_ptrs.append(self.cpu_cache_kvs[f"value_caches_scales_{i}_rank{rank}"])
 
         cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
         self.cache_ready_signal = IPCSignal(
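The pinned host buffers for scales reuse `bytes_per_layer_per_block`, which is sized for the payload. That looks safe but appears to over-allocate: a block's scales need only `kv_num_head * block_size` elements rather than a full payload's worth. A hypothetical tighter bound for comparison:

```python
def scale_bytes_per_layer(num_cpu_blocks: int, kv_num_head: int,
                          block_size: int, scale_itemsize: int = 2) -> int:
    """One scale per (block, head, token slot); 2 bytes assumes bfloat16."""
    return num_cpu_blocks * kv_num_head * block_size * scale_itemsize

# e.g. 2048 CPU blocks, 8 KV heads, 64-token blocks -> 2 MiB per layer
print(scale_bytes_per_layer(2048, 8, 64))
```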
@@ -388,6 +428,25 @@ def _transfer_data(
                 self.device,
                 0,
             )
+            if self.cache_dtype == "block_wise_fp8":
+                swap_cache_all_layers(
+                    self.gpu_cache_scales_k_tensors,
+                    self.k_scales_ptrs,
+                    self.num_cpu_blocks,
+                    gpu_block_ids,
+                    cpu_block_ids,
+                    self.device,
+                    0,
+                )
+                swap_cache_all_layers(
+                    self.gpu_cache_scales_v_tensors,
+                    self.v_scales_ptrs,
+                    self.num_cpu_blocks,
+                    gpu_block_ids,
+                    cpu_block_ids,
+                    self.device,
+                    0,
+                )
 
         elif event_type.value == CacheStatus.SWAP2GPU.value:
             swap_cache_all_layers(
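The scale transfers repeat the payload calls with only the tensor lists, host pointers, and final mode flag changing (0 in the SWAP2CPU branch, 1 in SWAP2GPU, as read from the surrounding calls). A possible dedupe, sketched as a method on this class and assuming the module's existing `swap_cache_all_layers` plus the attributes initialized above:

```python
def _swap_scales(self, gpu_block_ids, cpu_block_ids, mode: int) -> None:
    """Hypothetical helper: one body for both branches
    (mode=0 for SWAP2CPU, mode=1 for SWAP2GPU)."""
    for tensors, ptrs in (
        (self.gpu_cache_scales_k_tensors, self.k_scales_ptrs),
        (self.gpu_cache_scales_v_tensors, self.v_scales_ptrs),
    ):
        swap_cache_all_layers(
            tensors,
            ptrs,
            self.num_cpu_blocks,
            gpu_block_ids,
            cpu_block_ids,
            self.device,
            mode,
        )
```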
@@ -408,6 +467,25 @@ def _transfer_data(
                 self.device,
                 1,
             )
+            if self.cache_dtype == "block_wise_fp8":
+                swap_cache_all_layers(
+                    self.gpu_cache_scales_k_tensors,
+                    self.k_scales_ptrs,
+                    self.num_cpu_blocks,
+                    gpu_block_ids,
+                    cpu_block_ids,
+                    self.device,
+                    1,
+                )
+                swap_cache_all_layers(
+                    self.gpu_cache_scales_v_tensors,
+                    self.v_scales_ptrs,
+                    self.num_cpu_blocks,
+                    gpu_block_ids,
+                    cpu_block_ids,
+                    self.device,
+                    1,
+                )
         else:
             logger.warning(
                 f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"
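For reference, the branch structure reduces to a small dispatch on the event type. A standalone illustration with a stand-in enum (the real `CacheStatus` lives in the project; member values here are assumptions):

```python
from enum import Enum, auto

class CacheStatus(Enum):  # stand-in; the project's real enum is assumed
    SWAP2CPU = auto()
    SWAP2GPU = auto()

def swap_mode(event_type: CacheStatus) -> int:
    """Map the event to swap_cache_all_layers' final argument:
    0 writes GPU blocks out to pinned host memory, 1 reads them back."""
    if event_type == CacheStatus.SWAP2CPU:
        return 0
    if event_type == CacheStatus.SWAP2GPU:
        return 1
    raise ValueError(f"only SWAP2CPU and SWAP2GPU supported, got {event_type}")

assert swap_mode(CacheStatus.SWAP2CPU) == 0
```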