2 changes: 2 additions & 0 deletions xllm/api_service/CMakeLists.txt
@@ -8,6 +8,7 @@ cc_library(
     api_service_impl.h
     call.h
     completion_service_impl.h
+    rec_completion_service_impl.h
     chat_service_impl.h
     embedding_service_impl.h
     image_generation_service_impl.h
@@ -23,6 +24,7 @@ cc_library(
     api_service.cpp
     call.cpp
     completion_service_impl.cpp
+    rec_completion_service_impl.cpp
     chat_service_impl.cpp
     embedding_service_impl.cpp
     image_generation_service_impl.cpp
35 changes: 33 additions & 2 deletions xllm/api_service/api_service.cpp
@@ -27,6 +27,8 @@ limitations under the License.
 #include "core/common/metrics.h"
 #include "core/runtime/dit_master.h"
 #include "core/runtime/llm_master.h"
+// TODO: include this when the next PR lands.
+// #include "core/runtime/rec_master.h"
 #include "core/runtime/vlm_master.h"
 #include "core/util/closure_guard.h"
 #include "embedding.pb.h"
@@ -70,6 +72,11 @@ APIService::APIService(Master* master,
     image_generation_service_impl_ =
         std::make_unique<ImageGenerationServiceImpl>(
             dynamic_cast<DiTMaster*>(master), model_names);
+  } else if (FLAGS_backend == "rec") {
+    // TODO: delete this when the next PR lands.
+    using RecMaster = LLMMaster;
+    rec_completion_service_impl_ = std::make_unique<RecCompletionServiceImpl>(
+        dynamic_cast<RecMaster*>(master), model_names);
   }
   models_service_impl_ =
       ServiceImplFactory<ModelsServiceImpl>::create_service_impl(
@@ -80,7 +87,27 @@ void APIService::Completions(::google::protobuf::RpcController* controller,
                              const proto::CompletionRequest* request,
                              proto::CompletionResponse* response,
                              ::google::protobuf::Closure* done) {
-  // TODO with xllm-service
+  xllm::ClosureGuard done_guard(
+      done,
+      std::bind(request_in_metric, nullptr),
+      std::bind(request_out_metric, (void*)controller));
+  if (!request || !response || !controller) {
+    LOG(ERROR) << "brpc request | response | controller is null.";
+    return;
+  }
+  auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
+  auto arena = response->GetArena();
+  std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
+      ctrl,
+      done_guard.release(),
+      const_cast<proto::CompletionRequest*>(request),
+      response,
+      arena != nullptr);
+  if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
+    completion_service_impl_->process_async(call);
+  } else if (FLAGS_backend == "rec") {
+    rec_completion_service_impl_->process_async(call);
+  }
 }
 
 void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
@@ -116,7 +143,11 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
 
   std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
       ctrl, done_guard.release(), req_pb, resp_pb, arena != nullptr);
-  completion_service_impl_->process_async(call);
+  if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
+    completion_service_impl_->process_async(call);
+  } else if (FLAGS_backend == "rec") {
+    rec_completion_service_impl_->process_async(call);
+  }
 }
 
 void APIService::ChatCompletions(::google::protobuf::RpcController* controller,
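The Completions handler above relies on a release-style ownership handoff: the guard runs done (plus the request metrics hooks) automatically if the handler returns early, while done_guard.release() transfers the closure into the CompletionCall, which invokes it once the response is written. A minimal, self-contained analog of that RAII pattern (illustrative only, not xLLM's actual ClosureGuard, whose real constructor takes the metric callbacks shown above):

#include <functional>
#include <iostream>
#include <utility>

// Illustrative analog of xllm::ClosureGuard: runs the stored callback on
// scope exit unless ownership has been released to a longer-lived object.
class ScopedClosure {
 public:
  explicit ScopedClosure(std::function<void()> done) : done_(std::move(done)) {}
  ~ScopedClosure() {
    if (done_) done_();  // early-return path: the callback still fires
  }
  // Hands the callback to the caller (here: the async Call object).
  std::function<void()> release() { return std::exchange(done_, nullptr); }

 private:
  std::function<void()> done_;
};

int main() {
  {
    ScopedClosure guard([] { std::cout << "done ran on early return\n"; });
    // An error return here would still invoke the callback via the destructor.
  }
  std::function<void()> owned_by_call;
  {
    ScopedClosure guard([] { std::cout << "done ran after finish\n"; });
    owned_by_call = guard.release();  // transfer ownership; destructor is a no-op
  }
  owned_by_call();  // the Call invokes it once the response is finished
}
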
2 changes: 2 additions & 0 deletions xllm/api_service/api_service.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "image_generation_service_impl.h"
 #include "models_service_impl.h"
 #include "qwen3_rerank_service_impl.h"
+#include "rec_completion_service_impl.h"
 #include "rerank_service_impl.h"
 #include "xllm_service.pb.h"
 
@@ -124,6 +125,7 @@ class APIService : public proto::XllmAPIService {
   std::unique_ptr<ModelsServiceImpl> models_service_impl_;
   std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
   std::unique_ptr<RerankServiceImpl> rerank_service_impl_;
+  std::unique_ptr<RecCompletionServiceImpl> rec_completion_service_impl_;
 };
 
 }  // namespace xllm
228 changes: 228 additions & 0 deletions xllm/api_service/rec_completion_service_impl.cpp
@@ -0,0 +1,228 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "rec_completion_service_impl.h"

#include <absl/time/clock.h>
#include <absl/time/time.h>
#include <glog/logging.h>
#include <torch/torch.h>

#include <cstdint>
#include <string>

#include "common/global_flags.h"
#include "common/instance_name.h"
#include "completion.pb.h"
#include "core/framework/request/mm_data.h"
#include "core/framework/request/request_output.h"
#include "core/runtime/llm_master.h"
// TODO: include this when the next PR lands.
// #include "core/runtime/rec_master.h"
#include "core/util/utils.h"

#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)

namespace xllm {
namespace {
void set_logprobs(proto::Choice* choice,
                  const std::optional<std::vector<LogProb>>& logprobs) {
  if (!logprobs.has_value() || logprobs.value().empty()) {
    return;
  }

  auto* proto_logprobs = choice->mutable_logprobs();
  for (const auto& logprob : logprobs.value()) {
    proto_logprobs->add_tokens(logprob.token);
    proto_logprobs->add_token_ids(logprob.token_id);
    proto_logprobs->add_token_logprobs(logprob.logprob);
  }
}

bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,
                                    const std::string& request_id,
                                    int64_t created_time,
                                    const std::string& model,
                                    const RequestOutput& req_output) {
  auto& response = call->response();
  response.set_object("text_completion");
  response.set_id(request_id);
  response.set_created(created_time);
  response.set_model(model);

  // add choices into response
  response.mutable_choices()->Reserve(req_output.outputs.size());
  for (const auto& output : req_output.outputs) {
    auto* choice = response.add_choices();
    choice->set_index(output.index);
    choice->set_text(output.text);
    set_logprobs(choice, output.logprobs);
    if (output.finish_reason.has_value()) {
      choice->set_finish_reason(output.finish_reason.value());
    }
  }

  // add usage statistics
  if (req_output.usage.has_value()) {
    const auto& usage = req_output.usage.value();
    auto* proto_usage = response.mutable_usage();
    proto_usage->set_prompt_tokens(
        static_cast<int32_t>(usage.num_prompt_tokens));
    proto_usage->set_completion_tokens(
        static_cast<int32_t>(usage.num_generated_tokens));
    proto_usage->set_total_tokens(static_cast<int32_t>(usage.num_total_tokens));
  }

  // Add rec-specific output tensors.
  auto output_tensor = response.mutable_output_tensors()->Add();
  output_tensor->set_name("rec_result");
  // TODO: enable this check when the next PR lands.
  // if (FLAGS_enable_constrained_decoding) {
  if (true) {
    output_tensor->set_datatype(proto::DataType::INT64);
    output_tensor->mutable_shape()->Add(req_output.outputs.size());
    output_tensor->mutable_shape()->Add(1);  // Single item per output
    // TODO: enable this when the next PR lands.
    /*
    auto context = output_tensor->mutable_contents();
    for (int i = 0; i < req_output.outputs.size(); ++i) {
      if (req_output.outputs[i].item_ids.has_value()) {
        context->mutable_int64_contents()->Add(
            req_output.outputs[i].item_ids.value());
      }
    }
    */
  } else {
    output_tensor->set_datatype(proto::DataType::INT32);

    output_tensor->mutable_shape()->Add(req_output.outputs.size());
    output_tensor->mutable_shape()->Add(req_output.outputs[0].token_ids.size());

    auto context = output_tensor->mutable_contents();
    for (int i = 0; i < req_output.outputs.size(); ++i) {
      // LOG(INFO) << req_output.outputs[i].token_ids;
      context->mutable_int_contents()->Add(
          req_output.outputs[i].token_ids.begin(),
          req_output.outputs[i].token_ids.end());
    }
  }

  return call->write_and_finish(response);
}

}  // namespace

RecCompletionServiceImpl::RecCompletionServiceImpl(
    RecMaster* master,
    const std::vector<std::string>& models)
    : APIServiceImpl(models), master_(master) {
  CHECK(master_ != nullptr);
}

void RecCompletionServiceImpl::process_async_impl(
    std::shared_ptr<CompletionCall> call) {
  const auto& rpc_request = call->request();

  // check if model is supported
  const auto& model = rpc_request.model();
  if (unlikely(!models_.contains(model))) {
    call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
    return;
  }

  // Check if the request is being rate-limited.
  if (unlikely(master_->get_rate_limiter()->is_limited())) {
    call->finish_with_error(
        StatusCode::RESOURCE_EXHAUSTED,
        "The number of concurrent requests has reached the limit.");
    return;
  }

  RequestParams request_params(
      rpc_request, call->get_x_request_id(), call->get_x_request_time());
  bool include_usage = false;
  if (rpc_request.has_stream_options()) {
    include_usage = rpc_request.stream_options().include_usage();
  }

  std::optional<std::vector<int>> prompt_tokens = std::nullopt;
  if (rpc_request.has_routing()) {
    prompt_tokens = std::vector<int>{};
    prompt_tokens->reserve(rpc_request.token_ids_size());
    for (int i = 0; i < rpc_request.token_ids_size(); i++) {
      prompt_tokens->emplace_back(rpc_request.token_ids(i));
    }

    request_params.decode_address = rpc_request.routing().decode_name();
  }

  const auto& rpc_request_ref = call->request();
  std::optional<MMData> mm_data = std::nullopt;
  if (rpc_request_ref.input_tensors_size()) {
    // HISTOGRAM_OBSERVE(rec_input_first_dim,
    //                   rpc_request_ref.input_tensors(0).shape(0));

    MMDict mm_dict;
    for (int i = 0; i < rpc_request_ref.input_tensors_size(); ++i) {
      const auto& tensor = rpc_request_ref.input_tensors(i);
      mm_dict[tensor.name()] =
          xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16);
    }
    mm_data = std::move(MMData(MMType::EMBEDDING, mm_dict));
  }

  // schedule the request
  auto saved_streaming = request_params.streaming;
  auto saved_request_id = request_params.request_id;
  master_->handle_request(
      std::move(rpc_request_ref.prompt()),
      std::move(prompt_tokens),
      // TODO: pass this when the next PR lands.
      // std::move(mm_data),
      std::move(request_params),
      // TODO: delete this when the next PR lands.
      call.get(),
      [call,
       model,
       master = master_,
       stream = std::move(saved_streaming),
       include_usage = include_usage,
       request_id = saved_request_id,
       created_time = absl::ToUnixSeconds(absl::Now())](
          const RequestOutput& req_output) -> bool {
        if (req_output.status.has_value()) {
          const auto& status = req_output.status.value();
          if (!status.ok()) {
            // Reduce the number of concurrent requests when a request is
            // finished with error.
            master->get_rate_limiter()->decrease_one_request();

            return call->finish_with_error(status.code(), status.message());
          }
        }

        // Reduce the number of concurrent requests when a request is finished
        // or canceled.
        if (req_output.finished || req_output.cancelled) {
          master->get_rate_limiter()->decrease_one_request();
        }

        return send_result_to_client_brpc_rec(
            call, request_id, created_time, model, req_output);
      });
}

}  // namespace xllm
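For orientation, process_async_impl consumes these request fields: model, pre-tokenized token_ids (only read when routing is present, whose decode_name becomes the decode address), optional stream_options.include_usage, and optional input_tensors that are converted server-side into bfloat16 torch tensors keyed by name. A hypothetical client-side sketch of assembling such a request follows; the setter names are inferred from the protobuf accessors above, and it assumes input tensors reuse the same Tensor message as the rec_result output tensor:

#include "completion.pb.h"

proto::CompletionRequest make_rec_request() {
  proto::CompletionRequest req;
  req.set_model("my-rec-model");  // must match one of the served models
  // Pre-tokenized prompt; the service only reads token_ids when routing is set.
  for (int token_id : {101, 2043, 7}) {
    req.add_token_ids(token_id);
  }
  req.mutable_routing()->set_decode_name("decode-instance-0");
  req.mutable_stream_options()->set_include_usage(true);
  // Optional embedding input, converted to a bfloat16 torch tensor above.
  auto* tensor = req.add_input_tensors();
  tensor->set_name("user_embedding");  // hypothetical tensor name
  tensor->mutable_shape()->Add(1);
  tensor->mutable_shape()->Add(1024);
  // ... fill the tensor contents (the exact contents field for embedding
  // payloads is not shown in this PR).
  return req;
}
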
48 changes: 48 additions & 0 deletions xllm/api_service/rec_completion_service_impl.h
@@ -0,0 +1,48 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <absl/container/flat_hash_set.h>

#include "api_service_impl.h"
#include "completion.pb.h"
#include "rec.pb.h"
#include "stream_call.h"

namespace xllm {

using CompletionCall =
    StreamCall<proto::CompletionRequest, proto::CompletionResponse>;

// TODO: declare this when the next PR lands.
// class RecMaster;
using RecMaster = LLMMaster;

// A class to handle rec completion requests.
class RecCompletionServiceImpl final : public APIServiceImpl<CompletionCall> {
 public:
  RecCompletionServiceImpl(RecMaster* master,
                           const std::vector<std::string>& models);

  // brpc call_data needs to use shared_ptr
  void process_async_impl(std::shared_ptr<CompletionCall> call);

 private:
  DISALLOW_COPY_AND_ASSIGN(RecCompletionServiceImpl);
  RecMaster* master_ = nullptr;
};

}  // namespace xllm
6 changes: 6 additions & 0 deletions xllm/core/common/global_flags.cpp
@@ -424,6 +424,12 @@ DEFINE_bool(
     "The default prefetching ratio for gateup weight is 40%."
     "If adjustments are needed, e.g. export PREFETCH_COEFFOCIENT=0.5");
 
+// rec prefill-only mode
+DEFINE_bool(enable_rec_prefill_only,
+            false,
+            "Enable rec prefill-only mode (no decoder self-attention blocks "
+            "allocation).");
+
 // --- dp load balance ---
 
 DEFINE_bool(
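This PR only defines the flag; a hypothetical consumer sketch follows (the allocator hook is an assumption about how a later PR might honor the flag, not code from this change):

#include <cstddef>

#include <gflags/gflags.h>

DECLARE_bool(enable_rec_prefill_only);  // defined in core/common/global_flags.cpp

// Hypothetical hook: in rec prefill-only mode, skip allocating decoder
// self-attention KV-cache blocks, per the flag's description.
size_t num_decoder_kv_blocks(size_t requested_blocks) {
  return FLAGS_enable_rec_prefill_only ? 0 : requested_blocks;
}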