55import copy
66from datetime import datetime , date
77from dateutil .relativedelta import relativedelta
8- from functools import reduce
8+ from functools import reduce , partial
99from mongoengine import (
1010 EmbeddedDocumentField ,
1111 EmbeddedDocument ,
@@ -446,15 +446,24 @@ def _get_reference_model(cls, key):
446446 return None , None , None , None
447447
448448 @classmethod
449- def _change_reference_condition (cls , key , value , operator ):
449+ def _change_reference_condition (cls , key , value , operator , reference_filter = None ):
450450 ref_model , ref_key , ref_query_key , foreign_key = cls ._get_reference_model (key )
451451 if ref_model :
452452 if value is None :
453453 return ref_key , value , operator
454454 else :
455- ref_vos , total_count = ref_model .query (
456- filter = [{"k" : ref_query_key , "v" : value , "o" : operator }]
457- )
455+ if operator == "not" :
456+ _filter = [{"k" : ref_query_key , "v" : value , "o" : "eq" }]
457+ elif operator == "not_in" :
458+ _filter = [{"k" : ref_query_key , "v" : value , "o" : "in" }]
459+ else :
460+ _filter = [{"k" : ref_query_key , "v" : value , "o" : operator }]
461+ if reference_filter :
462+ for key , value in reference_filter .items ():
463+ if value :
464+ _filter .append ({"k" : key , "v" : value , "o" : "eq" })
465+
466+ ref_vos , total_count = ref_model .query (filter = _filter )
458467
459468 if foreign_key :
460469 ref_values = []
@@ -464,13 +473,17 @@ def _change_reference_condition(cls, key, value, operator):
464473 ref_values .append (ref_value )
465474 else :
466475 ref_values = list (ref_vos )
467- return ref_key , ref_values , "in"
476+
477+ if operator in ["not" , "not_in" ]:
478+ return ref_key , ref_values , "not_in"
479+ else :
480+ return ref_key , ref_values , "in"
468481
469482 else :
470483 return key , value , operator
471484
472485 @classmethod
473- def _make_condition (cls , condition ):
486+ def _make_condition (cls , condition , reference_filter = None ):
474487 key = condition .get ("key" , condition .get ("k" ))
475488 value = condition .get ("value" , condition .get ("v" ))
476489 operator = condition .get ("operator" , condition .get ("o" ))
@@ -479,7 +492,7 @@ def _make_condition(cls, condition):
479492 if operator not in FILTER_OPERATORS :
480493 raise ERROR_DB_QUERY (
481494 reason = f"Filter operator is not supported. (operator = "
482- f"{ FILTER_OPERATORS .keys ()} )"
495+ f"{ FILTER_OPERATORS .keys ()} )"
483496 )
484497
485498 resolver , mongo_operator , is_multiple = FILTER_OPERATORS .get (operator )
@@ -493,7 +506,7 @@ def _make_condition(cls, condition):
493506 if operator not in ["regex" , "regex_in" ]:
494507 if cls ._check_reference_field (key ):
495508 key , value , operator = cls ._change_reference_condition (
496- key , value , operator
509+ key , value , operator , reference_filter
497510 )
498511
499512 resolver , mongo_operator , is_multiple = FILTER_OPERATORS [operator ]
@@ -507,15 +520,27 @@ def _make_condition(cls, condition):
507520 )
508521
509522 @classmethod
510- def _make_filter (cls , filter , filter_or ):
523+ def _make_filter (cls , filter , filter_or , reference_filter ):
511524 _filter = None
512525 _filter_or = None
513526
514527 if len (filter ) > 0 :
515- _filter = reduce (lambda x , y : x & y , map (cls ._make_condition , filter ))
528+ _filter = reduce (
529+ lambda x , y : x & y ,
530+ map (
531+ partial (cls ._make_condition , reference_filter = reference_filter ),
532+ filter ,
533+ ),
534+ )
516535
517536 if len (filter_or ) > 0 :
518- _filter_or = reduce (lambda x , y : x | y , map (cls ._make_condition , filter_or ))
537+ _filter_or = reduce (
538+ lambda x , y : x | y ,
539+ map (
540+ partial (cls ._make_condition , reference_filter = reference_filter ),
541+ filter_or ,
542+ ),
543+ )
519544
520545 if _filter and _filter_or :
521546 _filter = _filter & _filter_or
@@ -566,14 +591,14 @@ def _make_unwind_project_stage(only: list):
566591
567592 @classmethod
568593 def _stat_with_unwind (
569- cls ,
570- unwind : list ,
571- only : list = None ,
572- filter : list = None ,
573- filter_or : list = None ,
574- sort : list = None ,
575- page : dict = None ,
576- target : str = None ,
594+ cls ,
595+ unwind : list ,
596+ only : list = None ,
597+ filter : list = None ,
598+ filter_or : list = None ,
599+ sort : list = None ,
600+ page : dict = None ,
601+ target : str = None ,
577602 ):
578603 if only is None :
579604 raise ERROR_DB_QUERY (reason = "unwind option requires only option." )
@@ -641,19 +666,20 @@ def _stat_with_unwind(
641666
642667 @classmethod
643668 def query (
644- cls ,
645- * args ,
646- only = None ,
647- exclude = None ,
648- filter = None ,
649- filter_or = None ,
650- sort = None ,
651- page = None ,
652- minimal = False ,
653- count_only = False ,
654- unwind = None ,
655- target = None ,
656- ** kwargs ,
669+ cls ,
670+ * args ,
671+ only = None ,
672+ exclude = None ,
673+ filter = None ,
674+ filter_or = None ,
675+ sort = None ,
676+ page = None ,
677+ minimal = False ,
678+ count_only = False ,
679+ unwind = None ,
680+ reference_filter = None ,
681+ target = None ,
682+ ** kwargs ,
657683 ):
658684 filter = filter or []
659685 filter_or = filter_or or []
@@ -669,7 +695,7 @@ def query(
669695 _order_by = []
670696 minimal_fields = cls ._meta .get ("minimal_fields" )
671697
672- _filter = cls ._make_filter (filter , filter_or )
698+ _filter = cls ._make_filter (filter , filter_or , reference_filter )
673699
674700 for sort_option in sort :
675701 if sort_option .get ("desc" , False ):
@@ -715,7 +741,7 @@ def query(
715741 if start < 1 :
716742 start = 1
717743
718- vos = vos [start - 1 : start + page ["limit" ] - 1 ]
744+ vos = vos [start - 1 : start + page ["limit" ] - 1 ]
719745
720746 return vos , total_count
721747
@@ -786,7 +812,7 @@ def _make_sub_conditions(cls, sub_conditions, _before_group_keys):
786812 if operator not in _SUPPORTED_OPERATOR :
787813 raise ERROR_DB_QUERY (
788814 reason = f"'aggregate.group.fields.conditions.operator' condition's { operator } operator is not "
789- f"supported. (supported_operator = { _SUPPORTED_OPERATOR } )"
815+ f"supported. (supported_operator = { _SUPPORTED_OPERATOR } )"
790816 )
791817
792818 if key in _before_group_keys :
@@ -808,7 +834,7 @@ def _get_group_fields(cls, condition, _before_group_keys):
808834 if operator not in STAT_GROUP_OPERATORS :
809835 raise ERROR_DB_QUERY (
810836 reason = f"'aggregate.group.fields' condition's { operator } operator is not supported. "
811- f"(supported_operator = { list (STAT_GROUP_OPERATORS .keys ())} )"
837+ f"(supported_operator = { list (STAT_GROUP_OPERATORS .keys ())} )"
812838 )
813839
814840 if name is None :
@@ -927,7 +953,7 @@ def _get_project_fields(cls, condition):
927953 if operator and operator not in STAT_PROJECT_OPERATORS :
928954 raise ERROR_DB_QUERY (
929955 reason = f"'aggregate.project.fields' condition's { operator } operator is not supported. "
930- f"(supported_operator = { list (STAT_PROJECT_OPERATORS .keys ())} )"
956+ f"(supported_operator = { list (STAT_PROJECT_OPERATORS .keys ())} )"
931957 )
932958
933959 if name is None :
@@ -1085,9 +1111,9 @@ def _make_aggregate_rules(cls, aggregate):
10851111 else :
10861112 raise ERROR_REQUIRED_PARAMETER (
10871113 key = "aggregate.unwind or aggregate.group or "
1088- "aggregate.count or aggregate.sort or "
1089- "aggregate.project or aggregate.limit or "
1090- "aggregate.skip"
1114+ "aggregate.count or aggregate.sort or "
1115+ "aggregate.project or aggregate.limit or "
1116+ "aggregate.skip"
10911117 )
10921118
10931119 return _aggregate_rules
@@ -1141,23 +1167,24 @@ def _stat_distinct(cls, vos, distinct, page):
11411167 start = 1
11421168
11431169 result ["total_count" ] = len (values )
1144- values = values [start - 1 : start + page ["limit" ] - 1 ]
1170+ values = values [start - 1 : start + page ["limit" ] - 1 ]
11451171
11461172 result ["results" ] = cls ._make_distinct_values (values )
11471173 return result
11481174
11491175 @classmethod
11501176 def stat (
1151- cls ,
1152- * args ,
1153- aggregate = None ,
1154- distinct = None ,
1155- filter = None ,
1156- filter_or = None ,
1157- page = None ,
1158- target = "SECONDARY_PREFERRED" ,
1159- allow_disk_use = False ,
1160- ** kwargs ,
1177+ cls ,
1178+ * args ,
1179+ aggregate = None ,
1180+ distinct = None ,
1181+ filter = None ,
1182+ filter_or = None ,
1183+ page = None ,
1184+ reference_filter = None ,
1185+ target = "SECONDARY_PREFERRED" ,
1186+ allow_disk_use = False ,
1187+ ** kwargs ,
11611188 ):
11621189 filter = filter or []
11631190 filter_or = filter_or or []
@@ -1166,7 +1193,7 @@ def stat(
11661193 if not (aggregate or distinct ):
11671194 raise ERROR_REQUIRED_PARAMETER (key = "aggregate" )
11681195
1169- _filter = cls ._make_filter (filter , filter_or )
1196+ _filter = cls ._make_filter (filter , filter_or , reference_filter )
11701197
11711198 try :
11721199 vos = cls ._get_target_objects (target ).filter (_filter )
@@ -1453,24 +1480,25 @@ def _convert_date_value(cls, date_value, date_field_format):
14531480
14541481 @classmethod
14551482 def analyze (
1456- cls ,
1457- * args ,
1458- granularity = None ,
1459- fields = None ,
1460- select = None ,
1461- group_by = None ,
1462- field_group = None ,
1463- filter = None ,
1464- filter_or = None ,
1465- page = None ,
1466- sort = None ,
1467- start = None ,
1468- end = None ,
1469- date_field = "date" ,
1470- date_field_format = "%Y-%m-%d" ,
1471- target = "SECONDARY_PREFERRED" ,
1472- allow_disk_use = False ,
1473- ** kwargs ,
1483+ cls ,
1484+ * args ,
1485+ granularity = None ,
1486+ fields = None ,
1487+ select = None ,
1488+ group_by = None ,
1489+ field_group = None ,
1490+ filter = None ,
1491+ filter_or = None ,
1492+ page = None ,
1493+ sort = None ,
1494+ start = None ,
1495+ end = None ,
1496+ date_field = "date" ,
1497+ date_field_format = "%Y-%m-%d" ,
1498+ reference_filter = None ,
1499+ target = "SECONDARY_PREFERRED" ,
1500+ allow_disk_use = False ,
1501+ ** kwargs ,
14741502 ):
14751503 if fields is None :
14761504 raise ERROR_REQUIRED_PARAMETER (key = "fields" )
@@ -1504,6 +1532,7 @@ def analyze(
15041532 "aggregate" : [{"group" : {"keys" : group_keys , "fields" : group_fields }}],
15051533 "target" : target ,
15061534 "allow_disk_use" : allow_disk_use ,
1535+ "reference_filter" : reference_filter ,
15071536 }
15081537
15091538 if select :
0 commit comments