From 9644525b5942ff379ab4bee66fd972bc3d8069cb Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 17 Sep 2024 09:14:52 -0700 Subject: [PATCH] Fix some deprecation warnings in tests --- libs/langgraph/langgraph/pregel/__init__.py | 18 ++++++++++++++++++ libs/langgraph/tests/test_pregel.py | 9 ++++----- libs/langgraph/tests/test_pregel_async.py | 9 +++------ libs/langgraph/tests/test_state.py | 13 ++++--------- 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index dc803c735..8fe8608cf 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -316,6 +316,15 @@ def get_input_schema( }, ) + def get_input_jsonschema( + self, config: Optional[RunnableConfig] = None + ) -> Dict[All, Any]: + schema = self.get_input_schema(config) + if hasattr(schema, "model_json_schema"): + return schema.model_json_schema() + else: + return schema.schema() + @property def OutputType(self) -> Any: if isinstance(self.output_channels, str): @@ -335,6 +344,15 @@ def get_output_schema( }, ) + def get_output_jsonschema( + self, config: Optional[RunnableConfig] = None + ) -> Dict[All, Any]: + schema = self.get_output_schema(config) + if hasattr(schema, "model_json_schema"): + return schema.model_json_schema() + else: + return schema.schema() + @property def stream_channels_list(self) -> Sequence[str]: stream_channels = self.stream_channels_asis diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 8c88f8ad0..3e3ebc72f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -361,7 +361,7 @@ def node_c(state: StateForC) -> StateForC: "messages": [_AnyIdHumanMessage(content="hello")], } - builder = StateGraph(input=State, output=Output) + builder = StateGraph(State, output=Output) builder.add_node("a", node_a) builder.add_node("b", node_b) builder.add_node("c", node_c) @@ -7663,10 +7663,9 @@ def decider(data: State) -> str: app = workflow.compile() - # because it's a v1 pydantic, we're using .schema() here instead of the new methods assert app.get_graph().draw_mermaid(with_styles=False) == snapshot - assert app.get_input_schema().schema() == snapshot - assert app.get_output_schema().schema() == snapshot + assert app.get_input_jsonschema() == snapshot + assert app.get_output_jsonschema() == snapshot with pytest.raises(ValidationError), assert_ctx_once(): app.invoke({"query": {}}) @@ -9868,7 +9867,7 @@ def edit(state: JokeState): return {"subject": f"{subject} - hohoho"} # subgraph - subgraph = StateGraph(input=JokeState, output=OverallState) + subgraph = StateGraph(JokeState, output=OverallState) subgraph.add_node("edit", edit) subgraph.add_node( "generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]} diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index a3a1510c2..ec53d5633 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -601,7 +601,7 @@ async def node_c(state: StateForC): "messages": [_AnyIdHumanMessage(content="hello")], } - builder = StateGraph(input=State, output=Output) + builder = StateGraph(State, output=Output) builder.add_node("a", node_a) builder.add_node("b", node_b) builder.add_node("c", node_c) @@ -3179,10 +3179,7 @@ async def assert_ctx_once() -> AsyncIterator[None]: setup.reset_mock() teardown.reset_mock() - class MyPydanticContextModel(BaseModel): - class Config: - arbitrary_types_allowed = True - + class MyPydanticContextModel(BaseModel, arbitrary_types_allowed=True): session: httpx.AsyncClient something_else: str @@ -8603,7 +8600,7 @@ async def edit(state: JokeState): return {"subject": f"{subject} - hohoho"} # subgraph - subgraph = StateGraph(input=JokeState, output=OverallState) + subgraph = StateGraph(JokeState, output=OverallState) subgraph.add_node("edit", edit) subgraph.add_node( "generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]} diff --git a/libs/langgraph/tests/test_state.py b/libs/langgraph/tests/test_state.py index d2a31690f..73546b6a1 100644 --- a/libs/langgraph/tests/test_state.py +++ b/libs/langgraph/tests/test_state.py @@ -73,7 +73,7 @@ def only_return_hint(state, config) -> OutputState: def miss_all_hint(state, config): return {"input_state": state} - graph = StateGraph(input=InputState, output=OutputState) + graph = StateGraph(InputState, output=OutputState) actions = [complete_hint, miss_first_hint, only_return_hint, miss_all_hint] for action in actions: @@ -125,8 +125,7 @@ class State(InputState): # this would be ignored builder.add_node("n", lambda x: x) builder.add_edge("__start__", "n") graph = builder.compile() - model = graph.get_input_schema() - json_schema = model.schema() + json_schema = graph.get_input_jsonschema() if total_ is False: expected_required = set() @@ -146,7 +145,7 @@ class State(InputState): # this would be ignored ) # Check output schema. Should be the same process - output_schema = graph.get_output_schema().schema() + output_schema = graph.get_output_jsonschema() if total_ is False: expected_required = set() expected_optional = {"out_val2", "out_val1"} @@ -192,9 +191,7 @@ class InputState: builder.add_node("n", lambda x: x) builder.add_edge("__start__", "n") graph = builder.compile() - for model in [graph.get_input_schema(), graph.get_output_schema()]: - json_schema = model.schema() - + for json_schema in [graph.get_input_jsonschema(), graph.get_output_jsonschema()]: expected_required = {"val1", "val7"} expected_optional = { "val2", @@ -256,8 +253,6 @@ class State(TypedDict): StateGraph(_state, input=_inp, output=_outp) bad_output_examples = [ (State, InputState, BadOutputState), - (None, InputState, BadOutputState), - (None, State, BadOutputState), (State, None, BadOutputState), ] for _state, _inp, _outp in bad_output_examples: