Skip to content

Commit 622e82c

Browse files
author
Sam Partee
authored
Implement Geographic Filter (#34)
Implement Geographic filter and clean up the hybrid query notebook.
1 parent d74053a commit 622e82c

File tree

8 files changed

+314
-167
lines changed

8 files changed

+314
-167
lines changed

docs/api/filter.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,20 @@ NumericFilter
5757
:show-inheritance:
5858
:members:
5959
:inherited-members:
60+
61+
62+
GeoFilter
63+
=========
64+
65+
.. currentmodule:: redisvl.query
66+
67+
.. autosummary::
68+
69+
GeoFilter.__init__
70+
GeoFilter.to_string
71+
72+
73+
.. autoclass:: GeoFilter
74+
:show-inheritance:
75+
:members:
76+
:inherited-members:
494 Bytes
Binary file not shown.

docs/user_guide/hybrid_queries_02.ipynb

Lines changed: 185 additions & 166 deletions
Large diffs are not rendered by default.

docs/user_guide/jupyterutils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from IPython.display import display, HTML
2+
3+
def table_print(dict_list):
4+
# If there's nothing in the list, there's nothing to print
5+
if len(dict_list) == 0:
6+
return
7+
8+
# Getting column names (dictionary keys) using the first dictionary
9+
columns = dict_list[0].keys()
10+
11+
# HTML table header
12+
html = '<table><tr><th>'
13+
html += '</th><th>'.join(columns)
14+
html += '</th></tr>'
15+
16+
# HTML table content
17+
for dictionary in dict_list:
18+
html += '<tr><td>'
19+
html += '</td><td>'.join(str(dictionary[column]) for column in columns)
20+
html += '</td></tr>'
21+
22+
# HTML table footer
23+
html += '</table>'
24+
25+
# Displaying the table
26+
display(HTML(html))
27+
28+
29+
def result_print(results):
30+
# If there's nothing in the list, there's nothing to print
31+
if len(results.docs) == 0:
32+
return
33+
34+
data = [doc.__dict__ for doc in results.docs]
35+
36+
to_remove = ["id", "payload"]
37+
for doc in data:
38+
for key in to_remove:
39+
if key in doc:
40+
del doc[key]
41+
42+
table_print(data)
File renamed without changes.

redisvl/cli/query.py

Whitespace-only changes.

redisvl/query.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,61 @@ def to_string(self) -> str:
6666
)
6767

6868

69+
class GeoFilter(Filter):
70+
GEO_UNITS = ["m", "km", "mi", "ft"]
71+
72+
def __init__(self, field, longitude, latitude, radius, unit="km"):
73+
"""Filter for Geo fields.
74+
75+
Args:
76+
field (str): The field to filter on.
77+
longitude (float): The longitude.
78+
latitude (float): The latitude.
79+
radius (float): The radius.
80+
unit (str, optional): The unit of the radius. Defaults to "km".
81+
82+
Raises:
83+
ValueError: If the unit is not one of ["m", "km", "mi", "ft"].
84+
85+
Examples:
86+
>>> # looking for Chinese restaurants near San Francisco
87+
>>> # (within a 5km radius) would be
88+
>>> #
89+
>>> from redisvl.query import GeoFilter
90+
>>> gf = GeoFilter("location", -122.4194, 37.7749, 5)
91+
"""
92+
super().__init__(field)
93+
self._longitude = longitude
94+
self._latitude = latitude
95+
self._radius = radius
96+
self._unit = self._set_unit(unit)
97+
98+
def _set_unit(self, unit):
99+
if unit.lower() not in self.GEO_UNITS:
100+
raise ValueError(f"Unit must be one of {self.GEO_UNITS}")
101+
return unit.lower()
102+
103+
def to_string(self) -> str:
104+
"""Converts the geo filter to a string.
105+
106+
Returns:
107+
str: The geo filter as a string.
108+
"""
109+
return (
110+
"@"
111+
+ self._field
112+
+ ":["
113+
+ str(self._longitude)
114+
+ " "
115+
+ str(self._latitude)
116+
+ " "
117+
+ str(self._radius)
118+
+ " "
119+
+ self._unit
120+
+ "]"
121+
)
122+
123+
69124
class NumericFilter(Filter):
70125
def __init__(self, field, minval, maxval, min_exclusive=False, max_exclusive=False):
71126
"""Filter for Numeric fields.

tests/test_filter.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import pytest
22

3-
from redisvl.query import Filter, NumericFilter, TagFilter, TextFilter, VectorQuery
3+
from redisvl.query import (
4+
Filter,
5+
GeoFilter,
6+
NumericFilter,
7+
TagFilter,
8+
TextFilter,
9+
VectorQuery,
10+
)
411
from redisvl.utils.utils import TokenEscaper
512

613

@@ -25,6 +32,13 @@ def test_text_filter(self):
2532
txt_f = TextFilter("text_field", "text")
2633
assert txt_f.to_string() == "@text_field:text"
2734

35+
def test_geo_filter(self):
36+
geo_f = GeoFilter("geo_field", 1, 2, 3)
37+
assert geo_f.to_string() == "@geo_field:[1 2 3 km]"
38+
39+
geo_f = GeoFilter("geo_field", 1, 2, 3, unit="m")
40+
assert geo_f.to_string() == "@geo_field:[1 2 3 m]"
41+
2842
def test_filters_combination(self):
2943
tf1 = TagFilter("tag_field", ["tag1", "tag2"])
3044
tf2 = TagFilter("tag_field", ["tag3"])

0 commit comments

Comments
 (0)