From 3c598d25a67b5779a2cc45b9207cf5b992146f4e Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 5 Sep 2024 10:36:42 -0400 Subject: [PATCH] core[minor]: Add get_input_jsonschema, get_output_jsonschema, get_config_jsonschema (#26034) This PR adds methods to directly get the json schema for inputs, outputs, and config. Currently, it's delegating to the underlying pydantic implementation, but this may be changed in the future to be independent of pydantic. --- libs/core/langchain_core/runnables/base.py | 71 +++ .../prompts/__snapshots__/test_chat.ambr | 43 +- .../tests/unit_tests/prompts/test_chat.py | 5 +- .../tests/unit_tests/prompts/test_prompt.py | 150 +++--- .../__snapshots__/test_runnable.ambr | 58 +-- .../tests/unit_tests/runnables/test_graph.py | 5 +- .../unit_tests/runnables/test_runnable.py | 463 ++++++++++-------- 7 files changed, 479 insertions(+), 316 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index b4989030f389c..444b191c2f55d 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -350,6 +350,34 @@ def get_input_schema( __root__=root_type, ) + def get_input_jsonschema( + self, config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + """Get a JSON schema that represents the input to the Runnable. + + Args: + config: A config to use when generating the schema. + + Returns: + A JSON schema that represents the input to the Runnable. + + Example: + + .. code-block:: python + + from langchain_core.runnables import RunnableLambda + + def add_one(x: int) -> int: + return x + 1 + + runnable = RunnableLambda(add_one) + + print(runnable.get_input_jsonschema()) + + .. versionadded:: 0.3.0 + """ + return self.get_input_schema(config).model_json_schema() + @property def output_schema(self) -> Type[BaseModel]: """The type of output this Runnable produces specified as a pydantic model.""" @@ -382,6 +410,34 @@ def get_output_schema( __root__=root_type, ) + def get_output_jsonschema( + self, config: Optional[RunnableConfig] = None + ) -> Dict[str, Any]: + """Get a JSON schema that represents the output of the Runnable. + + Args: + config: A config to use when generating the schema. + + Returns: + A JSON schema that represents the output of the Runnable. + + Example: + + .. code-block:: python + + from langchain_core.runnables import RunnableLambda + + def add_one(x: int) -> int: + return x + 1 + + runnable = RunnableLambda(add_one) + + print(runnable.get_output_jsonschema()) + + .. versionadded:: 0.3.0 + """ + return self.get_output_schema(config).model_json_schema() + @property def config_specs(self) -> List[ConfigurableFieldSpec]: """List configurable fields for this Runnable.""" @@ -435,6 +491,21 @@ def config_schema( ) return model + def get_config_jsonschema( + self, *, include: Optional[Sequence[str]] = None + ) -> Dict[str, Any]: + """Get a JSON schema that represents the output of the Runnable. + + Args: + include: A list of fields to include in the config schema. + + Returns: + A JSON schema that represents the output of the Runnable. + + .. versionadded:: 0.3.0 + """ + return self.config_schema(include=include).model_json_schema() + def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph: """Return a graph representation of this Runnable.""" from langchain_core.runnables.graph import Graph diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 83ac40f37a81d..3d9d246164656 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -1,7 +1,7 @@ # serializer version: 1 # name: test_chat_input_schema[partial] dict({ - 'definitions': dict({ + '$defs': dict({ 'AIMessage': dict({ 'additionalProperties': True, 'description': ''' @@ -60,7 +60,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/InvalidToolCall', + '$ref': '#/$defs/InvalidToolCall', }), 'title': 'Invalid Tool Calls', 'type': 'array', @@ -85,7 +85,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/ToolCall', + '$ref': '#/$defs/ToolCall', }), 'title': 'Tool Calls', 'type': 'array', @@ -102,7 +102,7 @@ 'usage_metadata': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/UsageMetadata', + '$ref': '#/$defs/UsageMetadata', }), dict({ 'type': 'null', @@ -1133,6 +1133,7 @@ 'type': 'object', }), 'artifact': dict({ + 'default': None, 'title': 'Artifact', }), 'content': dict({ @@ -1345,25 +1346,26 @@ }), 'properties': dict({ 'history': dict({ + 'default': None, 'items': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/AIMessage', + '$ref': '#/$defs/AIMessage', }), dict({ - '$ref': '#/definitions/HumanMessage', + '$ref': '#/$defs/HumanMessage', }), dict({ - '$ref': '#/definitions/ChatMessage', + '$ref': '#/$defs/ChatMessage', }), dict({ - '$ref': '#/definitions/SystemMessage', + '$ref': '#/$defs/SystemMessage', }), dict({ - '$ref': '#/definitions/FunctionMessage', + '$ref': '#/$defs/FunctionMessage', }), dict({ - '$ref': '#/definitions/ToolMessage', + '$ref': '#/$defs/ToolMessage', }), dict({ '$ref': '#/definitions/AIMessageChunk', @@ -1402,7 +1404,7 @@ # --- # name: test_chat_input_schema[required] dict({ - 'definitions': dict({ + '$defs': dict({ 'AIMessage': dict({ 'additionalProperties': True, 'description': ''' @@ -1461,7 +1463,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/InvalidToolCall', + '$ref': '#/$defs/InvalidToolCall', }), 'title': 'Invalid Tool Calls', 'type': 'array', @@ -1486,7 +1488,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/ToolCall', + '$ref': '#/$defs/ToolCall', }), 'title': 'Tool Calls', 'type': 'array', @@ -1503,7 +1505,7 @@ 'usage_metadata': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/UsageMetadata', + '$ref': '#/$defs/UsageMetadata', }), dict({ 'type': 'null', @@ -2534,6 +2536,7 @@ 'type': 'object', }), 'artifact': dict({ + 'default': None, 'title': 'Artifact', }), 'content': dict({ @@ -2749,22 +2752,22 @@ 'items': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/AIMessage', + '$ref': '#/$defs/AIMessage', }), dict({ - '$ref': '#/definitions/HumanMessage', + '$ref': '#/$defs/HumanMessage', }), dict({ - '$ref': '#/definitions/ChatMessage', + '$ref': '#/$defs/ChatMessage', }), dict({ - '$ref': '#/definitions/SystemMessage', + '$ref': '#/$defs/SystemMessage', }), dict({ - '$ref': '#/definitions/FunctionMessage', + '$ref': '#/$defs/FunctionMessage', }), dict({ - '$ref': '#/definitions/ToolMessage', + '$ref': '#/$defs/ToolMessage', }), dict({ '$ref': '#/definitions/AIMessageChunk', diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 08f5540f7c5a5..a5145fbfaf970 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -31,7 +31,6 @@ SystemMessagePromptTemplate, _convert_to_message, ) -from tests.unit_tests.pydantic_utils import _schema @pytest.fixture @@ -796,14 +795,14 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: assert prompt_all_required.optional_variables == [] with pytest.raises(ValidationError): prompt_all_required.input_schema(input="") - assert _schema(prompt_all_required.input_schema) == snapshot(name="required") + assert prompt_all_required.get_input_jsonschema() == snapshot(name="required") prompt_optional = ChatPromptTemplate( messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")] ) # input variables only lists required variables assert set(prompt_optional.input_variables) == {"input"} prompt_optional.input_schema(input="") # won't raise error - assert _schema(prompt_optional.input_schema) == snapshot(name="partial") + assert prompt_optional.get_input_jsonschema() == snapshot(name="partial") def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None: diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index a9bee11ba8933..396114c718d29 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -7,7 +7,6 @@ from langchain_core.prompts.prompt import PromptTemplate from langchain_core.tracers.run_collector import RunCollectorCallbackHandler -from tests.unit_tests.pydantic_utils import _schema def test_prompt_valid() -> None: @@ -70,10 +69,10 @@ def test_mustache_prompt_from_template() -> None: prompt = PromptTemplate.from_template(template, template_format="mustache") assert prompt.format(foo="bar") == "This is a bar test." assert prompt.input_variables == ["foo"] - assert _schema(prompt.input_schema) == { + assert prompt.get_input_jsonschema() == { "title": "PromptInput", "type": "object", - "properties": {"foo": {"title": "Foo", "type": "string"}}, + "properties": {"foo": {"title": "Foo", "type": "string", "default": None}}, } # Multiple input variables. @@ -81,12 +80,12 @@ def test_mustache_prompt_from_template() -> None: prompt = PromptTemplate.from_template(template, template_format="mustache") assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test." assert prompt.input_variables == ["bar", "foo"] - assert _schema(prompt.input_schema) == { + assert prompt.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": { - "bar": {"title": "Bar", "type": "string"}, - "foo": {"title": "Foo", "type": "string"}, + "bar": {"title": "Bar", "type": "string", "default": None}, + "foo": {"title": "Foo", "type": "string", "default": None}, }, } @@ -95,12 +94,12 @@ def test_mustache_prompt_from_template() -> None: prompt = PromptTemplate.from_template(template, template_format="mustache") assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar." assert prompt.input_variables == ["bar", "foo"] - assert _schema(prompt.input_schema) == { + assert prompt.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": { - "bar": {"title": "Bar", "type": "string"}, - "foo": {"title": "Foo", "type": "string"}, + "bar": {"title": "Bar", "type": "string", "default": None}, + "foo": {"title": "Foo", "type": "string", "default": None}, }, } @@ -111,23 +110,23 @@ def test_mustache_prompt_from_template() -> None: "This foo is a bar test baz." ) assert prompt.input_variables == ["foo", "obj"] - assert _schema(prompt.input_schema) == { - "title": "PromptInput", - "type": "object", - "properties": { - "foo": {"title": "Foo", "type": "string"}, - "obj": {"$ref": "#/definitions/obj"}, - }, - "definitions": { + assert prompt.get_input_jsonschema() == { + "$defs": { "obj": { - "title": "obj", - "type": "object", "properties": { - "foo": {"title": "Foo", "type": "string"}, - "bar": {"title": "Bar", "type": "string"}, + "bar": {"default": None, "title": "Bar", "type": "string"}, + "foo": {"default": None, "title": "Foo", "type": "string"}, }, + "title": "obj", + "type": "object", } }, + "properties": { + "foo": {"default": None, "title": "Foo", "type": "string"}, + "obj": {"allOf": [{"$ref": "#/$defs/obj"}], "default": None}, + }, + "title": "PromptInput", + "type": "object", } # . variables @@ -135,7 +134,7 @@ def test_mustache_prompt_from_template() -> None: prompt = PromptTemplate.from_template(template, template_format="mustache") assert prompt.format(foo="baz") == ("This {'foo': 'baz'} is a test.") assert prompt.input_variables == [] - assert _schema(prompt.input_schema) == { + assert prompt.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {}, @@ -152,17 +151,19 @@ def test_mustache_prompt_from_template() -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - assert _schema(prompt.input_schema) == { - "title": "PromptInput", - "type": "object", - "properties": {"foo": {"$ref": "#/definitions/foo"}}, - "definitions": { + assert prompt.get_input_jsonschema() == { + "$defs": { "foo": { + "properties": { + "bar": {"default": None, "title": "Bar", "type": "string"} + }, "title": "foo", "type": "object", - "properties": {"bar": {"title": "Bar", "type": "string"}}, } }, + "properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}}, + "title": "PromptInput", + "type": "object", } # more complex nested section/context variables @@ -184,26 +185,28 @@ def test_mustache_prompt_from_template() -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - assert _schema(prompt.input_schema) == { - "title": "PromptInput", - "type": "object", - "properties": {"foo": {"$ref": "#/definitions/foo"}}, - "definitions": { - "foo": { - "title": "foo", - "type": "object", + assert prompt.get_input_jsonschema() == { + "$defs": { + "baz": { "properties": { - "bar": {"title": "Bar", "type": "string"}, - "baz": {"$ref": "#/definitions/baz"}, - "quux": {"title": "Quux", "type": "string"}, + "qux": {"default": None, "title": "Qux", "type": "string"} }, - }, - "baz": { "title": "baz", "type": "object", - "properties": {"qux": {"title": "Qux", "type": "string"}}, + }, + "foo": { + "properties": { + "bar": {"default": None, "title": "Bar", "type": "string"}, + "baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None}, + "quux": {"default": None, "title": "Quux", "type": "string"}, + }, + "title": "foo", + "type": "object", }, }, + "properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}}, + "title": "PromptInput", + "type": "object", } # triply nested section/context variables @@ -239,39 +242,43 @@ def test_mustache_prompt_from_template() -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - assert _schema(prompt.input_schema) == { - "title": "PromptInput", - "type": "object", - "properties": {"foo": {"$ref": "#/definitions/foo"}}, - "definitions": { - "foo": { - "title": "foo", - "type": "object", + assert prompt.get_input_jsonschema() == { + "$defs": { + "barfoo": { "properties": { - "bar": {"title": "Bar", "type": "string"}, - "baz": {"$ref": "#/definitions/baz"}, - "quux": {"title": "Quux", "type": "string"}, + "foobar": {"default": None, "title": "Foobar", "type": "string"} }, + "title": "barfoo", + "type": "object", }, "baz": { + "properties": { + "qux": {"allOf": [{"$ref": "#/$defs/qux"}], "default": None} + }, "title": "baz", "type": "object", - "properties": {"qux": {"$ref": "#/definitions/qux"}}, }, - "qux": { - "title": "qux", - "type": "object", + "foo": { "properties": { - "foobar": {"title": "Foobar", "type": "string"}, - "barfoo": {"$ref": "#/definitions/barfoo"}, + "bar": {"default": None, "title": "Bar", "type": "string"}, + "baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None}, + "quux": {"default": None, "title": "Quux", "type": "string"}, }, + "title": "foo", + "type": "object", }, - "barfoo": { - "title": "barfoo", + "qux": { + "properties": { + "barfoo": {"allOf": [{"$ref": "#/$defs/barfoo"}], "default": None}, + "foobar": {"default": None, "title": "Foobar", "type": "string"}, + }, + "title": "qux", "type": "object", - "properties": {"foobar": {"title": "Foobar", "type": "string"}}, }, }, + "properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}}, + "title": "PromptInput", + "type": "object", } # section/context variables with repeats @@ -287,19 +294,20 @@ def test_mustache_prompt_from_template() -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - assert _schema(prompt.input_schema) == { - "title": "PromptInput", - "type": "object", - "properties": {"foo": {"$ref": "#/definitions/foo"}}, - "definitions": { + assert prompt.get_input_jsonschema() == { + "$defs": { "foo": { + "properties": { + "bar": {"default": None, "title": "Bar", "type": "string"} + }, "title": "foo", "type": "object", - "properties": {"bar": {"title": "Bar", "type": "string"}}, } }, + "properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}}, + "title": "PromptInput", + "type": "object", } - template = """This{{^foo}} no foos {{/foo}}is a test.""" @@ -310,10 +318,10 @@ def test_mustache_prompt_from_template() -> None: is a test.""" ) assert prompt.input_variables == ["foo"] - assert _schema(prompt.input_schema) == { + assert prompt.get_input_jsonschema() == { + "properties": {"foo": {"default": None, "title": "Foo", "type": "object"}}, "title": "PromptInput", "type": "object", - "properties": {"foo": {"title": "Foo", "type": "object"}}, } diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index 164e3f8366fa3..ea1ec005fee4c 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -5178,7 +5178,7 @@ # --- # name: test_schemas[chat_prompt_input_schema] dict({ - 'definitions': dict({ + '$defs': dict({ 'AIMessage': dict({ 'additionalProperties': True, 'description': ''' @@ -5237,7 +5237,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/InvalidToolCall', + '$ref': '#/$defs/InvalidToolCall', }), 'title': 'Invalid Tool Calls', 'type': 'array', @@ -5262,7 +5262,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/ToolCall', + '$ref': '#/$defs/ToolCall', }), 'title': 'Tool Calls', 'type': 'array', @@ -5279,7 +5279,7 @@ 'usage_metadata': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/UsageMetadata', + '$ref': '#/$defs/UsageMetadata', }), dict({ 'type': 'null', @@ -6310,6 +6310,7 @@ 'type': 'object', }), 'artifact': dict({ + 'default': None, 'title': 'Artifact', }), 'content': dict({ @@ -6525,22 +6526,22 @@ 'items': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/AIMessage', + '$ref': '#/$defs/AIMessage', }), dict({ - '$ref': '#/definitions/HumanMessage', + '$ref': '#/$defs/HumanMessage', }), dict({ - '$ref': '#/definitions/ChatMessage', + '$ref': '#/$defs/ChatMessage', }), dict({ - '$ref': '#/definitions/SystemMessage', + '$ref': '#/$defs/SystemMessage', }), dict({ - '$ref': '#/definitions/FunctionMessage', + '$ref': '#/$defs/FunctionMessage', }), dict({ - '$ref': '#/definitions/ToolMessage', + '$ref': '#/$defs/ToolMessage', }), dict({ '$ref': '#/definitions/AIMessageChunk', @@ -6575,15 +6576,7 @@ # --- # name: test_schemas[chat_prompt_output_schema] dict({ - 'anyOf': list([ - dict({ - '$ref': '#/definitions/StringPromptValue', - }), - dict({ - '$ref': '#/definitions/ChatPromptValueConcrete', - }), - ]), - 'definitions': dict({ + '$defs': dict({ 'AIMessage': dict({ 'additionalProperties': True, 'description': ''' @@ -6642,7 +6635,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/InvalidToolCall', + '$ref': '#/$defs/InvalidToolCall', }), 'title': 'Invalid Tool Calls', 'type': 'array', @@ -6667,7 +6660,7 @@ 'default': list([ ]), 'items': dict({ - '$ref': '#/definitions/ToolCall', + '$ref': '#/$defs/ToolCall', }), 'title': 'Tool Calls', 'type': 'array', @@ -6684,7 +6677,7 @@ 'usage_metadata': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/UsageMetadata', + '$ref': '#/$defs/UsageMetadata', }), dict({ 'type': 'null', @@ -6981,22 +6974,22 @@ 'items': dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/AIMessage', + '$ref': '#/$defs/AIMessage', }), dict({ - '$ref': '#/definitions/HumanMessage', + '$ref': '#/$defs/HumanMessage', }), dict({ - '$ref': '#/definitions/ChatMessage', + '$ref': '#/$defs/ChatMessage', }), dict({ - '$ref': '#/definitions/SystemMessage', + '$ref': '#/$defs/SystemMessage', }), dict({ - '$ref': '#/definitions/FunctionMessage', + '$ref': '#/$defs/FunctionMessage', }), dict({ - '$ref': '#/definitions/ToolMessage', + '$ref': '#/$defs/ToolMessage', }), dict({ '$ref': '#/definitions/AIMessageChunk', @@ -7804,6 +7797,7 @@ 'type': 'object', }), 'artifact': dict({ + 'default': None, 'title': 'Artifact', }), 'content': dict({ @@ -8014,6 +8008,14 @@ 'type': 'object', }), }), + 'anyOf': list([ + dict({ + '$ref': '#/$defs/StringPromptValue', + }), + dict({ + '$ref': '#/$defs/ChatPromptValueConcrete', + }), + ]), 'title': 'ChatPromptTemplateOutput', }) # --- diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 06ac81277b254..c681db36365ac 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -11,7 +11,6 @@ from langchain_core.runnables.base import Runnable, RunnableConfig from langchain_core.runnables.graph import Edge, Graph, Node from langchain_core.runnables.graph_mermaid import _escape_node_label -from tests.unit_tests.pydantic_utils import _schema def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None: @@ -19,10 +18,10 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None: graph = StrOutputParser().get_graph() first_node = graph.first_node() assert first_node is not None - assert _schema(first_node.data) == _schema(runnable.input_schema) # type: ignore[union-attr] + assert first_node.data.schema() == runnable.get_input_jsonschema() # type: ignore[union-attr] last_node = graph.last_node() assert last_node is not None - assert _schema(last_node.data) == _schema(runnable.output_schema) # type: ignore[union-attr] + assert last_node.data.schema() == runnable.get_output_jsonschema() # type: ignore[union-attr] assert len(graph.nodes) == 3 assert len(graph.edges) == 2 assert graph.edges[0].source == first_node.id diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index de15d43cbc394..2ad80e95a43a7 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -90,7 +90,7 @@ RunLogPatch, ) from langchain_core.tracers.context import collect_runs -from tests.unit_tests.pydantic_utils import _schema, replace_all_of_with_ref +from tests.unit_tests.pydantic_utils import _schema from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk @@ -226,42 +226,47 @@ async def _aget_relevant_documents( def test_schemas(snapshot: SnapshotAssertion) -> None: fake = FakeRunnable() # str -> int - assert _schema(fake.input_schema) == { + assert fake.get_input_jsonschema() == { "title": "FakeRunnableInput", "type": "string", } - assert _schema(fake.output_schema) == { + assert fake.get_output_jsonschema() == { "title": "FakeRunnableOutput", "type": "integer", } - assert _schema(fake.config_schema(include=["tags", "metadata", "run_name"])) == { - "title": "FakeRunnableConfig", - "type": "object", + assert fake.get_config_jsonschema(include=["tags", "metadata", "run_name"]) == { "properties": { - "metadata": {"title": "Metadata", "type": "object"}, - "run_name": {"title": "Run Name", "type": "string"}, - "tags": {"items": {"type": "string"}, "title": "Tags", "type": "array"}, + "metadata": {"default": None, "title": "Metadata", "type": "object"}, + "run_name": {"default": None, "title": "Run Name", "type": "string"}, + "tags": { + "default": None, + "items": {"type": "string"}, + "title": "Tags", + "type": "array", + }, }, + "title": "FakeRunnableConfig", + "type": "object", } fake_bound = FakeRunnable().bind(a="b") # str -> int - assert _schema(fake_bound.input_schema) == { + assert fake_bound.get_input_jsonschema() == { "title": "FakeRunnableInput", "type": "string", } - assert _schema(fake_bound.output_schema) == { + assert fake_bound.get_output_jsonschema() == { "title": "FakeRunnableOutput", "type": "integer", } fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,)) # str -> int - assert _schema(fake_w_fallbacks.input_schema) == { + assert fake_w_fallbacks.get_input_jsonschema() == { "title": "FakeRunnableInput", "type": "string", } - assert _schema(fake_w_fallbacks.output_schema) == { + assert fake_w_fallbacks.get_output_jsonschema() == { "title": "FakeRunnableOutput", "type": "integer", } @@ -271,11 +276,11 @@ def typed_lambda_impl(x: str) -> int: typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int - assert _schema(typed_lambda.input_schema) == { + assert typed_lambda.get_input_jsonschema() == { "title": "typed_lambda_impl_input", "type": "string", } - assert _schema(typed_lambda.output_schema) == { + assert typed_lambda.get_output_jsonschema() == { "title": "typed_lambda_impl_output", "type": "integer", } @@ -285,49 +290,64 @@ async def typed_async_lambda_impl(x: str) -> int: typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int - assert _schema(typed_async_lambda.input_schema) == { + assert typed_async_lambda.get_input_jsonschema() == { "title": "typed_async_lambda_impl_input", "type": "string", } - assert _schema(typed_async_lambda.output_schema) == { + assert typed_async_lambda.get_output_jsonschema() == { "title": "typed_async_lambda_impl_output", "type": "integer", } fake_ret = FakeRetriever() # str -> List[Document] - assert _schema(fake_ret.input_schema) == { + assert fake_ret.get_input_jsonschema() == { "title": "FakeRetrieverInput", "type": "string", } - assert _schema(fake_ret.output_schema) == { - "title": "FakeRetrieverOutput", - "type": "array", - "items": {"$ref": "#/definitions/Document"}, - "definitions": { + assert fake_ret.get_output_jsonschema() == { + "$defs": { "Document": { - "title": "Document", - "description": AnyStr(), - "type": "object", + "description": "Class for storing a piece of text and " + "associated metadata.\n" + "\n" + "Example:\n" + "\n" + " .. code-block:: python\n" + "\n" + " from langchain_core.documents " + "import Document\n" + "\n" + " document = Document(\n" + ' page_content="Hello, ' + 'world!",\n' + ' metadata={"source": ' + '"https://example.com"}\n' + " )", "properties": { - "page_content": {"title": "Page Content", "type": "string"}, - "metadata": {"title": "Metadata", "type": "object"}, "id": { - "title": "Id", "anyOf": [{"type": "string"}, {"type": "null"}], "default": None, + "title": "Id", }, + "metadata": {"title": "Metadata", "type": "object"}, + "page_content": {"title": "Page Content", "type": "string"}, "type": { - "title": "Type", - "enum": ["Document"], - "default": "Document", "const": "Document", + "default": "Document", + "enum": ["Document"], + "title": "Type", "type": "string", }, }, "required": ["page_content"], + "title": "Document", + "type": "object", } }, + "items": {"$ref": "#/$defs/Document"}, + "title": "FakeRetrieverOutput", + "type": "array", } fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]] @@ -350,16 +370,16 @@ async def typed_async_lambda_impl(x: str) -> int: ] ) - assert _schema(chat_prompt.input_schema) == snapshot( + assert chat_prompt.get_input_jsonschema() == snapshot( name="chat_prompt_input_schema" ) - assert _schema(chat_prompt.output_schema) == snapshot( + assert chat_prompt.get_output_jsonschema() == snapshot( name="chat_prompt_output_schema" ) prompt = PromptTemplate.from_template("Hello, {name}!") - assert _schema(prompt.input_schema) == { + assert prompt.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}}, @@ -369,8 +389,8 @@ async def typed_async_lambda_impl(x: str) -> int: prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map() - assert _schema(prompt_mapper.input_schema) == { - "definitions": { + assert prompt_mapper.get_input_jsonschema() == { + "$defs": { "PromptInput": { "properties": {"name": {"title": "Name", "type": "string"}}, "required": ["name"], @@ -378,9 +398,10 @@ async def typed_async_lambda_impl(x: str) -> int: "type": "object", } }, - "items": {"$ref": "#/definitions/PromptInput"}, - "type": "array", + "default": None, + "items": {"$ref": "#/$defs/PromptInput"}, "title": "RunnableEachInput", + "type": "array", } assert _schema(prompt_mapper.output_schema) == snapshot( name="prompt_mapper_output_schema" @@ -399,13 +420,13 @@ async def typed_async_lambda_impl(x: str) -> int: seq = prompt | fake_llm | list_parser - assert _schema(seq.input_schema) == { + assert seq.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}}, "required": ["name"], } - assert _schema(seq.output_schema) == { + assert seq.get_output_jsonschema() == { "type": "array", "items": {"type": "string"}, "title": "CommaSeparatedListOutputParserOutput", @@ -435,7 +456,7 @@ async def typed_async_lambda_impl(x: str) -> int: }, "title": "RouterRunnableInput", } - assert _schema(router.output_schema) == {"title": "RouterRunnableOutput"} + assert router.get_output_jsonschema() == {"title": "RouterRunnableOutput"} seq_w_map: Runnable = ( prompt @@ -447,13 +468,13 @@ async def typed_async_lambda_impl(x: str) -> int: } ) - assert _schema(seq_w_map.input_schema) == { + assert seq_w_map.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}}, "required": ["name"], } - assert _schema(seq_w_map.output_schema) == { + assert seq_w_map.get_output_jsonschema() == { "title": "RunnableParallelOutput", "type": "object", "properties": { @@ -493,13 +514,13 @@ def test_passthrough_assign_schema() -> None: | fake_llm ) - assert _schema(seq_w_assign.input_schema) == { + assert seq_w_assign.get_input_jsonschema() == { "properties": {"question": {"title": "Question", "type": "string"}}, "title": "RunnableSequenceInput", "type": "object", "required": ["question"], } - assert _schema(seq_w_assign.output_schema) == { + assert seq_w_assign.get_output_jsonschema() == { "title": "FakeListLLMOutput", "type": "string", } @@ -511,7 +532,7 @@ def test_passthrough_assign_schema() -> None: # fallback to RunnableAssign.input_schema if next runnable doesn't have # expected dict input_schema - assert _schema(invalid_seq_w_assign.input_schema) == { + assert invalid_seq_w_assign.get_input_jsonschema() == { "properties": {"question": {"title": "Question"}}, "title": "RunnableParallelInput", "type": "object", @@ -524,7 +545,7 @@ def test_passthrough_assign_schema() -> None: ) def test_lambda_schemas() -> None: first_lambda = lambda x: x["hello"] # noqa: E731 - assert _schema(RunnableLambda(first_lambda).input_schema) == { + assert RunnableLambda(first_lambda).get_input_jsonschema() == { "title": "RunnableLambdaInput", "type": "object", "properties": {"hello": {"title": "Hello"}}, @@ -532,7 +553,7 @@ def test_lambda_schemas() -> None: } second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731 - assert _schema(RunnableLambda(second_lambda).input_schema) == { # type: ignore[arg-type] + assert RunnableLambda(second_lambda).get_input_jsonschema() == { # type: ignore[arg-type] "title": "RunnableLambdaInput", "type": "object", "properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}}, @@ -542,7 +563,7 @@ def test_lambda_schemas() -> None: def get_value(input): # type: ignore[no-untyped-def] return input["variable_name"] - assert _schema(RunnableLambda(get_value).input_schema) == { + assert RunnableLambda(get_value).get_input_jsonschema() == { "title": "get_value_input", "type": "object", "properties": {"variable_name": {"title": "Variable Name"}}, @@ -552,7 +573,7 @@ def get_value(input): # type: ignore[no-untyped-def] async def aget_value(input): # type: ignore[no-untyped-def] return (input["variable_name"], input.get("another")) - assert _schema(RunnableLambda(aget_value).input_schema) == { + assert RunnableLambda(aget_value).get_input_jsonschema() == { "title": "aget_value_input", "type": "object", "properties": { @@ -569,7 +590,7 @@ async def aget_values(input): # type: ignore[no-untyped-def] "byebye": input["yo"], } - assert _schema(RunnableLambda(aget_values).input_schema) == { + assert RunnableLambda(aget_values).get_input_jsonschema() == { "title": "aget_values_input", "type": "object", "properties": { @@ -596,15 +617,11 @@ async def aget_values_typed(input: InputType) -> OutputType: } assert ( - _schema( - RunnableLambda( - aget_values_typed # type: ignore[arg-type] - ).input_schema - ) + RunnableLambda( + aget_values_typed # type: ignore[arg-type] + ).get_input_jsonschema() == { - "title": "aget_values_typed_input", - "$ref": "#/definitions/InputType", - "definitions": { + "$defs": { "InputType": { "properties": { "variable_name": { @@ -618,13 +635,13 @@ async def aget_values_typed(input: InputType) -> OutputType: "type": "object", } }, + "allOf": [{"$ref": "#/$defs/InputType"}], + "title": "aget_values_typed_input", } ) - assert _schema(RunnableLambda(aget_values_typed).output_schema) == { # type: ignore[arg-type] - "title": "aget_values_typed_output", - "$ref": "#/definitions/OutputType", - "definitions": { + assert RunnableLambda(aget_values_typed).get_output_jsonschema() == { # type: ignore[arg-type] + "$defs": { "OutputType": { "properties": { "bye": {"title": "Bye", "type": "string"}, @@ -636,6 +653,8 @@ async def aget_values_typed(input: InputType) -> OutputType: "type": "object", } }, + "allOf": [{"$ref": "#/$defs/OutputType"}], + "title": "aget_values_typed_output", } @@ -697,7 +716,7 @@ def test_schema_complex_seq() -> None: | StrOutputParser() ) - assert _schema(chain2.input_schema) == { + assert chain2.get_input_jsonschema() == { "title": "RunnableParallelInput", "type": "object", "properties": { @@ -707,17 +726,17 @@ def test_schema_complex_seq() -> None: "required": ["person", "language"], } - assert _schema(chain2.output_schema) == { + assert chain2.get_output_jsonschema() == { "title": "StrOutputParserOutput", "type": "string", } - assert _schema(chain2.with_types(input_type=str).input_schema) == { + assert chain2.with_types(input_type=str).get_input_jsonschema() == { "title": "RunnableSequenceInput", "type": "string", } - assert _schema(chain2.with_types(input_type=int).output_schema) == { + assert chain2.with_types(input_type=int).get_output_jsonschema() == { "title": "StrOutputParserOutput", "type": "string", } @@ -725,7 +744,7 @@ def test_schema_complex_seq() -> None: class InputType(BaseModel): person: str - assert _schema(chain2.with_types(input_type=InputType).input_schema) == { + assert chain2.with_types(input_type=InputType).get_input_jsonschema() == { "title": "InputType", "type": "object", "properties": {"person": {"title": "Person", "type": "string"}}, @@ -748,25 +767,37 @@ def test_configurable_fields() -> None: assert fake_llm_configurable.invoke("...") == "a" - assert _schema(fake_llm_configurable.config_schema()) == { - "title": "RunnableConfigurableFieldsConfig", - "type": "object", - "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, - "definitions": { + assert fake_llm_configurable.get_config_jsonschema() == { + "$defs": { "Configurable": { - "title": "Configurable", - "type": "object", "properties": { "llm_responses": { - "title": "LLM Responses", - "description": "A list of fake responses for this LLM", "default": ["a"], - "type": "array", + "description": "A " + "list " + "of " + "fake " + "responses " + "for " + "this " + "LLM", "items": {"type": "string"}, + "title": "LLM " "Responses", + "type": "array", } }, + "title": "Configurable", + "type": "object", + } + }, + "properties": { + "configurable": { + "allOf": [{"$ref": "#/$defs/Configurable"}], + "default": None, } }, + "title": "RunnableConfigurableFieldsConfig", + "type": "object", } fake_llm_configured = fake_llm_configurable.with_config( @@ -791,24 +822,34 @@ def test_configurable_fields() -> None: text="Hello, John!" ) - assert _schema(prompt_configurable.config_schema()) == { - "title": "RunnableConfigurableFieldsConfig", - "type": "object", - "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, - "definitions": { + assert prompt_configurable.get_config_jsonschema() == { + "$defs": { "Configurable": { - "title": "Configurable", - "type": "object", "properties": { "prompt_template": { - "title": "Prompt Template", - "description": "The prompt template for this chain", - "default": "Hello, {name}!", + "default": "Hello, " "{name}!", + "description": "The " + "prompt " + "template " + "for " + "this " + "chain", + "title": "Prompt " "Template", "type": "string", } }, + "title": "Configurable", + "type": "object", + } + }, + "properties": { + "configurable": { + "allOf": [{"$ref": "#/$defs/Configurable"}], + "default": None, } }, + "title": "RunnableConfigurableFieldsConfig", + "type": "object", } prompt_configured = prompt_configurable.with_config( @@ -819,11 +860,9 @@ def test_configurable_fields() -> None: text="Hello, John! John!" ) - assert _schema( - prompt_configurable.with_config( - configurable={"prompt_template": "Hello {name} in {lang}"} - ).input_schema - ) == { + assert prompt_configurable.with_config( + configurable={"prompt_template": "Hello {name} in {lang}"} + ).get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": { @@ -837,31 +876,48 @@ def test_configurable_fields() -> None: assert chain_configurable.invoke({"name": "John"}) == "a" - assert _schema(chain_configurable.config_schema()) == { - "title": "RunnableSequenceConfig", - "type": "object", - "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, - "definitions": { + assert chain_configurable.get_config_jsonschema() == { + "$defs": { "Configurable": { - "title": "Configurable", - "type": "object", "properties": { "llm_responses": { - "title": "LLM Responses", - "description": "A list of fake responses for this LLM", "default": ["a"], - "type": "array", + "description": "A " + "list " + "of " + "fake " + "responses " + "for " + "this " + "LLM", "items": {"type": "string"}, + "title": "LLM " "Responses", + "type": "array", }, "prompt_template": { - "title": "Prompt Template", - "description": "The prompt template for this chain", - "default": "Hello, {name}!", + "default": "Hello, " "{name}!", + "description": "The " + "prompt " + "template " + "for " + "this " + "chain", + "title": "Prompt " "Template", "type": "string", }, }, + "title": "Configurable", + "type": "object", } }, + "properties": { + "configurable": { + "allOf": [{"$ref": "#/$defs/Configurable"}], + "default": None, + } + }, + "title": "RunnableSequenceConfig", + "type": "object", } assert ( @@ -874,14 +930,12 @@ def test_configurable_fields() -> None: == "c" ) - assert _schema( - chain_configurable.with_config( - configurable={ - "prompt_template": "A very good morning to you, {name} {lang}!", - "llm_responses": ["c"], - } - ).input_schema - ) == { + assert chain_configurable.with_config( + configurable={ + "prompt_template": "A very good morning to you, {name} {lang}!", + "llm_responses": ["c"], + } + ).get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": { @@ -906,37 +960,54 @@ def test_configurable_fields() -> None: "llm3": "a", } - assert _schema(chain_with_map_configurable.config_schema()) == { - "title": "RunnableSequenceConfig", - "type": "object", - "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, - "definitions": { + assert chain_with_map_configurable.get_config_jsonschema() == { + "$defs": { "Configurable": { - "title": "Configurable", - "type": "object", "properties": { "llm_responses": { - "title": "LLM Responses", - "description": "A list of fake responses for this LLM", "default": ["a"], - "type": "array", + "description": "A " + "list " + "of " + "fake " + "responses " + "for " + "this " + "LLM", "items": {"type": "string"}, + "title": "LLM " "Responses", + "type": "array", }, "other_responses": { - "title": "Other Responses", "default": ["a"], - "type": "array", "items": {"type": "string"}, + "title": "Other " "Responses", + "type": "array", }, "prompt_template": { - "title": "Prompt Template", - "description": "The prompt template for this chain", - "default": "Hello, {name}!", + "default": "Hello, " "{name}!", + "description": "The " + "prompt " + "template " + "for " + "this " + "chain", + "title": "Prompt " "Template", "type": "string", }, }, + "title": "Configurable", + "type": "object", + } + }, + "properties": { + "configurable": { + "allOf": [{"$ref": "#/$defs/Configurable"}], + "default": None, } }, + "title": "RunnableSequenceConfig", + "type": "object", } assert chain_with_map_configurable.with_config( @@ -1008,6 +1079,9 @@ def test_configurable_fields_prefix_keys() -> None: chain = prompt | fake_llm assert _schema(chain.config_schema()) == { + "title": "RunnableSequenceConfig", + "type": "object", + "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, "definitions": { "Chat_Responses": { "enum": ["hello", "bye", "helpful"], @@ -1068,9 +1142,6 @@ def test_configurable_fields_prefix_keys() -> None: "type": "string", }, }, - "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, - "title": "RunnableSequenceConfig", - "type": "object", } @@ -1119,63 +1190,73 @@ def test_configurable_fields_example() -> None: chain_configurable = prompt | fake_llm | (lambda x: {"name": x}) | prompt | fake_llm assert chain_configurable.invoke({"name": "John"}) == "a" - expected = { - "title": "RunnableSequenceConfig", - "type": "object", - "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, - "definitions": { - "LLM": { - "title": "LLM", - "enum": ["chat", "default"], - "type": "string", - }, + + assert chain_configurable.get_config_jsonschema() == { + "$defs": { "Chat_Responses": { "enum": ["hello", "bye", "helpful"], "title": "Chat Responses", "type": "string", }, - "Prompt_Template": { - "enum": ["hello", "good_morning"], - "title": "Prompt Template", - "type": "string", - }, "Configurable": { - "title": "Configurable", - "type": "object", "properties": { "chat_responses": { "default": ["hello", "bye"], - "items": {"$ref": "#/definitions/Chat_Responses"}, - "title": "Chat Responses", + "items": {"$ref": "#/$defs/Chat_Responses"}, + "title": "Chat " "Responses", "type": "array", }, "llm": { - "title": "LLM", + "allOf": [{"$ref": "#/$defs/LLM"}], "default": "default", - "allOf": [{"$ref": "#/definitions/LLM"}], + "title": "LLM", }, "llm_responses": { - "title": "LLM Responses", - "description": "A list of fake responses for this LLM", "default": ["a"], - "type": "array", + "description": "A " + "list " + "of " + "fake " + "responses " + "for " + "this " + "LLM", "items": {"type": "string"}, + "title": "LLM " "Responses", + "type": "array", }, "prompt_template": { - "title": "Prompt Template", - "description": "The prompt template for this chain", + "allOf": [{"$ref": "#/$defs/Prompt_Template"}], "default": "hello", - "allOf": [{"$ref": "#/definitions/Prompt_Template"}], + "description": "The " + "prompt " + "template " + "for " + "this " + "chain", + "title": "Prompt " "Template", }, }, + "title": "Configurable", + "type": "object", + }, + "LLM": {"enum": ["chat", "default"], "title": "LLM", "type": "string"}, + "Prompt_Template": { + "enum": ["hello", "good_morning"], + "title": "Prompt Template", + "type": "string", }, }, + "properties": { + "configurable": { + "allOf": [{"$ref": "#/$defs/Configurable"}], + "default": None, + } + }, + "title": "RunnableSequenceConfig", + "type": "object", } - replace_all_of_with_ref(expected) - - assert _schema(chain_configurable.config_schema()) == expected - assert ( chain_configurable.with_config(configurable={"llm": "chat"}).invoke( {"name": "John"} @@ -3204,7 +3285,7 @@ def test_map_stream() -> None: chain_pick_one = chain.pick("llm") - assert _schema(chain_pick_one.output_schema) == { + assert chain_pick_one.get_output_jsonschema() == { "title": "RunnableSequenceOutput", "type": "string", } @@ -3227,7 +3308,7 @@ def test_map_stream() -> None: ["llm", "hello"] ) - assert _schema(chain_pick_two.output_schema) == { + assert chain_pick_two.get_output_jsonschema() == { "title": "RunnableSequenceOutput", "type": "object", "properties": { @@ -3599,13 +3680,13 @@ def test_deep_stream_assign() -> None: chain_with_assign = chain.assign(hello=itemgetter("str") | llm) - assert _schema(chain_with_assign.input_schema) == { + assert chain_with_assign.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {"question": {"title": "Question", "type": "string"}}, "required": ["question"], } - assert _schema(chain_with_assign.output_schema) == { + assert chain_with_assign.get_output_jsonschema() == { "title": "RunnableSequenceOutput", "type": "object", "properties": { @@ -3651,13 +3732,13 @@ def test_deep_stream_assign() -> None: hello=itemgetter("str") | llm, ) - assert _schema(chain_with_assign_shadow.input_schema) == { + assert chain_with_assign_shadow.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {"question": {"title": "Question", "type": "string"}}, "required": ["question"], } - assert _schema(chain_with_assign_shadow.output_schema) == { + assert chain_with_assign_shadow.get_output_jsonschema() == { "title": "RunnableSequenceOutput", "type": "object", "properties": { @@ -3727,13 +3808,13 @@ async def test_deep_astream_assign() -> None: hello=itemgetter("str") | llm, ) - assert _schema(chain_with_assign.input_schema) == { + assert chain_with_assign.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {"question": {"title": "Question", "type": "string"}}, "required": ["question"], } - assert _schema(chain_with_assign.output_schema) == { + assert chain_with_assign.get_output_jsonschema() == { "title": "RunnableSequenceOutput", "type": "object", "properties": { @@ -3779,13 +3860,13 @@ async def test_deep_astream_assign() -> None: hello=itemgetter("str") | llm, ) - assert _schema(chain_with_assign_shadow.input_schema) == { + assert chain_with_assign_shadow.get_input_jsonschema() == { "title": "PromptInput", "type": "object", "properties": {"question": {"title": "Question", "type": "string"}}, "required": ["question"], } - assert _schema(chain_with_assign_shadow.output_schema) == { + assert chain_with_assign_shadow.get_output_jsonschema() == { "title": "RunnableSequenceOutput", "type": "object", "properties": { @@ -4784,7 +4865,7 @@ async def test_tool_from_runnable() -> None: {"question": "What up"} ) assert chain_tool.description.endswith(repr(chain)) - assert _schema(chain_tool.args_schema) == _schema(chain.input_schema) + assert _schema(chain_tool.args_schema) == chain.get_input_jsonschema() assert _schema(chain_tool.args_schema) == { "properties": {"question": {"title": "Question", "type": "string"}}, "title": "PromptInput", @@ -4803,8 +4884,8 @@ def gen(input: Iterator[Any]) -> Iterator[int]: runnable = RunnableGenerator(gen) - assert _schema(runnable.input_schema) == {"title": "gen_input"} - assert _schema(runnable.output_schema) == { + assert runnable.get_input_jsonschema() == {"title": "gen_input"} + assert runnable.get_output_jsonschema() == { "title": "gen_output", "type": "integer", } @@ -4855,8 +4936,8 @@ def gen(input: Iterator[Any]) -> Iterator[int]: runnable = RunnableGenerator(gen) - assert _schema(runnable.input_schema) == {"title": "gen_input"} - assert _schema(runnable.output_schema) == { + assert runnable.get_input_jsonschema() == {"title": "gen_input"} + assert runnable.get_output_jsonschema() == { "title": "gen_output", "type": "integer", } @@ -4989,11 +5070,11 @@ def gen(input: str) -> Iterator[int]: yield fake.invoke(input * 2) yield fake.invoke(input * 3) - assert _schema(gen.input_schema) == { + assert gen.get_input_jsonschema() == { "title": "gen_input", "type": "string", } - assert _schema(gen.output_schema) == { + assert gen.get_output_jsonschema() == { "title": "gen_output", "type": "integer", } @@ -5040,11 +5121,11 @@ async def agen(input: str) -> AsyncIterator[int]: yield await fake.ainvoke(input * 2) yield await fake.ainvoke(input * 3) - assert _schema(agen.input_schema) == { + assert agen.get_input_jsonschema() == { "title": "agen_input", "type": "string", } - assert _schema(agen.output_schema) == { + assert agen.get_output_jsonschema() == { "title": "agen_output", "type": "integer", } @@ -5107,8 +5188,8 @@ def fun(input: str) -> int: output += fake.invoke(input * 3) return output - assert _schema(fun.input_schema) == {"title": "fun_input", "type": "string"} - assert _schema(fun.output_schema) == { + assert fun.get_input_jsonschema() == {"title": "fun_input", "type": "string"} + assert fun.get_output_jsonschema() == { "title": "fun_output", "type": "integer", } @@ -5156,8 +5237,8 @@ async def afun(input: str) -> int: output += await fake.ainvoke(input * 3) return output - assert _schema(afun.input_schema) == {"title": "afun_input", "type": "string"} - assert _schema(afun.output_schema) == { + assert afun.get_input_jsonschema() == {"title": "afun_input", "type": "string"} + assert afun.get_output_jsonschema() == { "title": "afun_output", "type": "integer", } @@ -5217,19 +5298,19 @@ async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]: chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one - assert _schema(chain.input_schema) == { + assert chain.get_input_jsonschema() == { "title": "gen_indexes_input", "type": "integer", } - assert _schema(chain.output_schema) == { + assert chain.get_output_jsonschema() == { "title": "plus_one_output", "type": "integer", } - assert _schema(achain.input_schema) == { + assert achain.get_input_jsonschema() == { "title": "gen_indexes_input", "type": "integer", } - assert _schema(achain.output_schema) == { + assert achain.get_output_jsonschema() == { "title": "aplus_one_output", "type": "integer", }