Skip to content

Commit

Permalink
core[minor]: Add get_input_jsonschema, get_output_jsonschema, get_con…
Browse files Browse the repository at this point in the history
…fig_jsonschema (#26034)

This PR adds methods to directly get the json schema for inputs,
outputs, and config.
Currently, it's delegating to the underlying pydantic implementation,
but this may be changed in the future to be independent of pydantic.
  • Loading branch information
eyurtsev authored Sep 5, 2024
1 parent e5aa0f9 commit 3c598d2
Show file tree
Hide file tree
Showing 7 changed files with 479 additions and 316 deletions.
71 changes: 71 additions & 0 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,34 @@ def get_input_schema(
__root__=root_type,
)

def get_input_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
"""Get a JSON schema that represents the input to the Runnable.
Args:
config: A config to use when generating the schema.
Returns:
A JSON schema that represents the input to the Runnable.
Example:
.. code-block:: python
from langchain_core.runnables import RunnableLambda
def add_one(x: int) -> int:
return x + 1
runnable = RunnableLambda(add_one)
print(runnable.get_input_jsonschema())
.. versionadded:: 0.3.0
"""
return self.get_input_schema(config).model_json_schema()

@property
def output_schema(self) -> Type[BaseModel]:
"""The type of output this Runnable produces specified as a pydantic model."""
Expand Down Expand Up @@ -382,6 +410,34 @@ def get_output_schema(
__root__=root_type,
)

def get_output_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable.
Args:
config: A config to use when generating the schema.
Returns:
A JSON schema that represents the output of the Runnable.
Example:
.. code-block:: python
from langchain_core.runnables import RunnableLambda
def add_one(x: int) -> int:
return x + 1
runnable = RunnableLambda(add_one)
print(runnable.get_output_jsonschema())
.. versionadded:: 0.3.0
"""
return self.get_output_schema(config).model_json_schema()

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
"""List configurable fields for this Runnable."""
Expand Down Expand Up @@ -435,6 +491,21 @@ def config_schema(
)
return model

def get_config_jsonschema(
self, *, include: Optional[Sequence[str]] = None
) -> Dict[str, Any]:
"""Get a JSON schema that represents the output of the Runnable.
Args:
include: A list of fields to include in the config schema.
Returns:
A JSON schema that represents the output of the Runnable.
.. versionadded:: 0.3.0
"""
return self.config_schema(include=include).model_json_schema()

def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Return a graph representation of this Runnable."""
from langchain_core.runnables.graph import Graph
Expand Down
43 changes: 23 additions & 20 deletions libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# serializer version: 1
# name: test_chat_input_schema[partial]
dict({
'definitions': dict({
'$defs': dict({
'AIMessage': dict({
'additionalProperties': True,
'description': '''
Expand Down Expand Up @@ -60,7 +60,7 @@
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
'$ref': '#/$defs/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
Expand All @@ -85,7 +85,7 @@
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
'$ref': '#/$defs/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
Expand All @@ -102,7 +102,7 @@
'usage_metadata': dict({
'anyOf': list([
dict({
'$ref': '#/definitions/UsageMetadata',
'$ref': '#/$defs/UsageMetadata',
}),
dict({
'type': 'null',
Expand Down Expand Up @@ -1133,6 +1133,7 @@
'type': 'object',
}),
'artifact': dict({
'default': None,
'title': 'Artifact',
}),
'content': dict({
Expand Down Expand Up @@ -1345,25 +1346,26 @@
}),
'properties': dict({
'history': dict({
'default': None,
'items': dict({
'anyOf': list([
dict({
'$ref': '#/definitions/AIMessage',
'$ref': '#/$defs/AIMessage',
}),
dict({
'$ref': '#/definitions/HumanMessage',
'$ref': '#/$defs/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
'$ref': '#/$defs/ChatMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
'$ref': '#/$defs/SystemMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
'$ref': '#/$defs/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
'$ref': '#/$defs/ToolMessage',
}),
dict({
'$ref': '#/definitions/AIMessageChunk',
Expand Down Expand Up @@ -1402,7 +1404,7 @@
# ---
# name: test_chat_input_schema[required]
dict({
'definitions': dict({
'$defs': dict({
'AIMessage': dict({
'additionalProperties': True,
'description': '''
Expand Down Expand Up @@ -1461,7 +1463,7 @@
'default': list([
]),
'items': dict({
'$ref': '#/definitions/InvalidToolCall',
'$ref': '#/$defs/InvalidToolCall',
}),
'title': 'Invalid Tool Calls',
'type': 'array',
Expand All @@ -1486,7 +1488,7 @@
'default': list([
]),
'items': dict({
'$ref': '#/definitions/ToolCall',
'$ref': '#/$defs/ToolCall',
}),
'title': 'Tool Calls',
'type': 'array',
Expand All @@ -1503,7 +1505,7 @@
'usage_metadata': dict({
'anyOf': list([
dict({
'$ref': '#/definitions/UsageMetadata',
'$ref': '#/$defs/UsageMetadata',
}),
dict({
'type': 'null',
Expand Down Expand Up @@ -2534,6 +2536,7 @@
'type': 'object',
}),
'artifact': dict({
'default': None,
'title': 'Artifact',
}),
'content': dict({
Expand Down Expand Up @@ -2749,22 +2752,22 @@
'items': dict({
'anyOf': list([
dict({
'$ref': '#/definitions/AIMessage',
'$ref': '#/$defs/AIMessage',
}),
dict({
'$ref': '#/definitions/HumanMessage',
'$ref': '#/$defs/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
'$ref': '#/$defs/ChatMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
'$ref': '#/$defs/SystemMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
'$ref': '#/$defs/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
'$ref': '#/$defs/ToolMessage',
}),
dict({
'$ref': '#/definitions/AIMessageChunk',
Expand Down
5 changes: 2 additions & 3 deletions libs/core/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
SystemMessagePromptTemplate,
_convert_to_message,
)
from tests.unit_tests.pydantic_utils import _schema


@pytest.fixture
Expand Down Expand Up @@ -796,14 +795,14 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert prompt_all_required.optional_variables == []
with pytest.raises(ValidationError):
prompt_all_required.input_schema(input="")
assert _schema(prompt_all_required.input_schema) == snapshot(name="required")
assert prompt_all_required.get_input_jsonschema() == snapshot(name="required")
prompt_optional = ChatPromptTemplate(
messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")]
)
# input variables only lists required variables
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
assert _schema(prompt_optional.input_schema) == snapshot(name="partial")
assert prompt_optional.get_input_jsonschema() == snapshot(name="partial")


def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
Expand Down
Loading

0 comments on commit 3c598d2

Please sign in to comment.