Skip to content

Commit

Permalink
remove unnecessary $literal wrapping in Value.as_mql()
Browse files Browse the repository at this point in the history
It only needs to wrap integers, and wrapping over types causes incorrect
matching in ArrayField.
  • Loading branch information
timgraham committed Jan 1, 2025
1 parent 3aa239c commit f91a5ea
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
24 changes: 14 additions & 10 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,20 +204,24 @@ def when(self, compiler, connection):

def value(self, compiler, connection): # noqa: ARG001
value = self.value
if isinstance(value, int):
# Wrap numbers in $literal to prevent ambiguity when Value appears in
# $project.
return {"$literal": value}
if isinstance(value, Decimal):
value = Decimal128(value)
elif isinstance(value, datetime.date):
return Decimal128(value)
if isinstance(value, datetime.date):
# Turn dates into datetimes since BSON doesn't support dates.
value = datetime.datetime.combine(value, datetime.datetime.min.time())
elif isinstance(value, datetime.time):
return datetime.datetime.combine(value, datetime.datetime.min.time())
if isinstance(value, datetime.time):
# Turn times into datetimes since BSON doesn't support times.
value = datetime.datetime.combine(datetime.datetime.min.date(), value)
elif isinstance(value, datetime.timedelta):
return datetime.datetime.combine(datetime.datetime.min.date(), value)
if isinstance(value, datetime.timedelta):
# DurationField stores milliseconds rather than microseconds.
value /= datetime.timedelta(milliseconds=1)
elif isinstance(value, UUID):
value = value.hex
return {"$literal": value}
return value / datetime.timedelta(milliseconds=1)
if isinstance(value, UUID):
return value.hex
return value


def register_expressions():
Expand Down
18 changes: 7 additions & 11 deletions tests/expressions_/test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,33 @@ class ValueTests(SimpleTestCase):
def test_date(self):
self.assertEqual(
Value(datetime.date(2025, 1, 1)).as_mql(None, None),
{"$literal": datetime.datetime(2025, 1, 1)},
datetime.datetime(2025, 1, 1),
)

def test_datetime(self):
self.assertEqual(
Value(datetime.datetime(2025, 1, 1, 9, 8, 7)).as_mql(None, None),
{"$literal": datetime.datetime(2025, 1, 1)},
datetime.datetime(2025, 1, 1),
)

def test_decimal(self):
self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), {"$literal": Decimal128("1.0")})
self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), Decimal128("1.0"))

def test_time(self):
self.assertEqual(
Value(datetime.time(9, 8, 7)).as_mql(None, None),
{"$literal": datetime.datetime(1, 1, 1, 9, 8, 7)},
datetime.datetime(1, 1, 1, 9, 8, 7),
)

def test_timedelta(self):
self.assertEqual(
Value(datetime.timedelta(3600)).as_mql(None, None), {"$literal": 311040000000.0}
)
self.assertEqual(Value(datetime.timedelta(3600)).as_mql(None, None), 311040000000.0)

def test_int(self):
self.assertEqual(Value(1).as_mql(None, None), {"$literal": 1})

def test_str(self):
self.assertEqual(Value("foo").as_mql(None, None), {"$literal": "foo"})
self.assertEqual(Value("foo").as_mql(None, None), "foo")

def test_uuid(self):
value = uuid.UUID(int=1)
self.assertEqual(
Value(value).as_mql(None, None), {"$literal": "00000000000000000000000000000001"}
)
self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001")

0 comments on commit f91a5ea

Please sign in to comment.