diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 649aa2156..fea501f89 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -248,6 +248,15 @@ def _uses_svs_vamana(self) -> bool: def _validate_query(self, query: BaseQuery) -> None: """Validate a query.""" + if isinstance(query, BaseVectorQuery): + field = self.schema.fields[query._vector_field_name] + dist_metric = VectorDistanceMetric(field.attrs.distance_metric.upper()) # type: ignore + if dist_metric == VectorDistanceMetric.COSINE_SIMILARITY and getattr( + query, "_uses_default_vector_distance_sort", False + ): + query.sort_by(query.DISTANCE_ID, asc=False) + query._uses_default_vector_distance_sort = True + if isinstance(query, VectorQuery): field = self.schema.fields[query._vector_field_name] if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore @@ -1148,6 +1157,8 @@ def batch_query( self, queries: Sequence[BaseQuery], batch_size: int = 10 ) -> List[List[Dict[str, Any]]]: """Execute a batch of queries and process results.""" + for query in queries: + self._validate_query(query) results = self.batch_search( [(query.query, query.params) for query in queries], batch_size=batch_size ) @@ -2071,6 +2082,8 @@ async def batch_query( self, queries: List[BaseQuery], batch_size: int = 10 ) -> List[List[Dict[str, Any]]]: """Asynchronously execute a batch of queries and process results.""" + for query in queries: + self._validate_query(query) results = await self.batch_search( [(query.query, query.params) for query in queries], batch_size=batch_size ) diff --git a/redisvl/query/query.py b/redisvl/query/query.py index dfdc41fd9..77718606f 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -174,6 +174,9 @@ def sort_by( if sort_spec is None or sort_spec == []: # No sorting self._sortby = None + if hasattr(self, "_uses_default_vector_distance_sort"): + self._uses_default_vector_distance_sort = False + self._built_query_string = None return self # Handle backward compatibility: if sort_spec is a string and asc is specified @@ -204,6 +207,11 @@ def sort_by( # Call parent's sort_by with the first field super().sort_by(first_field, asc=first_asc) + if hasattr(self, "_uses_default_vector_distance_sort"): + self._uses_default_vector_distance_sort = False + + self._built_query_string = None + return self def set_filter( @@ -421,6 +429,7 @@ def _build_query_string(self) -> str: class BaseVectorQuery: DISTANCE_ID: str = "vector_distance" VECTOR_PARAM: str = "vector" + _vector_field_name: str # HNSW runtime parameters EF_RUNTIME: str = "EF_RUNTIME" @@ -550,6 +559,7 @@ def __init__( self._use_search_history: Optional[str] = None self._search_buffer_capacity: Optional[int] = None self._normalize_vector_distance = normalize_vector_distance + self._uses_default_vector_distance_sort = False self.set_filter(filter_expression) # Initialize the base query @@ -569,6 +579,7 @@ def __init__( self.sort_by(sort_by) else: self.sort_by(self.DISTANCE_ID) + self._uses_default_vector_distance_sort = True if in_order: self.in_order() @@ -996,6 +1007,7 @@ def __init__( self._hybrid_policy: Optional[HybridPolicy] = None self._batch_size: Optional[int] = None self._normalize_vector_distance = normalize_vector_distance + self._uses_default_vector_distance_sort = False # Initialize the base query super().__init__("*") @@ -1035,6 +1047,7 @@ def __init__( self.sort_by(sort_by) else: self.sort_by(self.DISTANCE_ID) + self._uses_default_vector_distance_sort = True if in_order: self.in_order() diff --git a/redisvl/schema/fields.py b/redisvl/schema/fields.py index 1d9aab221..f910c3110 100644 --- a/redisvl/schema/fields.py +++ b/redisvl/schema/fields.py @@ -18,7 +18,7 @@ - algorithm: Indexing algorithm ('flat', 'hnsw', or 'svs-vamana') - datatype: Float precision ('float16', 'float32', 'float64', 'bfloat16') Note: SVS-VAMANA only supports 'float16' and 'float32' - - distance_metric: Similarity metric ('COSINE', 'L2', 'IP') + - distance_metric: Similarity metric ('COSINE', 'COSINE_SIMILARITY', 'L2', 'IP') - initial_cap: Initial capacity hint for memory allocation (optional) - index_missing: Allow searching for documents without this field (optional) @@ -52,6 +52,7 @@ VECTOR_NORM_MAP = { "COSINE": norm_cosine_distance, + "COSINE_SIMILARITY": None, # already returned as a normalized similarity score "L2": norm_l2_distance, "IP": None, # normalized inner product is cosine similarity by definition } @@ -67,6 +68,7 @@ class FieldTypes(str, Enum): class VectorDistanceMetric(str, Enum): COSINE = "COSINE" + COSINE_SIMILARITY = "COSINE_SIMILARITY" L2 = "L2" IP = "IP" diff --git a/tests/unit/test_fields.py b/tests/unit/test_fields.py index eeeda6046..47b0b3e64 100644 --- a/tests/unit/test_fields.py +++ b/tests/unit/test_fields.py @@ -125,6 +125,26 @@ def test_vector_fields_as_field(): assert hnsw_vector_field.name == "example_hnswvectorfield" +def test_vector_field_supports_cosine_similarity_metric(): + vector_field = FlatVectorField( + name="embedding", + attrs={ + "dims": 128, + "algorithm": "flat", + "distance_metric": "cosine_similarity", + }, + ) + + redis_field = vector_field.as_redis_field() + + assert vector_field.attrs.distance_metric.value == "COSINE_SIMILARITY" + assert "DISTANCE_METRIC" in redis_field.args + assert ( + redis_field.args[redis_field.args.index("DISTANCE_METRIC") + 1] + == "COSINE_SIMILARITY" + ) + + @pytest.mark.parametrize( "vector_schema_func,extra_params", [ diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index fa7ad69af..d4f17c015 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -1,13 +1,18 @@ +from types import SimpleNamespace +from unittest.mock import Mock + import pytest from redis import __version__ as redis_version from redis.commands.search.query import Query from redis.commands.search.result import Result +from redisvl.index import SearchIndex from redisvl.index.index import process_results from redisvl.query import CountQuery, FilterQuery, RangeQuery, TextQuery, VectorQuery from redisvl.query.filter import Tag from redisvl.query.query import VectorRangeQuery from redisvl.redis.connection import is_version_gte +from redisvl.schema import IndexSchema # Sample data for testing sample_vector = [0.1, 0.2, 0.3, 0.4] @@ -36,6 +41,108 @@ def test_count_query(): assert process_results(fake_result, count_query, "json") == 2 +def test_process_results_preserves_cosine_similarity_scores(): + schema = IndexSchema.from_dict( + { + "index": { + "name": "test-index", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + { + "name": "embedding", + "type": "vector", + "attrs": { + "dims": 3, + "algorithm": "flat", + "datatype": "float32", + "distance_metric": "cosine_similarity", + }, + } + ], + } + ) + query = VectorQuery( + vector=[0.1, 0.2, 0.3], + vector_field_name="embedding", + normalize_vector_distance=True, + return_score=True, + ) + fake_results = SimpleNamespace( + docs=[SimpleNamespace(id="doc:1", vector_distance="0.7")] + ) + + processed = process_results(fake_results, query, schema) + + assert processed[0]["vector_distance"] == "0.7" + + +def test_cosine_similarity_vector_query_defaults_to_desc_sort(): + schema = IndexSchema.from_dict( + { + "index": { + "name": "test-index", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + { + "name": "embedding", + "type": "vector", + "attrs": { + "dims": 3, + "algorithm": "flat", + "datatype": "float32", + "distance_metric": "cosine_similarity", + }, + } + ], + } + ) + index = SearchIndex(schema=schema, redis_client=Mock()) + query = VectorQuery(vector=[0.1, 0.2, 0.3], vector_field_name="embedding") + + index._validate_query(query) + + assert query._sortby.args == [query.DISTANCE_ID, "DESC"] + assert query._uses_default_vector_distance_sort is True + + +def test_explicit_sort_is_not_overridden_for_cosine_similarity_vector_query(): + schema = IndexSchema.from_dict( + { + "index": { + "name": "test-index", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + { + "name": "embedding", + "type": "vector", + "attrs": { + "dims": 3, + "algorithm": "flat", + "datatype": "float32", + "distance_metric": "cosine_similarity", + }, + } + ], + } + ) + index = SearchIndex(schema=schema, redis_client=Mock()) + query = VectorQuery( + vector=[0.1, 0.2, 0.3], + vector_field_name="embedding", + sort_by="custom_field", + ) + + index._validate_query(query) + + assert query._sortby.args == ["custom_field", "ASC"] + + def test_filter_query(): # Create a filter expression filter_expression = Tag("brand") == "Nike"