Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ def _uses_svs_vamana(self) -> bool:

def _validate_query(self, query: BaseQuery) -> None:
"""Validate a query."""
if isinstance(query, BaseVectorQuery):
field = self.schema.fields[query._vector_field_name]
dist_metric = VectorDistanceMetric(field.attrs.distance_metric.upper()) # type: ignore
Comment on lines +251 to +253
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accessing self.schema.fields[query._vector_field_name] will raise a raw KeyError if the query references a vector field name not present in the schema. Since this is part of user input validation, it would be better to handle the missing-field case explicitly and raise QueryValidationError with a clear message (instead of leaking KeyError).

Copilot uses AI. Check for mistakes.
if dist_metric == VectorDistanceMetric.COSINE_SIMILARITY and getattr(
query, "_uses_default_vector_distance_sort", False
):
query.sort_by(query.DISTANCE_ID, asc=False)
query._uses_default_vector_distance_sort = True

if isinstance(query, VectorQuery):
field = self.schema.fields[query._vector_field_name]
if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore
Expand Down Expand Up @@ -1148,6 +1157,8 @@ def batch_query(
self, queries: Sequence[BaseQuery], batch_size: int = 10
) -> List[List[Dict[str, Any]]]:
"""Execute a batch of queries and process results."""
for query in queries:
self._validate_query(query)
Comment on lines +1160 to +1161
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_query() wraps QueryValidationError to add context ("Invalid query: ..."), but batch_query() now calls _validate_query() without similar handling. Consider catching QueryValidationError here as well (and indicating which query in the batch failed) so batch and non-batch APIs report validation errors consistently.

Suggested change
for query in queries:
self._validate_query(query)
for i, query in enumerate(queries):
try:
self._validate_query(query)
except QueryValidationError as e:
raise QueryValidationError(
f"Invalid query at batch index {i}: {str(e)}"
) from e

Copilot uses AI. Check for mistakes.
results = self.batch_search(
[(query.query, query.params) for query in queries], batch_size=batch_size
)
Expand Down Expand Up @@ -2071,6 +2082,8 @@ async def batch_query(
self, queries: List[BaseQuery], batch_size: int = 10
) -> List[List[Dict[str, Any]]]:
"""Asynchronously execute a batch of queries and process results."""
for query in queries:
self._validate_query(query)
Comment on lines +2085 to +2086
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as sync batch_query(): now that _validate_query() is called here, consider catching QueryValidationError and adding context about which query failed validation so async batch and single-query APIs have consistent error reporting.

Suggested change
for query in queries:
self._validate_query(query)
for i, query in enumerate(queries):
try:
self._validate_query(query)
except QueryValidationError as e:
raise QueryValidationError(
f"Invalid query at batch index {i}: {str(e)}"
) from e

