Skip to content

Commit

Permalink
Add FewShotSQLTool
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Nov 20, 2024
1 parent 1691884 commit 097906b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
3 changes: 3 additions & 0 deletions libs/community/langchain_community/tools/few_shot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from langchain_community.tools.few_shot.tool import FewShotSQLTool

__all__ = ["FewShotSQLTool"]
46 changes: 46 additions & 0 deletions libs/community/langchain_community/tools/few_shot/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional, Type

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, ConfigDict


class _FewShotToolInput(BaseModel):

Check failure on line 10 in libs/community/langchain_community/tools/few_shot/tool.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (I001)

langchain_community/tools/few_shot/tool.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 10 in libs/community/langchain_community/tools/few_shot/tool.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (I001)

langchain_community/tools/few_shot/tool.py:1:1: I001 Import block is un-sorted or un-formatted
question: str = Field(
..., description="The question for which we want example SQL queries."
)


class FewShotSQLTool(BaseTool):
"""Tool to get example SQL queries related to an input question."""

name: str = "few_shot_sql"
description: str = "Tool to get example SQL queries related to an input question."
args_schema: Type[BaseModel] = _FewShotToolInput

example_selector: BaseExampleSelector = Field(exclude=True)
example_input_key: str = "input"
example_query_key: str = "query"

model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def _run(
self,
question: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Execute the query, return the results or an error message."""
example_prompt = PromptTemplate.from_template(
f"User input: {self.example_input_key}\nSQL query: {self.example_query_key}"
)
prompt = FewShotPromptTemplate(
example_prompt=example_prompt,
example_selector=self.example_selector,
suffix="",
input_variables=[self.example_input_key],
)
return prompt.format(**{self.example_input_key: question})

0 comments on commit 097906b

Please sign in to comment.