diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index 683a78556..82deb7820 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -171,7 +171,7 @@ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): mask = attention_mask.unsqueeze(-1).expand(data.size()).float() masked_embeddings = data * mask sum_embeddings = torch.sum(masked_embeddings, dim=1) - token_num = torch.sum(attention_mask).item() + token_num = attention_mask.sum(dim=1, keepdim=True) return sum_embeddings, token_num @@ -224,10 +224,10 @@ def get_embeddings(self, params): ): embedding = embedding / token_num normalized_embeddings = F.normalize(embedding, p=2, dim=1) - ret["token_num"] = token_num + ret["token_num"] = token_num.sum().item() else: all_embeddings = [] - all_token_num = 0 + all_token_num = 0 # per-sequence tensor, accumulated across chunks for i in range(0, input_ids.size(1), self.context_len): chunk_input_ids = input_ids[:, i : i + self.context_len] chunk_attention_mask = attention_mask[:, i : i + self.context_len] @@ -273,7 +273,11 @@ def get_embeddings(self, params): embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num normalized_embeddings = F.normalize(embedding, p=2, dim=1) - ret["token_num"] = all_token_num + ret["token_num"] = ( + all_token_num.sum().item() + if isinstance(all_token_num, torch.Tensor) + else all_token_num + ) if base64_encode == "base64": out_embeddings = self.__encode_base64(normalized_embeddings)