Skip to content

Commit 2588585

Browse files
committed
Use full query num_results for KNN; set parameter defaults
1 parent dd11f2a commit 2588585

File tree

3 files changed

+65
-46
lines changed

3 files changed

+65
-46
lines changed

redisvl/query/hybrid.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,18 @@ def __init__(
5252
text_scorer: str = "BM25STD",
5353
yield_text_score_as: Optional[str] = None,
5454
vector_search_method: Optional[Literal["KNN", "RANGE"]] = None,
55-
knn_num_results: Optional[int] = None,
56-
knn_ef_runtime: Optional[int] = None,
57-
range_radius: Optional[int] = None,
58-
range_epsilon: Optional[float] = None,
55+
knn_ef_runtime: int = 10,
56+
range_radius: Optional[float] = None,
57+
range_epsilon: float = 0.01,
5958
yield_vsim_score_as: Optional[str] = None,
6059
filter_expression: Optional[Union[str, FilterExpression]] = None,
6160
combination_method: Optional[Literal["RRF", "LINEAR"]] = None,
62-
rrf_window: Optional[int] = None,
63-
rrf_constant: Optional[float] = None,
64-
linear_alpha: Optional[float] = None,
61+
rrf_window: int = 20,
62+
rrf_constant: int = 60,
63+
linear_alpha: float = 0.3,
6564
yield_combined_score_as: Optional[str] = None,
6665
dtype: str = "float32",
67-
num_results: Optional[int] = None,
66+
num_results: Optional[int] = 10,
6867
return_fields: Optional[List[str]] = None,
6968
stopwords: Optional[Union[str, Set[str]]] = "english",
7069
text_weights: Optional[Dict[str, float]] = None,
@@ -83,7 +82,6 @@ def __init__(
8382
see https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/
8483
yield_text_score_as: The name of the field to yield the text score as.
8584
vector_search_method: The vector search method to use. Options are {KNN, RANGE}. Defaults to None.
86-
knn_num_results: The number of nearest neighbors to return, required if `vector_search_method` is "KNN".
8785
knn_ef_runtime: The exploration factor parameter for HNSW, optional if `vector_search_method` is "KNN".
8886
range_radius: The search radius to use, required if `vector_search_method` is "RANGE".
8987
range_epsilon: The epsilon value to use, optional if `vector_search_method` is "RANGE"; defines the
@@ -98,10 +96,10 @@ def __init__(
9896
fusion scope.
9997
rrf_constant: The constant to use for the reciprocal rank fusion (RRF) combination method. Controls decay
10098
of rank influence.
101-
linear_alpha: The weight of the first query for the linear combination method (LINEAR).
99+
linear_alpha: The weight of the text query for the linear combination method (LINEAR).
102100
yield_combined_score_as: The name of the field to yield the combined score as.
103101
dtype: The data type of the vector. Defaults to "float32".
104-
num_results: The number of results to return. If not specified, the server default will be used (10).
102+
num_results: The number of results to return.
105103
return_fields: The fields to return. Defaults to None.
106104
stopwords (Optional[Union[str, Set[str]]], optional): The stopwords to remove from the
107105
provided text prior to search-use. If a string such as "english" "german" is
@@ -155,7 +153,7 @@ def __init__(
155153
text_scorer=text_scorer,
156154
yield_text_score_as=yield_text_score_as,
157155
vector_search_method=vector_search_method,
158-
knn_num_results=knn_num_results,
156+
num_results=num_results,
159157
knn_ef_runtime=knn_ef_runtime,
160158
range_radius=range_radius,
161159
range_epsilon=range_epsilon,
@@ -185,9 +183,9 @@ def build_base_query(
185183
text_scorer: str = "BM25STD",
186184
yield_text_score_as: Optional[str] = None,
187185
vector_search_method: Optional[Literal["KNN", "RANGE"]] = None,
188-
knn_num_results: Optional[int] = None,
186+
num_results: Optional[int] = None,
189187
knn_ef_runtime: Optional[int] = None,
190-
range_radius: Optional[int] = None,
188+
range_radius: Optional[float] = None,
191189
range_epsilon: Optional[float] = None,
192190
yield_vsim_score_as: Optional[str] = None,
193191
filter_expression: Optional[Union[str, FilterExpression]] = None,
@@ -205,7 +203,7 @@ def build_base_query(
205203
see https://redis.io/docs/latest/develop/ai/search-and-query/advanced-concepts/scoring/
206204
yield_text_score_as: The name of the field to yield the text score as.
207205
vector_search_method: The vector search method to use. Options are {KNN, RANGE}. Defaults to None.
208-
knn_num_results: The number of nearest neighbors to return, required if `vector_search_method` is "KNN".
206+
num_results: The number of nearest neighbors to return, required if `vector_search_method` is "KNN".
209207
knn_ef_runtime: The exploration factor parameter for HNSW, optional if `vector_search_method` is "KNN".
210208
range_radius: The search radius to use, required if `vector_search_method` is "RANGE".
211209
range_epsilon: The epsilon value to use, optional if `vector_search_method` is "RANGE"; defines the
@@ -254,12 +252,12 @@ def build_base_query(
254252
vsim_search_method_params: Dict[str, Any] = {}
255253
if vector_search_method == "KNN":
256254
vsim_search_method = VectorSearchMethods.KNN
257-
if not knn_num_results:
255+
if not num_results:
258256
raise ValueError(
259-
"Must provide `knn_num_results` if vector_search_method is KNN"
257+
"Must provide `num_results` if vector_search_method is KNN"
260258
)
261259

262-
vsim_search_method_params["K"] = knn_num_results
260+
vsim_search_method_params["K"] = num_results
263261
if knn_ef_runtime:
264262
vsim_search_method_params["EF_RUNTIME"] = knn_ef_runtime
265263

tests/integration/test_hybrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_hybrid_query(index):
136136

137137
results = index.query(hybrid_query)
138138
assert isinstance(results, list)
139-
assert len(results) == 10 # Server-side default for hybrid search
139+
assert len(results) == 10 # default for hybrid search
140140
for doc in results:
141141
assert doc["user"] in [
142142
"john",

tests/unit/test_hybrid_types.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def test_hybrid_query_basic_initialization():
7777
"VSIM",
7878
"@embedding",
7979
bytes_vector,
80+
"LIMIT",
81+
"0",
82+
"10",
8083
]
8184

8285
# Verify that no combination method is set
@@ -97,16 +100,16 @@ def test_hybrid_query_with_all_parameters():
97100
text_scorer="TFIDF",
98101
yield_text_score_as="text_score",
99102
vector_search_method="KNN",
100-
knn_num_results=10,
101103
knn_ef_runtime=100,
102104
yield_vsim_score_as="vsim_score",
103105
filter_expression=filter_expression,
104106
stopwords=None,
105107
text_weights=text_weights,
106108
combination_method="RRF",
107109
rrf_window=10,
108-
rrf_constant=0.5,
110+
rrf_constant=50,
109111
yield_combined_score_as="hybrid_score",
112+
num_results=10,
110113
)
111114

112115
assert hybrid_query._ft_helper is not None
@@ -140,9 +143,12 @@ def test_hybrid_query_with_all_parameters():
140143
"WINDOW",
141144
10,
142145
"CONSTANT",
143-
0.5,
146+
50,
144147
"YIELD_SCORE_AS",
145148
"hybrid_score",
149+
"LIMIT",
150+
"0",
151+
"10",
146152
]
147153

148154
# Add post-processing and verify that it is reflected in the query
@@ -373,6 +379,9 @@ def test_hybrid_query_with_string_filter():
373379
bytes_vector,
374380
"FILTER",
375381
"@category:{tech|science|engineering}",
382+
"LIMIT",
383+
"0",
384+
"10",
376385
]
377386

378387

@@ -399,6 +408,9 @@ def test_hybrid_query_with_tag_filter():
399408
bytes_vector,
400409
"FILTER",
401410
"@genre:{comedy}",
411+
"LIMIT",
412+
"0",
413+
"10",
402414
]
403415

404416

@@ -506,28 +518,18 @@ def test_hybrid_query_without_filter():
506518
@pytest.mark.skipif(not REDIS_HYBRID_AVAILABLE, reason=SKIP_REASON)
507519
def test_hybrid_query_vector_search_method_knn():
508520
"""Test HybridQuery with KNN vector search method."""
509-
with pytest.raises(ValueError):
510-
# KNN requires K
511-
HybridQuery(
512-
text=sample_text,
513-
text_field_name="description",
514-
vector=sample_vector,
515-
vector_field_name="embedding",
516-
vector_search_method="KNN",
517-
)
518-
519521
hybrid_query = HybridQuery(
520522
text=sample_text,
521523
text_field_name="description",
522524
vector=sample_vector,
523525
vector_field_name="embedding",
524526
vector_search_method="KNN",
525-
knn_num_results=10,
527+
num_results=10,
526528
)
527529

528530
# KNN with params should be in args
529531
args = get_query_pieces(hybrid_query)
530-
assert args[-4:] == ["KNN", 2, "K", 10]
532+
assert args[7:13] == ["KNN", 4, "K", 10, "EF_RUNTIME", 10]
531533

532534
# With optional EF_RUNTIME param
533535
hybrid_query = HybridQuery(
@@ -536,13 +538,13 @@ def test_hybrid_query_vector_search_method_knn():
536538
vector=sample_vector,
537539
vector_field_name="embedding",
538540
vector_search_method="KNN",
539-
knn_num_results=10,
540541
knn_ef_runtime=100,
542+
num_results=10,
541543
)
542544

543545
# KNN with params should be in args
544546
args = get_query_pieces(hybrid_query)
545-
assert args[-6:] == ["KNN", 4, "K", 10, "EF_RUNTIME", 100]
547+
assert args[7:13] == ["KNN", 4, "K", 10, "EF_RUNTIME", 100]
546548

547549

548550
@pytest.mark.skipif(not REDIS_HYBRID_AVAILABLE, reason=SKIP_REASON)
@@ -569,7 +571,7 @@ def test_hybrid_query_vector_search_method_range():
569571

570572
# RANGE with params should be in args
571573
args = get_query_pieces(hybrid_query)
572-
assert args[-4:] == ["RANGE", 2, "RADIUS", 10]
574+
assert args[7:13] == ["RANGE", 4, "RADIUS", 10, "EPSILON", 0.01]
573575

574576
# With optional EPSILON param
575577
hybrid_query = HybridQuery(
@@ -584,7 +586,7 @@ def test_hybrid_query_vector_search_method_range():
584586

585587
# RANGE with params should be in args
586588
args = get_query_pieces(hybrid_query)
587-
assert args[-6:] == ["RANGE", 4, "RADIUS", 10, "EPSILON", 0.1]
589+
assert args[7:13] == ["RANGE", 4, "RADIUS", 10, "EPSILON", 0.1]
588590

589591

590592
@pytest.mark.skipif(not REDIS_HYBRID_AVAILABLE, reason=SKIP_REASON)
@@ -644,6 +646,9 @@ def test_hybrid_query_special_characters_in_text():
644646
"VSIM",
645647
"@embedding",
646648
bytes_vector,
649+
"LIMIT",
650+
"0",
651+
"10",
647652
]
648653

649654

@@ -668,6 +673,9 @@ def test_hybrid_query_unicode_text():
668673
"VSIM",
669674
"@embedding",
670675
bytes_vector,
676+
"LIMIT",
677+
"0",
678+
"10",
671679
]
672680

673681

@@ -682,13 +690,22 @@ def test_hybrid_query_with_vector_filter_and_method():
682690
vector=sample_vector,
683691
vector_field_name="embedding",
684692
vector_search_method="KNN",
685-
knn_num_results=10,
686693
filter_expression=tag_filter,
694+
num_results=10,
687695
)
688696

689697
# Verify KNN params and filter are both in args
690698
args = get_query_pieces(hybrid_query)
691-
assert args[-6:] == ["KNN", 2, "K", 10, "FILTER", "@genre:{comedy}"]
699+
assert args[7:15] == [
700+
"KNN",
701+
4,
702+
"K",
703+
10,
704+
"EF_RUNTIME",
705+
10,
706+
"FILTER",
707+
"@genre:{comedy}",
708+
]
692709

693710

694711
# Combination method tests
@@ -713,9 +730,11 @@ def test_hybrid_query_combination_method_rrf_basic():
713730
assert hybrid_query.combination_method.get_args() == [
714731
"COMBINE",
715732
"RRF",
716-
2,
733+
4,
717734
"WINDOW",
718735
10,
736+
"CONSTANT",
737+
60,
719738
]
720739

721740

@@ -728,7 +747,7 @@ def test_hybrid_query_combination_method_rrf_with_constant():
728747
vector=sample_vector,
729748
vector_field_name="embedding",
730749
combination_method="RRF",
731-
rrf_constant=0.5,
750+
rrf_constant=50,
732751
)
733752

734753
# Verify RRF combination method is set
@@ -738,9 +757,11 @@ def test_hybrid_query_combination_method_rrf_with_constant():
738757
assert hybrid_query.combination_method.get_args() == [
739758
"COMBINE",
740759
"RRF",
741-
2,
760+
4,
761+
"WINDOW",
762+
20,
742763
"CONSTANT",
743-
0.5,
764+
50,
744765
]
745766

746767

@@ -754,7 +775,7 @@ def test_hybrid_query_combination_method_rrf_with_both_params():
754775
vector_field_name="embedding",
755776
combination_method="RRF",
756777
rrf_window=20,
757-
rrf_constant=1.0,
778+
rrf_constant=50,
758779
yield_combined_score_as="rrf_score",
759780
)
760781

@@ -769,7 +790,7 @@ def test_hybrid_query_combination_method_rrf_with_both_params():
769790
"WINDOW",
770791
20,
771792
"CONSTANT",
772-
1.0,
793+
50,
773794
"YIELD_SCORE_AS",
774795
"rrf_score",
775796
]

0 commit comments

Comments
 (0)