-
Notifications
You must be signed in to change notification settings - Fork 75
Direct COSINE_SIMILARITY metric #582
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -248,6 +248,15 @@ def _uses_svs_vamana(self) -> bool: | |||||||||||||||||||
|
|
||||||||||||||||||||
| def _validate_query(self, query: BaseQuery) -> None: | ||||||||||||||||||||
| """Validate a query.""" | ||||||||||||||||||||
| if isinstance(query, BaseVectorQuery): | ||||||||||||||||||||
| field = self.schema.fields[query._vector_field_name] | ||||||||||||||||||||
| dist_metric = VectorDistanceMetric(field.attrs.distance_metric.upper()) # type: ignore | ||||||||||||||||||||
| if dist_metric == VectorDistanceMetric.COSINE_SIMILARITY and getattr( | ||||||||||||||||||||
| query, "_uses_default_vector_distance_sort", False | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| query.sort_by(query.DISTANCE_ID, asc=False) | ||||||||||||||||||||
| query._uses_default_vector_distance_sort = True | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if isinstance(query, VectorQuery): | ||||||||||||||||||||
| field = self.schema.fields[query._vector_field_name] | ||||||||||||||||||||
| if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore | ||||||||||||||||||||
|
|
@@ -1148,6 +1157,8 @@ def batch_query( | |||||||||||||||||||
| self, queries: Sequence[BaseQuery], batch_size: int = 10 | ||||||||||||||||||||
| ) -> List[List[Dict[str, Any]]]: | ||||||||||||||||||||
| """Execute a batch of queries and process results.""" | ||||||||||||||||||||
| for query in queries: | ||||||||||||||||||||
| self._validate_query(query) | ||||||||||||||||||||
|
Comment on lines
+1160
to
+1161
|
||||||||||||||||||||
| for query in queries: | |
| self._validate_query(query) | |
| for i, query in enumerate(queries): | |
| try: | |
| self._validate_query(query) | |
| except QueryValidationError as e: | |
| raise QueryValidationError( | |
| f"Invalid query at batch index {i}: {str(e)}" | |
| ) from e |
Copilot
AI
Apr 13, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as sync batch_query(): now that _validate_query() is called here, consider catching QueryValidationError and adding context about which query failed validation so async batch and single-query APIs have consistent error reporting.
| for query in queries: | |
| self._validate_query(query) | |
| for i, query in enumerate(queries): | |
| try: | |
| self._validate_query(query) | |
| except QueryValidationError as e: | |
| raise QueryValidationError( | |
| f"Invalid query at batch index {i}: {str(e)}" | |
| ) from e |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,18 @@ | ||
| from types import SimpleNamespace | ||
| from unittest.mock import Mock | ||
|
|
||
| import pytest | ||
| from redis import __version__ as redis_version | ||
| from redis.commands.search.query import Query | ||
| from redis.commands.search.result import Result | ||
|
|
||
| from redisvl.index import SearchIndex | ||
| from redisvl.index.index import process_results | ||
| from redisvl.query import CountQuery, FilterQuery, RangeQuery, TextQuery, VectorQuery | ||
| from redisvl.query.filter import Tag | ||
| from redisvl.query.query import VectorRangeQuery | ||
| from redisvl.redis.connection import is_version_gte | ||
| from redisvl.schema import IndexSchema | ||
|
|
||
| # Sample data for testing | ||
| sample_vector = [0.1, 0.2, 0.3, 0.4] | ||
|
|
@@ -36,6 +41,108 @@ def test_count_query(): | |
| assert process_results(fake_result, count_query, "json") == 2 | ||
|
||
|
|
||
|
|
||
| def test_process_results_preserves_cosine_similarity_scores(): | ||
| schema = IndexSchema.from_dict( | ||
| { | ||
| "index": { | ||
| "name": "test-index", | ||
| "prefix": "doc", | ||
| "storage_type": "hash", | ||
| }, | ||
| "fields": [ | ||
| { | ||
| "name": "embedding", | ||
| "type": "vector", | ||
| "attrs": { | ||
| "dims": 3, | ||
| "algorithm": "flat", | ||
| "datatype": "float32", | ||
| "distance_metric": "cosine_similarity", | ||
| }, | ||
| } | ||
| ], | ||
| } | ||
| ) | ||
| query = VectorQuery( | ||
| vector=[0.1, 0.2, 0.3], | ||
| vector_field_name="embedding", | ||
| normalize_vector_distance=True, | ||
| return_score=True, | ||
| ) | ||
| fake_results = SimpleNamespace( | ||
| docs=[SimpleNamespace(id="doc:1", vector_distance="0.7")] | ||
| ) | ||
|
|
||
| processed = process_results(fake_results, query, schema) | ||
|
|
||
| assert processed[0]["vector_distance"] == "0.7" | ||
|
|
||
|
|
||
| def test_cosine_similarity_vector_query_defaults_to_desc_sort(): | ||
| schema = IndexSchema.from_dict( | ||
| { | ||
| "index": { | ||
| "name": "test-index", | ||
| "prefix": "doc", | ||
| "storage_type": "hash", | ||
| }, | ||
| "fields": [ | ||
| { | ||
| "name": "embedding", | ||
| "type": "vector", | ||
| "attrs": { | ||
| "dims": 3, | ||
| "algorithm": "flat", | ||
| "datatype": "float32", | ||
| "distance_metric": "cosine_similarity", | ||
| }, | ||
| } | ||
| ], | ||
| } | ||
| ) | ||
| index = SearchIndex(schema=schema, redis_client=Mock()) | ||
| query = VectorQuery(vector=[0.1, 0.2, 0.3], vector_field_name="embedding") | ||
|
|
||
| index._validate_query(query) | ||
|
|
||
| assert query._sortby.args == [query.DISTANCE_ID, "DESC"] | ||
| assert query._uses_default_vector_distance_sort is True | ||
|
|
||
|
|
||
| def test_explicit_sort_is_not_overridden_for_cosine_similarity_vector_query(): | ||
| schema = IndexSchema.from_dict( | ||
| { | ||
| "index": { | ||
| "name": "test-index", | ||
| "prefix": "doc", | ||
| "storage_type": "hash", | ||
| }, | ||
| "fields": [ | ||
| { | ||
| "name": "embedding", | ||
| "type": "vector", | ||
| "attrs": { | ||
| "dims": 3, | ||
| "algorithm": "flat", | ||
| "datatype": "float32", | ||
| "distance_metric": "cosine_similarity", | ||
| }, | ||
| } | ||
| ], | ||
| } | ||
| ) | ||
| index = SearchIndex(schema=schema, redis_client=Mock()) | ||
| query = VectorQuery( | ||
| vector=[0.1, 0.2, 0.3], | ||
| vector_field_name="embedding", | ||
| sort_by="custom_field", | ||
| ) | ||
|
|
||
| index._validate_query(query) | ||
|
|
||
| assert query._sortby.args == ["custom_field", "ASC"] | ||
|
|
||
|
|
||
| def test_filter_query(): | ||
| # Create a filter expression | ||
| filter_expression = Tag("brand") == "Nike" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing
self.schema.fields[query._vector_field_name]will raise a rawKeyErrorif the query references a vector field name not present in the schema. Since this is part of user input validation, it would be better to handle the missing-field case explicitly and raiseQueryValidationErrorwith a clear message (instead of leakingKeyError).