
Commit 70ec1e1

[Features] add audio request & fix embedding bug (#5201)
* [Features] add audio request & fix embedding bug
* fix bug
1 parent 9f4977e commit 70ec1e1

File tree

2 files changed: +46, -14 lines

fastdeploy/input/tokenzier_client.py

Lines changed: 29 additions & 2 deletions
@@ -27,27 +27,43 @@
 class BaseEncodeRequest(BaseModel):
     version: str
     req_id: str
-    is_gen: bool
-    resolution: int
 
 
 class ImageEncodeRequest(BaseEncodeRequest):
     image_url: Union[str, HttpUrl]
+    is_gen: bool
+    resolution: int
 
 
 class VideoEncodeRequest(BaseEncodeRequest):
     video_url: Union[str, HttpUrl]
+    is_gen: bool
+    resolution: int
     start_ts: int
     end_ts: int
     frames: int
     vit_merge: bool
 
 
+class AudioEncodeRequest(BaseEncodeRequest):
+    audio_url: Union[str, HttpUrl]
+    is_add_spk_emb: bool
+    is_pad_aug: bool
+    is_aug: bool
+    audio_start: Optional[float]
+    audio_dur: Optional[float]
+
+
 class ImageDecodeRequest(BaseModel):
     req_id: str
     data: list[Any]
 
 
+class AudioDecodeRequest(BaseModel):
+    req_id: str
+    data: list[Any]
+
+
 class AsyncTokenizerClient:
     def __init__(
         self,
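For reference, the new request types are plain Pydantic models and can be constructed directly. A minimal sketch follows; every field value is an illustrative placeholder, and the assumption that audio_start/audio_dur are expressed in seconds is not stated in this commit:

from fastdeploy.input.tokenzier_client import AudioEncodeRequest

# All values below are placeholders, not values taken from FastDeploy.
audio_req = AudioEncodeRequest(
    version="1.0",                                # hypothetical version string
    req_id="req-001",                             # hypothetical request id
    audio_url="https://example.com/sample.wav",   # hypothetical audio location
    is_add_spk_emb=False,                         # presumably: add a speaker embedding
    is_pad_aug=False,
    is_aug=False,
    audio_start=0.0,                              # assumed: offset into the clip, in seconds
    audio_dur=None,                               # assumed: None means "use the full clip"
)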
@@ -74,9 +90,15 @@ async def encode_image(self, request: ImageEncodeRequest):
     async def encode_video(self, request: VideoEncodeRequest):
         return await self._async_encode_request("video", request.__dict__)
 
+    async def encode_audio(self, request: AudioEncodeRequest):
+        return await self._async_encode_request("audio", request.__dict__)
+
     async def decode_image(self, request: ImageDecodeRequest):
         return await self._async_decode_request("image", request.__dict__)
 
+    async def decode_audio(self, request: AudioDecodeRequest):
+        return await self._async_decode_request("audio", request.__dict__)
+
     async def log_request(self, request):
         data_processor_logger.debug(f">>> Request: {request.method} {request.url}")
         data_processor_logger.debug(f">>> Headers: {request.headers}")

@@ -101,6 +123,8 @@ async def _async_encode_request(self, type: str, request: dict):
             url = f"{self.base_url}/image/encode"
         elif type == "video":
             url = f"{self.base_url}/video/encode"
+        elif type == "audio":
+            url = f"{self.base_url}/audio/encode"
         else:
             raise ValueError("Invalid type")
 

@@ -110,6 +134,7 @@ async def _async_encode_request(self, type: str, request: dict):
             raise RuntimeError(f"Failed to create tokenize task: {e}") from e
 
         task_info = resp.json()
+
         if task_info.get("code") != 0:
             raise RuntimeError(f"Tokenize task creation failed, {task_info.get('message')}")
 

@@ -154,6 +179,8 @@ async def _async_decode_request(self, type: str, request: dict):
         url = None
         if type == "image":
             url = f"{self.base_url}/image/decode"
+        elif type == "audio":
+            url = f"{self.base_url}/audio/decode"
         else:
             raise ValueError("Invalid type")
 
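A usage sketch for the new client entry points, assuming a tokenizer service is reachable and that `client` is an already-constructed AsyncTokenizerClient (its constructor arguments are elided in this diff); the empty data list passed to AudioDecodeRequest is only a placeholder:

from fastdeploy.input.tokenzier_client import (
    AsyncTokenizerClient,
    AudioDecodeRequest,
    AudioEncodeRequest,
)

async def tokenize_audio(client: AsyncTokenizerClient, audio_req: AudioEncodeRequest):
    # Routed to f"{base_url}/audio/encode" by _async_encode_request("audio", ...).
    encode_result = await client.encode_audio(audio_req)

    # Routed to f"{base_url}/audio/decode" by _async_decode_request("audio", ...).
    # The shape of `data` depends on the tokenizer service; [] is a placeholder.
    decode_result = await client.decode_audio(
        AudioDecodeRequest(req_id=audio_req.req_id, data=[])
    )
    return encode_result, decode_result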
fastdeploy/model_executor/layers/embeddings.py

Lines changed: 17 additions & 12 deletions
@@ -106,6 +106,8 @@ def __init__(
         params_dtype: str = "bfloat16",
         prefix="",
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        org_num_embeddings: int | None = None,
+        general=False,
     ) -> None:
         """
         Initialize the VocabParallelEmbedding layer for the model.

@@ -132,17 +134,23 @@ def __init__(
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype
-        self.padding_size = padding_size
 
-        self.org_vocab_size = num_embeddings
+        self.general = general  # used for general Embedding
         self.num_embeddings = num_embeddings
-        num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.padding_size = padding_size
+        if self.general:
+            self.org_vocab_size = num_embeddings
+            self.num_embeddings_padded = num_embeddings
+            self.org_vocab_size_padded = num_embeddings
+        else:
+            self.org_vocab_size = org_num_embeddings or num_embeddings
+            num_added_embeddings = num_embeddings - self.org_vocab_size
 
-        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
-        self.num_embeddings_padded = pad_vocab_size(
-            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
-        )
-        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+            self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
+            self.num_embeddings_padded = pad_vocab_size(
+                self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+            )
+            assert self.org_vocab_size_padded <= self.num_embeddings_padded
         self.shard_indices = self._get_indices(
             self.num_embeddings_padded,
             self.org_vocab_size_padded,

@@ -152,9 +160,6 @@ def __init__(
             self.world_size,
         )
 
-        if num_embeddings % self.world_size != 0:
-            self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)
-
         if not self.column_cut:
             self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 self.num_embeddings_padded,

@@ -188,7 +193,7 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         Args:
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
-        if self.tie_word_embeddings:
+        if self.tie_word_embeddings and not self.general:
             weight_tensor = get_tensor(state_dict[self.prefix + ".weight"]).astype(paddle.get_default_dtype())
         else:
             weight_tensor = get_tensor(state_dict.pop(self.prefix + ".weight")).astype(paddle.get_default_dtype())
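The embedding fix comes down to how the padded vocabulary sizes are derived. A standalone sketch of that bookkeeping, assuming pad_vocab_size(n, p) rounds n up to the next multiple of p (the usual convention; the real helper lives in FastDeploy) and using 64 as an illustrative padding size:

def pad_vocab_size(n: int, padding_size: int) -> int:
    # Assumed behavior: round n up to the next multiple of padding_size.
    return ((n + padding_size - 1) // padding_size) * padding_size

def padded_sizes(num_embeddings: int, padding_size: int = 64,
                 org_num_embeddings: int | None = None, general: bool = False):
    """Return (org_vocab_size_padded, num_embeddings_padded) as computed above."""
    if general:
        # "General" (non-vocab) embeddings are left unpadded.
        return num_embeddings, num_embeddings
    org_vocab_size = org_num_embeddings or num_embeddings
    num_added_embeddings = num_embeddings - org_vocab_size
    org_vocab_size_padded = pad_vocab_size(org_vocab_size, padding_size)
    num_embeddings_padded = pad_vocab_size(
        org_vocab_size_padded + num_added_embeddings, padding_size
    )
    assert org_vocab_size_padded <= num_embeddings_padded
    return org_vocab_size_padded, num_embeddings_padded

# padded_sizes(100_000)               -> (100032, 100032)
# padded_sizes(100_000, general=True) -> (100000, 100000)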
