|
1 | 1 | from typing import Any, Dict, List, Optional, Set, Tuple, Union |
2 | 2 |
|
| 3 | +from pydantic import BaseModel, field_validator |
3 | 4 | from redis.commands.search.aggregation import AggregateRequest, Desc |
4 | 5 |
|
5 | 6 | from redisvl.query.filter import FilterExpression |
6 | 7 | from redisvl.redis.utils import array_to_buffer |
| 8 | +from redisvl.schema.fields import VectorDataType |
7 | 9 | from redisvl.utils.token_escaper import TokenEscaper |
8 | 10 | from redisvl.utils.utils import lazy_import |
9 | 11 |
|
10 | 12 | nltk = lazy_import("nltk") |
11 | 13 | nltk_stopwords = lazy_import("nltk.corpus.stopwords") |
12 | 14 |
|
13 | 15 |
|
class Vector(BaseModel):
    """
    Simple object containing the necessary arguments to perform a multi vector query.

    Fields are declared as pydantic model fields; ``dtype`` is additionally
    checked by :meth:`validate_dtype`.
    """

    # The query vector, given either as a list of floats or as raw bytes.
    vector: Union[List[float], bytes]
    # Name of the vector field this query vector is associated with.
    field_name: str
    # Data type name; must match a VectorDataType member (case-insensitive).
    dtype: str = "float32"
    # Weight attached to this vector; defaults to 1.0.
    weight: float = 1.0

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, dtype: str) -> str:
        """Ensure ``dtype`` names a member of ``VectorDataType``.

        The lookup is done on the upper-cased name, so the check is
        case-insensitive. The value is returned exactly as it was given.

        Args:
            dtype (str): The data type name to validate.

        Returns:
            str: The unchanged ``dtype`` value.

        Raises:
            ValueError: If ``dtype`` does not correspond to a ``VectorDataType``.
        """
        try:
            VectorDataType(dtype.upper())
        except ValueError as err:
            # Chain explicitly so the failed enum lookup is recorded as the
            # cause rather than as an incidental "during handling" context.
            raise ValueError(
                f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
            ) from err

        return dtype
| 37 | + |
| 38 | + |
14 | 39 | class AggregationQuery(AggregateRequest): |
15 | 40 | """ |
16 | 41 | Base class for aggregation queries used to create aggregation queries for Redis. |
@@ -227,3 +252,149 @@ def _build_query_string(self) -> str: |
def __str__(self) -> str:
    """Render the query as its build arguments, each stringified and space-separated."""
    return " ".join(map(str, self.build_args()))
| 255 | + |
| 256 | + |
class MultiVectorQuery(AggregationQuery):
    """
    MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
    The final score will be a weighted combination of the individual vector similarity scores
    following the formula:

    score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )

    Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.

    .. code-block:: python

        from redisvl.query import MultiVectorQuery, Vector
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        vector_1 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float32",
            weight=0.7,
        )
        vector_2 = Vector(
            vector=[0.5, 0.5],
            field_name="image_vector",
            dtype="bfloat16",
            weight=0.2,
        )
        vector_3 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float64",
            weight=0.5,
        )

        query = MultiVectorQuery(
            vectors=[vector_1, vector_2, vector_3],
            filter_expression=None,
            num_results=10,
            return_fields=["field1", "field2"],
            dialect=2,
        )

        results = index.query(query)
    """

    _vectors: List[Vector]

    def __init__(
        self,
        vectors: Union[Vector, List[Vector]],
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        num_results: int = 10,
        dialect: int = 2,
    ):
        """
        Instantiates a MultiVectorQuery object.

        Args:
            vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
                Defaults to None.
            num_results (int, optional): The number of results to return. Defaults to 10.
            dialect (int, optional): The Redis dialect version. Defaults to 2.

        Raises:
            TypeError: If ``vectors`` is not a Vector or an iterable of Vector objects.
            ValueError: If ``vectors`` is empty.
        """

        # These attributes are read by _build_query_string(), so they must be
        # set before the query string is built and handed to the parent.
        self._filter_expression = filter_expression
        self._num_results = num_results
        self._vectors = self._coerce_vectors(vectors)

        super().__init__(self._build_query_string())

        # Calculate the respective vector similarities: each yielded distance
        # is rescaled with (2 - d) / 2 (the class requires cosine distance).
        for i, _ in enumerate(self._vectors):
            self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})

        # Construct the scoring string from the similarity scores and weights.
        combined_score_string = " + ".join(
            f"@score_{i} * {v.weight}" for i, v in enumerate(self._vectors)
        )
        self.apply(combined_score=combined_score_string)

        self.sort_by(Desc("@combined_score"), max=num_results)  # type: ignore
        self.dialect(dialect)
        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @staticmethod
    def _coerce_vectors(vectors: Union[Vector, List[Vector]]) -> List[Vector]:
        """Normalize the ``vectors`` argument into a non-empty list of Vector objects."""
        type_error_message = (
            "vector argument must be a Vector object or list of Vector objects."
        )

        if isinstance(vectors, Vector):
            return [vectors]

        # Materialize once: this gives non-iterables the documented TypeError
        # and keeps one-shot iterables from being exhausted by the type check.
        try:
            candidates = list(vectors)
        except TypeError as err:
            raise TypeError(type_error_message) from err

        if not all(isinstance(v, Vector) for v in candidates):
            raise TypeError(type_error_message)

        # Without any vector the query string and the combined score
        # expression would both be empty, so fail early with a clear message.
        if not candidates:
            raise ValueError(
                "At least one Vector is required to build a MultiVectorQuery."
            )

        return candidates

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Each vector is exposed under the key ``vector_<i>``, matching the
        ``$vector_<i>`` placeholders in the query string. Vectors given as
        lists are converted with ``array_to_buffer`` using the vector's
        ``dtype``; any other value is passed through unchanged.

        Returns:
            Dict[str, Any]: The parameters for the aggregation.
        """
        params: Dict[str, Any] = {}
        for i, v in enumerate(self._vectors):
            value = v.vector
            if isinstance(value, list):
                value = array_to_buffer(value, dtype=v.dtype)  # type: ignore
            params[f"vector_{i}"] = value
        return params

    def _build_query_string(self) -> str:
        """Build the full query string for multi vector search with optional filtering.

        Returns:
            str: One vector range clause per vector joined with ``|``; when a
            filter other than the match-all ``*`` is set, the clauses are
            parenthesized and combined with the filter.
        """

        # One range clause per vector, each yielding its distance as distance_<i>.
        # NOTE(review): the 2.0 radius presumably covers the whole cosine
        # distance range so that no document is excluded by distance - confirm.
        range_query = " | ".join(
            f"@{v.field_name}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
            for i, v in enumerate(self._vectors)
        )

        filter_expression = self._filter_expression
        if isinstance(filter_expression, FilterExpression):
            filter_expression = str(filter_expression)

        # "*" matches every document, so combining with it adds nothing;
        # skip it rather than emitting "(...) AND (*)".
        if filter_expression and filter_expression != "*":
            return f"({range_query}) AND ({filter_expression})"
        return range_query

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join(str(x) for x in self.build_args())
0 commit comments