55from datetime import timedelta
66import sys
77import time
8+ import re
9+
10+
11+ class EscapeQueryOptions (Enum ):
12+ EscapeAll = 0
13+ AllowPostfixWildcard = 1
14+ AllowAllWildcards = 2
15+ RawQuery = 3
816
917
1018class Query (object ):
@@ -56,7 +64,20 @@ def select(self, *args):
5664 self .fetch = args
5765 return self
5866
59- def _lucene_builder (self , value , action = None ):
67+ def _lucene_builder (self , value , action = None , escape_query_options = EscapeQueryOptions .EscapeAll ):
68+
69+ if isinstance (value , str ):
70+ if escape_query_options == EscapeQueryOptions .EscapeAll :
71+ value = Utils .escape (value , False , False )
72+
73+ elif escape_query_options == EscapeQueryOptions .AllowPostfixWildcard :
74+ value = Utils .escape (value , False , False )
75+ elif escape_query_options == EscapeQueryOptions .AllowAllWildcards :
76+ value = Utils .escape (value , False , False )
77+ value = re .sub (r'"\\\*(\s|$)"' , "*${1}" , value )
78+ elif escape_query_options == EscapeQueryOptions .RawQuery :
79+ value = Utils .escape (value , False , False ).replace ("\\ *" , "*" )
80+
6081 lucene_text = Utils .to_lucene (value , action = action )
6182
6283 if len (self .query_builder ) > 0 and not self .query_builder .endswith (' ' ):
@@ -70,13 +91,15 @@ def _lucene_builder(self, value, action=None):
7091 def __iter__ (self ):
7192 return self ._execute_query ().__iter__ ()
7293
73- def where_equals (self , field_name , value ):
94+ def where_equals (self , field_name , value , escape_query_options = EscapeQueryOptions . EscapeAll ):
7495 """
7596 To get all the document that equal to the value in the given field_name
7697
7798 @param field_name:The field name in the index you want to query.
7899 :type str
79100 @param value: The value will be the fields value you want to query
101+ @param escape_query_options: the way we should escape special characters
102+ :type EscapeQueryOptions
80103 """
81104 if field_name is None :
82105 raise ValueError ("None field_name is invalid" )
@@ -89,7 +112,7 @@ def where_equals(self, field_name, value):
89112 sort_hint = self .session .conventions .get_default_sort_option ("long" )
90113 self ._sort_hints .add ("SortHint-{0}={1}" .format (field_name , sort_hint ))
91114
92- lucene_text = self ._lucene_builder (value , action = "equal" )
115+ lucene_text = self ._lucene_builder (value , action = "equal" , escape_query_options = escape_query_options )
93116 self .query_builder += "{0}:{1}" .format (field_name , lucene_text )
94117 return self
95118
@@ -108,17 +131,19 @@ def where(self, **kwargs):
108131 self .where_equals (field_name , kwargs [field_name ])
109132 return self
110133
111- def search (self , field_name , search_terms ):
134+ def search (self , field_name , search_terms , escape_query_options = EscapeQueryOptions . RawQuery ):
112135 """
113136 for more complex text searching
114137
115138 @param field_name:The field name in the index you want to query.
116139 :type str
117140 @param search_terms: the terms you want to query
118141 :type str
142+ @param escape_query_options: the way we should escape special characters
143+ :type EscapeQueryOptions
119144 """
120145 search_terms = Utils .quote_key (str (search_terms ))
121- search_terms = self ._lucene_builder (search_terms , "search" )
146+ search_terms = self ._lucene_builder (search_terms , "search" , escape_query_options )
122147 self .query_builder += "{0}:{1}" .format (field_name , search_terms )
123148 return self
124149
0 commit comments