Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/llm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,23 @@ ovms_cc_library(
"//src:image_conversion",
"//src/filesystem:libovmsfilesystem",
"@stb//:image",
":canonical_request",
":openai_request",
":output_parsers",
"//third_party:genai",],
visibility = ["//visibility:public"],
)

# Header-only target (hdrs only, no srcs) exposing preprocessing/canonical_request.hpp.
# It depends on the genai third-party target and on the sibling :openai_request target.
ovms_cc_library(
name = "canonical_request",
hdrs = ["preprocessing/canonical_request.hpp"],
deps = [
"//third_party:genai",
":openai_request",
],
visibility = ["//visibility:public"],
)

ovms_cc_library(
name = "openai_completions_api_handler",
hdrs = ["apis/openai_completions.hpp", "apis/openai_json_response.hpp"],
Expand Down Expand Up @@ -296,6 +307,7 @@ ovms_cc_library(
"language_model/continuous_batching/llm_executor.hpp",
"language_model/continuous_batching/servable_initializer.hpp",
"visual_language_model/continuous_batching/servable.hpp",
"visual_language_model/image_prompt_utils.hpp",
"language_model/legacy/servable.hpp",
"language_model/legacy/servable_initializer.hpp",
"language_model/legacy/legacy_executor.hpp",
Expand All @@ -307,6 +319,7 @@ ovms_cc_library(
"servable_initializer.cpp",
"language_model/continuous_batching/servable.cpp",
"language_model/continuous_batching/servable_initializer.cpp",
"visual_language_model/image_prompt_utils.cpp",
"visual_language_model/continuous_batching/servable.cpp",
"language_model/legacy/servable.cpp",
"language_model/legacy/servable_initializer.cpp",
Expand Down
48 changes: 35 additions & 13 deletions src/llm/apis/openai_api_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,9 @@ absl::Status OpenAIApiHandler::parseTools() {
return absl::InvalidArgumentError("tool_choice is not a valid JSON object or string");
}
}
bool jsonChanged = false;
if (toolChoice == "none") {
// remove tools from the request
doc.RemoveMember("tools");
jsonChanged = true;
}
auto it = doc.FindMember("tools");
if (it != doc.MemberEnd() && !it->value.IsNull()) {
Expand Down Expand Up @@ -405,7 +403,6 @@ absl::Status OpenAIApiHandler::parseTools() {
// If toolChoice is set to a specific function name, we keep only that tool
if (toolChoice != "auto" && toolChoice != "required" && toolChoice != functionName) {
it->value.Erase(&obj);
jsonChanged = true;
continue;
}

Expand All @@ -430,16 +427,10 @@ absl::Status OpenAIApiHandler::parseTools() {
}

request.toolChoice = toolChoice;
if (jsonChanged) {
StringBuffer buffer;
Writer<StringBuffer> writer(buffer);
doc.Accept(writer);
request.processedJson = buffer.GetString();
}
return absl::OkStatus();
}

absl::StatusOr<std::optional<ov::genai::JsonContainer>> OpenAIApiHandler::parseToolsToJsonContainer() {
absl::StatusOr<std::optional<ov::genai::JsonContainer>> OpenAIApiHandler::parseToolsToJsonContainer() const {
auto it = doc.FindMember("tools");
if (it == doc.MemberEnd() || it->value.IsNull()) {
return std::nullopt;
Expand All @@ -460,7 +451,7 @@ absl::StatusOr<std::optional<ov::genai::JsonContainer>> OpenAIApiHandler::parseT
}
}

absl::StatusOr<std::optional<ov::genai::JsonContainer>> OpenAIApiHandler::parseChatTemplateKwargsToJsonContainer() {
absl::StatusOr<std::optional<ov::genai::JsonContainer>> OpenAIApiHandler::parseChatTemplateKwargsToJsonContainer() const {
auto it = doc.FindMember("chat_template_kwargs");
if (it == doc.MemberEnd() || it->value.IsNull()) {
return std::nullopt;
Expand Down Expand Up @@ -492,15 +483,47 @@ const OpenAIRequest& OpenAIApiHandler::getRequest() const {
return request;
}

// Non-virtual entry point for building a canonical request.
// Delegates entirely to the endpoint-specific buildCanonicalRequestImpl() and
// hands its result (value or error status) back to the caller unchanged.
absl::StatusOr<CanonicalRequest> OpenAIApiHandler::buildCanonicalRequest(RendererType rendererType) const {
    auto built = buildCanonicalRequestImpl(rendererType);
    return built;
}

// Returns a pointer to the cached canonical request for the given renderer,
// building it on first use.
//
// One cache slot is kept per renderer: RendererType::CPP_TOKENIZER uses
// cachedCppCanonicalRequest, every other value uses cachedPyCanonicalRequest.
// When the build fails, the error status is returned and the slot stays empty,
// so the next call attempts the build again.
absl::StatusOr<const CanonicalRequest*> OpenAIApiHandler::getCanonicalRequest(RendererType rendererType) const {
    std::optional<CanonicalRequest>& slot =
        (rendererType == RendererType::CPP_TOKENIZER) ? cachedCppCanonicalRequest : cachedPyCanonicalRequest;

    // Fast path: already built for this renderer.
    if (slot.has_value()) {
        return &slot.value();
    }

    auto built = buildCanonicalRequest(rendererType);
    if (!built.ok()) {
        return built.status();
    }
    slot = std::move(built).value();
    return &slot.value();
}

// Returns the serialized request JSON carried by the PY_JINJA canonical request.
//
// The canonical request is obtained (and lazily built) through
// getCanonicalRequest(RendererType::PY_JINJA). If that fails, or the result
// does not hold the PyPath alternative, a reference to a function-local empty
// string is returned instead; the error status itself is not surfaced.
//
// The stale `return request.processedJson;` was dropped: it made the rest of the
// function unreachable and read a field that OpenAIRequest no longer declares.
const std::string& OpenAIApiHandler::getProcessedJson() const {
    static const std::string EMPTY_JSON{};

    const auto canonicalRequest = getCanonicalRequest(RendererType::PY_JINJA);
    if (!canonicalRequest.ok()) {
        return EMPTY_JSON;
    }

    const auto* pyPath = std::get_if<PyPath>(canonicalRequest.value());
    if (pyPath == nullptr) {
        return EMPTY_JSON;
    }
    // processedJson is a reference wrapper; unwrap it to the referenced string.
    return pyPath->processedJson.get();
}

// Returns the image history stored on the parsed request.
//
// The CPP_TOKENIZER canonical request is still requested first, which populates
// its cache slot as a side effect. The previous `if (!ok)` branch and the
// fall-through both returned request.imageHistory, so the dead conditional was
// collapsed into a single return.
// NOTE(review): the canonical request result is not consulted here; confirm
// whether the eager build is intended or a placeholder.
const ImageHistory& OpenAIApiHandler::getImageHistory() const {
    [[maybe_unused]] const auto canonicalRequest = getCanonicalRequest(RendererType::CPP_TOKENIZER);
    return request.imageHistory;
}
Comment on lines 514 to 520

// Returns a mutable reference to the chat history stored on the parsed request.
//
// The CPP_TOKENIZER canonical request is still requested first, which populates
// its cache slot as a side effect. The previous `if (!ok)` branch and the
// fall-through both returned request.chatHistory, so the dead conditional was
// collapsed into a single return.
// NOTE(review): the canonical request result is not consulted here; confirm
// whether the eager build is intended or a placeholder.
ov::genai::ChatHistory& OpenAIApiHandler::getChatHistory() {
    [[maybe_unused]] const auto canonicalRequest = getCanonicalRequest(RendererType::CPP_TOKENIZER);
    return request.chatHistory;
}
Comment on lines 522 to 528

Expand All @@ -512,7 +535,6 @@ std::optional<std::string> OpenAIApiHandler::getResponseFormat() const {
return request.responseFormat;
}

// Simple by-value accessors over fields of the parsed request.
// Returns a copy of request.prompt.
std::optional<std::string> OpenAIApiHandler::getPrompt() const { return request.prompt; }
// Returns a copy of request.numReturnSequences.
std::optional<int> OpenAIApiHandler::getNumReturnSequences() const { return request.numReturnSequences; }
// Returns a copy of request.streamOptions.
StreamOptions OpenAIApiHandler::getStreamOptions() const { return request.streamOptions; }

Expand Down
14 changes: 11 additions & 3 deletions src/llm/apis/openai_api_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#pragma warning(pop)
#include "../io_processing/output_parser.hpp"
#include "openai_request.hpp"
#include "../preprocessing/canonical_request.hpp"

// Forward declarations for types only used by reference in virtual method signatures
namespace ov {
Expand Down Expand Up @@ -119,6 +120,13 @@ class OpenAIApiHandler {
// Shared VLM workaround: encode text to tokens using tokenizer, validates shape
std::vector<int64_t> encodeTextToTokens(const std::string& text);

// Endpoint-specific construction of the canonical request for the given renderer.
// Pure virtual: each concrete handler must provide it.
virtual absl::StatusOr<CanonicalRequest> buildCanonicalRequestImpl(RendererType rendererType) const = 0;
// Non-virtual wrapper that forwards to buildCanonicalRequestImpl().
absl::StatusOr<CanonicalRequest> buildCanonicalRequest(RendererType rendererType) const;

// Lazily filled by getCanonicalRequest(): one slot for RendererType::CPP_TOKENIZER,
// one for every other renderer. Declared mutable so const accessors can populate them.
mutable std::optional<CanonicalRequest> cachedCppCanonicalRequest;
mutable std::optional<CanonicalRequest> cachedPyCanonicalRequest;
// Backing storage for the serialized request JSON; a PyPath refers to this string
// by reference, so it has to live as long as the handler.
mutable std::optional<std::string> synthesizedProcessedJson;

public:
OpenAIApiHandler(Document& doc, Endpoint endpoint, std::chrono::time_point<std::chrono::system_clock> creationTime,
ov::genai::Tokenizer tokenizer, const std::string& toolParserName = "", const std::string& reasoningParserName = "") :
Expand Down Expand Up @@ -147,18 +155,18 @@ class OpenAIApiHandler {

// Shared parsing (non-virtual)
absl::Status parseTools();
absl::StatusOr<std::optional<ov::genai::JsonContainer>> parseToolsToJsonContainer();
absl::StatusOr<std::optional<ov::genai::JsonContainer>> parseChatTemplateKwargsToJsonContainer();
absl::StatusOr<std::optional<ov::genai::JsonContainer>> parseToolsToJsonContainer() const;
absl::StatusOr<std::optional<ov::genai::JsonContainer>> parseChatTemplateKwargsToJsonContainer() const;
const bool areToolsAvailable() const;

// Accessors (non-virtual)
const OpenAIRequest& getRequest() const;
std::optional<std::string> getPrompt() const;
std::optional<int> getNumReturnSequences() const;
StreamOptions getStreamOptions() const;
const std::string& getProcessedJson() const;
const ImageHistory& getImageHistory() const;
ov::genai::ChatHistory& getChatHistory();
absl::StatusOr<const CanonicalRequest*> getCanonicalRequest(RendererType rendererType) const;
std::optional<int> getMaxTokens() const;
std::optional<std::string> getResponseFormat() const;
bool isStream() const;
Expand Down
45 changes: 36 additions & 9 deletions src/llm/apis/openai_completions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,9 @@ absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() {
if (it != doc.MemberEnd()) {
if (!it->value.IsString()) {
return absl::InvalidArgumentError("prompt is not a string");
} else {
request.prompt = it->value.GetString();
}
}
if (!request.prompt.has_value() || !request.prompt.value().size()) {
if (it == doc.MemberEnd() || it->value.GetStringLength() == 0) {
return absl::Status(absl::StatusCode::kInvalidArgument, "prompt is missing");
}
// logprobs: int; 1 value allowed
Expand Down Expand Up @@ -265,16 +263,45 @@ absl::Status OpenAIChatCompletionsHandler::parseMessages(std::optional<std::stri
return status;
}
}
if (jsonChanged) {
StringBuffer buffer;
Writer<StringBuffer> writer(buffer);
doc.Accept(writer);
request.processedJson = buffer.GetString();
}
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Parsed messages successfully");
return absl::OkStatus();
}

// Builds the renderer-specific canonical view of the already parsed request.
//
// - Any renderer other than RendererType::CPP_TOKENIZER: the JSON document is
//   serialized into synthesizedProcessedJson and a PyPath referencing that
//   string is returned.
// - RendererType::CPP_TOKENIZER: a CppPath is returned that references the
//   request's chat and image histories and carries the parsed tools, the chat
//   template kwargs and, for the COMPLETIONS endpoint only, the raw "prompt"
//   string when it is present and a string.
//
// Returns the error status of tools / chat_template_kwargs parsing when either fails.
absl::StatusOr<CanonicalRequest> OpenAIChatCompletionsHandler::buildCanonicalRequestImpl(RendererType rendererType) const {
    if (rendererType != RendererType::CPP_TOKENIZER) {
        StringBuffer serialized;
        Writer<StringBuffer> jsonWriter(serialized);
        doc.Accept(jsonWriter);
        // Stored on the handler so the reference held by PyPath stays valid.
        synthesizedProcessedJson = serialized.GetString();
        PyPath pythonView{std::cref(*synthesizedProcessedJson)};
        return CanonicalRequest(std::move(pythonView));
    }

    auto toolsOr = parseToolsToJsonContainer();
    if (!toolsOr.ok()) {
        return toolsOr.status();
    }
    auto kwargsOr = parseChatTemplateKwargsToJsonContainer();
    if (!kwargsOr.ok()) {
        return kwargsOr.status();
    }

    // Only the completions endpoint carries a raw prompt.
    std::optional<std::string> completionPrompt;
    if (endpoint == Endpoint::COMPLETIONS) {
        auto member = doc.FindMember("prompt");
        if (member != doc.MemberEnd() && member->value.IsString()) {
            completionPrompt.emplace(member->value.GetString(), member->value.GetStringLength());
        }
    }

    CppPath nativeView{
        std::cref(request.chatHistory),
        std::cref(request.imageHistory),
        std::move(toolsOr.value()),
        std::move(kwargsOr.value()),
        std::move(completionPrompt),
        true};
    return CanonicalRequest(std::move(nativeView));
}

// --- Unary response serialization ---

std::string OpenAIChatCompletionsHandler::serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs) {
Expand Down
1 change: 1 addition & 0 deletions src/llm/apis/openai_completions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class OpenAIChatCompletionsHandler : public OpenAIApiHandler {
absl::Status parseRequest(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, std::optional<uint32_t> maxModelLength,
std::optional<std::string> allowedLocalMediaPath = std::nullopt, std::optional<std::vector<std::string>> allowedMediaDomains = std::nullopt) override;
absl::Status parseMessages(std::optional<std::string> allowedLocalMediaPath = std::nullopt, std::optional<std::vector<std::string>> allowedMediaDomains = std::nullopt);
absl::StatusOr<CanonicalRequest> buildCanonicalRequestImpl(RendererType rendererType) const override;

std::string serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs) override;
std::string serializeUnaryResponse(ov::genai::EncodedResults& results) override;
Expand Down
2 changes: 0 additions & 2 deletions src/llm/apis/openai_request.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ struct StreamOptions {
// Class that maps OpenAI request content.
struct OpenAIRequest {
ov::genai::ChatHistory chatHistory;
std::string processedJson;
ImageHistory imageHistory;
std::optional<std::string> prompt{std::nullopt};
bool stream{false};
StreamOptions streamOptions;
std::string model;
Expand Down
Loading