Skip to content

Commit

Permalink
core: allow artifact in create_retriever_tool (#28903)
Browse files Browse the repository at this point in the history
Add option to return content and artifacts, to also be able to access
the full info of the retrieved documents.

They are returned as a list of dicts in the `artifacts` property if
parameter `response_format` is set to `"content_and_artifact"`.

Defaults to `"content"` to keep current behavior.

---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
ianchi and efriis authored Jan 3, 2025
1 parent 3e618b1 commit acddfc7
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 6 deletions.
31 changes: 26 additions & 5 deletions libs/core/langchain_core/tools/retriever.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from functools import partial
from typing import Optional
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field

from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.prompts import (
BasePromptTemplate,
PromptTemplate,
Expand All @@ -28,11 +29,16 @@ def _get_relevant_documents(
document_prompt: BasePromptTemplate,
document_separator: str,
callbacks: Callbacks = None,
) -> str:
response_format: Literal["content", "content_and_artifact"] = "content",
) -> Union[str, tuple[str, list[Document]]]:
docs = retriever.invoke(query, config={"callbacks": callbacks})
return document_separator.join(
content = document_separator.join(
format_document(doc, document_prompt) for doc in docs
)
if response_format == "content_and_artifact":
return (content, docs)

return content


async def _aget_relevant_documents(
Expand All @@ -41,12 +47,18 @@ async def _aget_relevant_documents(
document_prompt: BasePromptTemplate,
document_separator: str,
callbacks: Callbacks = None,
) -> str:
response_format: Literal["content", "content_and_artifact"] = "content",
) -> Union[str, tuple[str, list[Document]]]:
docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
return document_separator.join(
content = document_separator.join(
[await aformat_document(doc, document_prompt) for doc in docs]
)

if response_format == "content_and_artifact":
return (content, docs)

return content


def create_retriever_tool(
retriever: BaseRetriever,
Expand All @@ -55,6 +67,7 @@ def create_retriever_tool(
*,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
response_format: Literal["content", "content_and_artifact"] = "content",
) -> Tool:
"""Create a tool to do retrieval of documents.
Expand All @@ -66,6 +79,11 @@ def create_retriever_tool(
model, so should be descriptive.
document_prompt: The prompt to use for the document. Defaults to None.
document_separator: The separator to use between documents. Defaults to "\n\n".
response_format: The tool response format. If "content" then the output of
the tool is interpreted as the contents of a ToolMessage. If
"content_and_artifact" then the output is expected to be a two-tuple
corresponding to the (content, artifact) of a ToolMessage (artifact
being a list of documents in this case). Defaults to "content".
Returns:
Tool class to pass to an agent.
Expand All @@ -76,17 +94,20 @@ def create_retriever_tool(
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
response_format=response_format,
)
afunc = partial(
_aget_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
response_format=response_format,
)
return Tool(
name=name,
description=description,
func=func,
coroutine=afunc,
args_schema=RetrieverInput,
response_format=response_format,
)
58 changes: 57 additions & 1 deletion libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,13 @@
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.messages import ToolMessage
from langchain_core.callbacks.manager import (
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
Runnable,
RunnableConfig,
Expand Down Expand Up @@ -2118,6 +2123,57 @@ def my_tool(val: int, other_val: Annotated[dict, "my annotation"]) -> str:
assert schema.__annotations__ == expected_type_hints


def test_create_retriever_tool() -> None:
class MyRetriever(BaseRetriever):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
return [Document(page_content=f"foo {query}"), Document(page_content="bar")]

retriever = MyRetriever()
retriever_tool = tools.create_retriever_tool(
retriever, "retriever_tool_content", "Retriever Tool Content"
)
assert isinstance(retriever_tool, BaseTool)
assert retriever_tool.name == "retriever_tool_content"
assert retriever_tool.description == "Retriever Tool Content"
assert retriever_tool.invoke("bar") == "foo bar\n\nbar"
assert retriever_tool.invoke(
ToolCall(
name="retriever_tool_content",
args={"query": "bar"},
id="123",
type="tool_call",
)
) == ToolMessage(
"foo bar\n\nbar", tool_call_id="123", name="retriever_tool_content"
)

retriever_tool_artifact = tools.create_retriever_tool(
retriever,
"retriever_tool_artifact",
"Retriever Tool Artifact",
response_format="content_and_artifact",
)
assert isinstance(retriever_tool_artifact, BaseTool)
assert retriever_tool_artifact.name == "retriever_tool_artifact"
assert retriever_tool_artifact.description == "Retriever Tool Artifact"
assert retriever_tool_artifact.invoke("bar") == "foo bar\n\nbar"
assert retriever_tool_artifact.invoke(
ToolCall(
name="retriever_tool_artifact",
args={"query": "bar"},
id="123",
type="tool_call",
)
) == ToolMessage(
"foo bar\n\nbar",
artifact=[Document(page_content="foo bar"), Document(page_content="bar")],
tool_call_id="123",
name="retriever_tool_artifact",
)


@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2
Expand Down

0 comments on commit acddfc7

Please sign in to comment.