/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "rec_completion_service_impl.h"

#include <absl/time/clock.h>
#include <absl/time/time.h>
#include <glog/logging.h>
#include <torch/torch.h>

#include <cstdint>
#include <string>

#include "common/global_flags.h"
#include "common/instance_name.h"
#include "completion.pb.h"
#include "core/framework/request/mm_data.h"
#include "core/framework/request/request_output.h"
#include "core/runtime/llm_master.h"
// TODO: add the following in the next PR.
// #include "core/runtime/rec_master.h"
#include "core/util/utils.h"

#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)

namespace xllm {
namespace {
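// Copies per-token log probabilities (token text, token id and logprob) from
// the request output into the protobuf Choice, if any were produced.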
void set_logprobs(proto::Choice* choice,
                  const std::optional<std::vector<LogProb>>& logprobs) {
  if (!logprobs.has_value() || logprobs.value().empty()) {
    return;
  }

  auto* proto_logprobs = choice->mutable_logprobs();
  for (const auto& logprob : logprobs.value()) {
    proto_logprobs->add_tokens(logprob.token);
    proto_logprobs->add_token_ids(logprob.token_id);
    proto_logprobs->add_token_logprobs(logprob.logprob);
  }
}

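// Builds the completion response for a rec request (choices, optional usage
// statistics and the rec-specific output tensor) and sends it back to the
// client in a single brpc write_and_finish call.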
bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,
                                    const std::string& request_id,
                                    int64_t created_time,
                                    const std::string& model,
                                    const RequestOutput& req_output) {
  auto& response = call->response();
  response.set_object("text_completion");
  response.set_id(request_id);
  response.set_created(created_time);
  response.set_model(model);

  // add choices into response
  response.mutable_choices()->Reserve(req_output.outputs.size());
  for (const auto& output : req_output.outputs) {
    auto* choice = response.add_choices();
    choice->set_index(output.index);
    choice->set_text(output.text);
    set_logprobs(choice, output.logprobs);
    if (output.finish_reason.has_value()) {
      choice->set_finish_reason(output.finish_reason.value());
    }
  }

  // add usage statistics
  if (req_output.usage.has_value()) {
    const auto& usage = req_output.usage.value();
    auto* proto_usage = response.mutable_usage();
    proto_usage->set_prompt_tokens(
        static_cast<int32_t>(usage.num_prompt_tokens));
    proto_usage->set_completion_tokens(
        static_cast<int32_t>(usage.num_generated_tokens));
    proto_usage->set_total_tokens(static_cast<int32_t>(usage.num_total_tokens));
  }

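  // The rec result tensor has two layouts: with constrained decoding it carries
  // one INT64 item id per output (shape [num_outputs, 1], contents to be filled
  // in once RecMaster lands); otherwise it carries the raw INT32 token ids with
  // shape [num_outputs, num_tokens_per_output].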
  // Add rec specific output tensors
  auto output_tensor = response.mutable_output_tensors()->Add();
  output_tensor->set_name("rec_result");
  // TODO: add the following in the next PR.
  // if (FLAGS_enable_constrained_decoding) {
  if (true) {
    output_tensor->set_datatype(proto::DataType::INT64);
    output_tensor->mutable_shape()->Add(req_output.outputs.size());
    output_tensor->mutable_shape()->Add(1);  // Single item per output
    // TODO: add the following in the next PR.
    /*
    auto context = output_tensor->mutable_contents();
    for (int i = 0; i < req_output.outputs.size(); ++i) {
      if (req_output.outputs[i].item_ids.has_value()) {
        context->mutable_int64_contents()->Add(
            req_output.outputs[i].item_ids.value());
      }
    }
    */
  } else {
    output_tensor->set_datatype(proto::DataType::INT32);

    output_tensor->mutable_shape()->Add(req_output.outputs.size());
    output_tensor->mutable_shape()->Add(req_output.outputs[0].token_ids.size());

    auto context = output_tensor->mutable_contents();
    for (int i = 0; i < req_output.outputs.size(); ++i) {
      // LOG(INFO) << req_output.outputs[i].token_ids;
      context->mutable_int_contents()->Add(
          req_output.outputs[i].token_ids.begin(),
          req_output.outputs[i].token_ids.end());
    }
  }

  return call->write_and_finish(response);
}

}  // namespace

RecCompletionServiceImpl::RecCompletionServiceImpl(
    RecMaster* master,
    const std::vector<std::string>& models)
    : APIServiceImpl(models), master_(master) {
  CHECK(master_ != nullptr);
}

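// Validates the request, applies rate limiting, builds the request parameters
// and multimodal inputs, and hands the request off to the master; results are
// returned to the client from the response callback below.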
void RecCompletionServiceImpl::process_async_impl(
    std::shared_ptr<CompletionCall> call) {
  const auto& rpc_request = call->request();

  // check if model is supported
  const auto& model = rpc_request.model();
  if (unlikely(!models_.contains(model))) {
    call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
    return;
  }

  // Check if the request is being rate-limited.
  if (unlikely(master_->get_rate_limiter()->is_limited())) {
    call->finish_with_error(
        StatusCode::RESOURCE_EXHAUSTED,
        "The number of concurrent requests has reached the limit.");
    return;
  }

  RequestParams request_params(
      rpc_request, call->get_x_request_id(), call->get_x_request_time());
  bool include_usage = false;
  if (rpc_request.has_stream_options()) {
    include_usage = rpc_request.stream_options().include_usage();
  }

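  // Requests that carry routing info arrive pre-tokenized: copy the token ids
  // into prompt_tokens and record the target decode instance address.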
  std::optional<std::vector<int>> prompt_tokens = std::nullopt;
  if (rpc_request.has_routing()) {
    prompt_tokens = std::vector<int>{};
    prompt_tokens->reserve(rpc_request.token_ids_size());
    for (int i = 0; i < rpc_request.token_ids_size(); i++) {
      prompt_tokens->emplace_back(rpc_request.token_ids(i));
    }

    request_params.decode_address = rpc_request.routing().decode_name();
  }

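  // Convert any request input tensors to bfloat16 torch tensors, keyed by
  // tensor name, and wrap them as multimodal embedding data.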
  const auto& rpc_request_ref = call->request();
  std::optional<MMData> mm_data = std::nullopt;
  if (rpc_request_ref.input_tensors_size()) {
    // HISTOGRAM_OBSERVE(rec_input_first_dim,
    //                   rpc_request_ref.input_tensors(0).shape(0));

    MMDict mm_dict;
    for (int i = 0; i < rpc_request_ref.input_tensors_size(); ++i) {
      const auto& tensor = rpc_request_ref.input_tensors(i);
      mm_dict[tensor.name()] =
          xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16);
    }
    mm_data = MMData(MMType::EMBEDDING, std::move(mm_dict));
  }

  // schedule the request
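  // Copy the fields the response callback needs before request_params is moved
  // into handle_request.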
  auto saved_streaming = request_params.streaming;
  auto saved_request_id = request_params.request_id;
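  // The response callback runs once per RequestOutput: it releases the
  // rate-limiter slot when the request fails, finishes or is cancelled, and
  // forwards the output to the client.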
  master_->handle_request(
      std::move(rpc_request_ref.prompt()),
      std::move(prompt_tokens),
      // TODO: add the following in the next PR.
      // std::move(mm_data),
      std::move(request_params),
      // TODO: remove this in the next PR.
      call.get(),
      [call,
       model,
       master = master_,
       stream = std::move(saved_streaming),
       include_usage = include_usage,
       request_id = saved_request_id,
       created_time = absl::ToUnixSeconds(absl::Now())](
          const RequestOutput& req_output) -> bool {
        if (req_output.status.has_value()) {
          const auto& status = req_output.status.value();
          if (!status.ok()) {
            // Reduce the number of concurrent requests when a request is
            // finished with error.
            master->get_rate_limiter()->decrease_one_request();

            return call->finish_with_error(status.code(), status.message());
          }
        }

        // Reduce the number of concurrent requests when a request is finished
        // or canceled.
        if (req_output.finished || req_output.cancelled) {
          master->get_rate_limiter()->decrease_one_request();
        }

        return send_result_to_client_brpc_rec(
            call, request_id, created_time, model, req_output);
      });
}

}  // namespace xllm