diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index fb41ce4fc..9921e34ab 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -52,6 +52,7 @@ def count(self, compiler, connection, resolve_inner_expression=False): # If distinct=True or resolve_inner_expression=False, sum the size of the # set. lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + lhs_mql = {"$ifNull": [lhs_mql, []]} # None shouldn't be counted, so subtract 1 if it's present. exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}} return {"$add": [{"$size": lhs_mql}, exits_null]} diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index cb867221e..1388e9ba7 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -38,6 +38,7 @@ def __init__(self, *args, **kwargs): self.subqueries = [] # Atlas search stage. self.search_pipeline = [] + self.wrap_for_global_aggregation = False def _get_group_alias_column(self, expr, annotation_group_idx): """Generate a dummy field for use in the ids fields in $group.""" @@ -234,21 +235,8 @@ def _build_aggregation_pipeline(self, ids, group): """Build the aggregation pipeline for grouping.""" pipeline = [] if not ids: - group["_id"] = None - pipeline.append({"$facet": {"group": [{"$group": group}]}}) - pipeline.append( - { - "$addFields": { - key: { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": key, - } - } - for key in group - } - } - ) + pipeline.append({"$group": {"_id": None, **group}}) + self.wrap_for_global_aggregation = True else: group["_id"] = ids pipeline.append({"$group": group}) diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 5b4f0ec51..85cf0c774 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -56,6 +56,7 @@ def __init__(self, compiler): # $lookup stage that encapsulates the pipeline for performing a nested # subquery. self.subquery_lookup = None + self.wrap_for_global_aggregation = compiler.wrap_for_global_aggregation def __repr__(self): return f"" @@ -91,6 +92,22 @@ def get_pipeline(self): pipeline.append({"$match": self.match_mql}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) + if self.wrap_for_global_aggregation: + pipeline = [ + {"$collStats": {}}, + { + "$lookup": { + "from": self.compiler.collection_name, + "as": "wrapped", + "pipeline": pipeline, + } + }, + { + "$replaceWith": { + "$cond": [{"$eq": ["$wrapped", []]}, {}, {"$first": "$wrapped"}] + } + }, + ] if self.project_fields: pipeline.append({"$project": self.project_fields}) if self.combinator_pipeline: