Skip to content

Commit 195fea1

Browse files
authored
feat: support v1/rerank interface for embedding model. (#272)
1 parent b26c50c commit 195fea1

File tree

11 files changed

+342
-0
lines changed

11 files changed

+342
-0
lines changed

xllm/api_service/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ cc_library(
1111
chat_service_impl.h
1212
embedding_service_impl.h
1313
image_generation_service_impl.h
14+
rerank_service_impl.h
1415
non_stream_call.h
1516
service_impl_factory.h
1617
stream_call.h
@@ -23,6 +24,7 @@ cc_library(
2324
embedding_service_impl.cpp
2425
image_generation_service_impl.cpp
2526
models_service_impl.cpp
27+
rerank_service_impl.cpp
2628
DEPS
2729
:master
2830
:chat_template
@@ -32,5 +34,7 @@ cc_library(
3234
absl::flat_hash_set
3335
absl::random_random
3436
:function_call
37+
torch
38+
$<$<BOOL:${USE_NPU}>:torch_npu>
3539
)
3640

xllm/api_service/api_service.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ APIService::APIService(Master* master,
5151
embedding_service_impl_ =
5252
ServiceImplFactory<EmbeddingServiceImpl>::create_service_impl(
5353
llm_master, model_names);
54+
rerank_service_impl_ =
55+
ServiceImplFactory<RerankServiceImpl>::create_service_impl(llm_master,
56+
model_names);
5457
} else if (FLAGS_backend == "vlm") {
5558
auto vlm_master = dynamic_cast<VLMMaster*>(master);
5659
mm_chat_service_impl_ =
@@ -260,6 +263,47 @@ void APIService::ImageGenerationHttp(
260263
image_generation_service_impl_->process_async(call);
261264
}
262265

266+
// RPC entry point for rerank requests.
//
// Not implemented yet (TODO with xllm-service). The closure is still handed
// to a guard so that `done` is run when this function returns; without that
// the caller would never receive a reply. When a controller is supplied the
// call is marked as failed so the client can tell the request was not served.
void APIService::Rerank(::google::protobuf::RpcController* controller,
                        const proto::RerankRequest* request,
                        proto::RerankResponse* response,
                        ::google::protobuf::Closure* done) {
  xllm::ClosureGuard done_guard(
      done,
      std::bind(request_in_metric, nullptr),
      std::bind(request_out_metric, static_cast<void*>(controller)));
  // TODO with xllm-service
  if (controller != nullptr) {
    controller->SetFailed("Rerank over RPC is not implemented yet");
  }
}
272+
273+
// HTTP entry point for rerank requests.
//
// Parses the JSON body carried in the request attachment into a
// proto::RerankRequest and hands the resulting call to the rerank service.
// `done` stays owned by the guard (and is therefore run on every early
// return) until it is released to the RerankCall.
void APIService::RerankHttp(::google::protobuf::RpcController* controller,
                            const proto::HttpRequest* request,
                            proto::HttpResponse* response,
                            ::google::protobuf::Closure* done) {
  xllm::ClosureGuard done_guard(
      done,
      std::bind(request_in_metric, nullptr),
      std::bind(request_out_metric, static_cast<void*>(controller)));
  if (!request || !response || !controller) {
    LOG(ERROR) << "brpc request | respose | controller is null";
    return;
  }

  auto* ctrl = static_cast<brpc::Controller*>(controller);

  // The rerank service is created in only one backend branch of the
  // constructor, so it may be null here; fail the request instead of
  // dereferencing a null pointer.
  if (!rerank_service_impl_) {
    ctrl->SetFailed("rerank is not supported by the current backend");
    LOG(ERROR) << "rerank service is not initialized";
    return;
  }

  auto arena = response->GetArena();
  auto req_pb =
      google::protobuf::Arena::CreateMessage<proto::RerankRequest>(arena);
  auto resp_pb =
      google::protobuf::Arena::CreateMessage<proto::RerankResponse>(arena);

  // to_string() already returns a temporary; no std::move needed.
  std::string attachment = ctrl->request_attachment().to_string();
  std::string error;
  auto st = json2pb::JsonToProtoMessage(attachment, req_pb, &error);
  if (!st) {
    ctrl->SetFailed(error);
    LOG(ERROR) << "parse json to proto failed: " << error;
    return;
  }

  // From here on the call owns `done` and is responsible for running it.
  std::shared_ptr<Call> call =
      std::make_shared<RerankCall>(ctrl, done_guard.release(), req_pb, resp_pb);
  rerank_service_impl_->process_async(call);
}
306+
263307
void APIService::Models(::google::protobuf::RpcController* controller,
264308
const proto::ModelListRequest* request,
265309
proto::ModelListResponse* response,

xllm/api_service/api_service.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include "embedding_service_impl.h"
2121
#include "image_generation_service_impl.h"
2222
#include "models_service_impl.h"
23+
#include "rerank_service_impl.h"
2324
#include "xllm_service.pb.h"
2425

2526
namespace xllm {
@@ -70,6 +71,17 @@ class APIService : public proto::XllmAPIService {
7071
const proto::HttpRequest* request,
7172
proto::HttpResponse* response,
7273
::google::protobuf::Closure* done) override;
74+
75+
void Rerank(::google::protobuf::RpcController* controller,
76+
const proto::RerankRequest* request,
77+
proto::RerankResponse* response,
78+
::google::protobuf::Closure* done) override;
79+
80+
void RerankHttp(::google::protobuf::RpcController* controller,
81+
const proto::HttpRequest* request,
82+
proto::HttpResponse* response,
83+
::google::protobuf::Closure* done) override;
84+
7385
void Models(::google::protobuf::RpcController* controller,
7486
const proto::ModelListRequest* request,
7587
proto::ModelListResponse* response,
@@ -109,6 +121,7 @@ class APIService : public proto::XllmAPIService {
109121
std::unique_ptr<EmbeddingServiceImpl> embedding_service_impl_;
110122
std::unique_ptr<ModelsServiceImpl> models_service_impl_;
111123
std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
124+
std::unique_ptr<RerankServiceImpl> rerank_service_impl_;
112125
};
113126

114127
} // namespace xllm
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "rerank_service_impl.h"
17+
18+
#include <glog/logging.h>
19+
#include <torch/torch.h>
20+
21+
#include <string>
22+
23+
#include "common/instance_name.h"
24+
#include "framework/request/request_params.h"
25+
#include "runtime/llm_master.h"
26+
#include "util/blocking_counter.h"
27+
#include "util/utils.h"
28+
#include "util/uuid.h"
29+
30+
namespace xllm {
31+
namespace {
32+
33+
// One scored entry of a rerank result: the document's index, its text and
// the score assigned to it.
struct RerankRequestOutput {
  int32_t index = 0;
  std::string document;
  float score = 0.0f;

  RerankRequestOutput(int32_t idx, std::string text, float relevance)
      : index(idx), document(std::move(text)), score(relevance) {}
};
41+
42+
bool send_result_to_client_brpc(std::shared_ptr<RerankCall> call,
43+
const std::string& request_id,
44+
int64_t created_time,
45+
const std::string& model,
46+
const std::vector<std::string>& documents,
47+
int32_t top_n,
48+
const std::vector<RequestOutput>& req_outputs) {
49+
auto& response = call->response();
50+
response.set_id(request_id);
51+
response.set_model(model);
52+
53+
// calculate cosine similarity
54+
size_t doc_size = documents.size() - 1;
55+
std::string query = documents[doc_size];
56+
auto query_embed = req_outputs[doc_size].outputs[0].embeddings.value();
57+
auto query_tensor = torch::from_blob(
58+
query_embed.data(), {query_embed.size()}, torch::kFloat32);
59+
60+
std::vector<RerankRequestOutput> rerank_outputs;
61+
rerank_outputs.reserve(doc_size);
62+
for (size_t i = 0; i < doc_size; ++i) {
63+
if (req_outputs[i].outputs[0].embeddings.has_value()) {
64+
auto doc_embed = req_outputs[i].outputs[0].embeddings.value();
65+
auto doc_tensor = torch::from_blob(
66+
doc_embed.data(), {doc_embed.size()}, torch::kFloat32);
67+
auto score =
68+
torch::cosine_similarity(query_tensor, doc_tensor, 0).item<float>();
69+
rerank_outputs.emplace_back(i, documents[i], score);
70+
}
71+
}
72+
73+
std::sort(rerank_outputs.begin(),
74+
rerank_outputs.end(),
75+
[](const RerankRequestOutput& a, const RerankRequestOutput& b) {
76+
return a.score > b.score;
77+
});
78+
79+
// add top_n results
80+
int32_t valid_top_n = std::min(top_n, static_cast<int32_t>(doc_size));
81+
response.mutable_results()->Reserve(valid_top_n);
82+
for (int32_t i = 0; i < valid_top_n; ++i) {
83+
auto* result = response.add_results();
84+
result->set_index(rerank_outputs[i].index);
85+
auto* document = result->mutable_document();
86+
document->set_text(rerank_outputs[i].document);
87+
result->set_relevance_score(rerank_outputs[i].score);
88+
}
89+
90+
// add usage statistics
91+
int32_t num_prompt_tokens = 0;
92+
int32_t num_generated_tokens = 0;
93+
int32_t num_total_tokens = 0;
94+
for (auto req_output : req_outputs) {
95+
if (req_output.usage.has_value()) {
96+
const auto& usage = req_output.usage.value();
97+
num_prompt_tokens += usage.num_prompt_tokens;
98+
num_generated_tokens += usage.num_generated_tokens;
99+
num_total_tokens += usage.num_total_tokens;
100+
}
101+
}
102+
if (num_total_tokens > 0) {
103+
auto* proto_usage = response.mutable_usage();
104+
proto_usage->set_prompt_tokens(num_prompt_tokens);
105+
proto_usage->set_completion_tokens(num_generated_tokens);
106+
proto_usage->set_total_tokens(num_total_tokens);
107+
}
108+
109+
return call->write_and_finish(response);
110+
}
111+
112+
} // namespace
113+
114+
RerankServiceImpl::RerankServiceImpl(LLMMaster* master,
115+
const std::vector<std::string>& models)
116+
: APIServiceImpl(models), master_(master) {
117+
CHECK(master_ != nullptr);
118+
}
119+
120+
// rerank_async for brpc
121+
void RerankServiceImpl::process_async_impl(std::shared_ptr<RerankCall> call) {
122+
const auto& rpc_request = call->request();
123+
// check if model is supported
124+
const auto& model = rpc_request.model();
125+
if (!models_.contains(model)) {
126+
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
127+
return;
128+
}
129+
130+
std::vector<std::string> documents;
131+
if (rpc_request.documents_size() > 0) {
132+
documents = std::vector<std::string>(rpc_request.documents().begin(),
133+
rpc_request.documents().end());
134+
}
135+
documents.emplace_back(rpc_request.query());
136+
137+
// create RequestParams for rerank request
138+
RequestParams request_params(
139+
rpc_request, call->get_x_request_id(), call->get_x_request_time());
140+
std::vector<RequestParams> sps(documents.size(), request_params);
141+
auto request_id = request_params.request_id;
142+
auto created_time = absl::ToUnixSeconds(absl::Now());
143+
144+
// schedule the request
145+
std::vector<RequestOutput> req_outputs;
146+
req_outputs.resize(documents.size());
147+
BlockingCounter counter(documents.size());
148+
149+
auto batch_callback = [&req_outputs, &counter](size_t index,
150+
RequestOutput output) -> bool {
151+
req_outputs[index] = std::move(output);
152+
counter.decrement_count();
153+
return true;
154+
};
155+
156+
master_->handle_batch_request(documents, sps, batch_callback);
157+
158+
// Wait for all tasks to complete
159+
counter.wait();
160+
161+
int32_t top_n = documents.size() - 1;
162+
if (rpc_request.has_top_n()) {
163+
top_n = rpc_request.top_n();
164+
}
165+
166+
send_result_to_client_brpc(
167+
call, request_id, created_time, model, documents, top_n, req_outputs);
168+
}
169+
170+
} // namespace xllm
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
#include <absl/container/flat_hash_set.h>
18+
19+
#include "api_service/api_service_impl.h"
20+
#include "api_service/call.h"
21+
#include "api_service/non_stream_call.h"
22+
#include "rerank.pb.h"
23+
24+
namespace xllm {
25+
26+
using RerankCall = NonStreamCall<proto::RerankRequest, proto::RerankResponse>;
27+
28+
// a class to handle completion requests
29+
class RerankServiceImpl final : public APIServiceImpl<RerankCall> {
30+
public:
31+
RerankServiceImpl(LLMMaster* master, const std::vector<std::string>& models);
32+
33+
// brpc call_data needs to use shared_ptr
34+
void process_async_impl(std::shared_ptr<RerankCall> call);
35+
36+
private:
37+
DISALLOW_COPY_AND_ASSIGN(RerankServiceImpl);
38+
LLMMaster* master_ = nullptr;
39+
};
40+
41+
} // namespace xllm

xllm/core/framework/request/request_params.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ std::string generate_chat_request_id() {
3939
short_uuid.random();
4040
}
4141

42+
std::string generate_rerank_request_id() {
43+
return "rerankcmpl-" + InstanceName::name()->get_name_hash() + "-" +
44+
short_uuid.random();
45+
}
46+
4247
} // namespace
4348

4449
RequestParams::RequestParams(const proto::CompletionRequest& request,
@@ -332,6 +337,20 @@ RequestParams::RequestParams(const proto::EmbeddingRequest& request,
332337
streaming = false;
333338
}
334339

340+
// Builds the parameters for a rerank request.
//
// A fresh rerank request id is generated; the service request id is copied
// only when the proto carries one. `x_rid` and `x_rtime` are stored as given.
RequestParams::RequestParams(const proto::RerankRequest& request,
                             const std::string& x_rid,
                             const std::string& x_rtime) {
  // identifiers
  request_id = generate_rerank_request_id();
  x_request_id = x_rid;
  x_request_time = x_rtime;
  if (request.has_service_request_id()) {
    service_request_id = request.service_request_id();
  }

  // rerank runs as a non-streaming embeddings request with max_tokens of 1
  streaming = false;
  is_embeddings = true;
  max_tokens = 1;
}
353+
335354
bool RequestParams::verify_params(OutputCallback callback) const {
336355
if (n == 0) {
337356
CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,

xllm/core/framework/request/request_params.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "multimodal.pb.h"
3232
#include "request.h"
3333
#include "request_output.h"
34+
#include "rerank.pb.h"
3435

3536
namespace xllm {
3637

@@ -48,6 +49,9 @@ struct RequestParams {
4849
RequestParams(const proto::EmbeddingRequest& request,
4950
const std::string& x_rid,
5051
const std::string& x_rtime);
52+
RequestParams(const proto::RerankRequest& request,
53+
const std::string& x_rid,
54+
const std::string& x_rtime);
5155

5256
bool verify_params(OutputCallback callback) const;
5357

xllm/proto/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ proto_library(
1010
chat.proto
1111
multimodal.proto
1212
embedding.proto
13+
rerank.proto
1314
models.proto
1415
worker.proto
1516
disagg_pd.proto

0 commit comments

Comments
 (0)