Skip to content

Commit 336fcf8

Browse files
DragonFivemaxiaolong.maxwell
authored and committed
feat: add rec proto, service and utils for rec framework [2/6].
1 parent 67a55c1 commit 336fcf8

File tree

15 files changed

+575
-7
lines changed

15 files changed

+575
-7
lines changed

xllm/api_service/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cc_library(
88
api_service_impl.h
99
call.h
1010
completion_service_impl.h
11+
rec_completion_service_impl.h
1112
chat_service_impl.h
1213
embedding_service_impl.h
1314
image_generation_service_impl.h
@@ -23,6 +24,7 @@ cc_library(
2324
api_service.cpp
2425
call.cpp
2526
completion_service_impl.cpp
27+
rec_completion_service_impl.cpp
2628
chat_service_impl.cpp
2729
embedding_service_impl.cpp
2830
image_generation_service_impl.cpp

xllm/api_service/api_service.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ limitations under the License.
2727
#include "core/common/metrics.h"
2828
#include "core/runtime/dit_master.h"
2929
#include "core/runtime/llm_master.h"
30+
// TODO. add following when next pr.
31+
// #include "core/runtime/rec_master.h"
3032
#include "core/runtime/vlm_master.h"
3133
#include "core/util/closure_guard.h"
3234
#include "embedding.pb.h"
@@ -70,6 +72,11 @@ APIService::APIService(Master* master,
7072
image_generation_service_impl_ =
7173
std::make_unique<ImageGenerationServiceImpl>(
7274
dynamic_cast<DiTMaster*>(master), model_names);
75+
} else if (FLAGS_backend == "rec") {
76+
// TODO. delete this when next pr.
77+
using RecMaster = LLMMaster;
78+
rec_completion_service_impl_ = std::make_unique<RecCompletionServiceImpl>(
79+
dynamic_cast<RecMaster*>(master), model_names);
7380
}
7481
models_service_impl_ =
7582
ServiceImplFactory<ModelsServiceImpl>::create_service_impl(
@@ -80,7 +87,27 @@ void APIService::Completions(::google::protobuf::RpcController* controller,
8087
const proto::CompletionRequest* request,
8188
proto::CompletionResponse* response,
8289
::google::protobuf::Closure* done) {
83-
// TODO with xllm-service
90+
xllm::ClosureGuard done_guard(
91+
done,
92+
std::bind(request_in_metric, nullptr),
93+
std::bind(request_out_metric, (void*)controller));
94+
if (!request || !response || !controller) {
95+
LOG(ERROR) << "brpc request | respose | controller is null.";
96+
return;
97+
}
98+
auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
99+
auto arena = response->GetArena();
100+
std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
101+
ctrl,
102+
done_guard.release(),
103+
const_cast<proto::CompletionRequest*>(request),
104+
response,
105+
arena != nullptr);
106+
if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
107+
completion_service_impl_->process_async(call);
108+
} else if (FLAGS_backend == "rec") {
109+
rec_completion_service_impl_->process_async(call);
110+
}
84111
}
85112

86113
void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
@@ -116,7 +143,11 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
116143

117144
std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
118145
ctrl, done_guard.release(), req_pb, resp_pb, arena != nullptr);
119-
completion_service_impl_->process_async(call);
146+
if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
147+
completion_service_impl_->process_async(call);
148+
} else if (FLAGS_backend == "rec") {
149+
rec_completion_service_impl_->process_async(call);
150+
}
120151
}
121152

