Skip to content

Commit

Permalink
docs, standard-tests: how to standard test a custom tool, imports (#2…
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored Nov 15, 2024
1 parent 39fcb47 commit 409c794
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 8 deletions.
223 changes: 223 additions & 0 deletions docs/docs/how_to/tools_standard_tests.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to add standard tests to a tool\n",
"\n",
"When creating either a custom tool or a new tool to publish in a LangChain integration, it is important to add standard tests to ensure the tool works as expected. This guide will show you how to add standard tests to a tool.\n",
"\n",
"## Setup\n",
"\n",
"First, let's install 2 dependencies:\n",
"\n",
"- `langchain-core` will define the interfaces we want to import to define our custom tool.\n",
"- `langchain-tests==0.3.0` will provide the standard tests we want to use.\n",
"\n",
":::note\n",
"\n",
"The `langchain-tests` package contains the module `langchain_standard_tests`. This name\n",
"mistmatch is due to this package historically being called `langchain_standard_tests` and\n",
"the name not being available on PyPi. This will either be reconciled by our \n",
"[PEP 541 request](https://github.com/pypi/support/issues/5062) (we welcome upvotes!), \n",
"or in a new release of `langchain-tests`.\n",
"\n",
"Because added tests in new versions of `langchain-tests` will always break your CI/CD pipelines, we recommend pinning the \n",
"version of `langchain-tests==0.3.0` to avoid unexpected changes.\n",
"\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -U langchain-core langchain-tests==0.3.0 pytest pytest-socket"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's say we're publishing a package, `langchain_parrot_link`, that exposes a\n",
"tool called `ParrotMultiplyTool`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# title=\"langchain_parrot_link/tools.py\"\n",
"from langchain_core.tools import BaseTool\n",
"\n",
"\n",
"class ParrotMultiplyTool(BaseTool):\n",
" name: str = \"ParrotMultiplyTool\"\n",
" description: str = (\n",
" \"Multiply two numbers like a parrot. Parrots always add \"\n",
" \"eighty for their matey.\"\n",
" )\n",
"\n",
" def _run(self, a: int, b: int) -> int:\n",
" return a * b + 80"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we'll assume you've structured your package the same way as the main LangChain\n",
"packages:\n",
"\n",
"```\n",
"/\n",
"├── langchain_parrot_link/\n",
"│ └── tools.py\n",
"└── tests/\n",
" ├── unit_tests/\n",
" │ └── test_tools.py\n",
" └── integration_tests/\n",
" └── test_tools.py\n",
"```\n",
"\n",
"## Add and configure standard tests\n",
"\n",
"There are 2 namespaces in the `langchain-tests` package: \n",
"\n",
"- unit tests (`langchain_standard_tests.unit_tests`): designed to be used to test the tool in isolation and without access to external services\n",
"- integration tests (`langchain_standard_tests.integration_tests`): designed to be used to test the tool with access to external services (in particular, the external service that the tool is designed to interact with).\n",
"\n",
":::note\n",
"\n",
"Integration tests can also be run without access to external services, **if** they are properly mocked.\n",
"\n",
":::\n",
"\n",
"Both types of tests are implemented as [`pytest` class-based test suites](https://docs.pytest.org/en/7.1.x/getting-started.html#group-multiple-tests-in-a-class).\n",
"\n",
"By subclassing the base classes for each type of standard test (see below), you get all of the standard tests for that type, and you\n",
"can override the properties that the test suite uses to configure the tests.\n",
"\n",
"### Standard tools tests\n",
"\n",
"Here's how you would configure the standard unit tests for the custom tool, e.g. in `tests/test_tools.py`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"title": "tests/test_custom_tool.py"
},
"outputs": [],
"source": [
"# title=\"tests/unit_tests/test_custom_tool.py\"\n",
"from typing import Type\n",
"\n",
"from langchain_parrot_link.tools import ParrotMultiplyTool\n",
"from langchain_standard_tests.unit_tests import ToolsUnitTests\n",
"\n",
"\n",
"class MultiplyToolUnitTests(ToolsUnitTests):\n",
" @property\n",
" def tool_constructor(self) -> Type[ParrotMultiplyTool]:\n",
" return ParrotMultiplyTool\n",
"\n",
" def tool_constructor_params(self) -> dict:\n",
" # if your tool constructor instead required initialization arguments like\n",
" # `def __init__(self, some_arg: int):`, you would return those here\n",
" # as a dictionary, e.g.: `return {'some_arg': 42}`\n",
" return {}\n",
"\n",
" def tool_invoke_params_example(self) -> dict:\n",
" \"\"\"\n",
" Returns a dictionary representing the \"args\" of an example tool call.\n",
"\n",
" This should NOT be a ToolCall dict - i.e. it should not\n",
" have {\"name\", \"id\", \"args\"} keys.\n",
" \"\"\"\n",
" return {\"a\": 2, \"b\": 3}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# title=\"tests/integration_tests/test_custom_tool.py\"\n",
"from typing import Type\n",
"\n",
"from langchain_parrot_link.tools import ParrotMultiplyTool\n",
"from langchain_standard_tests.integration_tests import ToolsIntegrationTests\n",
"\n",
"\n",
"class MultiplyToolIntegrationTests(ToolsIntegrationTests):\n",
" @property\n",
" def tool_constructor(self) -> Type[ParrotMultiplyTool]:\n",
" return ParrotMultiplyTool\n",
"\n",
" def tool_constructor_params(self) -> dict:\n",
" # if your tool constructor instead required initialization arguments like\n",
" # `def __init__(self, some_arg: int):`, you would return those here\n",
" # as a dictionary, e.g.: `return {'some_arg': 42}`\n",
" return {}\n",
"\n",
" def tool_invoke_params_example(self) -> dict:\n",
" \"\"\"\n",
" Returns a dictionary representing the \"args\" of an example tool call.\n",
"\n",
" This should NOT be a ToolCall dict - i.e. it should not\n",
" have {\"name\", \"id\", \"args\"} keys.\n",
" \"\"\"\n",
" return {\"a\": 2, \"b\": 3}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and you would run these with the following commands from your project root\n",
"\n",
"```bash\n",
"# run unit tests without network access\n",
"pytest --disable-socket --enable-unix-socket tests/unit_tests\n",
"\n",
"# run integration tests\n",
"pytest tests/integration_tests\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
10 changes: 10 additions & 0 deletions docs/scripts/notebook_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ def preprocess_cell(self, cell, resources, cell_index):
# escape ``` in code
cell.source = cell.source.replace("```", r"\`\`\`")
# escape ``` in output

# allow overriding title based on comment at beginning of cell
if cell.source.startswith("# title="):
lines = cell.source.split("\n")
title = lines[0].split("# title=")[1]
if title.startswith('"') and title.endswith('"'):
title = title[1:-1]
cell.metadata["title"] = title
cell.source = "\n".join(lines[1:])

if "outputs" in cell:
filter_out = set()
for i, output in enumerate(cell["outputs"]):
Expand Down
15 changes: 15 additions & 0 deletions docs/scripts/notebook_convert_templates/mdoutput/index.md.j2
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
{% extends 'markdown/index.md.j2' %}

{% block input %}
```
{%- if 'magics_language' in cell.metadata -%}
{{ cell.metadata.magics_language}}
{%- elif 'name' in nb.metadata.get('language_info', {}) -%}
{{ nb.metadata.language_info.name }}
{%- endif %}
{%- if 'title' in cell.metadata -%}
{{ ' ' }}title="{{ cell.metadata.title }}"

{%- endif %}
{{ cell.source}}
```
{% endblock input %}

{%- block traceback_line -%}
```output
{{ line.rstrip() | strip_ansi }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,29 @@
"chat_models",
"vectorstores",
"embeddings",
"tools",
]

for module in modules:
pytest.register_assert_rewrite(
f"langchain_standard_tests.integration_tests.{module}"
)

from langchain_standard_tests.integration_tests.chat_models import (
ChatModelIntegrationTests,
)
from langchain_standard_tests.integration_tests.embeddings import (
EmbeddingsIntegrationTests,
)
from .base_store import BaseStoreAsyncTests, BaseStoreSyncTests
from .cache import AsyncCacheTestSuite, SyncCacheTestSuite
from .chat_models import ChatModelIntegrationTests
from .embeddings import EmbeddingsIntegrationTests
from .tools import ToolsIntegrationTests
from .vectorstores import AsyncReadWriteTestSuite, ReadWriteTestSuite

__all__ = [
"ChatModelIntegrationTests",
"EmbeddingsIntegrationTests",
"ToolsIntegrationTests",
"BaseStoreAsyncTests",
"BaseStoreSyncTests",
"AsyncCacheTestSuite",
"SyncCacheTestSuite",
"AsyncReadWriteTestSuite",
"ReadWriteTestSuite",
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
modules = [
"chat_models",
"embeddings",
"tools",
]

for module in modules:
pytest.register_assert_rewrite(f"langchain_standard_tests.unit_tests.{module}")

from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests
from .chat_models import ChatModelUnitTests
from .embeddings import EmbeddingsUnitTests
from .tools import ToolsUnitTests

__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests"]
__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests", "ToolsUnitTests"]

0 comments on commit 409c794

Please sign in to comment.