Skip to content

Commit 474938d

Browse files
authored
Merge pull request #9 from IdanHaim/RavenDB-6452
RavenDB-6452
2 parents 167d428 + d8eff8b commit 474938d

File tree

3 files changed

+69
-8
lines changed

3 files changed

+69
-8
lines changed

pyravendb/store/session_query.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55
from datetime import timedelta
66
import sys
77
import time
8+
import re
9+
10+
11+
class EscapeQueryOptions(Enum):
12+
EscapeAll = 0
13+
AllowPostfixWildcard = 1
14+
AllowAllWildcards = 2
15+
RawQuery = 3
816

917

1018
class Query(object):
@@ -56,7 +64,20 @@ def select(self, *args):
5664
self.fetch = args
5765
return self
5866

59-
def _lucene_builder(self, value, action=None):
67+
def _lucene_builder(self, value, action=None, escape_query_options=EscapeQueryOptions.EscapeAll):
68+
69+
if isinstance(value, str):
70+
if escape_query_options == EscapeQueryOptions.EscapeAll:
71+
value = Utils.escape(value, False, False)
72+
73+
elif escape_query_options == EscapeQueryOptions.AllowPostfixWildcard:
74+
value = Utils.escape(value, False, False)
75+
elif escape_query_options == EscapeQueryOptions.AllowAllWildcards:
76+
value = Utils.escape(value, False, False)
77+
value = re.sub(r'"\\\*(\s|$)"', "*${1}", value)
78+
elif escape_query_options == EscapeQueryOptions.RawQuery:
79+
value = Utils.escape(value, False, False).replace("\\*", "*")
80+
6081
lucene_text = Utils.to_lucene(value, action=action)
6182

6283
if len(self.query_builder) > 0 and not self.query_builder.endswith(' '):
@@ -70,13 +91,15 @@ def _lucene_builder(self, value, action=None):
7091
def __iter__(self):
7192
return self._execute_query().__iter__()
7293

73-
def where_equals(self, field_name, value):
94+
def where_equals(self, field_name, value, escape_query_options=EscapeQueryOptions.EscapeAll):
7495
"""
7596
To get all the document that equal to the value in the given field_name
7697
7798
@param field_name:The field name in the index you want to query.
7899
:type str
79100
@param value: The value will be the fields value you want to query
101+
@param escape_query_options: the way we should escape special characters
102+
:type EscapeQueryOptions
80103
"""
81104
if field_name is None:
82105
raise ValueError("None field_name is invalid")
@@ -89,7 +112,7 @@ def where_equals(self, field_name, value):
89112
sort_hint = self.session.conventions.get_default_sort_option("long")
90113
self._sort_hints.add("SortHint-{0}={1}".format(field_name, sort_hint))
91114

92-
lucene_text = self._lucene_builder(value, action="equal")
115+
lucene_text = self._lucene_builder(value, action="equal", escape_query_options=escape_query_options)
93116
self.query_builder += "{0}:{1}".format(field_name, lucene_text)
94117
return self
95118

@@ -108,17 +131,19 @@ def where(self, **kwargs):
108131
self.where_equals(field_name, kwargs[field_name])
109132
return self
110133

111-
def search(self, field_name, search_terms):
134+
def search(self, field_name, search_terms, escape_query_options=EscapeQueryOptions.RawQuery):
112135
"""
113136
for more complex text searching
114137
115138
@param field_name:The field name in the index you want to query.
116139
:type str
117140
@param search_terms: the terms you want to query
118141
:type str
142+
@param escape_query_options: the way we should escape special characters
143+
:type EscapeQueryOptions
119144
"""
120145
search_terms = Utils.quote_key(str(search_terms))
121-
search_terms = self._lucene_builder(search_terms, "search")
146+
search_terms = self._lucene_builder(search_terms, "search", escape_query_options)
122147
self.query_builder += "{0}:{1}".format(field_name, search_terms)
123148
return self
124149

pyravendb/tools/utils.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,6 @@ def make_initialize_dict(document, entity_init):
166166
@staticmethod
167167
def to_lucene(value, action):
168168
query_text = ""
169-
if isinstance(value, str):
170-
value = re.escape(value).replace('\*', '*')
171169
if action == "in":
172170
if not value or len(value) == 0:
173171
return None
@@ -281,3 +279,41 @@ def timedelta_to_str(timedelta_obj):
281279
if microseconds > 0:
282280
timedelta_str += ".{0}".format(microseconds)
283281
return timedelta_str
282+
283+
@staticmethod
284+
def escape(term, allow_wild_cards, make_phrase):
285+
wild_cards = ['-', '&', '|', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', ':', '\\']
286+
if not term:
287+
return "\"\""
288+
start = 0
289+
length = len(term)
290+
builder = ""
291+
if length >= 2 and term[0] == '/' and term[1] == '/':
292+
builder += "//"
293+
start = 2
294+
i = start
295+
while i < length:
296+
ch = term[i]
297+
if ch == '*' or ch == '?':
298+
if allow_wild_cards:
299+
i += 1
300+
continue
301+
302+
if ch in wild_cards:
303+
if i > start:
304+
builder += term[start:i - start]
305+
306+
builder += '\\{0}'.format(ch)
307+
start = i + 1
308+
i += 1
309+
continue
310+
311+
if ch == ' ' or ch == '\t':
312+
if make_phrase:
313+
return "\"{0}\"".format(Utils.escape(term, allow_wild_cards, False))
314+
315+
i += 1
316+
if length > start:
317+
builder += term[start: length]
318+
319+
return builder

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name='pyravendb',
55
packages=find_packages(),
6-
version='1.3.1.2',
6+
version='1.3.1.3',
77
description='This is the official python client for RavenDB document database',
88
author='Idan Haim Shalom',
99
author_email='haimdude@gmail.com',

0 commit comments

Comments
 (0)