-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add Tool.from_component #159
Changes from 11 commits
784a19c
6d6f307
eb1496c
519d53f
7178e1d
d83ca13
6341268
1c259f9
34a0861
551a528
ce864dd
931df70
3b29458
d8f722c
2dbb8d2
b35c098
d2b7eb2
9c840bf
150fb81
1444749
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from dataclasses import MISSING, fields, is_dataclass | ||
from inspect import getdoc | ||
from typing import Any, Callable, Dict, Union, get_args, get_origin | ||
|
||
from docstring_parser import parse | ||
from haystack import logging | ||
from haystack.core.component import Component | ||
from pydantic.fields import FieldInfo | ||
|
||
from haystack_experimental.util.utils import is_pydantic_v2_model | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def extract_component_parameters(component: Component) -> Dict[str, Any]:
    """
    Extracts parameters from a Haystack component and converts them to OpenAI tools JSON format.

    Every input socket of the component becomes one property of the resulting object schema.
    Property descriptions come from the `:param ...:` entries of the component's `run()` docstring,
    with a generic fallback when an input is not documented there.

    :param component: The component to extract parameters from.
    :returns: A dictionary representing the component's input parameters schema. The "required"
        key is only present when at least one input is mandatory.
    :raises ValueError: If a property schema cannot be created for one of the inputs. The original
        error is attached as the cause.
    """
    properties = {}
    required = []

    param_descriptions = get_param_descriptions(component.run)

    for input_name, socket in component.__haystack_input__._sockets_dict.items():
        description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

        try:
            property_schema = create_property_schema(socket.type, description)
        except ValueError as e:
            # Chain the original exception so the root cause stays visible in the traceback
            raise ValueError(f"Error processing input '{input_name}': {e}") from e

        properties[input_name] = property_schema

        # `is_mandatory` is read as an attribute here, not called
        # NOTE(review): presumably a property on Haystack's InputSocket - confirm
        if socket.is_mandatory:
            required.append(input_name)

    parameters_schema: Dict[str, Any] = {"type": "object", "properties": properties}

    if required:
        parameters_schema["required"] = required

    return parameters_schema
|
||
|
||
def get_param_descriptions(method: Callable) -> Dict[str, str]:
    """
    Extracts parameter descriptions from the method's docstring using docstring_parser.

    Parameters that are listed in the docstring without a description are still returned,
    mapped to an empty string, and a warning is logged for each of them.

    :param method: The method to extract parameter descriptions from.
    :returns: A dictionary mapping parameter names to their descriptions. Empty if the method
        has no docstring.
    """
    docstring = getdoc(method)
    if not docstring:
        return {}

    param_descriptions = {}
    for param in parse(docstring).params:
        if not param.description:
            # The message is fully built before it is handed to the logger.
            # NOTE(review): the logger from `haystack.logging` appears to accept only keyword
            # arguments after the message, so positional %-style arguments would raise a
            # TypeError - confirm against the Haystack logging implementation.
            logger.warning(
                f"Missing description for parameter '{param.arg_name}'. Please add a description in the "
                f"component's run() method docstring using the format ':param {param.arg_name}: <description>'. "
                "This description is used to generate the Tool and helps the LLM understand how to use "
                "this parameter."
            )
        param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
    return param_descriptions
|
||
|
||
# ruff: noqa: PLR0912 | ||
def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
    """
    Creates a property schema for a given Python type, recursively if necessary.

    Supported shapes:

    - Optional types: the schema of the first non-None member is produced.
    - Lists (bare or parameterised): an "array" schema whose "items" schema is built recursively.
    - Dataclasses and Pydantic v2 models: an "object" schema with one property per field and a
      "required" list naming the fields that have no default.
    - Basic types (str, int, float, bool, dict and parameterised dicts): the matching JSON type.
    - Anything else falls back to "string".

    :param python_type: The Python type to create a property schema for.
    :param description: The description of the property.
    :param default: The default value of the property. Only emitted when it is not None.
    :returns: A dictionary representing the property schema.
    """
    if is_nullable_type(python_type):
        non_none_types = [t for t in get_args(python_type) if t is not type(None)]
        python_type = non_none_types[0] if non_none_types else str

    origin = get_origin(python_type)
    schema: Dict[str, Any]

    if origin is list or python_type is list:
        type_args = get_args(python_type)
        item_type = type_args[0] if type_args else Any
        # Recursively describe the item type; the item itself carries no description
        items_schema = create_property_schema(item_type, "")
        items_schema.pop("description", None)
        schema = {"type": "array", "description": description, "items": items_schema}
    elif is_dataclass(python_type) or is_pydantic_v2_model(python_type):
        properties: Dict[str, Any] = {}
        required_fields = []

        if is_dataclass(python_type):
            # Get the actual class if python_type is an instance otherwise use the type
            cls = python_type if isinstance(python_type, type) else python_type.__class__
            for field in fields(cls):
                field_description = f"Field '{field.name}' of '{cls.__name__}'."
                properties[field.name] = create_property_schema(field.type, field_description)
                # A constructor argument without default or default_factory must be supplied
                if field.init and field.default is MISSING and field.default_factory is MISSING:
                    required_fields.append(field.name)
        else:  # Pydantic model
            for m_name, m_field in python_type.model_fields.items():
                field_description = f"Field '{m_name}' of '{python_type.__name__}'."
                properties[m_name] = create_property_schema(m_field.annotation, field_description)
                if m_field.is_required():
                    required_fields.append(m_name)

        schema = {"type": "object", "description": description, "properties": properties}
        if required_fields:
            schema["required"] = required_fields
    else:
        # Basic types. Parameterised generics such as Dict[str, int] are looked up through
        # their origin (dict) so they map to "object" instead of the "string" fallback.
        type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
        schema = {"type": type_mapping.get(origin or python_type, "string"), "description": description}

    if default is not None:
        schema["default"] = default

    return schema
