@@ -330,7 +330,18 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
330330 info = info ._replace (context = Context ())
331331 info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args )
332332 elif _root is None or args :
333- count = self .get_queryset (self .model , info , required_fields , ** args ).count ()
333+ args_copy = args .copy ()
334+ for key in args .copy ():
335+ if key not in self .model ._fields_ordered :
336+ args_copy .pop (key )
337+ elif isinstance (getattr (self .model , key ),
338+ mongoengine .fields .ReferenceField ) or isinstance (getattr (self .model , key ),
339+ mongoengine .fields .GenericReferenceField ) or isinstance (
340+ getattr (self .model , key ),
341+ mongoengine .fields .LazyReferenceField ) or isinstance (getattr (self .model , key ),
342+ mongoengine .fields .CachedReferenceField ):
343+ args_copy [key ] = from_global_id (args_copy [key ])[1 ]
344+ count = mongoengine .get_db ()[self .model ._get_collection_name ()].find (args_copy ).count ()
334345 if count != 0 :
335346 skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last , before = before ,
336347 count = count )
@@ -388,13 +399,40 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
388399 if not bool (args ) or not is_partial :
389400 if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
390401 mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
402+
403+ skip = 0
404+ count = 0
405+ limit = None
406+ reverse = False
407+ first = args_copy .get ("first" )
408+ after = args_copy .get ("after" )
409+ if after :
410+ after = cursor_to_offset (after )
411+ last = args_copy .get ("last" )
412+ before = args_copy .get ("before" )
391413 for arg_name , arg in args .copy ().items ():
392414 if arg_name not in self .model ._fields_ordered + tuple (self .filter_args .keys ()):
393415 args_copy .pop (arg_name )
394416 if isinstance (info , GraphQLResolveInfo ):
395417 if not info .context :
396418 info = info ._replace (context = Context ())
397- info .context .queryset = self .get_queryset (self .model , info , required_fields , ** args_copy )
419+ args_count_copy = args .copy ()
420+ for key in args .copy ():
421+ if key not in self .model ._fields_ordered :
422+ args_count_copy .pop (key )
423+ elif isinstance (getattr (self .model , key ),
424+ mongoengine .fields .ReferenceField ) or isinstance (getattr (self .model , key ),
425+ mongoengine .fields .GenericReferenceField ) or isinstance (
426+ getattr (self .model , key ),
427+ mongoengine .fields .LazyReferenceField ) or isinstance (getattr (self .model , key ),
428+ mongoengine .fields .CachedReferenceField ):
429+ args_count_copy [key ] = from_global_id (args_count_copy [key ])[1 ]
430+ count = mongoengine .get_db ()[self .model ._get_collection_name ()].find (args_count_copy ).count ()
431+ if count != 0 :
432+ skip , limit , reverse = find_skip_and_limit (first = first , after = after , last = last ,
433+ before = before ,
434+ count = count )
435+ info .context .queryset = self .get_queryset (self .model , info , required_fields , skip , limit , reverse )
398436
399437 # XXX: Filter nested args
400438 resolved = resolver (root , info , ** args )
0 commit comments