diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index 957c5f15..477d0101 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -203,20 +203,24 @@ def when(self, compiler, connection): def value(self, compiler, connection): # noqa: ARG001 value = self.value + if isinstance(value, int): + # Wrap numbers in $literal to prevent ambiguity when Value appears in + # $project. + return {"$literal": value} if isinstance(value, Decimal): - value = Decimal128(value) - elif isinstance(value, datetime.date): + return Decimal128(value) + if isinstance(value, datetime.date): # Turn dates into datetimes since BSON doesn't support dates. - value = datetime.datetime.combine(value, datetime.datetime.min.time()) - elif isinstance(value, datetime.time): + return datetime.datetime.combine(value, datetime.datetime.min.time()) + if isinstance(value, datetime.time): # Turn times into datetimes since BSON doesn't support times. - value = datetime.datetime.combine(datetime.datetime.min.date(), value) - elif isinstance(value, datetime.timedelta): + return datetime.datetime.combine(datetime.datetime.min.date(), value) + if isinstance(value, datetime.timedelta): # DurationField stores milliseconds rather than microseconds. - value /= datetime.timedelta(milliseconds=1) - elif isinstance(value, UUID): - value = value.hex - return {"$literal": value} + return value / datetime.timedelta(milliseconds=1) + if isinstance(value, UUID): + return value.hex + return value def register_expressions(): diff --git a/tests/expressions_/test_value.py b/tests/expressions_/test_value.py index 3325092c..51ac2d28 100644 --- a/tests/expressions_/test_value.py +++ b/tests/expressions_/test_value.py @@ -11,37 +11,33 @@ class ValueTests(SimpleTestCase): def test_date(self): self.assertEqual( Value(datetime.date(2025, 1, 1)).as_mql(None, None), - {"$literal": datetime.datetime(2025, 1, 1)}, + datetime.datetime(2025, 1, 1), ) def test_datetime(self): self.assertEqual( Value(datetime.datetime(2025, 1, 1, 9, 8, 7)).as_mql(None, None), - {"$literal": datetime.datetime(2025, 1, 1)}, + datetime.datetime(2025, 1, 1), ) def test_decimal(self): - self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), {"$literal": Decimal128("1.0")}) + self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), Decimal128("1.0")) def test_time(self): self.assertEqual( Value(datetime.time(9, 8, 7)).as_mql(None, None), - {"$literal": datetime.datetime(1, 1, 1, 9, 8, 7)}, + datetime.datetime(1, 1, 1, 9, 8, 7), ) def test_timedelta(self): - self.assertEqual( - Value(datetime.timedelta(3600)).as_mql(None, None), {"$literal": 311040000000.0} - ) + self.assertEqual(Value(datetime.timedelta(3600)).as_mql(None, None), 311040000000.0) def test_int(self): self.assertEqual(Value(1).as_mql(None, None), {"$literal": 1}) def test_str(self): - self.assertEqual(Value("foo").as_mql(None, None), {"$literal": "foo"}) + self.assertEqual(Value("foo").as_mql(None, None), "foo") def test_uuid(self): value = uuid.UUID(int=1) - self.assertEqual( - Value(value).as_mql(None, None), {"$literal": "00000000000000000000000000000001"} - ) + self.assertEqual(Value(value).as_mql(None, None), "00000000000000000000000000000001")