Skip to content

Commit

Permalink
support .where() with modify (previously silently ignored)
Browse files Browse the repository at this point in the history
  • Loading branch information
bagerard committed Oct 8, 2024
1 parent e77daa6 commit 9f9cd0a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Development
- Add support for collation/hint/comment to delete/update and aggregate #2842
- BREAKING CHANGE: Remove LongField as it's equivalent to IntField since we drop support to Python2 long time ago (User should simply switch to IntField) #2309
- BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858

- BugFix - Take `where()` into account when using `.modify()`, as in MyDocument.objects().where("this[field] >= this[otherfield]").modify(field='new') #2044

Changes in 0.29.0
=================
Expand Down
5 changes: 5 additions & 0 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,11 @@ def modify(

queryset = self.clone()
query = queryset._query

if self._where_clause:
where_clause = self._sub_js_fields(self._where_clause)
query["$where"] = where_clause

if not remove:
update = transform.update(queryset._document, **update)
sort = queryset._ordering
Expand Down
57 changes: 56 additions & 1 deletion tests/queryset/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from mongoengine.queryset.base import BaseQuerySet
from tests.utils import (
db_ops_tracker,
get_as_pymongo,
requires_mongodb_gte_42,
requires_mongodb_gte_44,
requires_mongodb_lt_42,
Expand Down Expand Up @@ -4456,7 +4457,7 @@ class Comment(Document):
]
assert ([("_cls", 1), ("message", 1)], False, False) in info

def test_where(self):
def test_where_query(self):
"""Ensure that where clauses work."""

class IntPair(Document):
Expand Down Expand Up @@ -4499,6 +4500,60 @@ class IntPair(Document):
with pytest.raises(TypeError):
list(IntPair.objects.where(fielda__gte=3))

def test_where_query_field_name_subs(self):
class DomainObj(Document):
field_1 = StringField(db_field="field_2")

DomainObj.drop_collection()

DomainObj(field_1="test").save()

obj = DomainObj.objects.where("this[~field_1] == 'NOTMATCHING'")
assert not list(obj)

obj = DomainObj.objects.where("this[~field_1] == 'test'")
assert list(obj)

def test_where_modify(self):
class DomainObj(Document):
field = StringField()

DomainObj.drop_collection()

DomainObj(field="test").save()

obj = DomainObj.objects.where("this[~field] == 'NOTMATCHING'")
assert not list(obj)

obj = DomainObj.objects.where("this[~field] == 'test'")
assert list(obj)

qs = DomainObj.objects.where("this[~field] == 'NOTMATCHING'").modify(
field="new"
)
assert not qs

qs = DomainObj.objects.where("this[~field] == 'test'").modify(field="new")
assert qs

def test_where_modify_field_name_subs(self):
class DomainObj(Document):
field_1 = StringField(db_field="field_2")

DomainObj.drop_collection()

DomainObj(field_1="test").save()

obj = DomainObj.objects.where("this[~field_1] == 'NOTMATCHING'").modify(
field_1="new"
)
assert not obj

obj = DomainObj.objects.where("this[~field_1] == 'test'").modify(field_1="new")
assert obj

assert get_as_pymongo(obj) == {"_id": obj.id, "field_2": "new"}

def test_scalar(self):
class Organization(Document):
name = StringField()
Expand Down

0 comments on commit 9f9cd0a

Please sign in to comment.