Skip to content

Commit

Permalink
Fix some deprecation warnings in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 17, 2024
1 parent c4d4d61 commit 9644525
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 20 deletions.
18 changes: 18 additions & 0 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,15 @@ def get_input_schema(
},
)

def get_input_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[All, Any]:
schema = self.get_input_schema(config)
if hasattr(schema, "model_json_schema"):
return schema.model_json_schema()
else:
return schema.schema()

@property
def OutputType(self) -> Any:
if isinstance(self.output_channels, str):
Expand All @@ -335,6 +344,15 @@ def get_output_schema(
},
)

def get_output_jsonschema(
self, config: Optional[RunnableConfig] = None
) -> Dict[All, Any]:
schema = self.get_output_schema(config)
if hasattr(schema, "model_json_schema"):
return schema.model_json_schema()
else:
return schema.schema()

@property
def stream_channels_list(self) -> Sequence[str]:
stream_channels = self.stream_channels_asis
Expand Down
9 changes: 4 additions & 5 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def node_c(state: StateForC) -> StateForC:
"messages": [_AnyIdHumanMessage(content="hello")],
}

builder = StateGraph(input=State, output=Output)
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
Expand Down Expand Up @@ -7663,10 +7663,9 @@ def decider(data: State) -> str:

app = workflow.compile()

# because it's a v1 pydantic, we're using .schema() here instead of the new methods
assert app.get_graph().draw_mermaid(with_styles=False) == snapshot
assert app.get_input_schema().schema() == snapshot
assert app.get_output_schema().schema() == snapshot
assert app.get_input_jsonschema() == snapshot
assert app.get_output_jsonschema() == snapshot

with pytest.raises(ValidationError), assert_ctx_once():
app.invoke({"query": {}})
Expand Down Expand Up @@ -9868,7 +9867,7 @@ def edit(state: JokeState):
return {"subject": f"{subject} - hohoho"}

# subgraph
subgraph = StateGraph(input=JokeState, output=OverallState)
subgraph = StateGraph(JokeState, output=OverallState)
subgraph.add_node("edit", edit)
subgraph.add_node(
"generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]}
Expand Down
9 changes: 3 additions & 6 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ async def node_c(state: StateForC):
"messages": [_AnyIdHumanMessage(content="hello")],
}

builder = StateGraph(input=State, output=Output)
builder = StateGraph(State, output=Output)
builder.add_node("a", node_a)
builder.add_node("b", node_b)
builder.add_node("c", node_c)
Expand Down Expand Up @@ -3179,10 +3179,7 @@ async def assert_ctx_once() -> AsyncIterator[None]:
setup.reset_mock()
teardown.reset_mock()

class MyPydanticContextModel(BaseModel):
class Config:
arbitrary_types_allowed = True

class MyPydanticContextModel(BaseModel, arbitrary_types_allowed=True):
session: httpx.AsyncClient
something_else: str

Expand Down Expand Up @@ -8603,7 +8600,7 @@ async def edit(state: JokeState):
return {"subject": f"{subject} - hohoho"}

# subgraph
subgraph = StateGraph(input=JokeState, output=OverallState)
subgraph = StateGraph(JokeState, output=OverallState)
subgraph.add_node("edit", edit)
subgraph.add_node(
"generate", lambda state: {"jokes": [f"Joke about {state['subject']}"]}
Expand Down
13 changes: 4 additions & 9 deletions libs/langgraph/tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def only_return_hint(state, config) -> OutputState:
def miss_all_hint(state, config):
return {"input_state": state}

graph = StateGraph(input=InputState, output=OutputState)
graph = StateGraph(InputState, output=OutputState)
actions = [complete_hint, miss_first_hint, only_return_hint, miss_all_hint]

for action in actions:
Expand Down Expand Up @@ -125,8 +125,7 @@ class State(InputState): # this would be ignored
builder.add_node("n", lambda x: x)
builder.add_edge("__start__", "n")
graph = builder.compile()
model = graph.get_input_schema()
json_schema = model.schema()
json_schema = graph.get_input_jsonschema()

if total_ is False:
expected_required = set()
Expand All @@ -146,7 +145,7 @@ class State(InputState): # this would be ignored
)

# Check output schema. Should be the same process
output_schema = graph.get_output_schema().schema()
output_schema = graph.get_output_jsonschema()
if total_ is False:
expected_required = set()
expected_optional = {"out_val2", "out_val1"}
Expand Down Expand Up @@ -192,9 +191,7 @@ class InputState:
builder.add_node("n", lambda x: x)
builder.add_edge("__start__", "n")
graph = builder.compile()
for model in [graph.get_input_schema(), graph.get_output_schema()]:
json_schema = model.schema()

for json_schema in [graph.get_input_jsonschema(), graph.get_output_jsonschema()]:
expected_required = {"val1", "val7"}
expected_optional = {
"val2",
Expand Down Expand Up @@ -256,8 +253,6 @@ class State(TypedDict):
StateGraph(_state, input=_inp, output=_outp)
bad_output_examples = [
(State, InputState, BadOutputState),
(None, InputState, BadOutputState),
(None, State, BadOutputState),
(State, None, BadOutputState),
]
for _state, _inp, _outp in bad_output_examples:
Expand Down

0 comments on commit 9644525

Please sign in to comment.