diff --git a/core/src/main/java/com/redis/vl/query/TextQuery.java b/core/src/main/java/com/redis/vl/query/TextQuery.java index 87e3601..ee4a6ab 100644 --- a/core/src/main/java/com/redis/vl/query/TextQuery.java +++ b/core/src/main/java/com/redis/vl/query/TextQuery.java @@ -1,122 +1,264 @@ package com.redis.vl.query; -import java.util.ArrayList; -import java.util.List; +import com.redis.vl.utils.TokenEscaper; +import java.util.HashMap; +import java.util.Map; +import lombok.Getter; -/** Full-text search query */ +/** + * Full-text search query with support for field weights. + * + *

Python port: Implements text_field_name with Union[str, Dict[str, float]] for weighted text + * search across multiple fields. + * + *

Example usage: + * + *

{@code
+ * // Single field (backward compatible)
+ * TextQuery query = TextQuery.builder()
+ *     .text("search terms")
+ *     .textField("description")
+ *     .build();
+ *
+ * // Multiple fields with weights
+ * TextQuery query = TextQuery.builder()
+ *     .text("search terms")
+ *     .textFieldWeights(Map.of("title", 5.0, "content", 2.0, "tags", 1.0))
+ *     .build();
+ * }
+ */ +@Getter public class TextQuery { private final String text; - private final String textField; private final String scorer; private final Filter filterExpression; - private final List returnFields; + private final Integer numResults; - /** - * Create a text query without a filter expression. - * - * @param text The text to search for - * @param textField The field to search in - * @param scorer The scoring algorithm (e.g., "BM25", "TF IDF") - * @param returnFields List of fields to return in results - */ - public TextQuery(String text, String textField, String scorer, List returnFields) { - this(text, textField, scorer, null, returnFields); - } + /** Field names mapped to their search weights */ + private Map fieldWeights; - /** - * Create a text query with all parameters. - * - * @param text The text to search for - * @param textField The field to search in - * @param scorer The scoring algorithm - * @param filterExpression Optional filter to apply - * @param returnFields List of fields to return in results - */ - public TextQuery( - String text, - String textField, - String scorer, - Filter filterExpression, - List returnFields) { - this.text = text; - this.textField = textField; - this.scorer = scorer; - this.filterExpression = filterExpression; - this.returnFields = returnFields != null ? new ArrayList<>(returnFields) : null; + private TextQuery(Builder builder) { + this.text = builder.text; + this.scorer = builder.scorer; + this.filterExpression = builder.filterExpression; + this.numResults = builder.numResults; + this.fieldWeights = new HashMap<>(builder.fieldWeights); } /** - * Get the search text + * Update the field weights dynamically. * - * @return Search text + * @param fieldWeights Map of field names to weights */ - public String getText() { - return text; + public void setFieldWeights(Map fieldWeights) { + validateFieldWeights(fieldWeights); + this.fieldWeights = new HashMap<>(fieldWeights); } /** - * Get the text field to search in + * Get a copy of the field weights. * - * @return Text field name + * @return Map of field names to weights */ - public String getTextField() { - return textField; + public Map getFieldWeights() { + return new HashMap<>(fieldWeights); } /** - * Get the scoring algorithm + * Build the Redis query string for text search with weighted fields. + * + *

Format: * - * @return Scorer name + *

    + *
  • Single field default weight: {@code @field:(term1 | term2)} + *
  • Single field with weight: {@code @field:(term1 | term2) => { $weight: 5.0 }} + *
  • Multiple fields: {@code (@field1:(terms) => { $weight: 3.0 } | @field2:(terms) => { + * $weight: 2.0 })} + *
+ * + * @return Redis query string */ - public String getScorer() { - return scorer; + public String toQueryString() { + TokenEscaper escaper = new TokenEscaper(); + + // Tokenize and escape the query text + String[] tokens = text.split("\\s+"); + StringBuilder escapedQuery = new StringBuilder(); + + for (int i = 0; i < tokens.length; i++) { + if (i > 0) { + escapedQuery.append(" | "); + } + String cleanToken = + tokens[i].strip().stripLeading().stripTrailing().replace(",", "").toLowerCase(); + escapedQuery.append(escaper.escape(cleanToken)); + } + + String escapedText = escapedQuery.toString(); + + // Build query parts for each field with its weight + StringBuilder queryBuilder = new StringBuilder(); + int fieldCount = 0; + + for (Map.Entry entry : fieldWeights.entrySet()) { + String field = entry.getKey(); + Double weight = entry.getValue(); + + if (fieldCount > 0) { + queryBuilder.append(" | "); + } + + queryBuilder.append("@").append(field).append(":(").append(escapedText).append(")"); + + // Add weight modifier if not default + if (weight != 1.0) { + queryBuilder.append(" => { $weight: ").append(weight).append(" }"); + } + + fieldCount++; + } + + // Wrap multiple fields in parentheses + String textQuery; + if (fieldWeights.size() > 1) { + textQuery = "(" + queryBuilder.toString() + ")"; + } else { + textQuery = queryBuilder.toString(); + } + + // Add filter expression if present + if (filterExpression != null) { + return textQuery + " AND " + filterExpression.build(); + } + + return textQuery; } - /** - * Get the filter expression - * - * @return Filter expression or null - */ - public Filter getFilterExpression() { - return filterExpression; + @Override + public String toString() { + return toQueryString(); } /** - * Get the return fields + * Create a new Builder for TextQuery. * - * @return List of fields to return or null + * @return Builder instance */ - public List getReturnFields() { - return returnFields != null ? new ArrayList<>(returnFields) : null; + public static Builder builder() { + return new Builder(); } - /** - * Build the query string for Redis text search - * - * @return Query string - */ - public String toQueryString() { - StringBuilder query = new StringBuilder(); + /** Builder for TextQuery with support for field weights. */ + public static class Builder { + private String text; + private String scorer = "BM25STD"; + private Filter filterExpression; + private Integer numResults = 10; + private Map fieldWeights = new HashMap<>(); - // Add filter expression if present - if (filterExpression != null) { - query.append(filterExpression.build()).append(" "); + /** + * Set the text to search for. + * + * @param text Search text + * @return Builder + */ + public Builder text(String text) { + this.text = text; + return this; } - // Add text search - if (textField != null && !textField.isEmpty()) { - query.append("@").append(textField).append(":(").append(text).append(")"); - } else { - // Search all text fields - query.append(text); + /** + * Set a single text field to search (backward compatible). + * + * @param fieldName Field name + * @return Builder + */ + public Builder textField(String fieldName) { + this.fieldWeights = Map.of(fieldName, 1.0); + return this; } - return query.toString(); + /** + * Set multiple text fields with weights. + * + * @param fieldWeights Map of field names to weights + * @return Builder + */ + public Builder textFieldWeights(Map fieldWeights) { + validateFieldWeights(fieldWeights); + this.fieldWeights = new HashMap<>(fieldWeights); + return this; + } + + /** + * Set the scoring algorithm. + * + * @param scorer Scorer name (e.g., BM25STD, TFIDF) + * @return Builder + */ + public Builder scorer(String scorer) { + this.scorer = scorer; + return this; + } + + /** + * Set the filter expression. + * + * @param filterExpression Filter to apply + * @return Builder + */ + public Builder filterExpression(Filter filterExpression) { + this.filterExpression = filterExpression; + return this; + } + + /** + * Set the number of results to return. + * + * @param numResults Number of results + * @return Builder + */ + public Builder numResults(int numResults) { + this.numResults = numResults; + return this; + } + + /** + * Build the TextQuery instance. + * + * @return TextQuery + * @throws IllegalArgumentException if text is null or field weights are empty + */ + public TextQuery build() { + if (text == null || text.trim().isEmpty()) { + throw new IllegalArgumentException("Text cannot be null or empty"); + } + if (fieldWeights.isEmpty()) { + throw new IllegalArgumentException("At least one text field must be specified"); + } + return new TextQuery(this); + } } - @Override - public String toString() { - return toQueryString(); + /** + * Validate field weights. + * + * @param fieldWeights Map to validate + * @throws IllegalArgumentException if weights are invalid + */ + private static void validateFieldWeights(Map fieldWeights) { + for (Map.Entry entry : fieldWeights.entrySet()) { + String field = entry.getKey(); + Double weight = entry.getValue(); + + if (weight == null) { + throw new IllegalArgumentException("Weight for field '" + field + "' cannot be null"); + } + if (weight <= 0) { + throw new IllegalArgumentException( + "Weight for field '" + field + "' must be positive, got " + weight); + } + } } } diff --git a/core/src/test/java/com/redis/vl/query/QueryIntegrationTest.java b/core/src/test/java/com/redis/vl/query/QueryIntegrationTest.java index 2887ff5..9ed0f05 100644 --- a/core/src/test/java/com/redis/vl/query/QueryIntegrationTest.java +++ b/core/src/test/java/com/redis/vl/query/QueryIntegrationTest.java @@ -602,7 +602,8 @@ void testTextQuery() { List scorers = Arrays.asList("BM25", "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE"); for (String scorer : scorers) { - TextQuery textQuery = new TextQuery(text, textField, scorer, returnFields); + TextQuery textQuery = + TextQuery.builder().text(text).textField(textField).scorer(scorer).numResults(10).build(); List> results = index.query(textQuery); assertThat(results) .as("TextQuery with scorer " + scorer + " should return results") @@ -630,7 +631,14 @@ void testTextQueryWithFilter() { Filter.and(Filter.tag("credit_score", "high"), Filter.numeric("age").gt(30)); String scorer = "TFIDF"; - TextQuery textQuery = new TextQuery(text, textField, scorer, filterExpression, returnFields); + TextQuery textQuery = + TextQuery.builder() + .text(text) + .textField(textField) + .scorer(scorer) + .filterExpression(filterExpression) + .numResults(10) + .build(); List> results = index.query(textQuery); assertThat(results).hasSize(2); // mary and derrick diff --git a/core/src/test/java/com/redis/vl/query/TextQueryWeightsTest.java b/core/src/test/java/com/redis/vl/query/TextQueryWeightsTest.java new file mode 100644 index 0000000..9911e2d --- /dev/null +++ b/core/src/test/java/com/redis/vl/query/TextQueryWeightsTest.java @@ -0,0 +1,157 @@ +package com.redis.vl.query; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Map; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for TextQuery with field weights functionality. + * + *

Python port: tests/unit/test_text_query_weights.py + */ +@DisplayName("TextQuery Field Weights Tests") +class TextQueryWeightsTest { + + @Test + @DisplayName("Should accept weights dictionary") + void testTextQueryAcceptsWeightsDict() { + // Given + String text = "example search query"; + Map fieldWeights = Map.of("title", 5.0, "content", 2.0, "tags", 1.0); + + // When + TextQuery textQuery = + TextQuery.builder().text(text).textFieldWeights(fieldWeights).numResults(10).build(); + + // Then + assertThat(textQuery.getFieldWeights()).isEqualTo(fieldWeights); + } + + @Test + @DisplayName("Should generate weighted query string") + void testTextQueryGeneratesWeightedQueryString() { + // Given + String text = "search query"; + Map fieldWeights = Map.of("title", 5.0); + + // When + TextQuery textQuery = + TextQuery.builder().text(text).textFieldWeights(fieldWeights).numResults(10).build(); + + String queryString = textQuery.toQueryString(); + + // Then - should generate: @title:(search | query) => { $weight: 5.0 } + assertThat(queryString) + .containsAnyOf( + "@title:(search | query) => { $weight: 5.0 }", + "@title:(search | query)=>{ $weight: 5.0 }", + "@title:(search | query)=>{$weight:5.0}"); + } + + @Test + @DisplayName("Should handle multiple fields with weights") + void testTextQueryMultipleFieldsWithWeights() { + // Given + String text = "search terms"; + Map fieldWeights = Map.of("title", 3.0, "content", 1.5, "tags", 1.0); + + // When + TextQuery textQuery = + TextQuery.builder().text(text).textFieldWeights(fieldWeights).numResults(10).build(); + + String queryString = textQuery.toQueryString(); + + // Then - all fields should be present + assertThat(queryString).contains("@title:"); + assertThat(queryString).contains("@content:"); + assertThat(queryString).contains("@tags:"); + + // Weights should be in the query + assertThat(queryString).containsAnyOf("$weight: 3.0", "$weight:3.0"); + assertThat(queryString).containsAnyOf("$weight: 1.5", "$weight:1.5"); + // Weight of 1.0 might be omitted as it's the default + } + + @Test + @DisplayName("Should maintain backward compatibility with single string field") + void testTextQueryBackwardCompatibility() { + // Given + String text = "backward compatible"; + + // When - use single string field name (original API) + TextQuery textQuery = + TextQuery.builder().text(text).textField("description").numResults(5).build(); + + String queryString = textQuery.toQueryString(); + + // Then + assertThat(queryString).contains("@description:"); + assertThat(queryString).contains("backward | compatible"); + + // Field weights should have the single field with weight 1.0 + assertThat(textQuery.getFieldWeights()).isEqualTo(Map.of("description", 1.0)); + } + + @Test + @DisplayName("Should reject negative weights") + void testTextQueryRejectsNegativeWeights() { + // Given + String text = "test query"; + + // When/Then - negative weight should throw + assertThatThrownBy( + () -> + TextQuery.builder() + .text(text) + .textFieldWeights(Map.of("title", -1.0)) + .numResults(10) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("must be positive"); + } + + @Test + @DisplayName("Should reject zero weights") + void testTextQueryRejectsZeroWeights() { + // Given + String text = "test query"; + + // When/Then - zero weight should throw + assertThatThrownBy( + () -> + TextQuery.builder() + .text(text) + .textFieldWeights(Map.of("title", 0.0)) + .numResults(10) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("must be positive"); + } + + @Test + @DisplayName("Should support dynamic weight updates") + void testSetFieldWeightsMethod() { + // Given + String text = "dynamic weights"; + + // When - start with single field + TextQuery textQuery = TextQuery.builder().text(text).textField("title").numResults(10).build(); + + assertThat(textQuery.getFieldWeights()).isEqualTo(Map.of("title", 1.0)); + + // Update to multiple fields with weights + Map newWeights = Map.of("title", 5.0, "content", 2.0); + textQuery.setFieldWeights(newWeights); + + // Then + assertThat(textQuery.getFieldWeights()).isEqualTo(newWeights); + + // Query string should reflect new weights + String queryString = textQuery.toQueryString(); + assertThat(queryString).containsAnyOf("$weight: 5.0", "$weight:5.0"); + assertThat(queryString).containsAnyOf("$weight: 2.0", "$weight:2.0"); + } +}