|
||
|
||
def is_nullable_type(python_type: Any) -> bool:
    """
    Checks if the type is a Union with NoneType (i.e., Optional).

    Both `typing.Union[X, None]` / `typing.Optional[X]` and, on Python 3.10+, the PEP 604
    spelling `X | None` are recognised.

    :param python_type: The Python type to check.
    :returns: True if the type is a Union with NoneType, False otherwise.
    """
    import types  # local import: only needed to detect PEP 604 unions

    origin = get_origin(python_type)
    # `X | Y` reports types.UnionType as its origin. The attribute does not exist before
    # Python 3.10, in which case the fallback makes the second comparison a no-op.
    pep604_union = getattr(types, "UnionType", Union)
    if origin is Union or origin is pep604_union:
        return type(None) in get_args(python_type)
    return False
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,17 +4,22 @@ | |
|
||
import inspect | ||
from dataclasses import asdict, dataclass | ||
from typing import Any, Callable, Dict, Optional | ||
from typing import Any, Callable, Dict, Optional, get_args, get_origin | ||
|
||
from haystack import logging | ||
from haystack.core.component import Component | ||
from haystack.lazy_imports import LazyImport | ||
from haystack.utils import deserialize_callable, serialize_callable | ||
from pydantic import create_model | ||
from pydantic import TypeAdapter, create_model | ||
|
||
with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: | ||
from jsonschema import Draft202012Validator | ||
from jsonschema.exceptions import SchemaError | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ToolInvocationError(Exception): | ||
""" | ||
Exception raised when a Tool invocation fails. | ||
|
@@ -198,6 +203,66 @@ def get_weather( | |
|
||
return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) | ||
|
||
@classmethod
def from_component(cls, component: Component, name: str, description: str) -> "Tool":
    """
    Create a Tool instance from a Haystack component.

    The tool's parameters schema is derived from the component's input sockets. When the tool is
    invoked, the keyword arguments are converted to the socket types and passed to the
    component's `run()` method.

    :param component: The Haystack component to be converted into a Tool.
    :param name: Name for the tool.
    :param description: Description of the tool.
    :returns: The Tool created from the Component.
    :raises ValueError: If the component is invalid or schema generation fails.
    """
    if not isinstance(component, Component):
        raise ValueError(
            f"{component} is not a Haystack component! "
            "Can only create a Tool from a Haystack component instance."
        )

    # NOTE(review): reviewers asked whether rejecting components that were already added to a
    # Pipeline is a hard requirement or a self-imposed limitation (it may conflict with
    # deserializing Tools from YAML). The check may be removed depending on that discussion.
    if getattr(component, "__haystack_added_to_pipeline__", None):
        msg = (
            "Component has been added in a Pipeline and can't be used to create a Tool. "
            "Create Tool from a non-pipeline component instead."
        )
        raise ValueError(msg)

    # Imported inside the method rather than at module level
    # NOTE(review): presumably to avoid a circular import - confirm
    from haystack_experimental.components.tools.openai.component_caller import extract_component_parameters

    # Extract the parameters schema from the component
    parameters = extract_component_parameters(component)

    def component_invoker(**kwargs):
        """
        Invokes the component using keyword arguments provided by the LLM function calling/tool generated response.

        :param kwargs: The keyword arguments to invoke the component with.
        :returns: The result of the component invocation.
        """
        converted_kwargs = {}
        input_sockets = component.__haystack_input__._sockets_dict
        for param_name, param_value in kwargs.items():
            param_type = input_sockets[param_name].type

            # Check if the type (or list element type) has from_dict. An unparameterised list
            # type has no element type, so it is left for TypeAdapter to validate as a whole.
            type_args = get_args(param_type)
            target_type = type_args[0] if get_origin(param_type) is list and type_args else param_type
            if hasattr(target_type, "from_dict"):
                if isinstance(param_value, list):
                    # Only dict items are converted. Other items are passed through unchanged,
                    # mirroring how a non-dict single value is handled below.
                    param_value = [
                        target_type.from_dict(item) if isinstance(item, dict) else item for item in param_value
                    ]
                elif isinstance(param_value, dict):
                    param_value = target_type.from_dict(param_value)
            else:
                # Let TypeAdapter handle both single values and lists
                type_adapter = TypeAdapter(param_type)
                param_value = type_adapter.validate_python(param_value)

            converted_kwargs[param_name] = param_value
        logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
        return component.run(**converted_kwargs)

    # Return a new Tool instance with the component invoker as the function to be called
    return Tool(name=name, description=description, parameters=parameters, function=component_invoker)
|
||
|
||
def _remove_title_from_schema(schema: Dict[str, Any]): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't want to make this PR more complex, but have we considered using defaults for `name` and `description`? For example, the component name could be the default for `name`, and parts of the component docstring (extracted via docstring_parser) could be the default for `description`. That would also enable the ToolInvoker to call `from_component` internally if the input to its `tool` parameter is a list of components instead of tools, which I think is needed for defining tools as components in a YAML file.