From ab8aeea1290a72e3d2e793b08dc2b558b9961e3a Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Mon, 16 Sep 2024 13:44:33 -0400 Subject: [PATCH] langgraph: use create_model_v2 and fallback on create_model (#1708) * langgraph: use create_model_v2 and fallback on create_model * fix * code review * fix --- libs/langgraph/langgraph/graph/state.py | 6 ++-- libs/langgraph/langgraph/pregel/__init__.py | 8 +++-- libs/langgraph/langgraph/utils/pydantic.py | 37 +++++++++++++++++++++ 3 files changed, 45 insertions(+), 6 deletions(-) create mode 100644 libs/langgraph/langgraph/utils/pydantic.py diff --git a/libs/langgraph/langgraph/graph/state.py b/libs/langgraph/langgraph/graph/state.py index 3a20a06ab..12e823ae4 100644 --- a/libs/langgraph/langgraph/graph/state.py +++ b/libs/langgraph/langgraph/graph/state.py @@ -19,7 +19,6 @@ from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.base import RunnableLike -from langchain_core.runnables.utils import create_model from pydantic import BaseModel from pydantic.v1 import BaseModel as BaseModelV1 @@ -47,6 +46,7 @@ from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry from langgraph.store.base import BaseStore from langgraph.utils.fields import get_field_default +from langgraph.utils.pydantic import create_model from langgraph.utils.runnable import coerce_to_runnable logger = logging.getLogger(__name__) @@ -784,12 +784,12 @@ def _get_schema( if len(keys) == 1 and keys[0] == "__root__": return create_model( # type: ignore[call-overload] name, - __root__=(channels[keys[0]].UpdateType, None), + root=(channels[keys[0]].UpdateType, None), ) else: return create_model( # type: ignore[call-overload] name, - **{ + field_definitions={ k: ( channels[k].UpdateType, ( diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index b5cdb6b19..1d9b2c9df 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -34,7 +34,6 @@ ) from langchain_core.runnables.utils import ( ConfigurableFieldSpec, - create_model, get_function_nonlocals, get_unique_config_specs, ) @@ -91,6 +90,7 @@ patch_config, patch_configurable, ) +from langgraph.utils.pydantic import create_model from langgraph.utils.runnable import RunnableCallable WriteValue = Union[Callable[[Input], Output], Any] @@ -309,7 +309,7 @@ def get_input_schema( else: return create_model( # type: ignore[call-overload] self.get_name("Input"), - **{ + field_definitions={ k: (self.channels[k].UpdateType, None) for k in self.input_channels or self.channels.keys() }, @@ -329,7 +329,9 @@ def get_output_schema( else: return create_model( # type: ignore[call-overload] self.get_name("Output"), - **{k: (self.channels[k].ValueType, None) for k in self.output_channels}, + field_definitions={ + k: (self.channels[k].ValueType, None) for k in self.output_channels + }, ) @property diff --git a/libs/langgraph/langgraph/utils/pydantic.py b/libs/langgraph/langgraph/utils/pydantic.py new file mode 100644 index 000000000..9accc66e1 --- /dev/null +++ b/libs/langgraph/langgraph/utils/pydantic.py @@ -0,0 +1,37 @@ +from typing import Any, Dict, Optional, Union + +from pydantic import BaseModel +from pydantic.v1 import BaseModel as BaseModelV1 + + +def create_model( + model_name: str, + *, + field_definitions: Optional[Dict[str, Any]] = None, + root: Optional[Any] = None, +) -> Union[BaseModel, BaseModelV1]: + """Create a pydantic model with the given field definitions. + + Args: + model_name: The name of the model. + field_definitions: The field definitions for the model. + root: Type for a root model (RootModel) + """ + try: + # for langchain-core >= 0.3.0 + from langchain_core.runnables.pydantic import create_model_v2 + + return create_model_v2( + model_name, + field_definitions=field_definitions, + root=root, + ) + except ImportError: + # for langchain-core < 0.3.0 + from langchain_core.runnables.utils import create_model + + v1_kwargs = {} + if root is not None: + v1_kwargs["__root__"] = root + + return create_model(model_name, **v1_kwargs, **(field_definitions or {}))