diff --git a/src/test/tokenize_parser_test.cpp b/src/test/tokenize_parser_test.cpp index 9c45d9a108..21e8bae46d 100644 --- a/src/test/tokenize_parser_test.cpp +++ b/src/test/tokenize_parser_test.cpp @@ -154,7 +154,7 @@ TEST(TokenizeDeserialization, invalidTokenizeMaxLengthType) { auto status = ovms::TokenizeParser::parseTokenizeRequest(d, request); ASSERT_NE(status, absl::OkStatus()); auto error = status.message(); - ASSERT_EQ(error, "max_length should be integer"); + ASSERT_EQ(error, "max_length should be unsigned integer"); } TEST(TokenizeDeserialization, invalidTokenizePadToMaxLengthType) { @@ -228,3 +228,39 @@ TEST(TokenizeDeserialization, invalidTokenizePaddingSideValue) { auto error = status.message(); ASSERT_EQ(error, "padding_side should be either left or right"); } + +TEST(TokenizeDeserialization, invalidTokenizeMaxLengthNegative) { + std::string requestBody = R"( + { + "model": "embeddings", + "text": ["one", "two", "three"], + "max_length": -10 + } + )"; + rapidjson::Document d; + rapidjson::ParseResult ok = d.Parse(requestBody.c_str()); + ovms::TokenizeRequest request; + ASSERT_EQ(ok.Code(), 0); + auto status = ovms::TokenizeParser::parseTokenizeRequest(d, request); + ASSERT_NE(status, absl::OkStatus()); + auto error = status.message(); + ASSERT_EQ(error, "max_length should be unsigned integer"); +} + +TEST(TokenizeDeserialization, invalidTokenizeMaxLengthZero) { + std::string requestBody = R"( + { + "model": "embeddings", + "text": ["one", "two", "three"], + "max_length": 0 + } + )"; + rapidjson::Document d; + rapidjson::ParseResult ok = d.Parse(requestBody.c_str()); + ovms::TokenizeRequest request; + ASSERT_EQ(ok.Code(), 0); + auto status = ovms::TokenizeParser::parseTokenizeRequest(d, request); + ASSERT_NE(status, absl::OkStatus()); + auto error = status.message(); + ASSERT_EQ(error, "max_length should be greater than 0"); +} diff --git a/src/tokenize/tokenize_parser.cpp b/src/tokenize/tokenize_parser.cpp index 9f982491b3..f070606a23 100644 --- a/src/tokenize/tokenize_parser.cpp +++ b/src/tokenize/tokenize_parser.cpp @@ -81,13 +81,17 @@ std::variant TokenizeParser::validateTokenizeReque auto it = parsedJson.FindMember("max_length"); if (it != parsedJson.MemberEnd()) { - if (it->value.IsInt()) { - size_t max_length = it->value.GetInt(); + if (it->value.IsUint()) { + size_t max_length = it->value.GetUint(); + if (max_length == 0) { + return "max_length should be greater than 0"; + } + request.parameters["max_length"] = max_length; // Keep OVMS tokenize API contract: max_length implies truncation. request.parameters["truncation"] = true; } else { - return "max_length should be integer"; + return "max_length should be unsigned integer"; } } it = parsedJson.FindMember("pad_to_max_length");