55
66import graphene
77import mongoengine
8- from bson import DBRef
8+ from bson import DBRef , ObjectId
99from graphene import Context
1010from graphene .types .utils import get_type
1111from graphene .utils .str_converters import to_snake_case
@@ -215,7 +215,10 @@ def fields(self):
215215 self ._type = get_type (self ._type )
216216 return self ._type ._meta .fields
217217
218- def get_queryset (self , model , info , required_fields = list (), skip = None , limit = None , reversed = False , ** args ):
218+ def get_queryset (self , model , info , required_fields = None , skip = None , limit = None , reversed = False , ** args ):
219+ if required_fields is None :
220+ required_fields = list ()
221+
219222 if args :
220223 reference_fields = get_model_reference_fields (self .model )
221224 hydrated_references = {}
@@ -276,7 +279,9 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
276279 skip )
277280 return model .objects (** args ).no_dereference ().only (* required_fields ).order_by (self .order_by )
278281
279- def default_resolver (self , _root , info , required_fields = list (), ** args ):
282+ def default_resolver (self , _root , info , required_fields = None , ** args ):
283+ if required_fields is None :
284+ required_fields = list ()
280285 args = args or {}
281286 for key , value in dict (args ).items ():
282287 if value is None :
@@ -400,39 +405,13 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
400405 if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
401406 mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
402407
403- skip = 0
404- count = 0
405- limit = None
406- reverse = False
407- first = args_copy .get ("first" )
408- after = args_copy .get ("after" )
409- if after :
410- after = cursor_to_offset (after )
411- last = args_copy .get ("last" )
412- before = args_copy .get ("before" )
413408 for arg_name , arg in args .copy ().items ():
414409 if arg_name not in self .model ._fields_ordered + tuple (self .filter_args .keys ()):
415410 args_copy .pop (arg_name )
416411 if isinstance (info , GraphQLResolveInfo ):
417412 if not info .context :
418413 info = info ._replace (context = Context ())
419- args_count_copy = args .copy ()
420- for key in args .copy ():
421- if key not in self .model ._fields_ordered :
422- args_count_copy .pop (key )
423- elif isinstance (getattr (self .model , key ),
424- mongoengine .fields .ReferenceField ) or isinstance (getattr (self .model , key ),
425- mongoengine .fields .GenericReferenceField ) or isinstance (
426- getattr (self .model , key ),
427- mongoengine .fields .LazyReferenceField ) or isinstance (getattr (self .model , key ),
428- mongoengine .fields .CachedReferenceField ):
429- args_count_copy [key ] = from_global_id (args_count_copy [key ])[1 ]
430- count = mongoengine .get_db ()[self .model ._get_collection_name ()].find (args_count_copy ).count ()
431- if count != 0 :
432- skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last ,
433- before = before ,
434- count = count )
435- info .context .queryset = self .get_queryset (self .model , info , required_fields , skip , limit , reverse )
414+ info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
436415
437416 # XXX: Filter nested args
438417 resolved = resolver (root , info , ** args )
@@ -454,7 +433,7 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
454433 if arg_name == '_id' and isinstance (arg , dict ):
455434 operation = list (arg .keys ())[0 ]
456435 args_copy ['pk' + operation .replace ('$' , '__' )] = arg [operation ]
457- if '.' in arg_name :
436+ if '.' in arg_name and not isinstance ( arg , ObjectId ) :
458437 operation = list (arg .keys ())[0 ]
459438 args_copy [arg_name .replace ('.' , '__' ) + operation .replace ('$' , '__' )] = arg [operation ]
460439 else :
0 commit comments