diff --git a/ddtrace/contrib/pymongo/client.py b/ddtrace/contrib/pymongo/client.py index ab1916b582b..429826df811 100644 --- a/ddtrace/contrib/pymongo/client.py +++ b/ddtrace/contrib/pymongo/client.py @@ -68,6 +68,7 @@ def __init__(self, client=None, *args, **kwargs): client = _MongoClient(client, *args, **kwargs) super(TracedMongoClient, self).__init__(client) + client._datadog_proxy = self # NOTE[matt] the TracedMongoClient attempts to trace all of the network # calls in the trace library. This is good because it measures the # actual network time. It's bad because it uses a private API which @@ -85,6 +86,25 @@ def __getddpin__(self): return ddtrace.Pin.get_from(self._topology) +@contextlib.contextmanager +def wrapped_validate_session(wrapped, instance, args, kwargs): + # We do this to handle a validation `A is B` in pymongo that + # relies on IDs being equal. Since we are proxying objects, we need + # to ensure we're compare proxy with proxy or wrapped with wrapped + # or this validation will fail + client = args[0] + session = args[1] + session_client = session._client + if isinstance(session_client, TracedMongoClient): + if isinstance(client, _MongoClient): + client = getattr(client, "_datadog_proxy", client) + elif isinstance(session_client, _MongoClient): + if isinstance(client, TracedMongoClient): + client = client.__wrapped__ + + yield wrapped(client, session) + + class TracedTopology(ObjectProxy): def __init__(self, topology): super(TracedTopology, self).__init__(topology) diff --git a/ddtrace/contrib/pymongo/patch.py b/ddtrace/contrib/pymongo/patch.py index 13ee461689e..aa838a8c9ee 100644 --- a/ddtrace/contrib/pymongo/patch.py +++ b/ddtrace/contrib/pymongo/patch.py @@ -17,6 +17,7 @@ from ..trace_utils import unwrap as _u from .client import TracedMongoClient from .client import set_address_tags +from .client import wrapped_validate_session config._add( @@ -35,6 +36,7 @@ def get_version(): _VERSION = pymongo.version_tuple _CHECKOUT_FN_NAME = "get_socket" if _VERSION < (4, 5) else "checkout" +_VERIFY_VERSION_CLASS = pymongo.pool.SocketInfo if _VERSION < (4, 5) else pymongo.pool.Connection def patch(): @@ -59,6 +61,7 @@ def patch_pymongo_module(): # - Creates a new socket & performs a TCP handshake # - Grabs a socket already initialized before _w("pymongo.server", "Server.%s" % _CHECKOUT_FN_NAME, traced_get_socket) + _w("pymongo.pool", f"{_VERIFY_VERSION_CLASS.__name__}.validate_session", wrapped_validate_session) def unpatch_pymongo_module(): @@ -67,6 +70,7 @@ def unpatch_pymongo_module(): pymongo._datadog_patch = False _u(pymongo.server.Server, _CHECKOUT_FN_NAME) + _u(_VERIFY_VERSION_CLASS, "validate_session") @contextlib.contextmanager diff --git a/releasenotes/notes/pymongo-fix-session-validation-error-1a05f8ad45bbbc35.yaml b/releasenotes/notes/pymongo-fix-session-validation-error-1a05f8ad45bbbc35.yaml new file mode 100644 index 00000000000..92f081038cb --- /dev/null +++ b/releasenotes/notes/pymongo-fix-session-validation-error-1a05f8ad45bbbc35.yaml @@ -0,0 +1,3 @@ +fixes: + - | + pymongo: this resolves an issue where the library raised an error in ``pymongo.pool.validate_session`` \ No newline at end of file diff --git a/tests/contrib/pymongo/test.py b/tests/contrib/pymongo/test.py index aa3d77387d6..2ef8e10af7f 100644 --- a/tests/contrib/pymongo/test.py +++ b/tests/contrib/pymongo/test.py @@ -768,6 +768,17 @@ def test_single_op(self): self.check_socket_metadata(spans[0]) assert spans[1].name == "pymongo.cmd" + def test_validate_session_equivalence(self): + """ + This tests validate_session from: + https://github.com/mongodb/mongo-python-driver/blob/v3.13/pymongo/pool.py#L884-L898 + which fails under some circumstances unless we patch correctly + """ + # Trigger a command which calls validate_session internal to PyMongo + db_conn = pymongo.database.Database(self.client, "foo") + collection = db_conn["mycollection"] + collection.insert_one({"Foo": "Bar"}) + def test_service_name_override(self): with TracerTestCase.override_config("pymongo", dict(service_name="testdb")): self.client["some_db"].drop_collection("some_collection")