Skip to content

Commit

Permalink
Support multiple args in @task's (#2923)
Browse files Browse the repository at this point in the history
Also add support for just `@task` without the ()'s
  • Loading branch information
hinthornw authored Jan 4, 2025
2 parents 0809868 + 451bc03 commit 59a11c6
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 31 deletions.
53 changes: 44 additions & 9 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import concurrent
import concurrent.futures
import functools
import inspect
import types
from functools import partial, update_wrapper
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -33,17 +33,17 @@


def call(
func: Callable[[P1], T],
input: P1,
*,
func: Callable[P, T],
*args: Any,
retry: Optional[RetryPolicy] = None,
**kwargs: Any,
) -> concurrent.futures.Future[T]:
from langgraph.constants import CONFIG_KEY_CALL
from langgraph.utils.config import get_configurable

conf = get_configurable()
impl = conf[CONFIG_KEY_CALL]
fut = impl(func, input, retry=retry)
fut = impl(func, (args, kwargs), retry=retry)
return fut


Expand All @@ -59,16 +59,51 @@ def task( # type: ignore[overload-cannot-match]
) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ...


@overload
def task(
*, retry: Optional[RetryPolicy] = None
__func_or_none__: Callable[P, T],
) -> Callable[P, concurrent.futures.Future[T]]: ...


@overload
def task(
__func_or_none__: Callable[P, Awaitable[T]],
) -> Callable[P, asyncio.Future[T]]: ...


def task(
__func_or_none__: Optional[Union[Callable[P, T], Callable[P, Awaitable[T]]]] = None,
*,
retry: Optional[RetryPolicy] = None,
) -> Union[
Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]],
Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]],
Callable[P, asyncio.Future[T]],
Callable[P, concurrent.futures.Future[T]],
]:
def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]:
return update_wrapper(partial(call, func, retry=retry), func)
def decorator(
func: Union[Callable[P, Awaitable[T]], Callable[P, T]],
) -> Callable[P, concurrent.futures.Future[T]]:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def _tick(__allargs__: tuple) -> T:
return await func(*__allargs__[0], **__allargs__[1])

else:

@functools.wraps(func)
def _tick(__allargs__: tuple) -> T:
return func(*__allargs__[0], **__allargs__[1])

return functools.update_wrapper(
functools.partial(call, _tick, retry=retry), func
)

if __func_or_none__ is not None:
return decorator(__func_or_none__)

return _task
return decorator


def entrypoint(
Expand Down
31 changes: 19 additions & 12 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,27 +1515,32 @@ def test_imp_stream_order(
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

@task()
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}
def foo(state: dict) -> tuple:
return state["a"] + "foo", "bar"

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
@task
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
fut_bar = bar(*fut_foo.result())
fut_baz = baz(fut_bar.result())
return fut_baz.result()

thread1 = {"configurable": {"thread_id": "1"}}
assert [c for c in graph.stream({"a": "0"}, thread1)] == [
{"foo": {"a": "0foo", "b": "bar"}},
{
"foo": (
"0foo",
"bar",
)
},
{"bar": {"a": "0foobar", "c": "bark"}},
{"baz": {"a": "0foobarbaz", "c": "something else"}},
{"graph": {"a": "0foobarbaz", "c": "something else"}},
Expand Down Expand Up @@ -4168,10 +4173,12 @@ def __init__(self, i: Optional[int] = None):
def __call__(self, inputs: State, config: RunnableConfig, store: BaseStore):
assert isinstance(store, BaseStore)
store.put(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar"),
(
namespace
if self.i is not None
and config["configurable"]["thread_id"] in (thread_1, thread_2)
else (f"foo_{self.i}", "bar")
),
doc_id,
{
**doc,
Expand Down
21 changes: 11 additions & 10 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,18 +2571,18 @@ async def test_imp_sync_from_async(checkpointer_name: str) -> None:
def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(fut_foo.result())
foo_result = foo(state).result()
fut_bar = bar(foo_result["a"], foo_result["b"])
fut_baz = baz(fut_bar.result())
return fut_baz.result()

Expand All @@ -2607,18 +2607,19 @@ async def test_imp_stream_order(checkpointer_name: str) -> None:
async def foo(state: dict) -> dict:
return {"a": state["a"] + "foo", "b": "bar"}

@task()
async def bar(state: dict) -> dict:
return {"a": state["a"] + state["b"], "c": "bark"}
@task
async def bar(a: str, b: str, c: Optional[str] = None) -> dict:
return {"a": a + b, "c": (c or "") + "bark"}

@task()
async def baz(state: dict) -> dict:
return {"a": state["a"] + "baz", "c": "something else"}

@entrypoint(checkpointer=checkpointer)
async def graph(state: dict) -> dict:
fut_foo = foo(state)
fut_bar = bar(await fut_foo)
foo_res = await foo(state)

fut_bar = bar(foo_res["a"], foo_res["b"])
fut_baz = baz(await fut_bar)
return await fut_baz

Expand Down

0 comments on commit 59a11c6

Please sign in to comment.