From 784a19cf3b9d2ec3d6c8530e9aafc6502393a50e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 12 Nov 2024 15:02:39 +0100 Subject: [PATCH 01/18] Initial Tool.from_component --- .../tools/openai/component_caller.py | 129 +++++ haystack_experimental/dataclasses/tool.py | 106 +++- haystack_experimental/util/utils.py | 12 +- test/components/tools/test_tool_component.py | 452 ++++++++++++++++++ 4 files changed, 696 insertions(+), 3 deletions(-) create mode 100644 haystack_experimental/components/tools/openai/component_caller.py create mode 100644 test/components/tools/test_tool_component.py diff --git a/haystack_experimental/components/tools/openai/component_caller.py b/haystack_experimental/components/tools/openai/component_caller.py new file mode 100644 index 00000000..889fadac --- /dev/null +++ b/haystack_experimental/components/tools/openai/component_caller.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import MISSING, fields, is_dataclass +from inspect import getdoc +from typing import Any, Callable, Dict, Union, get_args, get_origin + +from docstring_parser import parse +from haystack import logging +from haystack.core.component import Component + +from haystack_experimental.util.utils import is_pydantic_v2_model + +logger = logging.getLogger(__name__) + + +def extract_component_parameters(component: Component) -> Dict[str, Any]: + """ + Extracts parameters from a Haystack component and converts them to OpenAI tools JSON format. + + :param component: The component to extract parameters from. + :returns: A dictionary representing the component's input parameters schema. + """ + properties = {} + required = [] + + param_descriptions = get_param_descriptions(component.run) + + for input_name, socket in component.__haystack_input__._sockets_dict.items(): + input_type = socket.type + description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") + + try: + property_schema = create_property_schema(input_type, description) + except ValueError as e: + raise ValueError(f"Error processing input '{input_name}': {e}") + + properties[input_name] = property_schema + + # Use socket.is_mandatory() to check if the input is required + if socket.is_mandatory: + required.append(input_name) + + parameters_schema = {"type": "object", "properties": properties} + + if required: + parameters_schema["required"] = required + + return parameters_schema + + +def get_param_descriptions(method: Callable) -> Dict[str, str]: + """ + Extracts parameter descriptions from the method's docstring using docstring_parser. + + :param method: The method to extract parameter descriptions from. + :returns: A dictionary mapping parameter names to their descriptions. + """ + docstring = getdoc(method) + if not docstring: + return {} + + parsed_doc = parse(docstring) + return {param.arg_name: param.description.strip() for param in parsed_doc.params} + + +def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: + """ + Creates a property schema for a given Python type, recursively if necessary. + + :param python_type: The Python type to create a property schema for. + :param description: The description of the property. + :param default: The default value of the property. + :returns: A dictionary representing the property schema. 
+ """ + nullable = is_nullable_type(python_type) + if nullable: + non_none_types = [t for t in get_args(python_type) if t is not type(None)] + python_type = non_none_types[0] if non_none_types else str + + origin = get_origin(python_type) + if origin is list: + item_type = get_args(python_type)[0] if get_args(python_type) else Any + items_schema = create_property_schema(item_type, "") + items_schema.pop("description", None) + schema = {"type": "array", "description": description, "items": items_schema} + elif is_dataclass(python_type) or is_pydantic_v2_model(python_type): + schema = {"type": "object", "description": description, "properties": {}} + required_fields = [] + + if is_dataclass(python_type): + for field in fields(python_type): + field_description = f"Field '{field.name}' of '{python_type.__name__}'." + schema["properties"][field.name] = create_property_schema(field.type, field_description) + if field.default is MISSING and field.default_factory is MISSING: + required_fields.append(field.name) + else: # Pydantic model + model_fields = python_type.model_fields + for name, field in model_fields.items(): + field_description = f"Field '{name}' of '{python_type.__name__}'." + schema["properties"][name] = create_property_schema(field.annotation, field_description) + if field.is_required(): + required_fields.append(name) + + if required_fields: + schema["required"] = required_fields + else: + # Basic types + type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"} + schema = {"type": type_mapping.get(python_type, "string"), "description": description} + + if default is not None: + schema["default"] = default + + return schema + + +def is_nullable_type(python_type: Any) -> bool: + """ + Checks if the type is a Union with NoneType (i.e., Optional). + + :param python_type: The Python type to check. + :returns: True if the type is a Union with NoneType, False otherwise. + """ + origin = get_origin(python_type) + if origin is Union: + return type(None) in get_args(python_type) + return False diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index 33719524..f2c04f3a 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -3,18 +3,25 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, Optional +from dataclasses import asdict, dataclass, is_dataclass +from typing import Any, Callable, Dict, Optional, get_args, get_origin, get_type_hints +from haystack import logging +from haystack.core.component import Component from haystack.lazy_imports import LazyImport from haystack.utils import deserialize_callable, serialize_callable from pydantic import create_model +from haystack_experimental.util.utils import is_pydantic_v2_model + with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: from jsonschema import Draft202012Validator from jsonschema.exceptions import SchemaError +logger = logging.getLogger(__name__) + + class ToolInvocationError(Exception): """ Exception raised when a Tool invocation fails. @@ -198,6 +205,101 @@ def get_weather( return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) + @classmethod + def from_component(cls, component: Component, name: str, description: str) -> "Tool": + """ + Create a Tool instance from a Haystack component. 
+ + :param component: The Haystack component to be converted into a Tool. + :param name: Name for the tool. + :param description: Description of the tool. + :returns: The Tool created from the Component. + :raises ValueError: If the component is invalid or schema generation fails. + """ + from haystack_experimental.components.tools.openai.component_caller import extract_component_parameters + + # Extract the parameters schema from the component + parameters = extract_component_parameters(component) + + def _convert_to_dataclass(data: Any, data_type: Any) -> Any: + """ + Recursively convert dictionaries into dataclass instances based on the provided data type. + + This function handles nested dataclasses by recursively converting each field. + + :param data: + The input data to convert. + :param data_type: + The target data type, expected to be a dataclass type. + :returns: + An instance of the dataclass with data populated from the input dictionary. + """ + if data is None or not isinstance(data, dict): + return data + + # Check if the target type is a dataclass + if is_dataclass(data_type): + # Get field types for the dataclass (field name -> field type) + field_types = get_type_hints(data_type) + converted_data = {} + # Recursively convert each field in the dataclass + for field_name, field_type in field_types.items(): + if field_name in data: + # Recursive step: convert nested dictionaries into dataclass instances + converted_data[field_name] = _convert_to_dataclass(data[field_name], field_type) + # Instantiate the dataclass with the converted data + return data_type(**converted_data) + # If data_type is not a dataclass, return the data unchanged + return data + + def component_invoker(**kwargs): + """ + Invokes the component using keyword arguments provided by the LLM function calling/tool generated response. + + :param kwargs: The keyword arguments to invoke the component with. + :returns: The result of the component invocation. 
+ """ + converted_kwargs = {} + + # Get input sockets for type information + input_sockets = component.__haystack_input__._sockets_dict + + for param_name, param_value in kwargs.items(): + socket = input_sockets[param_name] + param_type = socket.type + + # Determine the origin type (e.g., list) and target_type + origin = get_origin(param_type) or param_type + + if origin is list: + # Parameter is a list; get the element type + target_type = get_args(param_type)[0] + values_to_convert = param_value + else: + # Parameter is a single value + target_type = param_type + values_to_convert = [param_value] + + # Convert dictionary inputs into dataclass or Pydantic model instances if necessary + if is_dataclass(target_type) or is_pydantic_v2_model(target_type): + converted = [ + target_type.model_validate(item) + if is_pydantic_v2_model(target_type) + else _convert_to_dataclass(item, target_type) + for item in values_to_convert + if isinstance(item, dict) + ] + # Update the parameter value with the converted data + param_value = converted if origin is list else converted[0] + + converted_kwargs[param_name] = param_value + + logger.debug(f"Invoking component with kwargs: {converted_kwargs}") + return component.run(**converted_kwargs) + + # Return a new Tool instance with the component invoker as the function to be called + return Tool(name=name, description=description, parameters=parameters, function=component_invoker) + def _remove_title_from_schema(schema: Dict[str, Any]): """ diff --git a/haystack_experimental/util/utils.py b/haystack_experimental/util/utils.py index 59beeacd..568f3930 100644 --- a/haystack_experimental/util/utils.py +++ b/haystack_experimental/util/utils.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import List, Union +from typing import Any, List, Union def expand_page_range(page_range: List[Union[str, int]]) -> List[int]: @@ -41,3 +41,13 @@ def expand_page_range(page_range: List[Union[str, int]]) -> List[int]: raise ValueError("No valid page numbers or ranges found in the input list") return expanded_page_range + + +def is_pydantic_v2_model(instance: Any) -> bool: + """ + Checks if the instance is a Pydantic v2 model. + + :param instance: The instance to check. + :returns: True if the instance is a Pydantic v2 model, False otherwise. + """ + return hasattr(instance, "model_validate") diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py new file mode 100644 index 00000000..cfffffcd --- /dev/null +++ b/test/components/tools/test_tool_component.py @@ -0,0 +1,452 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import pytest +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from haystack import component +from pydantic import BaseModel +from haystack import Pipeline +from haystack_experimental.dataclasses import ChatMessage, ToolCall, ChatRole +from haystack_experimental.components.tools.tool_invoker import ToolInvoker +from haystack_experimental.components.generators.chat import OpenAIChatGenerator + +from haystack_experimental.dataclasses.tool import Tool + + +### Component and Model Definitions + +@component +class SimpleComponent: + """A simple component that generates text.""" + + @component.output_types(reply=str) + def run(self, text: str) -> Dict[str, str]: + """ + A simple component that generates text. + + :param text: The text to generate. + :return: A dictionary with the generated text. 
+ """ + return {"reply": f"Hello, {text}!"} + + +class Product(BaseModel): + """A product model.""" + name: str + price: float + + +@dataclass +class User: + """A simple user dataclass.""" + name: str = "Anonymous" + age: int = 0 + + +@component +class UserGreeter: + """A simple component that processes a User.""" + + @component.output_types(message=str) + def run(self, user: User) -> Dict[str, str]: + """ + A simple component that processes a User. + + :param user: The User object to process. + :return: A dictionary with a message about the user. + """ + return {"message": f"User {user.name} is {user.age} years old"} + + +@component +class ListProcessor: + """A component that processes a list of strings.""" + + @component.output_types(concatenated=str) + def run(self, texts: List[str]) -> Dict[str, str]: + """ + Concatenates a list of strings into a single string. + + :param texts: The list of strings to concatenate. + :return: A dictionary with the concatenated string. + """ + return {"concatenated": ' '.join(texts)} + + +@component +class ProductProcessor: + """A component that processes a Product.""" + + @component.output_types(description=str) + def run(self, product: Product) -> Dict[str, str]: + """ + Creates a description for the product. + + :param product: The Product to process. + :return: A dictionary with the product description. + """ + return { + "description": f"The product {product.name} costs ${product.price:.2f}." + } + + +@dataclass +class Address: + """A dataclass representing a physical address.""" + street: str + city: str + + +@dataclass +class Person: + """A person with an address.""" + name: str + address: Address + + +@component +class PersonProcessor: + """A component that processes a Person with nested Address.""" + + @component.output_types(info=str) + def run(self, person: Person) -> Dict[str, str]: + """ + Creates information about the person. + + :param person: The Person to process. + :return: A dictionary with the person's information. + """ + return { + "info": f"{person.name} lives at {person.address.street}, {person.address.city}." + } + + +## Unit tests +class TestToolComponent: + def test_from_component_basic(self): + component = SimpleComponent() + + tool = Tool.from_component( + component=component, + name="hello_tool", + description="A hello tool" + ) + + assert tool.name == "hello_tool" + assert tool.description == "A hello tool" + assert tool.parameters == { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The text to generate." + } + }, + "required": ["text"] + } + + # Test tool invocation + result = tool.invoke(text="world") + assert isinstance(result, dict) + assert "reply" in result + assert result["reply"] == "Hello, world!" + + def test_from_component_with_dataclass(self): + component = UserGreeter() + + tool = Tool.from_component( + component=component, + name="user_info_tool", + description="A tool that returns user information" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "user": { + "type": "object", + "description": "The User object to process.", + "properties": { + "name": { + "type": "string", + "description": "Field 'name' of 'User'." + }, + "age": { + "type": "integer", + "description": "Field 'age' of 'User'." 
+ } + } + } + }, + "required": ["user"] + } + + # Test tool invocation + result = tool.invoke(user={"name": "Alice", "age": 30}) + assert isinstance(result, dict) + assert "message" in result + assert result["message"] == "User Alice is 30 years old" + + def test_from_component_with_list_input(self): + component = ListProcessor() + + tool = Tool.from_component( + component=component, + name="list_processing_tool", + description="A tool that concatenates strings" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "texts": { + "type": "array", + "description": "The list of strings to concatenate.", + "items": { + "type": "string" + } + } + }, + "required": ["texts"] + } + + # Test tool invocation + result = tool.invoke(texts=["hello", "world"]) + assert isinstance(result, dict) + assert "concatenated" in result + assert result["concatenated"] == "hello world" + + def test_from_component_with_pydantic_model(self): + component = ProductProcessor() + + tool = Tool.from_component( + component=component, + name="product_tool", + description="A tool that processes products" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "product": { + "type": "object", + "description": "The Product to process.", + "properties": { + "name": { + "type": "string", + "description": "Field 'name' of 'Product'." + }, + "price": { + "type": "number", + "description": "Field 'price' of 'Product'." + } + }, + "required": ["name", "price"] + } + }, + "required": ["product"] + } + + # Test tool invocation + result = tool.invoke(product={"name": "Widget", "price": 19.99}) + assert isinstance(result, dict) + assert "description" in result + assert result["description"] == "The product Widget costs $19.99." + + def test_from_component_with_nested_dataclass(self): + component = PersonProcessor() + + tool = Tool.from_component( + component=component, + name="person_tool", + description="A tool that processes people" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "person": { + "type": "object", + "description": "The Person to process.", + "properties": { + "name": { + "type": "string", + "description": "Field 'name' of 'Person'." + }, + "address": { + "type": "object", + "description": "Field 'address' of 'Person'.", + "properties": { + "street": { + "type": "string", + "description": "Field 'street' of 'Address'." + }, + "city": { + "type": "string", + "description": "Field 'city' of 'Address'." + } + }, + "required": ["street", "city"] + } + }, + "required": ["name", "address"] + } + }, + "required": ["person"] + } + + # Test tool invocation + result = tool.invoke(person={ + "name": "Diana", + "address": { + "street": "123 Elm Street", + "city": "Metropolis" + } + }) + assert isinstance(result, dict) + assert "info" in result + assert result["info"] == "Diana lives at 123 Elm Street, Metropolis." 
+ + +## Integration tests +class TestToolComponentInPipeline: + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_component_tool_in_pipeline(self): + # Create component and convert it to tool + component = SimpleComponent() + tool = Tool.from_component( + component=component, + name="hello_tool", + description="A tool that generates a greeting message" + ) + + # Create pipeline with OpenAIChatGenerator and ToolInvoker + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + + # Connect components + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Hey I'm Vladimir") + + # Run pipeline + result = pipeline.run({"llm": {"messages": [message]}}) + + # Check results + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"reply": "Hello, Vladimir!"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_user_greeter_in_pipeline(self): + component = UserGreeter() + tool = Tool.from_component( + component=component, + name="user_greeter", + description="A tool that greets users with their name and age" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="I am Alice and I'm 30 years old") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"message": "User Alice is 30 years old"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_list_processor_in_pipeline(self): + component = ListProcessor() + tool = Tool.from_component( + component=component, + name="list_processor", + description="A tool that concatenates a list of strings" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Can you join these words: hello, beautiful, world") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"concatenated": "hello beautiful world"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_product_processor_in_pipeline(self): + component = ProductProcessor() + tool = Tool.from_component( + component=component, + 
name="product_processor", + description="A tool that creates a description for a product with its name and price" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Can you describe a product called Widget that costs $19.99?") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"description": "The product Widget costs $19.99."}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_person_processor_in_pipeline(self): + component = PersonProcessor() + tool = Tool.from_component( + component=component, + name="person_processor", + description="A tool that processes information about a person and their address" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Diana lives at 123 Elm Street in Metropolis") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"info": "Diana lives at 123 Elm Street, Metropolis."}) + assert not tool_message.tool_call_result.error From 6d6f3078817e70c32229e4ca1edf3a4d88b917d5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 18 Dec 2024 12:18:03 +0100 Subject: [PATCH 02/18] Simplify types conversion with TypeAdapter --- haystack_experimental/dataclasses/tool.py | 57 +++-------------------- 1 file changed, 7 insertions(+), 50 deletions(-) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index f2c04f3a..b4bff232 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -3,16 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from dataclasses import asdict, dataclass, is_dataclass -from typing import Any, Callable, Dict, Optional, get_args, get_origin, get_type_hints +from dataclasses import asdict, dataclass +from typing import Any, Callable, Dict, Optional, get_args, get_origin from haystack import logging from haystack.core.component import Component from haystack.lazy_imports import LazyImport from haystack.utils import deserialize_callable, serialize_callable -from pydantic import create_model - -from haystack_experimental.util.utils import is_pydantic_v2_model +from pydantic import TypeAdapter, create_model with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: from jsonschema import Draft202012Validator @@ -221,37 +219,6 @@ def from_component(cls, component: Component, name: str, description: str) -> "T # Extract the parameters schema from the component parameters = extract_component_parameters(component) - def _convert_to_dataclass(data: Any, data_type: Any) -> Any: - """ - 
Recursively convert dictionaries into dataclass instances based on the provided data type. - - This function handles nested dataclasses by recursively converting each field. - - :param data: - The input data to convert. - :param data_type: - The target data type, expected to be a dataclass type. - :returns: - An instance of the dataclass with data populated from the input dictionary. - """ - if data is None or not isinstance(data, dict): - return data - - # Check if the target type is a dataclass - if is_dataclass(data_type): - # Get field types for the dataclass (field name -> field type) - field_types = get_type_hints(data_type) - converted_data = {} - # Recursively convert each field in the dataclass - for field_name, field_type in field_types.items(): - if field_name in data: - # Recursive step: convert nested dictionaries into dataclass instances - converted_data[field_name] = _convert_to_dataclass(data[field_name], field_type) - # Instantiate the dataclass with the converted data - return data_type(**converted_data) - # If data_type is not a dataclass, return the data unchanged - return data - def component_invoker(**kwargs): """ Invokes the component using keyword arguments provided by the LLM function calling/tool generated response. @@ -260,36 +227,26 @@ def component_invoker(**kwargs): :returns: The result of the component invocation. """ converted_kwargs = {} - - # Get input sockets for type information input_sockets = component.__haystack_input__._sockets_dict for param_name, param_value in kwargs.items(): socket = input_sockets[param_name] param_type = socket.type - - # Determine the origin type (e.g., list) and target_type origin = get_origin(param_type) or param_type if origin is list: - # Parameter is a list; get the element type target_type = get_args(param_type)[0] values_to_convert = param_value else: - # Parameter is a single value target_type = param_type values_to_convert = [param_value] - # Convert dictionary inputs into dataclass or Pydantic model instances if necessary - if is_dataclass(target_type) or is_pydantic_v2_model(target_type): + if isinstance(param_value, dict): + # TypeAdapter handles dict conversion for both dataclasses and Pydantic models + type_adapter = TypeAdapter(target_type) converted = [ - target_type.model_validate(item) - if is_pydantic_v2_model(target_type) - else _convert_to_dataclass(item, target_type) - for item in values_to_convert - if isinstance(item, dict) + type_adapter.validate_python(item) for item in values_to_convert if isinstance(item, dict) ] - # Update the parameter value with the converted data param_value = converted if origin is list else converted[0] converted_kwargs[param_name] = param_value From eb1496c610d81be77c64e402328899e9ea27e0ca Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 18 Dec 2024 14:08:59 +0100 Subject: [PATCH 03/18] Pylint, small fixes --- .../tools/openai/component_caller.py | 25 +++++++++++-------- test/components/tools/test_tool_component.py | 4 +-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/haystack_experimental/components/tools/openai/component_caller.py b/haystack_experimental/components/tools/openai/component_caller.py index 889fadac..6f7b0c4a 100644 --- a/haystack_experimental/components/tools/openai/component_caller.py +++ b/haystack_experimental/components/tools/openai/component_caller.py @@ -9,6 +9,7 @@ from docstring_parser import parse from haystack import logging from haystack.core.component import Component +from pydantic.fields import FieldInfo from 
haystack_experimental.util.utils import is_pydantic_v2_model @@ -62,9 +63,10 @@ def get_param_descriptions(method: Callable) -> Dict[str, str]: return {} parsed_doc = parse(docstring) - return {param.arg_name: param.description.strip() for param in parsed_doc.params} + return {param.arg_name: param.description.strip() if param.description else "" for param in parsed_doc.params} +# ruff: noqa: PLR0912 def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: """ Creates a property schema for a given Python type, recursively if necessary. @@ -90,18 +92,21 @@ def create_property_schema(python_type: Any, description: str, default: Any = No required_fields = [] if is_dataclass(python_type): - for field in fields(python_type): - field_description = f"Field '{field.name}' of '{python_type.__name__}'." - schema["properties"][field.name] = create_property_schema(field.type, field_description) + # Get the actual class if python_type is an instance + cls = python_type if isinstance(python_type, type) else python_type.__class__ + for field in fields(cls): + field_description = f"Field '{field.name}' of '{cls.__name__}'." + if isinstance(schema["properties"], dict): + schema["properties"][field.name] = create_property_schema(field.type, field_description) if field.default is MISSING and field.default_factory is MISSING: required_fields.append(field.name) else: # Pydantic model - model_fields = python_type.model_fields - for name, field in model_fields.items(): - field_description = f"Field '{name}' of '{python_type.__name__}'." - schema["properties"][name] = create_property_schema(field.annotation, field_description) - if field.is_required(): - required_fields.append(name) + for m_name, m_field in python_type.model_fields.items(): + field_description = f"Field '{m_name}' of '{python_type.__name__}'." 
+ if isinstance(schema["properties"], dict): + schema["properties"][m_name] = create_property_schema(m_field.annotation, field_description) + if m_field.is_required(): + required_fields.append(m_name) if required_fields: schema["required"] = required_fields diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index cfffffcd..48343e87 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -322,7 +322,7 @@ def test_component_tool_in_pipeline(self): tool = Tool.from_component( component=component, name="hello_tool", - description="A tool that generates a greeting message" + description="A tool that generates a greeting message for the user" ) # Create pipeline with OpenAIChatGenerator and ToolInvoker @@ -333,7 +333,7 @@ def test_component_tool_in_pipeline(self): # Connect components pipeline.connect("llm.replies", "tool_invoker.messages") - message = ChatMessage.from_user(text="Hey I'm Vladimir") + message = ChatMessage.from_user(text="Hello, I'm Vladimir") # Run pipeline result = pipeline.run({"llm": {"messages": [message]}}) From 519d53fa761d96552bdf16e44d7bcaa4acd88870 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 18 Dec 2024 15:38:29 +0100 Subject: [PATCH 04/18] Improve warning when component run pydocs are missing --- .../components/tools/openai/component_caller.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/haystack_experimental/components/tools/openai/component_caller.py b/haystack_experimental/components/tools/openai/component_caller.py index 6f7b0c4a..8a23a09b 100644 --- a/haystack_experimental/components/tools/openai/component_caller.py +++ b/haystack_experimental/components/tools/openai/component_caller.py @@ -63,7 +63,18 @@ def get_param_descriptions(method: Callable) -> Dict[str, str]: return {} parsed_doc = parse(docstring) - return {param.arg_name: param.description.strip() if param.description else "" for param in parsed_doc.params} + param_descriptions = {} + for param in parsed_doc.params: + if not param.description: + logger.warning( + "Missing description for parameter '%s'. Please add a description in the component's " + "run() method docstring using the format ':param %s: '. " + "This description is used to generate the Tool and helps the LLM understand how to use this parameter.", + param.arg_name, + param.arg_name, + ) + param_descriptions[param.arg_name] = param.description.strip() if param.description else "" + return param_descriptions # ruff: noqa: PLR0912 @@ -84,6 +95,7 @@ def create_property_schema(python_type: Any, description: str, default: Any = No origin = get_origin(python_type) if origin is list: item_type = get_args(python_type)[0] if get_args(python_type) else Any + # recursively call create_property_schema for the item type items_schema = create_property_schema(item_type, "") items_schema.pop("description", None) schema = {"type": "array", "description": description, "items": items_schema} @@ -92,7 +104,7 @@ def create_property_schema(python_type: Any, description: str, default: Any = No required_fields = [] if is_dataclass(python_type): - # Get the actual class if python_type is an instance + # Get the actual class if python_type is an instance otherwise use the type cls = python_type if isinstance(python_type, type) else python_type.__class__ for field in fields(cls): field_description = f"Field '{field.name}' of '{cls.__name__}'." 
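
At this point in the series (after PATCH 04), Tool.from_component reads a component's input sockets and the ':param ...:' descriptions in its run() docstring to build an OpenAI-style JSON parameters schema, and wraps run() in an invoker that converts incoming dicts back into dataclass or Pydantic instances. Below is a minimal usage sketch distilled from the tests in PATCH 01; the GreetingComponent here is an illustrative stand-in (equivalent to the tests' SimpleComponent), not part of any patch.

from typing import Dict

from haystack import component

from haystack_experimental.dataclasses.tool import Tool


@component
class GreetingComponent:
    """Illustrative stand-in for the SimpleComponent defined in the test suite."""

    @component.output_types(reply=str)
    def run(self, text: str) -> Dict[str, str]:
        """
        A simple component that generates text.

        :param text: The text to generate.
        :return: A dictionary with the generated text.
        """
        return {"reply": f"Hello, {text}!"}


tool = Tool.from_component(
    component=GreetingComponent(),
    name="hello_tool",
    description="A tool that generates a greeting message",
)

# The ':param text:' docstring description flows into the generated JSON schema ...
print(tool.parameters["properties"]["text"])  # {'type': 'string', 'description': 'The text to generate.'}
# ... and invoking the tool delegates to the component's run() method.
print(tool.invoke(text="world"))  # {'reply': 'Hello, world!'}
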
From 7178e1d77555182695cb270bdfe261518cfec5e2 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 18 Dec 2024 16:22:07 +0100
Subject: [PATCH 05/18] Add Anthropic integration tests

---
 test/components/tools/test_tool_component.py | 158 ++++++++++++++++++-
 1 file changed, 151 insertions(+), 7 deletions(-)

diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py
index 48343e87..2be41af3 100644
--- a/test/components/tools/test_tool_component.py
+++ b/test/components/tools/test_tool_component.py
@@ -9,13 +9,14 @@
 from haystack import component
 from pydantic import BaseModel
 from haystack import Pipeline
-from haystack_experimental.dataclasses import ChatMessage, ToolCall, ChatRole
+from haystack_experimental.dataclasses import ChatMessage, ChatRole
 from haystack_experimental.components.tools.tool_invoker import ToolInvoker
 from haystack_experimental.components.generators.chat import OpenAIChatGenerator
-
+from haystack_experimental.components.generators.anthropic.chat import AnthropicChatGenerator
 from haystack_experimental.dataclasses.tool import Tool
 
 
+
 ### Component and Model Definitions
 
 @component
@@ -27,7 +28,7 @@ def run(self, text: str) -> Dict[str, str]:
         """
         A simple component that generates text.
 
-        :param text: The text to generate.
+        :param text: user's introductory message
         :return: A dictionary with the generated text.
         """
         return {"reply": f"Hello, {text}!"}
@@ -312,7 +313,7 @@ def test_from_component_with_nested_dataclass(self):
 
 
 ## Integration tests
-class TestToolComponentInPipeline:
+class TestToolComponentInPipelineWithOpenAI:
 
     @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
     @pytest.mark.integration
@@ -333,7 +334,7 @@ def test_component_tool_in_pipeline(self):
         # Connect components
         pipeline.connect("llm.replies", "tool_invoker.messages")
 
-        message = ChatMessage.from_user(text="Hello, I'm Vladimir")
+        message = ChatMessage.from_user(text="Vladimir")
 
         # Run pipeline
         result = pipeline.run({"llm": {"messages": [message]}})
@@ -344,7 +345,7 @@ def test_component_tool_in_pipeline(self):
 
         tool_message = tool_messages[0]
         assert tool_message.is_from(ChatRole.TOOL)
-        assert tool_message.tool_call_result.result == str({"reply": "Hello, Vladimir!"})
+        assert "Vladimir" in tool_message.tool_call_result.result
         assert not tool_message.tool_call_result.error
 
@@ -448,5 +449,148 @@ def test_person_processor_in_pipeline(self):
 
         tool_message = tool_messages[0]
         assert tool_message.is_from(ChatRole.TOOL)
-        assert tool_message.tool_call_result.result == str({"info": "Diana lives at 123 Elm Street, Metropolis."})
+        assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result
         assert not tool_message.tool_call_result.error
+
+
+
+
+## Integration tests
+class TestToolComponentInPipelineWithAnthropic:
+
+    @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not set")
+    @pytest.mark.integration
+    def test_component_tool_in_pipeline(self):
+        # Create component and convert it to tool
+        component = SimpleComponent()
+        tool = Tool.from_component(
+            component=component,
+            name="hello_tool",
+            description="A tool that generates a greeting message for the user"
+        )
+
+        # Create pipeline with AnthropicChatGenerator and ToolInvoker
+        pipeline = Pipeline()
+        pipeline.add_component("llm", AnthropicChatGenerator(model="claude-3-5-sonnet-20240620",
tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + + # Connect components + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Vladimir") + + # Run pipeline + result = pipeline.run({"llm": {"messages": [message]}}) + + # Check results + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert "Vladimir" in tool_message.tool_call_result.result + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_user_greeter_in_pipeline(self): + component = UserGreeter() + tool = Tool.from_component( + component=component, + name="user_greeter", + description="A tool that greets users with their name and age" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", AnthropicChatGenerator(model="claude-3-5-sonnet-20240620", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="I am Alice and I'm 30 years old") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"message": "User Alice is 30 years old"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_list_processor_in_pipeline(self): + component = ListProcessor() + tool = Tool.from_component( + component=component, + name="list_processor", + description="A tool that concatenates a list of strings" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", AnthropicChatGenerator(model="claude-3-5-sonnet-20240620", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Can you join these words: hello, beautiful, world") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"concatenated": "hello beautiful world"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_product_processor_in_pipeline(self): + component = ProductProcessor() + tool = Tool.from_component( + component=component, + name="product_processor", + description="A tool that creates a description for a product with its name and price" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", AnthropicChatGenerator(model="claude-3-5-sonnet-20240620", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Can you describe a product called Widget that costs $19.99?") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = 
result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"description": "The product Widget costs $19.99."}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY"), reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_person_processor_in_pipeline(self): + component = PersonProcessor() + tool = Tool.from_component( + component=component, + name="person_processor", + description="A tool that processes information about a person and their address" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", AnthropicChatGenerator(model="claude-3-5-sonnet-20240620", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Diana lives at 123 Elm Street in Metropolis") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result assert not tool_message.tool_call_result.error From d83ca1351a4054657d212751fc74af3b3f4d21ce Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 18 Dec 2024 16:25:10 +0100 Subject: [PATCH 06/18] Minor test fix --- test/components/tools/test_tool_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 2be41af3..7fc89df7 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -143,7 +143,7 @@ def test_from_component_basic(self): "properties": { "text": { "type": "string", - "description": "The text to generate." + "description": "user's introductory message" } }, "required": ["text"] From 1c259f9e7a5e1f90970f0f49d5bd7ce2ab628ff5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 19 Dec 2024 10:57:30 +0100 Subject: [PATCH 07/18] Handle our own dataclasses (e.g. 
Document) --- haystack_experimental/dataclasses/tool.py | 33 ++--- test/components/tools/test_tool_component.py | 147 +++++++++++++++++++ 2 files changed, 160 insertions(+), 20 deletions(-) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index b4bff232..dfa2e389 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -228,30 +228,23 @@ def component_invoker(**kwargs): """ converted_kwargs = {} input_sockets = component.__haystack_input__._sockets_dict - for param_name, param_value in kwargs.items(): - socket = input_sockets[param_name] - param_type = socket.type - origin = get_origin(param_type) or param_type - - if origin is list: - target_type = get_args(param_type)[0] - values_to_convert = param_value + param_type = input_sockets[param_name].type + + # Check if the type (or list element type) has from_dict + target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type + if hasattr(target_type, "from_dict"): + if isinstance(param_value, list): + param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)] + elif isinstance(param_value, dict): + param_value = target_type.from_dict(param_value) else: - target_type = param_type - values_to_convert = [param_value] - - if isinstance(param_value, dict): - # TypeAdapter handles dict conversion for both dataclasses and Pydantic models - type_adapter = TypeAdapter(target_type) - converted = [ - type_adapter.validate_python(item) for item in values_to_convert if isinstance(item, dict) - ] - param_value = converted if origin is list else converted[0] + # Let TypeAdapter handle both single values and lists + type_adapter = TypeAdapter(param_type) + param_value = type_adapter.validate_python(param_value) converted_kwargs[param_name] = param_value - - logger.debug(f"Invoking component with kwargs: {converted_kwargs}") + logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}") return component.run(**converted_kwargs) # Return a new Tool instance with the component invoker as the function to be called diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 7fc89df7..3e0b0685 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import json import os import pytest from typing import Dict, List, Optional, Any @@ -9,6 +10,7 @@ from haystack import component from pydantic import BaseModel from haystack import Pipeline +from haystack.dataclasses import Document from haystack_experimental.dataclasses import ChatMessage, ChatRole from haystack_experimental.components.tools.tool_invoker import ToolInvoker from haystack_experimental.components.generators.chat import OpenAIChatGenerator @@ -125,6 +127,21 @@ def run(self, person: Person) -> Dict[str, str]: } +@component +class DocumentProcessor: + """A component that processes a list of Documents.""" + + @component.output_types(concatenated=str) + def run(self, documents: List[Document]) -> Dict[str, str]: + """ + Concatenates the content of multiple documents with newlines. 
+ + :param documents: List of Documents whose content will be concatenated + :returns: Dictionary containing the concatenated document contents + """ + return {"concatenated": '\n'.join(doc.content for doc in documents)} + + ## Unit tests class TestToolComponent: def test_from_component_basic(self): @@ -311,6 +328,104 @@ def test_from_component_with_nested_dataclass(self): assert "info" in result assert result["info"] == "Diana lives at 123 Elm Street, Metropolis." + def test_from_component_with_document_list(self): + component = DocumentProcessor() + + tool = Tool.from_component( + component=component, + name="document_processor", + description="A tool that concatenates document contents" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "documents": { + "type": "array", + "description": "List of Documents whose content will be concatenated", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Field 'id' of 'Document'." + }, + "content": { + "type": "string", + "description": "Field 'content' of 'Document'." + }, + "dataframe": { + "type": "string", + "description": "Field 'dataframe' of 'Document'." + }, + "blob": { + "type": "object", + "description": "Field 'blob' of 'Document'.", + "properties": { + "data": { + "type": "string", + "description": "Field 'data' of 'ByteStream'." + }, + "meta": { + "type": "string", + "description": "Field 'meta' of 'ByteStream'." + }, + "mime_type": { + "type": "string", + "description": "Field 'mime_type' of 'ByteStream'." + } + }, + "required": ["data"] + }, + "meta": { + "type": "string", + "description": "Field 'meta' of 'Document'." + }, + "score": { + "type": "number", + "description": "Field 'score' of 'Document'." + }, + "embedding": { + "type": "array", + "description": "Field 'embedding' of 'Document'.", + "items": { + "type": "number" + } + }, + "sparse_embedding": { + "type": "object", + "description": "Field 'sparse_embedding' of 'Document'.", + "properties": { + "indices": { + "type": "array", + "description": "Field 'indices' of 'SparseEmbedding'.", + "items": { + "type": "integer" + } + }, + "values": { + "type": "array", + "description": "Field 'values' of 'SparseEmbedding'.", + "items": { + "type": "number" + } + } + }, + "required": ["indices", "values"] + } + } + } + } + }, + "required": ["documents"] + } + + # Test tool invocation + result = tool.invoke(documents=[{"content": "First document"}, {"content": "Second document"}]) + assert isinstance(result, dict) + assert "concatenated" in result + assert result["concatenated"] == "First document\nSecond document" + ## Integration tests class TestToolComponentInPipelineWithOpenAI: @@ -452,6 +567,38 @@ def test_person_processor_in_pipeline(self): assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result assert not tool_message.tool_call_result.error + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_document_processor_in_pipeline(self): + component = DocumentProcessor() + tool = Tool.from_component( + component=component, + name="document_processor", + description="A tool that concatenates the content of multiple documents" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool], convert_result_to_json_string=True)) + pipeline.connect("llm.replies", 
"tool_invoker.messages") + + message = ChatMessage.from_user( + text="I have two documents. First one says 'Hello world' and second one says 'Goodbye world'. Can you concatenate them?" + ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + result = json.loads(tool_message.tool_call_result.result) + assert "concatenated" in result + assert "Hello world" in result["concatenated"] + assert "Goodbye world" in result["concatenated"] + assert not tool_message.tool_call_result.error + From 34a0861d494679d09e9786fc60da7f7c787a6b6b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 19 Dec 2024 12:03:57 +0100 Subject: [PATCH 08/18] For dataclasses don't check required fields, add more itegration tests --- .../tools/openai/component_caller.py | 3 +- test/components/tools/test_tool_component.py | 40 +++++++++++++++---- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/haystack_experimental/components/tools/openai/component_caller.py b/haystack_experimental/components/tools/openai/component_caller.py index 8a23a09b..d4f20070 100644 --- a/haystack_experimental/components/tools/openai/component_caller.py +++ b/haystack_experimental/components/tools/openai/component_caller.py @@ -110,8 +110,7 @@ def create_property_schema(python_type: Any, description: str, default: Any = No field_description = f"Field '{field.name}' of '{cls.__name__}'." if isinstance(schema["properties"], dict): schema["properties"][field.name] = create_property_schema(field.type, field_description) - if field.default is MISSING and field.default_factory is MISSING: - required_fields.append(field.name) + else: # Pydantic model for m_name, m_field in python_type.model_fields.items(): field_description = f"Field '{m_name}' of '{python_type.__name__}'." diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 3e0b0685..fdc50fe3 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -306,11 +306,9 @@ def test_from_component_with_nested_dataclass(self): "type": "string", "description": "Field 'city' of 'Address'." } - }, - "required": ["street", "city"] + } } - }, - "required": ["name", "address"] + } } }, "required": ["person"] @@ -374,8 +372,7 @@ def test_from_component_with_document_list(self): "type": "string", "description": "Field 'mime_type' of 'ByteStream'." 
} - }, - "required": ["data"] + } }, "meta": { "type": "string", @@ -410,8 +407,7 @@ def test_from_component_with_document_list(self): "type": "number" } } - }, - "required": ["indices", "values"] + } } } } @@ -599,6 +595,34 @@ def test_document_processor_in_pipeline(self): assert "Goodbye world" in result["concatenated"] assert not tool_message.tool_call_result.error + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_lost_in_middle_ranker_in_pipeline(self): + from haystack.components.rankers import LostInTheMiddleRanker + + component = LostInTheMiddleRanker(top_k=2) + tool = Tool.from_component( + component=component, + name="lost_in_middle_ranker", + description="A tool that ranks documents using the Lost in the Middle algorithm and returns top k results" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user( + text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." + ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + From 551a528ea8d6d1d63ef9bc1ae667619c92b52531 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 19 Dec 2024 12:05:48 +0100 Subject: [PATCH 09/18] Small fix for better test --- test/components/tools/test_tool_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index fdc50fe3..e909734f 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -600,7 +600,7 @@ def test_document_processor_in_pipeline(self): def test_lost_in_middle_ranker_in_pipeline(self): from haystack.components.rankers import LostInTheMiddleRanker - component = LostInTheMiddleRanker(top_k=2) + component = LostInTheMiddleRanker() tool = Tool.from_component( component=component, name="lost_in_middle_ranker", From ce864dd8f4f10c2c42f3ded35a515ef71a995585 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 19 Dec 2024 15:42:11 +0100 Subject: [PATCH 10/18] Make sure we are only using non-pipeline components for Tools --- haystack_experimental/dataclasses/tool.py | 13 ++++++++++ test/components/tools/test_tool_component.py | 26 ++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index dfa2e389..6c3a026d 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -214,6 +214,19 @@ def from_component(cls, component: Component, name: str, description: str) -> "T :returns: The Tool created from the Component. :raises ValueError: If the component is invalid or schema generation fails. """ + + if not isinstance(component, Component): + raise ValueError( + f"{component} is not a Haystack component!" "Can only create a Tool from a Haystack component instance." 
+ ) + + if getattr(component, "__haystack_added_to_pipeline__", None): + msg = ( + "Component has been added to a Pipeline and can't be used to create a Tool. " + "Create Tool from a non-pipeline component instead." + ) + raise ValueError(msg) + from haystack_experimental.components.tools.openai.component_caller import extract_component_parameters # Extract the parameters schema from the component diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index e909734f..9cb26ca0 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -422,6 +422,32 @@ def test_from_component_with_document_list(self): assert "concatenated" in result assert result["concatenated"] == "First document\nSecond document" + def test_from_component_with_non_component(self): + class NotAComponent: + def foo(self, text: str): + return {"reply": f"Hello, {text}!"} + + not_a_component = NotAComponent() + + with pytest.raises(ValueError): + Tool.from_component( + component=not_a_component, + name="invalid_tool", + description="This should fail" + ) + + def test_from_component_for_pipeline_component(self): + pipeline = Pipeline() + component = SimpleComponent() + pipeline.add_component("component", component) + + with pytest.raises(ValueError): + Tool.from_component( + component=component, + name="invalid_tool", + description="This should fail" + ) + ## Integration tests class TestToolComponentInPipelineWithOpenAI: From 931df70122810c11db7e673c909ade880b50a032 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 20 Dec 2024 14:10:19 +0100 Subject: [PATCH 11/18] Move modules around --- haystack_experimental/dataclasses/tool.py | 17 ++++----- haystack_experimental/tools/__init__.py | 7 ++++ .../tool_component_descriptor.py} | 35 +++++++++++-------- test/components/tools/test_tool_component.py | 18 ++-------- 4 files changed, 36 insertions(+), 41 deletions(-) create mode 100644 haystack_experimental/tools/__init__.py rename haystack_experimental/{components/tools/openai/component_caller.py => tools/tool_component_descriptor.py} (97%) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index 6c3a026d..070cd70f 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -12,6 +12,8 @@ from haystack.utils import deserialize_callable, serialize_callable from pydantic import TypeAdapter, create_model +from haystack_experimental.tools import extract_component_parameters + with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: from jsonschema import Draft202012Validator from jsonschema.exceptions import SchemaError @@ -216,18 +218,11 @@ def from_component(cls, component: Component, name: str, description: str) -> "T """ if not isinstance(component, Component): - raise ValueError( - f"{component} is not a Haystack component! " "Can only create a Tool from a Haystack component instance." - ) - - if getattr(component, "__haystack_added_to_pipeline__", None): - msg = ( - "Component has been added to a Pipeline and can't be used to create a Tool. " - "Create Tool from a non-pipeline component instead." + message = ( + f"Object {component!r} is not a Haystack component. " + "Use this method to create a Tool only with Haystack component instances."
) - raise ValueError(msg) - - from haystack_experimental.components.tools.openai.component_caller import extract_component_parameters + raise ValueError(message) # Extract the parameters schema from the component parameters = extract_component_parameters(component) diff --git a/haystack_experimental/tools/__init__.py b/haystack_experimental/tools/__init__.py new file mode 100644 index 00000000..826ea101 --- /dev/null +++ b/haystack_experimental/tools/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .tool_component_descriptor import extract_component_parameters + +__all__ = ["extract_component_parameters"] diff --git a/haystack_experimental/components/tools/openai/component_caller.py b/haystack_experimental/tools/tool_component_descriptor.py similarity index 97% rename from haystack_experimental/components/tools/openai/component_caller.py rename to haystack_experimental/tools/tool_component_descriptor.py index d4f20070..b3aae0ce 100644 --- a/haystack_experimental/components/tools/openai/component_caller.py +++ b/haystack_experimental/tools/tool_component_descriptor.py @@ -2,14 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 -from dataclasses import MISSING, fields, is_dataclass +from dataclasses import fields, is_dataclass from inspect import getdoc from typing import Any, Callable, Dict, Union, get_args, get_origin from docstring_parser import parse from haystack import logging from haystack.core.component import Component -from pydantic.fields import FieldInfo from haystack_experimental.util.utils import is_pydantic_v2_model @@ -77,6 +76,25 @@ def get_param_descriptions(method: Callable) -> Dict[str, str]: return param_descriptions +class UnsupportedTypeError(Exception): + """Raised when a type is not supported for schema generation.""" + + pass + + +def is_nullable_type(python_type: Any) -> bool: + """ + Checks if the type is a Union with NoneType (i.e., Optional). + + :param python_type: The Python type to check. + :returns: True if the type is a Union with NoneType, False otherwise. + """ + origin = get_origin(python_type) + if origin is Union: + return type(None) in get_args(python_type) + return False + + # ruff: noqa: PLR0912 def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: """ @@ -130,16 +148,3 @@ def create_property_schema(python_type: Any, description: str, default: Any = No schema["default"] = default return schema - - -def is_nullable_type(python_type: Any) -> bool: - """ - Checks if the type is a Union with NoneType (i.e., Optional). - - :param python_type: The Python type to check. - :returns: True if the type is a Union with NoneType, False otherwise. - """ - origin = get_origin(python_type) - if origin is Union: - return type(None) in get_args(python_type) - return False diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 9cb26ca0..9866a278 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -30,7 +30,7 @@ def run(self, text: str) -> Dict[str, str]: """ A simple component that generates text. - :param text: user's introductory message + :param text: user's name :return: A dictionary with the generated text. 
""" return {"reply": f"Hello, {text}!"} @@ -160,7 +160,7 @@ def test_from_component_basic(self): "properties": { "text": { "type": "string", - "description": "user's introductory message" + "description": "user's name" } }, "required": ["text"] @@ -436,18 +436,6 @@ def foo(self, text: str): description="This should fail" ) - def test_from_component_for_pipeline_component(self): - pipeline = Pipeline() - component = SimpleComponent() - pipeline.add_component("component", component) - - with pytest.raises(ValueError): - Tool.from_component( - component=component, - name="invalid_tool", - description="This should fail" - ) - ## Integration tests class TestToolComponentInPipelineWithOpenAI: @@ -605,7 +593,7 @@ def test_document_processor_in_pipeline(self): pipeline.connect("llm.replies", "tool_invoker.messages") message = ChatMessage.from_user( - text="I have two documents. First one says 'Hello world' and second one says 'Goodbye world'. Can you concatenate them?" + text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world'. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." ) result = pipeline.run({"llm": {"messages": [message]}}) From 3b2945894d87fbec3d911d7e51182dea473a1275 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 20 Dec 2024 14:54:13 +0100 Subject: [PATCH 12/18] Refactor and simplify tools schema creation --- .../tools/tool_component_descriptor.py | 112 ++++++++++++------ 1 file changed, 73 insertions(+), 39 deletions(-) diff --git a/haystack_experimental/tools/tool_component_descriptor.py b/haystack_experimental/tools/tool_component_descriptor.py index b3aae0ce..8867c752 100644 --- a/haystack_experimental/tools/tool_component_descriptor.py +++ b/haystack_experimental/tools/tool_component_descriptor.py @@ -17,7 +17,7 @@ def extract_component_parameters(component: Component) -> Dict[str, Any]: """ - Extracts parameters from a Haystack component and converts them to OpenAI tools JSON format. + Extracts parameters from a component's run method and converts them to OpenAI tools definition format. :param component: The component to extract parameters from. :returns: A dictionary representing the component's input parameters schema. @@ -38,7 +38,7 @@ def extract_component_parameters(component: Component) -> Dict[str, Any]: properties[input_name] = property_schema - # Use socket.is_mandatory() to check if the input is required + # Use socket.is_mandatory to check if the input is required if socket.is_mandatory: required.append(input_name) @@ -76,12 +76,6 @@ def get_param_descriptions(method: Callable) -> Dict[str, str]: return param_descriptions -class UnsupportedTypeError(Exception): - """Raised when a type is not supported for schema generation.""" - - pass - - def is_nullable_type(python_type: Any) -> bool: """ Checks if the type is a Union with NoneType (i.e., Optional). @@ -95,7 +89,71 @@ def is_nullable_type(python_type: Any) -> bool: return False -# ruff: noqa: PLR0912 +def _create_list_schema(item_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a list type. + + :param item_type: The type of items in the list. + :param description: The description of the list. + :returns: A dictionary representing the list schema. 
+ """ + items_schema = create_property_schema(item_type, "") + items_schema.pop("description", None) + return {"type": "array", "description": description, "items": items_schema} + + +def _create_dataclass_schema(python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a dataclass. + + :param python_type: The dataclass type. + :param description: The description of the dataclass. + :returns: A dictionary representing the dataclass schema. + """ + schema = {"type": "object", "description": description, "properties": {}} + cls = python_type if isinstance(python_type, type) else python_type.__class__ + for field in fields(cls): + field_description = f"Field '{field.name}' of '{cls.__name__}'." + if isinstance(schema["properties"], dict): + schema["properties"][field.name] = create_property_schema(field.type, field_description) + return schema + + +def _create_pydantic_schema(python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a Pydantic model. + + :param python_type: The Pydantic model type. + :param description: The description of the model. + :returns: A dictionary representing the Pydantic model schema. + """ + schema = {"type": "object", "description": description, "properties": {}} + required_fields = [] + + for m_name, m_field in python_type.model_fields.items(): + field_description = f"Field '{m_name}' of '{python_type.__name__}'." + if isinstance(schema["properties"], dict): + schema["properties"][m_name] = create_property_schema(m_field.annotation, field_description) + if m_field.is_required(): + required_fields.append(m_name) + + if required_fields: + schema["required"] = required_fields + return schema + + +def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a basic Python type. + + :param python_type: The Python type. + :param description: The description of the type. + :returns: A dictionary representing the basic type schema. + """ + type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"} + return {"type": type_mapping.get(python_type, "string"), "description": description} + + def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: """ Creates a property schema for a given Python type, recursively if necessary. @@ -112,37 +170,13 @@ def create_property_schema(python_type: Any, description: str, default: Any = No origin = get_origin(python_type) if origin is list: - item_type = get_args(python_type)[0] if get_args(python_type) else Any - # recursively call create_property_schema for the item type - items_schema = create_property_schema(item_type, "") - items_schema.pop("description", None) - schema = {"type": "array", "description": description, "items": items_schema} - elif is_dataclass(python_type) or is_pydantic_v2_model(python_type): - schema = {"type": "object", "description": description, "properties": {}} - required_fields = [] - - if is_dataclass(python_type): - # Get the actual class if python_type is an instance otherwise use the type - cls = python_type if isinstance(python_type, type) else python_type.__class__ - for field in fields(cls): - field_description = f"Field '{field.name}' of '{cls.__name__}'." 
- if isinstance(schema["properties"], dict): - schema["properties"][field.name] = create_property_schema(field.type, field_description) - - else: # Pydantic model - for m_name, m_field in python_type.model_fields.items(): - field_description = f"Field '{m_name}' of '{python_type.__name__}'." - if isinstance(schema["properties"], dict): - schema["properties"][m_name] = create_property_schema(m_field.annotation, field_description) - if m_field.is_required(): - required_fields.append(m_name) - - if required_fields: - schema["required"] = required_fields + schema = _create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description) + elif is_dataclass(python_type): + schema = _create_dataclass_schema(python_type, description) + elif is_pydantic_v2_model(python_type): + schema = _create_pydantic_schema(python_type, description) else: - # Basic types - type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"} - schema = {"type": type_mapping.get(python_type, "string"), "description": description} + schema = _create_basic_type_schema(python_type, description) if default is not None: schema["default"] = default From d8f722caa63132abc346dfac65dc77451a9fc752 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 20 Dec 2024 15:05:37 +0100 Subject: [PATCH 13/18] Better naming --- haystack_experimental/dataclasses/tool.py | 8 ++++---- haystack_experimental/tools/__init__.py | 4 ++-- haystack_experimental/tools/tool_component_descriptor.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index 070cd70f..0fdb5f7d 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -12,7 +12,7 @@ from haystack.utils import deserialize_callable, serialize_callable from pydantic import TypeAdapter, create_model -from haystack_experimental.tools import extract_component_parameters +from haystack_experimental.tools import create_tool_parameters_schema with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: from jsonschema import Draft202012Validator @@ -224,8 +224,8 @@ def from_component(cls, component: Component, name: str, description: str) -> "T ) raise ValueError(message) - # Extract the parameters schema from the component - parameters = extract_component_parameters(component) + # Create the tools schema from the component run method parameters + tool_schema = create_tool_parameters_schema(component) def component_invoker(**kwargs): """ @@ -256,7 +256,7 @@ def component_invoker(**kwargs): return component.run(**converted_kwargs) # Return a new Tool instance with the component invoker as the function to be called - return Tool(name=name, description=description, parameters=parameters, function=component_invoker) + return Tool(name=name, description=description, parameters=tool_schema, function=component_invoker) def _remove_title_from_schema(schema: Dict[str, Any]): diff --git a/haystack_experimental/tools/__init__.py b/haystack_experimental/tools/__init__.py index 826ea101..61d49889 100644 --- a/haystack_experimental/tools/__init__.py +++ b/haystack_experimental/tools/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .tool_component_descriptor import extract_component_parameters +from .tool_component_descriptor import create_tool_parameters_schema -__all__ = ["extract_component_parameters"] +__all__ = ["create_tool_parameters_schema"] diff 
--git a/haystack_experimental/tools/tool_component_descriptor.py b/haystack_experimental/tools/tool_component_descriptor.py index 8867c752..e01665bc 100644 --- a/haystack_experimental/tools/tool_component_descriptor.py +++ b/haystack_experimental/tools/tool_component_descriptor.py @@ -15,12 +15,12 @@ logger = logging.getLogger(__name__) -def extract_component_parameters(component: Component) -> Dict[str, Any]: +def create_tool_parameters_schema(component: Component) -> Dict[str, Any]: """ - Extracts parameters from a component's run method and converts them to OpenAI tools definition format. + Creates an OpenAI tools schema from a component's run method parameters. - :param component: The component to extract parameters from. - :returns: A dictionary representing the component's input parameters schema. + :param component: The component to create the schema from. + :returns: OpenAI tools schema for the component's run method parameters. """ properties = {} required = [] From 2dbb8d242ef0561ba74bf1779aa48c8e945c827b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 20 Dec 2024 15:15:41 +0100 Subject: [PATCH 14/18] Rename module --- haystack_experimental/tools/__init__.py | 2 +- .../tools/{tool_component_descriptor.py => component_schema.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename haystack_experimental/tools/{tool_component_descriptor.py => component_schema.py} (100%) diff --git a/haystack_experimental/tools/__init__.py b/haystack_experimental/tools/__init__.py index 61d49889..e2fe866c 100644 --- a/haystack_experimental/tools/__init__.py +++ b/haystack_experimental/tools/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .tool_component_descriptor import create_tool_parameters_schema +from .component_schema import create_tool_parameters_schema __all__ = ["create_tool_parameters_schema"] diff --git a/haystack_experimental/tools/tool_component_descriptor.py b/haystack_experimental/tools/component_schema.py similarity index 100% rename from haystack_experimental/tools/tool_component_descriptor.py rename to haystack_experimental/tools/component_schema.py From b35c0982a9d54bd0873bf73cc916e60b65ffa928 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 20 Dec 2024 17:17:41 +0100 Subject: [PATCH 15/18] PR feedback --- haystack_experimental/dataclasses/tool.py | 4 ++-- haystack_experimental/tools/__init__.py | 4 ---- .../tools/component_schema.py | 20 +++++++++---------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index 0fdb5f7d..f3d85b48 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -12,7 +12,7 @@ from haystack.utils import deserialize_callable, serialize_callable from pydantic import TypeAdapter, create_model -from haystack_experimental.tools import create_tool_parameters_schema +from haystack_experimental.tools.component_schema import _create_tool_parameters_schema with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import: from jsonschema import Draft202012Validator @@ -225,7 +225,7 @@ def from_component(cls, component: Component, name: str, description: str) -> "T raise ValueError(message) # Create the tools schema from the component run method parameters - tool_schema = create_tool_parameters_schema(component) + tool_schema = _create_tool_parameters_schema(component) def component_invoker(**kwargs): """ diff --git 
a/haystack_experimental/tools/__init__.py b/haystack_experimental/tools/__init__.py index e2fe866c..c1764a6e 100644 --- a/haystack_experimental/tools/__init__.py +++ b/haystack_experimental/tools/__init__.py @@ -1,7 +1,3 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 - -from .component_schema import create_tool_parameters_schema - -__all__ = ["create_tool_parameters_schema"] diff --git a/haystack_experimental/tools/component_schema.py b/haystack_experimental/tools/component_schema.py index e01665bc..6dd81a96 100644 --- a/haystack_experimental/tools/component_schema.py +++ b/haystack_experimental/tools/component_schema.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def create_tool_parameters_schema(component: Component) -> Dict[str, Any]: +def _create_tool_parameters_schema(component: Component) -> Dict[str, Any]: """ Creates an OpenAI tools schema from a component's run method parameters. @@ -25,14 +25,14 @@ def create_tool_parameters_schema(component: Component) -> Dict[str, Any]: properties = {} required = [] - param_descriptions = get_param_descriptions(component.run) + param_descriptions = _get_param_descriptions(component.run) for input_name, socket in component.__haystack_input__._sockets_dict.items(): input_type = socket.type description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") try: - property_schema = create_property_schema(input_type, description) + property_schema = _create_property_schema(input_type, description) except ValueError as e: raise ValueError(f"Error processing input '{input_name}': {e}") @@ -50,7 +50,7 @@ def create_tool_parameters_schema(component: Component) -> Dict[str, Any]: return parameters_schema -def get_param_descriptions(method: Callable) -> Dict[str, str]: +def _get_param_descriptions(method: Callable) -> Dict[str, str]: """ Extracts parameter descriptions from the method's docstring using docstring_parser. @@ -76,7 +76,7 @@ def get_param_descriptions(method: Callable) -> Dict[str, str]: return param_descriptions -def is_nullable_type(python_type: Any) -> bool: +def _is_nullable_type(python_type: Any) -> bool: """ Checks if the type is a Union with NoneType (i.e., Optional). @@ -97,7 +97,7 @@ def _create_list_schema(item_type: Any, description: str) -> Dict[str, Any]: :param description: The description of the list. :returns: A dictionary representing the list schema. """ - items_schema = create_property_schema(item_type, "") + items_schema = _create_property_schema(item_type, "") items_schema.pop("description", None) return {"type": "array", "description": description, "items": items_schema} @@ -115,7 +115,7 @@ def _create_dataclass_schema(python_type: Any, description: str) -> Dict[str, An for field in fields(cls): field_description = f"Field '{field.name}' of '{cls.__name__}'." if isinstance(schema["properties"], dict): - schema["properties"][field.name] = create_property_schema(field.type, field_description) + schema["properties"][field.name] = _create_property_schema(field.type, field_description) return schema @@ -133,7 +133,7 @@ def _create_pydantic_schema(python_type: Any, description: str) -> Dict[str, Any for m_name, m_field in python_type.model_fields.items(): field_description = f"Field '{m_name}' of '{python_type.__name__}'." 
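# --- Illustrative aside, not part of the patch series: unlike the dataclass branch,
# the Pydantic branch here still collects a "required" list via FieldInfo.is_required().
# The City model is an invented example and assumes Pydantic v2.
from typing import Optional

from pydantic import BaseModel

class City(BaseModel):
    name: str  # no default, so is_required() is True
    population: Optional[int] = None  # defaulted, so optional

# _create_pydantic_schema(City, "A city") would yield, as a sketch:
# {"type": "object", "description": "A city",
#  "properties": {
#      "name": {"type": "string", "description": "Field 'name' of 'City'."},
#      "population": {"type": "integer", "description": "Field 'population' of 'City'."}},
#  "required": ["name"]}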
if isinstance(schema["properties"], dict): - schema["properties"][m_name] = create_property_schema(m_field.annotation, field_description) + schema["properties"][m_name] = _create_property_schema(m_field.annotation, field_description) if m_field.is_required(): required_fields.append(m_name) @@ -154,7 +154,7 @@ def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, A return {"type": type_mapping.get(python_type, "string"), "description": description} -def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: +def _create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: """ Creates a property schema for a given Python type, recursively if necessary. @@ -163,7 +163,7 @@ def create_property_schema(python_type: Any, description: str, default: Any = No :param default: The default value of the property. :returns: A dictionary representing the property schema. """ - nullable = is_nullable_type(python_type) + nullable = _is_nullable_type(python_type) if nullable: non_none_types = [t for t in get_args(python_type) if t is not type(None)] python_type = non_none_types[0] if non_none_types else str From 9c840bfa3d8a7c0b7e1091bd4b413aaadef383df Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 6 Jan 2025 11:39:25 +0100 Subject: [PATCH 16/18] Add top_k to DocumentProcessor, update tests --- test/components/tools/test_tool_component.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 9866a278..41d2ec38 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -132,14 +132,15 @@ class DocumentProcessor: """A component that processes a list of Documents.""" @component.output_types(concatenated=str) - def run(self, documents: List[Document]) -> Dict[str, str]: + def run(self, documents: List[Document], top_k: int = 5) -> Dict[str, str]: """ Concatenates the content of multiple documents with newlines. :param documents: List of Documents whose content will be concatenated + :param top_k: The number of top documents to concatenate :returns: Dictionary containing the concatenated document contents """ - return {"concatenated": '\n'.join(doc.content for doc in documents)} + return {"concatenated": '\n'.join(doc.content for doc in documents[:top_k])} ## Unit tests @@ -411,7 +412,11 @@ def test_from_component_with_document_list(self): } } } - } + }, + "top_k": { + "description": "The number of top documents to concatenate", + "type": "integer", + }, }, "required": ["documents"] } @@ -593,7 +598,7 @@ def test_document_processor_in_pipeline(self): pipeline.connect("llm.replies", "tool_invoker.messages") message = ChatMessage.from_user( - text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world'. Set only the content field of each document. Do not set the id, meta, score, embedding, sparse_embedding, dataframe, or blob fields." + text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only the content field of each document. Do not set the id, meta, score, embedding, sparse_embedding, dataframe, or blob fields."
) result = pipeline.run({"llm": {"messages": [message]}}) From 150fb8199dc33312855b8c9adbe425ebd1b5370c Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 6 Jan 2025 14:42:07 +0100 Subject: [PATCH 17/18] Make name/description optional --- haystack_experimental/dataclasses/tool.py | 27 +++++++++++++++++--- test/components/tools/test_tool_component.py | 20 +++++---------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index f3d85b48..090b3d6a 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -206,13 +206,20 @@ def get_weather( return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function) @classmethod - def from_component(cls, component: Component, name: str, description: str) -> "Tool": + def from_component( + cls, + component: Component, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> "Tool": """ Create a Tool instance from a Haystack component. :param component: The Haystack component to be converted into a Tool. - :param name: Name for the tool. - :param description: Description of the tool. + :param name: Name for the tool (optional). If not provided, the name will be derived + from the component's class name using snake_case (e.g. "user_component" for UserComponent). + :param description: Description of the tool (optional). If not provided, the pydoc description + of the component will be used. :returns: The Tool created from the Component. :raises ValueError: If the component is invalid or schema generation fails. """ @@ -255,6 +262,20 @@ def component_invoker(**kwargs): logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}") return component.run(**converted_kwargs) + # Generate a name for the tool if not provided + if not name: + class_name = component.__class__.__name__ + # Convert camelCase/PascalCase to snake_case + name = "".join( + [ + "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower() + for i, c in enumerate(class_name) + ] + ).lstrip("_") + + # Generate a description for the tool if not provided and truncate to 512 characters + description = (description or component.__doc__ or name)[:512] + # Return a new Tool instance with the component invoker as the function to be called return Tool(name=name, description=description, parameters=tool_schema, function=component_invoker) diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 41d2ec38..94bd085d 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -148,14 +148,10 @@ class TestToolComponent: def test_from_component_basic(self): component = SimpleComponent() - tool = Tool.from_component( - component=component, - name="hello_tool", - description="A hello tool" - ) + tool = Tool.from_component(component) - assert tool.name == "hello_tool" - assert tool.description == "A hello tool" + assert tool.name == "simple_component" + assert tool.description == "A simple component that generates text." 
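# --- Illustrative aside, not part of the patch series: the snake_case derivation from
# PATCH 17 can be sanity-checked in isolation. This standalone copy of the comprehension
# uses an invented helper name.
def _derive_tool_name(class_name: str) -> str:
    # Prefix "_" to an uppercase letter that follows a lowercase one, then lowercase
    # everything; runs of consecutive capitals stay glued together.
    return "".join(
        "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
        for i, c in enumerate(class_name)
    ).lstrip("_")

assert _derive_tool_name("SimpleComponent") == "simple_component"
assert _derive_tool_name("LostInTheMiddleRanker") == "lost_in_the_middle_ranker"
# Acronym runs collapse, e.g. "AI" here:
assert _derive_tool_name("OpenAIChatGenerator") == "open_aichat_generator"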
assert tool.parameters == { "type": "object", "properties": { @@ -176,12 +172,7 @@ def test_from_component_basic(self): def test_from_component_with_dataclass(self): component = UserGreeter() - tool = Tool.from_component( - component=component, - name="user_info_tool", - description="A tool that returns user information" - ) - + tool = Tool.from_component(component) assert tool.parameters == { "type": "object", "properties": { @@ -203,6 +194,9 @@ def test_from_component_with_dataclass(self): "required": ["user"] } + assert tool.name == "user_greeter" + assert tool.description == "A simple component that processes a User." + # Test tool invocation result = tool.invoke(user={"name": "Alice", "age": 30}) assert isinstance(result, dict) From 144474928a0506992f981923723b14dabb6ff0c4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 6 Jan 2025 14:59:51 +0100 Subject: [PATCH 18/18] Switch to gpt-4o-mini --- test/components/tools/test_tool_component.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 94bd085d..3f236ca7 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -452,7 +452,7 @@ def test_component_tool_in_pipeline(self): # Create pipeline with OpenAIChatGenerator and ToolInvoker pipeline = Pipeline() - pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) # Connect components @@ -483,7 +483,7 @@ def test_user_greeter_in_pipeline(self): ) pipeline = Pipeline() - pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) pipeline.connect("llm.replies", "tool_invoker.messages") @@ -509,7 +509,7 @@ def test_list_processor_in_pipeline(self): ) pipeline = Pipeline() - pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) pipeline.connect("llm.replies", "tool_invoker.messages") @@ -535,7 +535,7 @@ def test_product_processor_in_pipeline(self): ) pipeline = Pipeline() - pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) pipeline.connect("llm.replies", "tool_invoker.messages") @@ -561,7 +561,7 @@ def test_person_processor_in_pipeline(self): ) pipeline = Pipeline() - pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) pipeline.connect("llm.replies", "tool_invoker.messages") @@ -587,7 +587,7 @@ def test_document_processor_in_pipeline(self): ) pipeline = Pipeline() - pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool], 
convert_result_to_json_string=True)) pipeline.connect("llm.replies", "tool_invoker.messages")
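After the final patch, end-to-end usage looks roughly like the sketch below. It mirrors the integration tests in this series; the experimental import paths are assumptions based on the package layout shown here, and running it requires OPENAI_API_KEY to be set.

from typing import Dict

from haystack import Pipeline, component

from haystack_experimental.components.generators.chat import OpenAIChatGenerator
from haystack_experimental.components.tools import ToolInvoker
from haystack_experimental.dataclasses import ChatMessage, Tool

@component
class SimpleComponent:
    """A simple component that generates text."""

    @component.output_types(reply=str)
    def run(self, text: str) -> Dict[str, str]:
        """
        A simple component that generates text.

        :param text: user's name
        :return: A dictionary with the generated text.
        """
        return {"reply": f"Hello, {text}!"}

# Name and description are optional since PATCH 17: they default to the snake_case
# class name ("simple_component") and the class docstring, truncated to 512 characters.
tool = Tool.from_component(SimpleComponent())

pipeline = Pipeline()
pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
pipeline.connect("llm.replies", "tool_invoker.messages")

message = ChatMessage.from_user(text="Use the tool to greet Alice.")
result = pipeline.run({"llm": {"messages": [message]}})
print(result["tool_invoker"]["tool_messages"][0].tool_call_result.result)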