122153
void APIService::ChatCompletions(::google::protobuf::RpcController* controller,

xllm/api_service/api_service.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include "image_generation_service_impl.h"
2222
#include "models_service_impl.h"
2323
#include "qwen3_rerank_service_impl.h"
24+
#include "rec_completion_service_impl.h"
2425
#include "rerank_service_impl.h"
2526
#include "xllm_service.pb.h"
2627

@@ -124,6 +125,7 @@ class APIService : public proto::XllmAPIService {
124125
std::unique_ptr<ModelsServiceImpl> models_service_impl_;
125126
std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
126127
std::unique_ptr<RerankServiceImpl> rerank_service_impl_;
128+
std::unique_ptr<RecCompletionServiceImpl> rec_completion_service_impl_;
127129
};
128130

129131
} // namespace xllm
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "rec_completion_service_impl.h"
17+
18+
#include <absl/time/clock.h>
19+
#include <absl/time/time.h>
20+
#include <glog/logging.h>
21+
#include <torch/torch.h>
22+
23+
#include <cstdint>
24+
#include <string>
25+
26+
#include "common/global_flags.h"
27+
#include "common/instance_name.h"
28+
#include "completion.pb.h"
29+
#include "core/framework/request/mm_data.h"
30+
#include "core/framework/request/request_output.h"
31+
#include "core/runtime/llm_master.h"
32+
// TODO. add following when next pr.
33+
// #include "core/runtime/rec_master.h"
34+
#include "core/util/utils.h"
35+
36+
#define likely(x) __builtin_expect(!!(x), 1)
37+
#define unlikely(x) __builtin_expect(!!(x), 0)
38+
39+
namespace xllm {
40+
namespace {
41+
// Copies per-token logprob entries from the request output into the proto
// `choice`. Does nothing when no logprobs were produced.
void set_logprobs(proto::Choice* choice,
                  const std::optional<std::vector<LogProb>>& logprobs) {
  if (!logprobs || logprobs->empty()) {
    return;
  }

  auto* out = choice->mutable_logprobs();
  for (const auto& lp : *logprobs) {
    out->add_tokens(lp.token);
    out->add_token_ids(lp.token_id);
    out->add_token_logprobs(lp.logprob);
  }
}
54+
55+
bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,
56+
const std::string& request_id,
57+
int64_t created_time,
58+
const std::string& model,
59+
const RequestOutput& req_output) {
60+
auto& response = call->response();
61+
response.set_object("text_completion");
62+
response.set_id(request_id);
63+
response.set_created(created_time);
64+
response.set_model(model);
65+
66+
// add choices into response
67+
response.mutable_choices()->Reserve(req_output.outputs.size());
68+
for (const auto& output : req_output.outputs) {
69+
auto* choice = response.add_choices();
70+
choice->set_index(output.index);
71+
choice->set_text(output.text);
72+
set_logprobs(choice, output.logprobs);
73+
if (output.finish_reason.has_value()) {
74+
choice->set_finish_reason(output.finish_reason.value());
75+
}
76+
}
77+
78+
// add usage statistics
79+
if (req_output.usage.has_value()) {
80+
const auto& usage = req_output.usage.value();
81+
auto* proto_usage = response.mutable_usage();
82+
proto_usage->set_prompt_tokens(
83+
static_cast<int32_t>(usage.num_prompt_tokens));
84+
proto_usage->set_completion_tokens(
85+
static_cast<int32_t>(usage.num_generated_tokens));
86+
proto_usage->set_total_tokens(static_cast<int32_t>(usage.num_total_tokens));
87+
}
88+
89+
// Add rec specific output tensors
90+
auto output_tensor = response.mutable_output_tensors()->Add();
91+
output_tensor->set_name("rec_result");
92+
// TODO: add following when next pr.
93+
// if (FLAGS_enable_constrained_decoding) {
94+
if (true) {
95+
output_tensor->set_datatype(proto::DataType::INT64);
96+
output_tensor->mutable_shape()->Add(req_output.outputs.size());
97+
output_tensor->mutable_shape()->Add(1); // Single item per output
98+
// TODO: add following when next pr.
99+
/*
100+
auto context = output_tensor->mutable_contents();
101+
for (int i = 0; i < req_output.outputs.size(); ++i) {
102+
if (req_output.outputs[i].item_ids.has_value()) {
103+
context->mutable_int64_contents()->Add(
104+
req_output.outputs[i].item_ids.value());
105+
}
106+
}
107+
*/
108+
} else {
109+
output_tensor->set_datatype(proto::DataType::INT32);
110+
111+
output_tensor->mutable_shape()->Add(req_output.outputs.size());
112+
output_tensor->mutable_shape()->Add(req_output.outputs[0].token_ids.size());
113+
114+
auto context = output_tensor->mutable_contents();
115+
for (int i = 0; i < req_output.outputs.size(); ++i) {
116+
// LOG(INFO) << req_output.outputs[i].token_ids;
117+
context->mutable_int_contents()->Add(
118+
req_output.outputs[i].token_ids.begin(),
119+
req_output.outputs[i].token_ids.end());
120+
}
121+
}
122+
123+
return call->write_and_finish(response);
124+
}
125+
126+
} // namespace
127+
128+
// Constructs the rec completion service for the given supported `models`.
// `master` is borrowed (not owned) and must be non-null.
RecCompletionServiceImpl::RecCompletionServiceImpl(
    RecMaster* master,
    const std::vector<std::string>& models)
    : APIServiceImpl(models), master_(master) {
  CHECK(master_ != nullptr);
}
134+
135+
void RecCompletionServiceImpl::process_async_impl(
136+
std::shared_ptr<CompletionCall> call) {
137+
const auto& rpc_request = call->request();
138+
139+
// check if model is supported
140+
const auto& model = rpc_request.model();
141+
if (unlikely(!models_.contains(model))) {
142+
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
143+
return;
144+
}
145+
146+
// Check if the request is being rate-limited.
147+
if (unlikely(master_->get_rate_limiter()->is_limited())) {
148+
call->finish_with_error(
149+
StatusCode::RESOURCE_EXHAUSTED,
150+
"The number of concurrent requests has reached the limit.");
151+
return;
152+
}
153+
154+
RequestParams request_params(
155+
rpc_request, call->get_x_request_id(), call->get_x_request_time());
156+
bool include_usage = false;
157+
if (rpc_request.has_stream_options()) {
158+
include_usage = rpc_request.stream_options().include_usage();
159+
}
160+
161+
std::optional<std::vector<int>> prompt_tokens = std::nullopt;
162+
if (rpc_request.has_routing()) {
163+
prompt_tokens = std::vector<int>{};
164+
prompt_tokens->reserve(rpc_request.token_ids_size());
165+
for (int i = 0; i < rpc_request.token_ids_size(); i++) {
166+
prompt_tokens->emplace_back(rpc_request.token_ids(i));
167+
}
168+
169+
request_params.decode_address = rpc_request.routing().decode_name();
170+
}
171+
172+
const auto& rpc_request_ref = call->request();
173+
std::optional<MMData> mm_data = std::nullopt;
174+
if (rpc_request_ref.input_tensors_size()) {
175+
// HISTOGRAM_OBSERVE(rec_input_first_dim,
176+
// rpc_request_ref.input_tensors(0).shape(0));
177+
178+
MMDict mm_dict;
179+
for (int i = 0; i < rpc_request_ref.input_tensors_size(); ++i) {
180+
const auto& tensor = rpc_request_ref.input_tensors(i);
181+
mm_dict[tensor.name()] =
182+
xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16);
183+
}
184+
mm_data = std::move(MMData(MMType::EMBEDDING, mm_dict));
185+
}
186+
187+
// schedule the request
188+
auto saved_streaming = request_params.streaming;
189+
auto saved_request_id = request_params.request_id;
190+
master_->handle_request(
191+
std::move(rpc_request_ref.prompt()),
192+
std::move(prompt_tokens),
193+
// TODO. add following when next pr.
194+
// std::move(mm_data),
195+
std::move(request_params),
196+
// TODO. delete this when next pr.
197+
call.get(),
198+
[call,
199+
model,
200+
master = master_,
201+
stream = std::move(saved_streaming),
202+
include_usage = include_usage,
203+
request_id = saved_request_id,
204+
created_time = absl::ToUnixSeconds(absl::Now())](
205+
const RequestOutput& req_output) -> bool {
206+
if (req_output.status.has_value()) {
207+
const auto& status = req_output.status.value();
208+
if (!status.ok()) {
209+
// Reduce the number of concurrent requests when a request is
210+
// finished with error.
211+
master->get_rate_limiter()->decrease_one_request();
212+
213+
return call->finish_with_error(status.code(), status.message());
214+
}
215+
}
216+
217+
// Reduce the number of concurrent requests when a request is finished
218+
// or canceled.
219+
if (req_output.finished || req_output.cancelled) {
220+
master->get_rate_limiter()->decrease_one_request();
221+
}
222+
223+
return send_result_to_client_brpc_rec(
224+
call, request_id, created_time, model, req_output);
225+
});
226+
}
227+
228+
} // namespace xllm
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <absl/container/flat_hash_set.h>
19+
20+
#include "api_service_impl.h"
21+
#include "completion.pb.h"
22+
#include "rec.pb.h"
23+
#include "stream_call.h"
24+
25+
namespace xllm {
26+
27+
using CompletionCall =
28+
StreamCall<proto::CompletionRequest, proto::CompletionResponse>;
29+
30+
// TODO. add following when next pr.
31+
// class RecMaster;
32+
using RecMaster = LLMMaster;
33+
34+
// a class to handle completion requests
35+
class RecCompletionServiceImpl final : public APIServiceImpl<CompletionCall> {
36+
public:
37+
RecCompletionServiceImpl(RecMaster* master,
38+
const std::vector<std::string>& models);
39+
40+
// brpc call_data needs to use shared_ptr
41+
void process_async_impl(std::shared_ptr<CompletionCall> call);
42+
43+
private:
44+
DISALLOW_COPY_AND_ASSIGN(RecCompletionServiceImpl);
45+
RecMaster* master_ = nullptr;
46+
};
47+
48+
} // namespace xllm

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,12 @@ DEFINE_bool(
424424
"The default prefetching ratio for gateup weight is 40%."
425425
"If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");
426426

427+
// rec prefill-only mode
428+
DEFINE_bool(enable_rec_prefill_only,
429+
false,
430+
"Enable rec prefill-only mode (no decoder self-attention blocks "
431+
"allocation).");
432+
427433
// --- dp load balance ---
428434

429435
DEFINE_bool(

0 commit comments

Comments
 (0)