fix(langchain): safely check if instance is openai llm/chat (#8896)
Fixes #8889.

This PR adds a safe check for whether a traced LLM/chat model instance is an
OpenAI instance, by wrapping the check in a try/except and reversing the order
of the type checks. Previously we checked directly against the installed
langchain package, e.g.:
```python
if isinstance(instance, BASE_LANGCHAIN_MODULE.chat_models.ChatOpenAI) or (
    langchain_openai and isinstance(instance, langchain_openai.ChatOpenAI)
):
    ...
# BASE_LANGCHAIN_MODULE can be either `langchain` or `langchain_community`
```
But `langchain_community` does not automatically import its submodules, so
`langchain_community.chat_models.ChatOpenAI` raises an `AttributeError` unless
`from langchain_community import chat_models` has already been executed.
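
As a minimal sketch of the failure mode (assuming `langchain_community` is
installed but its `chat_models` submodule has not yet been imported anywhere
in the process):
```python
import langchain_community

try:
    # langchain_community/__init__.py does not import its submodules,
    # so plain attribute access fails here.
    langchain_community.chat_models.ChatOpenAI
except AttributeError as e:
    print(e)  # module 'langchain_community' has no attribute 'chat_models'

from langchain_community import chat_models  # noqa:F401

# The explicit import binds the submodule onto the parent package,
# so the same attribute access now succeeds.
print(langchain_community.chat_models.ChatOpenAI)
```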

With this fix, there are three scenarios for
`langchain_community`/`langchain_openai` users (a condensed, annotated version
of the new check follows this list):
1. They use `langchain_openai` --> we perform the type check against
`langchain_openai` first, which will always be available, and never fall
through to the `BASE_LANGCHAIN_MODULE` type check.
2. They use `langchain_community` --> since they reference
`langchain_community.chat_models` in their own code, they must have already
imported it, so the check should not raise. Regardless, we safely try the type
check against the submodule.
3. They use `langchain` --> `langchain` allows accessing submodules without
importing them directly, so this should also not raise.
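
To make the branch ordering concrete, here is the new chat-model helper from
this PR's diff, with comments added here to map each branch to the scenarios
above:
```python
def _is_openai_chat_instance(instance):
    """Safely check if a traced instance is an OpenAI chat model."""
    try:
        if langchain_openai:
            # Scenario 1: langchain_openai is installed and always
            # importable, so check it first and skip BASE_LANGCHAIN_MODULE.
            return isinstance(instance, langchain_openai.ChatOpenAI)
        # Scenarios 2 and 3: fall back to the base package. Attribute access
        # can raise if the submodule was never imported (langchain_community),
        # hence the try/except around the whole check.
        return isinstance(instance, BASE_LANGCHAIN_MODULE.chat_models.ChatOpenAI)
    except (AttributeError, ModuleNotFoundError, ImportError):
        return False
```
The `_is_openai_llm_instance` helper in the diff is identical except that it
checks `langchain_openai.OpenAI` and `BASE_LANGCHAIN_MODULE.llms.OpenAI`.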

## Checklist

- [x] Change(s) are motivated and described in the PR description
- [x] Testing strategy is described if automated tests are not included
in the PR
- [x] Risks are described (performance impact, potential for breakage,
maintainability)
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] [Library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
are followed or label `changelog/no-changelog` is set
- [x] Documentation is included (in-code, generated user docs, [public
corp docs](https://github.com/DataDog/documentation/))
- [x] Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))
- [x] If this PR changes the public interface, I've notified
`@DataDog/apm-tees`.

## Reviewer Checklist

- [x] Title is accurate
- [x] All changes are related to the pull request's stated goal
- [x] Description motivates each change
- [x] Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- [x] Testing strategy adequately addresses listed risks
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] Release note makes sense to a user of the library
- [x] Author has acknowledged and discussed the performance implications
of this PR as reported in the benchmarks PR comment
- [x] Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)

Co-authored-by: Alberto Vara <[email protected]>
Yun-Kim and avara1986 authored Apr 8, 2024
1 parent 131cf91 commit 698021b
Showing 2 changed files with 33 additions and 14 deletions.
ddtrace/contrib/langchain/patch.py (28 additions, 14 deletions)
```diff
@@ -145,6 +145,30 @@ def _tag_openai_token_usage(
         _tag_openai_token_usage(span._parent, llm_output, propagated_cost=propagated_cost + total_cost, propagate=True)
 
 
+def _is_openai_llm_instance(instance):
+    """Safely check if a traced instance is an OpenAI LLM.
+    langchain_community does not automatically import submodules which may result in AttributeErrors.
+    """
+    try:
+        if langchain_openai:
+            return isinstance(instance, langchain_openai.OpenAI)
+        return isinstance(instance, BASE_LANGCHAIN_MODULE.llms.OpenAI)
+    except (AttributeError, ModuleNotFoundError, ImportError):
+        return False
+
+
+def _is_openai_chat_instance(instance):
+    """Safely check if a traced instance is an OpenAI Chat Model.
+    langchain_community does not automatically import submodules which may result in AttributeErrors.
+    """
+    try:
+        if langchain_openai:
+            return isinstance(instance, langchain_openai.ChatOpenAI)
+        return isinstance(instance, BASE_LANGCHAIN_MODULE.chat_models.ChatOpenAI)
+    except (AttributeError, ModuleNotFoundError, ImportError):
+        return False
+
+
 @with_traced_module
 def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
     llm_provider = instance._llm_type
@@ -173,9 +197,7 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
            span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
 
         completions = func(*args, **kwargs)
-        if isinstance(instance, BASE_LANGCHAIN_MODULE.llms.OpenAI) or (
-            langchain_openai and isinstance(instance, langchain_openai.OpenAI)
-        ):
+        if _is_openai_llm_instance(instance):
             _tag_openai_token_usage(span, completions.llm_output)
             integration.record_usage(span, completions.llm_output)
 
@@ -253,9 +275,7 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
            span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
 
         completions = await func(*args, **kwargs)
-        if isinstance(instance, BASE_LANGCHAIN_MODULE.llms.OpenAI) or (
-            langchain_openai and isinstance(instance, langchain_openai.OpenAI)
-        ):
+        if _is_openai_llm_instance(instance):
             _tag_openai_token_usage(span, completions.llm_output)
             integration.record_usage(span, completions.llm_output)
 
@@ -346,9 +366,7 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
            span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
 
         chat_completions = func(*args, **kwargs)
-        if isinstance(instance, BASE_LANGCHAIN_MODULE.chat_models.ChatOpenAI) or (
-            langchain_openai and isinstance(instance, langchain_openai.ChatOpenAI)
-        ):
+        if _is_openai_chat_instance(instance):
             _tag_openai_token_usage(span, chat_completions.llm_output)
             integration.record_usage(span, chat_completions.llm_output)
 
@@ -453,9 +471,7 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwargs):
            span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
 
         chat_completions = await func(*args, **kwargs)
-        if isinstance(instance, BASE_LANGCHAIN_MODULE.chat_models.ChatOpenAI) or (
-            langchain_openai and isinstance(instance, langchain_openai.ChatOpenAI)
-        ):
+        if _is_openai_chat_instance(instance):
             _tag_openai_token_usage(span, chat_completions.llm_output)
             integration.record_usage(span, chat_completions.llm_output)
 
@@ -842,9 +858,7 @@ def patch():
     # ref: https://github.com/DataDog/dd-trace-py/issues/7123
     if SHOULD_PATCH_LANGCHAIN_COMMUNITY:
         from langchain.chains.base import Chain  # noqa:F401
-        from langchain_community import chat_models  # noqa:F401
         from langchain_community import embeddings  # noqa:F401
-        from langchain_community import llms  # noqa:F401
         from langchain_community import vectorstores  # noqa:F401
 
         wrap("langchain_core", "language_models.llms.BaseLLM.generate", traced_llm_generate(langchain))
```
New release note (file added):
```diff
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    langchain: This fix adds error handling for checking if a traced LLM or chat model is an OpenAI instance, as the
+    langchain_community package does not allow automatic submodule importing.
```