Commit 74f33ef

[Intel HPU] fix bugs caused by other commits (#5074)

* [Intel HPU] fix bugs caused by other commits
* update code by copilot

1 parent 33f96ff · commit 74f33ef

2 files changed: +13 −3 lines changed

fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py

Lines changed: 12 additions & 2 deletions
@@ -186,7 +186,15 @@ class HPUAttentionBackend(AttentionBackend_HPU):
     HPUAttentionBackend backend implementation.
     """
 
-    def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
+    def __init__(
+        self,
+        llm_config: FDConfig,
+        kv_num_heads: int,
+        num_heads: int,
+        head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
+    ):
         """
         HPUAttentionBackend __init__
         """
@@ -239,11 +247,13 @@ def init_attention_metadata(self, forward_meta):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: Optional[str] = None,
     ):
         """
         Caculate kv cache shape
         """
-        return (max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim)
+        key_cache_shape = value_cache_shape = [max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim]
+        return key_cache_shape, value_cache_shape
 
     def forward_extend(
         self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
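get_kv_cache_shape now accepts an optional kv_cache_quant_type and returns a (key_cache_shape, value_cache_shape) pair instead of a single tuple, presumably so callers can treat backends whose key and value cache layouts differ uniformly; on HPU the two shapes stay identical. A standalone sketch of the assumed caller pattern (names and sizes are illustrative, not FastDeploy code):

from typing import Optional

# Standalone sketch of the new contract: one shape per cache,
# returned as a pair even when the two shapes match.
def get_kv_cache_shape(max_num_blocks: int,
                       kv_cache_quant_type: Optional[str] = None,
                       block_size: int = 128,
                       kv_num_heads: int = 8,
                       head_dim: int = 128):
    key_cache_shape = value_cache_shape = [max_num_blocks, block_size, kv_num_heads, head_dim]
    return key_cache_shape, value_cache_shape

# Assumed caller: unpack both shapes instead of reusing a single one.
key_shape, value_shape = get_kv_cache_shape(max_num_blocks=1024)
assert key_shape == value_shape == [1024, 128, 8, 128]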

fastdeploy/worker/hpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -328,7 +328,7 @@ def __init__(
 
         # Sampler
         if not self.speculative_decoding:
-            self.sampler = Sampler()
+            self.sampler = Sampler(fd_config)
         else:
             self.sampler = SpeculativeSampler(fd_config)
 
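The one-line fix passes fd_config to Sampler, matching the SpeculativeSampler branch below it; presumably Sampler's constructor was changed by another commit to require the config, which is the kind of cross-commit breakage this PR's title describes. A standalone sketch of the pattern (DummySampler and DummyConfig are illustrative, not FastDeploy classes):

# Standalone sketch: both sampler variants now take the config,
# so the runner passes it on every branch.
class DummyConfig:
    temperature: float = 1.0

class DummySampler:
    def __init__(self, fd_config: DummyConfig):
        # Would raise TypeError if constructed as DummySampler(),
        # mirroring the bug this hunk fixes.
        self.temperature = fd_config.temperature

fd_config = DummyConfig()
sampler = DummySampler(fd_config)  # mirrors self.sampler = Sampler(fd_config)
print(sampler.temperature)  # 1.0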
