diff --git a/docs/changelog.rst b/docs/changelog.rst index 64d565578..01f6c3236 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,7 @@ Development - make sure to read https://www.mongodb.com/docs/manual/core/transactions-in-applications/#callback-api-vs-core-api - run_in_transaction context manager relies on Pymongo coreAPI, it will retry automatically in case of `UnknownTransactionCommitResult` but not `TransientTransactionError` exceptions - Using .count() in a transaction will always use Collection.count_document (as estimated_document_count is not supported in transactions) +- Fix use of $geoNear or $collStats in aggregate #2493 - BREAKING CHANGE: Further to the deprecation warning, remove ability to use an unpacked list to `Queryset.aggregate(*pipeline)`, a plain list must be provided instead `Queryset.aggregate(pipeline)`, as it's closer to pymongo interface - BREAKING CHANGE: Further to the deprecation warning, remove `full_response` from `QuerySet.modify` as it wasn't supported with Pymongo 3+ - Fixed stacklevel of many warnings (to point places emitting the warning more accurately) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index a64b26168..aef996996 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1350,6 +1350,14 @@ def from_json(self, json_data): def aggregate(self, pipeline, **kwargs): """Perform an aggregate function based on your queryset params + If the queryset contains a query or skip/limit/sort or if the target Document class + uses inheritance, this method will add steps prior to the provided pipeline in an arbitrary order. + This may affect the performance or outcome of the aggregation, so use it consciously. + + For complex/critical pipelines, we recommend using the aggregation framework of Pymongo directly; + it is available through the collection object (YourDocument._collection.aggregate) and will guarantee + that you have full control over the pipeline. 
+ :param pipeline: list of aggregation commands, see: https://www.mongodb.com/docs/manual/core/aggregation-pipeline/ :param kwargs: (optional) kwargs dictionary to be passed to pymongo's aggregate call @@ -1381,7 +1389,18 @@ def aggregate(self, pipeline, **kwargs): if self._skip is not None: initial_pipeline.append({"$skip": self._skip}) - final_pipeline = initial_pipeline + pipeline + # geoNear and collStats must be the first stages in the pipeline if present + first_step = [] + new_user_pipeline = [] + for step_step in pipeline: + if "$geoNear" in step_step: + first_step.append(step_step) + elif "$collStats" in step_step: + first_step.append(step_step) + else: + new_user_pipeline.append(step_step) + + final_pipeline = first_step + initial_pipeline + new_user_pipeline collection = self._collection if self._read_preference is not None or self._read_concern is not None: diff --git a/tests/queryset/test_queryset_aggregation.py b/tests/queryset/test_queryset_aggregation.py index 15e08698f..7e390e35a 100644 --- a/tests/queryset/test_queryset_aggregation.py +++ b/tests/queryset/test_queryset_aggregation.py @@ -1,5 +1,3 @@ -import unittest - import pytest from pymongo.read_preferences import ReadPreference @@ -334,6 +332,44 @@ class Person(Document): assert list(data) == [] + def test_aggregate_geo_near_used_as_initial_step_before_cls_implicit_step(self): + class BaseClass(Document): + meta = {"allow_inheritance": True} + + class Aggr(BaseClass): + name = StringField() + c = PointField() + + BaseClass.drop_collection() + + x = Aggr(name="X", c=[10.634584, 35.8245029]).save() + y = Aggr(name="Y", c=[10.634584, 35.8245029]).save() + + pipeline = [ + { + "$geoNear": { + "near": {"type": "Point", "coordinates": [10.634584, 35.8245029]}, + "distanceField": "c", + "spherical": True, + } + } + ] + res = list(Aggr.objects.aggregate(pipeline)) + assert res == [ + {"_cls": "BaseClass.Aggr", "_id": x.id, "c": 0.0, "name": "X"}, + {"_cls": "BaseClass.Aggr", "_id": y.id, "c": 0.0, 
"name": "Y"}, + ] + + def test_aggregate_collstats_used_as_initial_step_before_cls_implicit_step(self): + class SomeDoc(Document): + name = StringField() + + SomeDoc.drop_collection() + + SomeDoc(name="X").save() + SomeDoc(name="Y").save() -if __name__ == "__main__": - unittest.main() + pipeline = [{"$collStats": {"count": {}}}] + res = list(SomeDoc.objects.aggregate(pipeline)) + assert len(res) == 1 + assert res[0]["count"] == 2