Copilot uses AI. Check for mistakes.
results = await self.batch_search(
[(query.query, query.params) for query in queries], batch_size=batch_size
)
Expand Down
13 changes: 13 additions & 0 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def sort_by(
if sort_spec is None or sort_spec == []:
# No sorting
self._sortby = None
if hasattr(self, "_uses_default_vector_distance_sort"):
self._uses_default_vector_distance_sort = False
self._built_query_string = None
return self

# Handle backward compatibility: if sort_spec is a string and asc is specified
Expand Down Expand Up @@ -204,6 +207,11 @@ def sort_by(
# Call parent's sort_by with the first field
super().sort_by(first_field, asc=first_asc)

if hasattr(self, "_uses_default_vector_distance_sort"):
self._uses_default_vector_distance_sort = False

self._built_query_string = None

return self

def set_filter(
Expand Down Expand Up @@ -421,6 +429,7 @@ def _build_query_string(self) -> str:
class BaseVectorQuery:
DISTANCE_ID: str = "vector_distance"
VECTOR_PARAM: str = "vector"
_vector_field_name: str

# HNSW runtime parameters
EF_RUNTIME: str = "EF_RUNTIME"
Expand Down Expand Up @@ -550,6 +559,7 @@ def __init__(
self._use_search_history: Optional[str] = None
self._search_buffer_capacity: Optional[int] = None
self._normalize_vector_distance = normalize_vector_distance
self._uses_default_vector_distance_sort = False
self.set_filter(filter_expression)

# Initialize the base query
Expand All @@ -569,6 +579,7 @@ def __init__(
self.sort_by(sort_by)
else:
self.sort_by(self.DISTANCE_ID)
self._uses_default_vector_distance_sort = True

if in_order:
self.in_order()
Expand Down Expand Up @@ -996,6 +1007,7 @@ def __init__(
self._hybrid_policy: Optional[HybridPolicy] = None
self._batch_size: Optional[int] = None
self._normalize_vector_distance = normalize_vector_distance
self._uses_default_vector_distance_sort = False

# Initialize the base query
super().__init__("*")
Expand Down Expand Up @@ -1035,6 +1047,7 @@ def __init__(
self.sort_by(sort_by)
else:
self.sort_by(self.DISTANCE_ID)
self._uses_default_vector_distance_sort = True

if in_order:
self.in_order()
Expand Down
4 changes: 3 additions & 1 deletion redisvl/schema/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
- algorithm: Indexing algorithm ('flat', 'hnsw', or 'svs-vamana')
- datatype: Float precision ('float16', 'float32', 'float64', 'bfloat16')
Note: SVS-VAMANA only supports 'float16' and 'float32'
- distance_metric: Similarity metric ('COSINE', 'L2', 'IP')
- distance_metric: Similarity metric ('COSINE', 'COSINE_SIMILARITY', 'L2', 'IP')
- initial_cap: Initial capacity hint for memory allocation (optional)
- index_missing: Allow searching for documents without this field (optional)

Expand Down Expand Up @@ -52,6 +52,7 @@

# Maps each distance-metric name to the function used to normalize a raw
# distance for that metric. A value of None means no normalization function
# is registered for the metric (the per-entry comments give the reason).
# NOTE(review): the code that consumes this map is not visible here —
# presumably a None entry leaves the returned score untouched; confirm.
VECTOR_NORM_MAP = {
    "COSINE": norm_cosine_distance,
    "COSINE_SIMILARITY": None,  # already returned as a normalized similarity score
    "L2": norm_l2_distance,
    "IP": None,  # normalized inner product is cosine similarity by definition
}
Expand All @@ -67,6 +68,7 @@ class FieldTypes(str, Enum):

class VectorDistanceMetric(str, Enum):
    """Names of the distance metrics a vector field can be configured with.

    Members also subclass ``str``, so each one compares equal to its
    upper-case string value (for example ``"COSINE"``).
    """

    COSINE = "COSINE"
    # Similarity-style counterpart of COSINE.
    # NOTE(review): presumably higher scores mean closer matches for this
    # metric — confirm against the query sorting logic.
    COSINE_SIMILARITY = "COSINE_SIMILARITY"
    L2 = "L2"
    IP = "IP"

Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,26 @@ def test_vector_fields_as_field():
assert hnsw_vector_field.name == "example_hnswvectorfield"


def test_vector_field_supports_cosine_similarity_metric():
    """A lower-case ``cosine_similarity`` metric is accepted and emitted upper-cased."""
    expected_metric = "COSINE_SIMILARITY"
    field_attrs = {
        "dims": 128,
        "algorithm": "flat",
        "distance_metric": "cosine_similarity",
    }
    vector_field = FlatVectorField(name="embedding", attrs=field_attrs)

    redis_args = vector_field.as_redis_field().args

    # The parsed attribute is normalized to the upper-case enum value.
    assert vector_field.attrs.distance_metric.value == expected_metric

    # The metric name must directly follow the DISTANCE_METRIC keyword in the
    # arguments of the generated Redis field.
    assert "DISTANCE_METRIC" in redis_args
    metric_position = redis_args.index("DISTANCE_METRIC") + 1
    assert redis_args[metric_position] == expected_metric


@pytest.mark.parametrize(
"vector_schema_func,extra_params",
[
Expand Down
107 changes: 107 additions & 0 deletions tests/unit/test_query_types.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from types import SimpleNamespace
from unittest.mock import Mock

import pytest
from redis import __version__ as redis_version
from redis.commands.search.query import Query
from redis.commands.search.result import Result

from redisvl.index import SearchIndex
from redisvl.index.index import process_results
from redisvl.query import CountQuery, FilterQuery, RangeQuery, TextQuery, VectorQuery
from redisvl.query.filter import Tag
from redisvl.query.query import VectorRangeQuery
from redisvl.redis.connection import is_version_gte
from redisvl.schema import IndexSchema

# Sample data for testing
sample_vector = [0.1, 0.2, 0.3, 0.4]
Expand Down Expand Up @@ -36,6 +41,108 @@ def test_count_query():
assert process_results(fake_result, count_query, "json") == 2
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

process_results now takes an IndexSchema, but this test still passes the string "json". It works only because the CountQuery early-return skips schema usage; updating the test to pass a real schema (or a minimal IndexSchema fixture) will keep it aligned with the public signature and avoid future breakage if the implementation changes.

Copilot uses AI. Check for mistakes.


def _cosine_similarity_schema():
    """Build a minimal hash-index schema with one COSINE_SIMILARITY vector field.

    Shared by the cosine-similarity tests below so the schema definition
    lives in one place. A fresh schema is built on every call, so tests
    never share state.

    Returns:
        IndexSchema: schema with a single 3-dimensional flat ``embedding``
        vector field whose distance metric is ``cosine_similarity``.
    """
    return IndexSchema.from_dict(
        {
            "index": {
                "name": "test-index",
                "prefix": "doc",
                "storage_type": "hash",
            },
            "fields": [
                {
                    "name": "embedding",
                    "type": "vector",
                    "attrs": {
                        "dims": 3,
                        "algorithm": "flat",
                        "datatype": "float32",
                        "distance_metric": "cosine_similarity",
                    },
                }
            ],
        }
    )


def test_process_results_preserves_cosine_similarity_scores():
    """Normalization must leave a COSINE_SIMILARITY score exactly as returned."""
    schema = _cosine_similarity_schema()
    query = VectorQuery(
        vector=[0.1, 0.2, 0.3],
        vector_field_name="embedding",
        normalize_vector_distance=True,
        return_score=True,
    )
    # Stand-in for a search result: only the attributes the test reads.
    fake_results = SimpleNamespace(
        docs=[SimpleNamespace(id="doc:1", vector_distance="0.7")]
    )

    processed = process_results(fake_results, query, schema)

    # The score comes back as the same string, i.e. it was not rewritten.
    assert processed[0]["vector_distance"] == "0.7"


def test_cosine_similarity_vector_query_defaults_to_desc_sort():
    """With no explicit sort, validation flips the default sort to descending."""
    index = SearchIndex(schema=_cosine_similarity_schema(), redis_client=Mock())
    query = VectorQuery(vector=[0.1, 0.2, 0.3], vector_field_name="embedding")

    index._validate_query(query)

    assert query._sortby.args == [query.DISTANCE_ID, "DESC"]
    # The flag must stay set so the sort is still recognized as the default.
    assert query._uses_default_vector_distance_sort is True


def test_explicit_sort_is_not_overridden_for_cosine_similarity_vector_query():
    """A caller-supplied sort must survive validation untouched."""
    index = SearchIndex(schema=_cosine_similarity_schema(), redis_client=Mock())
    query = VectorQuery(
        vector=[0.1, 0.2, 0.3],
        vector_field_name="embedding",
        sort_by="custom_field",
    )

    index._validate_query(query)

    assert query._sortby.args == ["custom_field", "ASC"]
    # An explicit sort clears the default-sort flag; validation only rewrites
    # the sort when this flag is set, so it must remain False here.
    assert query._uses_default_vector_distance_sort is False


def test_filter_query():
# Create a filter expression
filter_expression = Tag("brand") == "Nike"
Expand Down
Loading