diff --git a/src/test/embeddingsnode_test.cpp b/src/test/embeddingsnode_test.cpp index 6443bba525..7a74d182b4 100644 --- a/src/test/embeddingsnode_test.cpp +++ b/src/test/embeddingsnode_test.cpp @@ -786,3 +786,19 @@ TEST_F(EmbeddingsTokenizeHttpTest, tokenizeBatchWithPadToMaxLen) { ovms::StatusCode::OK); AssertTokenizationResult(response, expectedTokens); } + +TEST_F(EmbeddingsTokenizeHttpTest, tokenizeEmptyNestedArray) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "embeddings_ov", "[[]]", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_F(EmbeddingsTokenizeHttpTest, tokenizeMultipleEmptyNestedArrays) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "embeddings_ov", "[[], [], []]", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_F(EmbeddingsTokenizeHttpTest, tokenizeMultipleEmptyNestedArraysAndOneNonEmpty) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "embeddings_ov", R"([[], ["hello world"], []])", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_F(EmbeddingsTokenizeHttpTest, tokenizeEmptyWithArrayMultipleLevelsOfNesting) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "embeddings_ov", "[[[[[]]]]]", response, comp, responseComponents, writer, multiPartParser); +} diff --git a/src/test/llm/tokenize_endpoint_test.cpp b/src/test/llm/tokenize_endpoint_test.cpp index 17c390b650..602961c5f5 100644 --- a/src/test/llm/tokenize_endpoint_test.cpp +++ b/src/test/llm/tokenize_endpoint_test.cpp @@ -417,9 +417,6 @@ TEST_P(LLMTokenizeTests, tokenizeArrayOfStringsWithPaddingSideLeft) { TEST_P(LLMTokenizeTests, tokenizeStringWithAddSpecialTokens) { auto params = GetParam(); - if (params.modelName == "vlm_cb_regular" || params.modelName == "vlm_legacy_regular") { - GTEST_SKIP() << "Skipping test for " << params.modelName; - } std::string requestBody = R"( { @@ -443,6 +440,26 @@ TEST_P(LLMTokenizeTests, tokenizeStringWithAddSpecialTokens) { ASSERT_GE(tokens.Size(), params.expectedTokens.size()); } +TEST_P(LLMTokenizeTests, tokenizeEmptyNestedArray) { + auto params = GetParam(); + assertTokenizeWithInvalidTextReturnsError(handler.get(), params.modelName, "[[]]", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_P(LLMTokenizeTests, tokenizeMultipleEmptyNestedArrays) { + auto params = GetParam(); + assertTokenizeWithInvalidTextReturnsError(handler.get(), params.modelName, "[[], [], []]", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_P(LLMTokenizeTests, tokenizeMultipleEmptyNestedArraysAndOneNonEmpty) { + auto params = GetParam(); + assertTokenizeWithInvalidTextReturnsError(handler.get(), params.modelName, R"([[], ["hello world"], []])", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_P(LLMTokenizeTests, tokenizeEmptyWithArrayMultipleLevelsOfNesting) { + auto params = GetParam(); + assertTokenizeWithInvalidTextReturnsError(handler.get(), params.modelName, "[[[[[]]]]]", response, comp, responseComponents, writer, multiPartParser); +} + INSTANTIATE_TEST_SUITE_P( LLMTokenizeTestInstances, LLMTokenizeTests, diff --git a/src/test/reranknode_test.cpp b/src/test/reranknode_test.cpp index fa19dba86e..057e83e6f5 100644 --- a/src/test/reranknode_test.cpp +++ b/src/test/reranknode_test.cpp @@ -605,3 +605,19 @@ TEST_F(RerankTokenizeHttpTest, tokenizeIgnoreAddSpecialTokensParameter) { ovms::StatusCode::OK); AssertTokenizationResult(response, expectedTokens); } + +TEST_F(RerankTokenizeHttpTest, tokenizeEmptyNestedArray) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "rerank_ov", "[[]]", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_F(RerankTokenizeHttpTest, tokenizeMultipleEmptyNestedArrays) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "rerank_ov", "[[], [], []]", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_F(RerankTokenizeHttpTest, tokenizeMultipleEmptyNestedArraysAndOneNonEmpty) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "rerank_ov", R"([[], ["hello world"], []])", response, comp, responseComponents, writer, multiPartParser); +} + +TEST_F(RerankTokenizeHttpTest, tokenizeEmptyWithArrayMultipleLevelsOfNesting) { + assertTokenizeWithInvalidTextReturnsError(handler.get(), "rerank_ov", "[[[[[]]]]]", response, comp, responseComponents, writer, multiPartParser); +} diff --git a/src/test/test_http_utils.hpp b/src/test/test_http_utils.hpp index 62ca393c40..b59214dbf0 100644 --- a/src/test/test_http_utils.hpp +++ b/src/test/test_http_utils.hpp @@ -14,6 +14,7 @@ // limitations under the License. //***************************************************************************** #pragma once +#include #include #include #include @@ -89,3 +90,18 @@ class V3HttpTest : public ::testing::Test { handler.reset(); } }; + +inline void assertTokenizeWithInvalidTextReturnsError( + ovms::HttpRestApiHandler* handler, + const std::string& modelName, + const std::string& textJson, + std::string& response, + ovms::HttpRequestComponents& comp, + ovms::HttpResponseComponents& responseComponents, + std::shared_ptr& writer, + std::shared_ptr& multiPartParser) { + std::string requestBody = R"({"model": ")" + modelName + R"(", "text": )" + textJson + R"(})"; + std::string endpoint = "/v3/tokenize"; + ovms::Status status = handler->dispatchToProcessor(endpoint, requestBody, &response, comp, responseComponents, writer, multiPartParser); + ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_EXECUTION_ERROR) << status.string(); +} diff --git a/src/tokenize/tokenize_parser.cpp b/src/tokenize/tokenize_parser.cpp index d6ee9dcdd1..e5dbb54620 100644 --- a/src/tokenize/tokenize_parser.cpp +++ b/src/tokenize/tokenize_parser.cpp @@ -147,28 +147,32 @@ std::variant TokenizeParser::parseI InputType input_type = InputType::NONE; for (auto& input : it->value.GetArray()) { if (input.IsArray()) { + const auto array = input.GetArray(); + if (array.Size() == 0) { + return "inner arrays in " + field_name + " should not be empty"; + } if (input_type != InputType::NONE && input_type != InputType::INT_VEC && input_type != InputType::STRING_VEC) return field_name + " must be homogeneous"; - if (input.GetArray()[0].IsInt()) { + if (array[0].IsInt()) { if (input_type == InputType::STRING_VEC) return field_name + " must be homogeneous"; input_type = InputType::INT_VEC; std::vector ints; - ints.reserve(input.GetArray().Size()); - for (auto& val : input.GetArray()) { + ints.reserve(array.Size()); + for (auto& val : array) { if (val.IsInt()) ints.push_back(val.GetInt()); else return field_name + " must be homogeneous"; } input_tokens.emplace_back(std::move(ints)); - } else if (input.GetArray()[0].IsString()) { + } else if (array[0].IsString()) { if (input_type == InputType::INT_VEC) return field_name + " must be homogeneous"; input_type = InputType::STRING_VEC; std::vector strings; - strings.reserve(input.GetArray().Size()); - for (auto& val : input.GetArray()) { + strings.reserve(array.Size()); + for (auto& val : array) { if (val.IsString()) strings.push_back(val.GetString()); else