From 336fcf8fcbcdc8c4ed250eaa37d197c357bffc53 Mon Sep 17 00:00:00 2001
From: DragonFive <1690302963@qq.com>
Date: Wed, 26 Nov 2025 11:41:44 +0800
Subject: [PATCH] feat: add rec proto, service and utils for rec framework
 [2/6].

---
 xllm/api_service/CMakeLists.txt                |   2 +
 xllm/api_service/api_service.cpp               |  35 ++-
 xllm/api_service/api_service.h                 |   2 +
 .../rec_completion_service_impl.cpp            | 228 ++++++++++++++++++
 .../api_service/rec_completion_service_impl.h  |  48 ++++
 xllm/core/common/global_flags.cpp              |   6 +
 xllm/core/common/metrics.cpp                   |  14 ++
 xllm/core/common/metrics.h                     |   8 +
 xllm/core/common/types.h                       |   5 +
 xllm/core/util/CMakeLists.txt                  |   9 +-
 xllm/core/util/utils.cpp                       | 101 ++++++++
 xllm/core/util/utils.h                         |   4 +
 xllm/proto/CMakeLists.txt                      |   1 +
 xllm/proto/completion.proto                    |   6 +
 xllm/proto/rec.proto                           | 119 +++++++++
 15 files changed, 581 insertions(+), 7 deletions(-)
 create mode 100644 xllm/api_service/rec_completion_service_impl.cpp
 create mode 100644 xllm/api_service/rec_completion_service_impl.h
 create mode 100644 xllm/proto/rec.proto

diff --git a/xllm/api_service/CMakeLists.txt b/xllm/api_service/CMakeLists.txt
index b200409c9..dcb13c121 100644
--- a/xllm/api_service/CMakeLists.txt
+++ b/xllm/api_service/CMakeLists.txt
@@ -8,6 +8,7 @@ cc_library(
     api_service_impl.h
     call.h
     completion_service_impl.h
+    rec_completion_service_impl.h
     chat_service_impl.h
     embedding_service_impl.h
     image_generation_service_impl.h
@@ -23,6 +24,7 @@ cc_library(
     api_service.cpp
     call.cpp
     completion_service_impl.cpp
+    rec_completion_service_impl.cpp
     chat_service_impl.cpp
     embedding_service_impl.cpp
     image_generation_service_impl.cpp
diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp
index 9e797a4f7..c50b047a6 100644
--- a/xllm/api_service/api_service.cpp
+++ b/xllm/api_service/api_service.cpp
@@ -27,6 +27,8 @@ limitations under the License.
 #include "core/common/metrics.h"
 #include "core/runtime/dit_master.h"
 #include "core/runtime/llm_master.h"
+// TODO: add the following in the next PR.
+// #include "core/runtime/rec_master.h"
 #include "core/runtime/vlm_master.h"
 #include "core/util/closure_guard.h"
 #include "embedding.pb.h"
@@ -70,6 +72,11 @@ APIService::APIService(Master* master,
     image_generation_service_impl_ =
         std::make_unique<ImageGenerationServiceImpl>(
            dynamic_cast<DiTMaster*>(master), model_names);
+  } else if (FLAGS_backend == "rec") {
+    // TODO: delete this in the next PR.
+    using RecMaster = LLMMaster;
+    rec_completion_service_impl_ = std::make_unique<RecCompletionServiceImpl>(
+        dynamic_cast<RecMaster*>(master), model_names);
   }
 
   models_service_impl_ = ServiceImplFactory::create_service_impl(
@@ -80,7 +87,27 @@ void APIService::Completions(::google::protobuf::RpcController* controller,
                              const proto::CompletionRequest* request,
                              proto::CompletionResponse* response,
                              ::google::protobuf::Closure* done) {
-  // TODO with xllm-service
+  xllm::ClosureGuard done_guard(
+      done,
+      std::bind(request_in_metric, nullptr),
+      std::bind(request_out_metric, (void*)controller));
+  if (!request || !response || !controller) {
+    LOG(ERROR) << "brpc request | response | controller is null.";
+    return;
+  }
+  auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
+  auto arena = response->GetArena();
+  std::shared_ptr<CompletionCall> call = std::make_shared<CompletionCall>(
+      ctrl,
+      done_guard.release(),
+      const_cast<proto::CompletionRequest*>(request),
+      response,
+      arena != nullptr);
+  if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
+    completion_service_impl_->process_async(call);
+  } else if (FLAGS_backend == "rec") {
+    rec_completion_service_impl_->process_async(call);
+  }
 }
 
 void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
@@ -116,7 +143,11 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
   std::shared_ptr<CompletionCall> call = std::make_shared<CompletionCall>(
       ctrl, done_guard.release(), req_pb, resp_pb, arena != nullptr);
 
-  completion_service_impl_->process_async(call);
+  if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
+    completion_service_impl_->process_async(call);
+  } else if (FLAGS_backend == "rec") {
+    rec_completion_service_impl_->process_async(call);
+  }
 }
 
 void APIService::ChatCompletions(::google::protobuf::RpcController* controller,
diff --git a/xllm/api_service/api_service.h b/xllm/api_service/api_service.h
index 4236fa2da..427843ae0 100644
--- a/xllm/api_service/api_service.h
+++ b/xllm/api_service/api_service.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "image_generation_service_impl.h"
 #include "models_service_impl.h"
 #include "qwen3_rerank_service_impl.h"
+#include "rec_completion_service_impl.h"
 #include "rerank_service_impl.h"
 #include "xllm_service.pb.h"
 
@@ -124,6 +125,7 @@ class APIService : public proto::XllmAPIService {
   std::unique_ptr<ModelsServiceImpl> models_service_impl_;
   std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
   std::unique_ptr<RerankServiceImpl> rerank_service_impl_;
+  std::unique_ptr<RecCompletionServiceImpl> rec_completion_service_impl_;
 };
 
 }  // namespace xllm
diff --git a/xllm/api_service/rec_completion_service_impl.cpp b/xllm/api_service/rec_completion_service_impl.cpp
new file mode 100644
index 000000000..fddc9e2a0
--- /dev/null
+++ b/xllm/api_service/rec_completion_service_impl.cpp
@@ -0,0 +1,228 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "rec_completion_service_impl.h"
+
+#include <absl/time/clock.h>
+#include <absl/time/time.h>
+#include <brpc/controller.h>
+#include <glog/logging.h>
+
+#include <string>
+#include <vector>
+
+#include "common/global_flags.h"
+#include "common/instance_name.h"
+#include "completion.pb.h"
+#include "core/framework/request/mm_data.h"
+#include "core/framework/request/request_output.h"
+#include "core/runtime/llm_master.h"
+// TODO: add the following in the next PR.
+// #include "core/runtime/rec_master.h"
+#include "core/util/utils.h"
+
+#define likely(x) __builtin_expect(!!(x), 1)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+
+namespace xllm {
+namespace {
+void set_logprobs(proto::Choice* choice,
+                  const std::optional<std::vector<LogProb>>& logprobs) {
+  if (!logprobs.has_value() || logprobs.value().empty()) {
+    return;
+  }
+
+  auto* proto_logprobs = choice->mutable_logprobs();
+  for (const auto& logprob : logprobs.value()) {
+    proto_logprobs->add_tokens(logprob.token);
+    proto_logprobs->add_token_ids(logprob.token_id);
+    proto_logprobs->add_token_logprobs(logprob.logprob);
+  }
+}
+
+bool send_result_to_client_brpc_rec(std::shared_ptr<CompletionCall> call,
+                                    const std::string& request_id,
+                                    int64_t created_time,
+                                    const std::string& model,
+                                    const RequestOutput& req_output) {
+  auto& response = call->response();
+  response.set_object("text_completion");
+  response.set_id(request_id);
+  response.set_created(created_time);
+  response.set_model(model);
+
+  // add choices into response
+  response.mutable_choices()->Reserve(req_output.outputs.size());
+  for (const auto& output : req_output.outputs) {
+    auto* choice = response.add_choices();
+    choice->set_index(output.index);
+    choice->set_text(output.text);
+    set_logprobs(choice, output.logprobs);
+    if (output.finish_reason.has_value()) {
+      choice->set_finish_reason(output.finish_reason.value());
+    }
+  }
+
+  // add usage statistics
+  if (req_output.usage.has_value()) {
+    const auto& usage = req_output.usage.value();
+    auto* proto_usage = response.mutable_usage();
+    proto_usage->set_prompt_tokens(
+        static_cast<int32_t>(usage.num_prompt_tokens));
+    proto_usage->set_completion_tokens(
+        static_cast<int32_t>(usage.num_generated_tokens));
+    proto_usage->set_total_tokens(
+        static_cast<int32_t>(usage.num_total_tokens));
+  }
+
+  // Add rec-specific output tensors.
+  auto output_tensor = response.mutable_output_tensors()->Add();
+  output_tensor->set_name("rec_result");
+  // TODO: use the following condition in the next PR.
+  // if (FLAGS_enable_constrained_decoding) {
+  if (true) {
+    output_tensor->set_datatype(proto::DataType::INT64);
+    output_tensor->mutable_shape()->Add(req_output.outputs.size());
+    output_tensor->mutable_shape()->Add(1);  // Single item per output
+    // TODO: add the following in the next PR.
+    /*
+    auto context = output_tensor->mutable_contents();
+    for (int i = 0; i < req_output.outputs.size(); ++i) {
+      if (req_output.outputs[i].item_ids.has_value()) {
+        context->mutable_int64_contents()->Add(
+            req_output.outputs[i].item_ids.value());
+      }
+    }
+    */
+  } else {
+    output_tensor->set_datatype(proto::DataType::INT32);
+
+    output_tensor->mutable_shape()->Add(req_output.outputs.size());
+    output_tensor->mutable_shape()->Add(req_output.outputs[0].token_ids.size());
+
+    auto context = output_tensor->mutable_contents();
+    for (int i = 0; i < req_output.outputs.size(); ++i) {
+      // LOG(INFO) << req_output.outputs[i].token_ids;
+      context->mutable_int_contents()->Add(
+          req_output.outputs[i].token_ids.begin(),
+          req_output.outputs[i].token_ids.end());
+    }
+  }
+
+  return call->write_and_finish(response);
+}
+
+}  // namespace
+
+RecCompletionServiceImpl::RecCompletionServiceImpl(
+    RecMaster* master,
+    const std::vector<std::string>& models)
+    : APIServiceImpl(models), master_(master) {
+  CHECK(master_ != nullptr);
+}
+
+void RecCompletionServiceImpl::process_async_impl(
+    std::shared_ptr<CompletionCall> call) {
+  const auto& rpc_request = call->request();
+
+  // check if model is supported
+  const auto& model = rpc_request.model();
+  if (unlikely(!models_.contains(model))) {
+    call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
+    return;
+  }
+
+  // Check if the request is being rate-limited.
+  if (unlikely(master_->get_rate_limiter()->is_limited())) {
+    call->finish_with_error(
+        StatusCode::RESOURCE_EXHAUSTED,
+        "The number of concurrent requests has reached the limit.");
+    return;
+  }
+
+  RequestParams request_params(
+      rpc_request, call->get_x_request_id(), call->get_x_request_time());
+  bool include_usage = false;
+  if (rpc_request.has_stream_options()) {
+    include_usage = rpc_request.stream_options().include_usage();
+  }
+
+  std::optional<std::vector<int>> prompt_tokens = std::nullopt;
+  if (rpc_request.has_routing()) {
+    prompt_tokens = std::vector<int>{};
+    prompt_tokens->reserve(rpc_request.token_ids_size());
+    for (int i = 0; i < rpc_request.token_ids_size(); i++) {
+      prompt_tokens->emplace_back(rpc_request.token_ids(i));
+    }
+
+    request_params.decode_address = rpc_request.routing().decode_name();
+  }
+
+  const auto& rpc_request_ref = call->request();
+  std::optional<MMData> mm_data = std::nullopt;
+  if (rpc_request_ref.input_tensors_size()) {
+    // HISTOGRAM_OBSERVE(rec_input_first_dim,
+    //                   rpc_request_ref.input_tensors(0).shape(0));
+
+    MMDict mm_dict;
+    for (int i = 0; i < rpc_request_ref.input_tensors_size(); ++i) {
+      const auto& tensor = rpc_request_ref.input_tensors(i);
+      mm_dict[tensor.name()] =
+          xllm::util::convert_rec_tensor_to_torch(tensor).to(torch::kBFloat16);
+    }
+    mm_data = std::move(MMData(MMType::EMBEDDING, mm_dict));
+  }
+
+  // schedule the request
+  auto saved_streaming = request_params.streaming;
+  auto saved_request_id = request_params.request_id;
+  master_->handle_request(
+      std::move(rpc_request_ref.prompt()),
+      std::move(prompt_tokens),
+      // TODO: add the following in the next PR.
+      // std::move(mm_data),
+      std::move(request_params),
+      // TODO: delete this in the next PR.
+      call.get(),
+      [call,
+       model,
+       master = master_,
+       stream = std::move(saved_streaming),
+       include_usage = include_usage,
+       request_id = saved_request_id,
+       created_time = absl::ToUnixSeconds(absl::Now())](
+          const RequestOutput& req_output) -> bool {
+        if (req_output.status.has_value()) {
+          const auto& status = req_output.status.value();
+          if (!status.ok()) {
+            // Reduce the number of concurrent requests when a request is
+            // finished with error.
+            master->get_rate_limiter()->decrease_one_request();
+
+            return call->finish_with_error(status.code(), status.message());
+          }
+        }
+
+        // Reduce the number of concurrent requests when a request is finished
+        // or canceled.
+        if (req_output.finished || req_output.cancelled) {
+          master->get_rate_limiter()->decrease_one_request();
+        }
+
+        return send_result_to_client_brpc_rec(
+            call, request_id, created_time, model, req_output);
+      });
+}
+
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/api_service/rec_completion_service_impl.h b/xllm/api_service/rec_completion_service_impl.h
new file mode 100644
index 000000000..e383a103f
--- /dev/null
+++ b/xllm/api_service/rec_completion_service_impl.h
@@ -0,0 +1,48 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <memory>
+
+#include "api_service_impl.h"
+#include "completion.pb.h"
+#include "rec.pb.h"
+#include "stream_call.h"
+
+namespace xllm {
+
+using CompletionCall =
+    StreamCall<proto::CompletionRequest, proto::CompletionResponse>;
+
+// TODO: add the following in the next PR.
+// class RecMaster;
+using RecMaster = LLMMaster;
+
+// A class to handle rec completion requests.
+class RecCompletionServiceImpl final : public APIServiceImpl<CompletionCall> {
+ public:
+  RecCompletionServiceImpl(RecMaster* master,
+                           const std::vector<std::string>& models);
+
+  // brpc call_data needs to use shared_ptr.
+  void process_async_impl(std::shared_ptr<CompletionCall> call);
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(RecCompletionServiceImpl);
+  RecMaster* master_ = nullptr;
+};
+
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index db5a38b22..b8e5ad90e 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -424,6 +424,12 @@ DEFINE_bool(
     "The default prefetching ratio for gateup weight is 40%."
     "If adjustments are needed, e.g. "
    "export PREFETCH_COEFFOCIENT=0.5");
 
+// rec prefill-only mode
+DEFINE_bool(enable_rec_prefill_only,
+            false,
+            "Enable rec prefill-only mode (no decoder self-attention blocks "
+            "allocation).");
+
 // --- dp load balance ---
 
 DEFINE_bool(
diff --git a/xllm/core/common/metrics.cpp b/xllm/core/common/metrics.cpp
index 5f792f2b9..c784caac0 100644
--- a/xllm/core/common/metrics.cpp
+++ b/xllm/core/common/metrics.cpp
@@ -181,6 +181,20 @@ DEFINE_COUNTER(proto_latency_seconds_o2proto,
 DEFINE_COUNTER(prepare_input_latency_seconds,
                "Latency of preparing input in seconds");
 
+// rec engine metrics
+DEFINE_COUNTER(prepare_input_latency_microseconds,
+               "Latency of preparing input in microseconds");
+DEFINE_COUNTER(rec_first_token_latency_microseconds,
+               "Latency of rec first token generation in microseconds");
+DEFINE_COUNTER(rec_second_token_latency_microseconds,
+               "Latency of rec second token generation in microseconds");
+DEFINE_COUNTER(rec_third_token_latency_microseconds,
+               "Latency of rec third token generation in microseconds");
+DEFINE_COUNTER(rec_sampling_latency_microseconds,
+               "Latency of rec sampling in microseconds");
+DEFINE_HISTOGRAM(expand_beam_latency_microseconds,
+                 "Histogram of expand beam latency in microseconds");
+
 // multi node metrics
 DEFINE_COUNTER(worker_service_latency_seconds,
                "Worker service execution latency in seconds");
diff --git a/xllm/core/common/metrics.h b/xllm/core/common/metrics.h
index 48663341d..82c9f231a 100644
--- a/xllm/core/common/metrics.h
+++ b/xllm/core/common/metrics.h
@@ -205,6 +205,14 @@ DECLARE_COUNTER(proto_latency_seconds_o2proto);
 // engine metrics
 DECLARE_COUNTER(prepare_input_latency_seconds);
 
+// rec engine metrics
+DECLARE_COUNTER(prepare_input_latency_microseconds);
+DECLARE_COUNTER(rec_first_token_latency_microseconds);
+DECLARE_COUNTER(rec_second_token_latency_microseconds);
+DECLARE_COUNTER(rec_third_token_latency_microseconds);
+DECLARE_COUNTER(rec_sampling_latency_microseconds);
+DECLARE_HISTOGRAM(expand_beam_latency_microseconds);
+
 // multi node metrics
 DECLARE_COUNTER(worker_service_latency_seconds);
 DECLARE_COUNTER(engine_latency_seconds);
diff --git a/xllm/core/common/types.h b/xllm/core/common/types.h
index 159739062..0339d4e3d 100644
--- a/xllm/core/common/types.h
+++ b/xllm/core/common/types.h
@@ -31,6 +31,7 @@ class EngineType {
     SSM = 1,
     VLM = 2,
     DIT = 3,
+    REC = 4,
     INVALID = -1,
   };
 
@@ -44,6 +45,8 @@ class EngineType {
       value_ = VLM;
     } else if (str == "DIT") {
      value_ = DIT;
+    } else if (str == "REC") {
+      value_ = REC;
     } else {
       value_ = INVALID;
     }
@@ -68,6 +71,8 @@ class EngineType {
       return "VLM";
     } else if (this->value_ == DIT) {
       return "DIT";
+    } else if (this->value_ == REC) {
+      return "REC";
     } else {
       return "INVALID";
     }
diff --git a/xllm/core/util/CMakeLists.txt b/xllm/core/util/CMakeLists.txt
index 2c11cc3bb..e309f58e0 100644
--- a/xllm/core/util/CMakeLists.txt
+++ b/xllm/core/util/CMakeLists.txt
@@ -33,6 +33,8 @@ cc_library(
     device_name_utils.cpp
     env_var.cpp
     http_downloader.cpp
+    # TODO: add the following in the next PR.
+    # hash_util.cpp
     json_reader.cpp
     net.cpp
     pretty_print.cpp
@@ -52,7 +54,9 @@ cc_library(
     Boost::serialization
     absl::synchronization
     ${Python_LIBRARIES}
+    proto::xllm_proto
     :platform
+    SMHasherSupport
   )
 target_link_libraries(util PRIVATE OpenSSL::SSL OpenSSL::Crypto)
 add_dependencies(util brpc-static)
@@ -72,8 +76,3 @@ cc_test(
 )
 target_link_libraries(util_test PRIVATE brpc leveldb::leveldb OpenSSL::SSL OpenSSL::Crypto)
 add_dependencies(util_test brpc-static)
-
-
-
-
-
diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp
index f03538640..b66db9bdc 100755
--- a/xllm/core/util/utils.cpp
+++ b/xllm/core/util/utils.cpp
@@ -148,5 +148,106 @@ std::vector<uint32_t> cal_vec_split_index(uint32_t vec_size,
   return split_index;
 }
 
+torch::Dtype convert_rec_type_to_torch(proto::DataType data_type) {
+  // Future extensions go here.
+  switch (data_type) {
+    case proto::DataType::FLOAT:
+      return torch::kFloat32;
+
+    case proto::DataType::BFLOAT16:
+      return torch::kBFloat16;
+
+    case proto::DataType::BOOL:
+      return torch::kBool;
+
+    case proto::DataType::UINT8:
+      return torch::kUInt8;
+
+    case proto::DataType::INT8:
+      return torch::kInt8;
+
+    case proto::DataType::INT16:
+      return torch::kInt16;
+
+    // Needed by the INT32/INT64 branches of convert_rec_tensor_to_torch
+    // below; without these mappings those branches are unreachable.
+    case proto::DataType::INT32:
+      return torch::kInt32;
+
+    case proto::DataType::INT64:
+      return torch::kInt64;
+
+    default:
+      throw std::runtime_error("Unsupported data type: " +
+                               std::to_string(static_cast<int>(data_type)));
+  }
+}
+
+torch::Tensor convert_rec_tensor_to_torch(
+    const proto::InferInputTensor& input_tensor) {
+  std::vector<int64_t> shape;
+  shape.reserve(input_tensor.shape_size());
+  for (int i = 0; i < input_tensor.shape_size(); ++i) {
+    shape.push_back(input_tensor.shape(i));
+  }
+
+  if (!input_tensor.has_contents()) {
+    throw std::runtime_error("Input tensor '" + input_tensor.name() +
+                             "' has no contents");
+  }
+
+  const auto& contents = input_tensor.contents();
+  torch::Dtype dtype = convert_rec_type_to_torch(input_tensor.data_type());
+
+  switch (dtype) {
+    case torch::kFloat32: {
+      // Directly use protobuf's float array
+      const auto& data = contents.fp32_contents();
+      return torch::from_blob(
+                 const_cast<float*>(data.data()),
+                 shape,
+                 torch::dtype(torch::kFloat32).requires_grad(false))
+          .clone();  // Clone to ensure independent memory
+    }
+    // not supported now.
+    // case torch::kFloat16: {
+    //   // Needs type conversion (protobuf usually stores float16 as uint16)
+    //   const auto& data = contents.bytes_contents();
+    //   std::vector<uint16_t> half_data;
+    //   half_data.reserve(data.size());
+    //   for (auto val : data) {
+    //     half_data.push_back(static_cast<uint16_t>(val));
+    //   }
+    //   return torch::tensor(half_data, torch::dtype(torch::kFloat16))
+    //       .view(shape);
+    // }
+
+    case torch::kInt32: {
+      const auto& data = contents.int_contents();
+      return torch::from_blob(const_cast<int32_t*>(data.data()),
+                              shape,
+                              torch::dtype(torch::kInt32))
+          .clone();
+    }
+
+    case torch::kInt64: {
+      const auto& data = contents.int64_contents();
+      return torch::from_blob(const_cast<int64_t*>(data.data()),
+                              shape,
+                              torch::dtype(torch::kInt64))
+          .clone();
+    }
+
+    case torch::kBool: {
+      const auto& data = contents.bool_contents();
+      return torch::tensor(std::vector<uint8_t>(data.begin(), data.end()),
+                           torch::dtype(torch::kBool))
+          .view(shape);
+    }
+
+    default:
+      throw std::runtime_error("Unhandled data type conversion for: " +
+                               std::to_string(static_cast<int>(dtype)));
+  }
+}
+
 }  // namespace util
 }  // namespace xllm
diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h
index 51491972d..3c95ca84e 100644
--- a/xllm/core/util/utils.h
+++ b/xllm/core/util/utils.h
@@ -24,6 +24,7 @@ limitations under the License.
 #include <torch/torch.h>
 #include <vector>
 
+#include "rec.pb.h"
 #include "slice.h"
 
 namespace xllm {
@@ -71,5 +72,8 @@ bool match_suffix(const Slice<int32_t>& data, const Slice<int32_t>& suffix);
 std::vector<uint32_t> cal_vec_split_index(uint32_t vec_size,
                                           uint32_t part_num);
 
+torch::Tensor convert_rec_tensor_to_torch(
+    const proto::InferInputTensor& input_tensor);
+
 }  // namespace util
 }  // namespace xllm
diff --git a/xllm/proto/CMakeLists.txt b/xllm/proto/CMakeLists.txt
index 38be2b75d..5014c409f 100644
--- a/xllm/proto/CMakeLists.txt
+++ b/xllm/proto/CMakeLists.txt
@@ -6,6 +6,7 @@ proto_library(
   SRCS
     tensor.proto
     common.proto
+    rec.proto
    completion.proto
    chat.proto
    multimodal.proto
diff --git a/xllm/proto/completion.proto b/xllm/proto/completion.proto
index 65e929a8c..a645fc44c 100644
--- a/xllm/proto/completion.proto
+++ b/xllm/proto/completion.proto
@@ -4,6 +4,7 @@ option go_package = "jd.com/jd-infer/xllm;xllm";
 package xllm.proto;
 
 import "common.proto";
+import "rec.proto";
 
 // Next ID: 26
 message CompletionRequest {
@@ -97,6 +98,8 @@ message CompletionRequest {
   optional int32 beam_width = 29;
 
   optional bool add_special_tokens = 30;
+  // tensor for rec embedding.
+  repeated InferInputTensor input_tensors = 31;
 }
 
 message LogProbs {
@@ -144,5 +147,8 @@ message CompletionResponse {
 
   // usage statistics for the completion request.
   Usage usage = 6;
+
+  // for rec output.
+  repeated InferOutputTensor output_tensors = 7;
 }
 
diff --git a/xllm/proto/rec.proto b/xllm/proto/rec.proto
new file mode 100644
index 000000000..5504b865a
--- /dev/null
+++ b/xllm/proto/rec.proto
@@ -0,0 +1,119 @@
+syntax = "proto3";
+option go_package = "jd.com/jd-infer/xllm;xllm";
+package xllm.proto;
+import "common.proto";
+
+option cc_enable_arenas = true;
+option cc_generic_services = true;
+
+enum DataType {
+  UNDEFINED = 0;
+  // Basic types.
+  FLOAT = 1;   // float
+  UINT8 = 2;   // uint8_t
+  INT8 = 3;    // int8_t
+  UINT16 = 4;  // uint16_t
+  INT16 = 5;   // int16_t
+  INT32 = 6;   // int32_t
+  INT64 = 7;   // int64_t
+  STRING = 8;  // string
+  BOOL = 9;    // bool
+
+  // IEEE754 half-precision floating-point format (16 bits wide).
+  // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+  FLOAT16 = 10;
+
+  DOUBLE = 11;
+  UINT32 = 12;
+  UINT64 = 13;
+  COMPLEX64 = 14;   // complex with float32 real and imaginary components
+  COMPLEX128 = 15;  // complex with float64 real and imaginary components
+
+  // Non-IEEE floating-point format based on IEEE754 single-precision
+  // floating-point number truncated to 16 bits.
+  // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+  BFLOAT16 = 16;
+
+  // Non-IEEE floating-point formats based on the papers
+  // FP8 Formats for Deep Learning (https://arxiv.org/abs/2209.05433) and
+  // 8-bit Numerical Formats For Deep Neural Networks
+  // (https://arxiv.org/pdf/2206.02915.pdf).
+  // Operators that support FP8 are Cast, CastLike, QuantizeLinear, and
+  // DequantizeLinear. The computation usually happens inside a block
+  // quantize / dequantize fused by the runtime.
+  FLOAT8E4M3FN = 17;    // float 8, mostly used for coefficients; supports nan, not inf
+  FLOAT8E4M3FNUZ = 18;  // float 8, mostly used for coefficients; supports nan, not inf, no negative zero
+  FLOAT8E5M2 = 19;      // follows IEEE 754; supports nan, inf; mostly used for gradients
+  FLOAT8E5M2FNUZ = 20;  // follows IEEE 754; supports nan, not inf; mostly used for gradients; no negative zero
+
+  // 4-bit integer data types.
+  UINT4 = 21;  // Unsigned integer in range [0, 15]
+  INT4 = 22;   // Signed integer in range [-8, 7], using two's-complement representation
+
+  // 4-bit floating-point data types.
+  FLOAT4E2M1 = 23;
+
+  // E8M0 type used as the scale for microscaling (MX) formats:
+  // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  FLOAT8E8M0 = 24;
+
+  // Future extensions go here.
+}
+
+// The data contained in a tensor, represented by the repeated type
+// that matches the tensor's data type. Protobuf oneof is not used
+// because oneofs cannot contain repeated fields.
+message InferTensorContents
+{
+  // Representation for BOOL data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated bool bool_contents = 1;
+
+  // Representation for INT8, INT16, and INT32 data types. The size
+  // must match what is expected by the tensor's shape. The contents
+  // must be the flattened, one-dimensional, row-major order of the
+  // tensor elements.
+  repeated int32 int_contents = 2;
+
+  // Representation for INT64 data types. The size must match what
+  // is expected by the tensor's shape. The contents must be the
+  // flattened, one-dimensional, row-major order of the tensor elements.
+  repeated int64 int64_contents = 3;
+
+  // Representation for UINT8, UINT16, and UINT32 data types. The size
+  // must match what is expected by the tensor's shape. The contents
+  // must be the flattened, one-dimensional, row-major order of the
+  // tensor elements.
+  repeated uint32 uint_contents = 4;
+
+  // Representation for UINT64 data types. The size must match what
+  // is expected by the tensor's shape. The contents must be the
+  // flattened, one-dimensional, row-major order of the tensor elements.
+  repeated uint64 uint64_contents = 5;
+
+  // Representation for FP32 data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated float fp32_contents = 6;
+
+  // Representation for FP64 data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated double fp64_contents = 7;
+
+  // Representation for BYTES data type. The size must match what is
+  // expected by the tensor's shape. The contents must be the flattened,
+  // one-dimensional, row-major order of the tensor elements.
+  repeated bytes bytes_contents = 8;
+}
+
+// An input tensor for an inference request.
+message InferInputTensor
+{
+  // The tensor name.
+  string name = 1;
+
+  // The tensor data type.
+  DataType data_type = 2;
+
+  // The tensor shape.
+  repeated int64 shape = 3;
+
+  // The tensor contents using a data-type format. This field must
+  // not be specified if "raw" tensor contents are being used for
+  // the inference request.
+  InferTensorContents contents = 4;
+}
+
+// An output tensor returned for an inference request.
+message InferOutputTensor
+{
+  // The tensor name.
+  string name = 1;
+
+  // The tensor data type.
+  DataType datatype = 2;
+
+  // The tensor shape.
+  repeated int64 shape = 3;
+
+  // The tensor contents using a data-type format. This field must
+  // not be specified if "raw" tensor contents are being used for
+  // the inference response.
+  InferTensorContents contents = 4;
+}
\ No newline at end of file
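
---

Usage sketch (illustration only, not part of the patch): assuming the
generated `rec.pb.h` bindings from rec.proto above and the
`convert_rec_tensor_to_torch()` helper added to xllm/core/util/utils.cpp, a
caller could build one of the new input tensors and round-trip it into a
torch tensor as below. The tensor name, shape, and values are hypothetical.

#include <torch/torch.h>

#include <iostream>

#include "core/util/utils.h"  // declares xllm::util::convert_rec_tensor_to_torch()
#include "rec.pb.h"           // generated from xllm/proto/rec.proto

int main() {
  // Build a 2x3 FP32 tensor the way rec.proto specifies it: an explicit
  // shape plus flattened, row-major contents.
  xllm::proto::InferInputTensor tensor;
  tensor.set_name("user_embedding");  // hypothetical tensor name
  tensor.set_data_type(xllm::proto::DataType::FLOAT);
  tensor.add_shape(2);
  tensor.add_shape(3);
  auto* contents = tensor.mutable_contents();
  for (float v : {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}) {
    contents->add_fp32_contents(v);
  }

  // The helper clones the protobuf buffer, so the returned tensor owns its
  // memory and stays valid after `tensor` goes out of scope.
  torch::Tensor t = xllm::util::convert_rec_tensor_to_torch(tensor);
  std::cout << t.sizes() << " " << t.dtype() << std::endl;  // [2, 3], float
  return 0;
}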