Skip to content

Commit

Permalink
core[patch]: Automatic upgrade to AddableDict in transform and atransform (#18743)
Browse files Browse the repository at this point in the history

Automatic upgrade to transform and atransform

Closes: 

#18741
langchain-ai/langgraph#136
langchain-ai/langserve#504
  • Loading branch information
eyurtsev authored and hinthornw committed Apr 26, 2024
1 parent 527cf9d commit 0fbc89c
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 8 deletions.
34 changes: 28 additions & 6 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,12 +1050,19 @@ def transform(

for chunk in input:
if not got_first_val:
final = chunk
final = _adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final = final + chunk # type: ignore[operator]
try:
final = final + chunk # type: ignore[operator]
except TypeError:
raise TypeError(
f"Failed while trying to add together "
f"type {type(final)} and {type(chunk)}."
f"These types should be addable for transform to work."
)

if got_first_val:
yield from self.stream(final, config, **kwargs)
Expand All @@ -1076,12 +1083,19 @@ async def atransform(

async for chunk in input:
if not got_first_val:
final = chunk
final = _adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final = final + chunk # type: ignore[operator]
try:
final = final + chunk # type: ignore[operator]
except TypeError:
raise TypeError(
f"Failed while trying to add together "
f"type {type(final)} and {type(chunk)}."
f"These types should be addable for atransform to work."
)

if got_first_val:
async for output in self.astream(final, config, **kwargs):
Expand Down Expand Up @@ -3560,7 +3574,7 @@ def _transform(
final: Optional[Input] = None
for ichunk in input:
if final is None:
final = ichunk
final = _adapt_first_streaming_chunk(ichunk) # type: ignore
else:
try:
final = final + ichunk # type: ignore[operator]
Expand Down Expand Up @@ -3644,7 +3658,7 @@ async def _atransform(
final: Optional[Input] = None
async for ichunk in input:
if final is None:
final = ichunk
final = _adapt_first_streaming_chunk(ichunk)
else:
try:
final = final + ichunk # type: ignore[operator]
Expand Down Expand Up @@ -4445,3 +4459,11 @@ def my_func(fields):
yield chunk
"""
return RunnableLambda(func)


def _adapt_first_streaming_chunk(chunk: Any) -> Any:
"""This might transform the first chunk of a stream into an AddableDict."""
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
return AddableDict(chunk)
else:
return chunk
8 changes: 6 additions & 2 deletions libs/core/tests/unit_tests/fake/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class GenericFakeChatModel(BaseChatModel):
streaming.
"""

messages: Iterator[AIMessage]
messages: Iterator[Union[AIMessage, str]]
"""Get an iterator over messages.
This can be expanded to accept other types like Callables / dicts / strings
Expand All @@ -187,7 +187,11 @@ def _generate(
) -> ChatResult:
"""Top Level call"""
message = next(self.messages)
generation = ChatGeneration(message=message)
if isinstance(message, str):
message_ = AIMessage(content=message)
else:
message_ = message
generation = ChatGeneration(message=message_)
return ChatResult(generations=[generation])

def _stream(
Expand Down
68 changes: 68 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
chain,
)
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
BaseTracer,
Expand Down Expand Up @@ -5183,3 +5184,70 @@ def add_one(x: int) -> int:
"name": "add_one",
"type": "chain",
}


def test_transform_of_runnable_lambda_with_dicts() -> None:
    """A RunnableLambda should merge streamed dict chunks before invoking."""
    identity_runnable = RunnableLambda(lambda value: value)
    stream = iter([{"foo": "a"}, {"foo": "n"}])
    # The two chunks must be added into a single {"foo": "an"} input.
    assert list(identity_runnable.transform(stream)) == [{"foo": "an"}]


async def test_atransform_of_runnable_lambda_with_dicts() -> None:
    """Async variant: atransform should merge streamed dict chunks into one."""

    async def passthrough(x: Dict[str, str]) -> Dict[str, str]:
        """Return x."""
        return x

    runnable = RunnableLambda[Dict[str, str], Dict[str, str]](passthrough)

    async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
        for piece in ({"foo": "a"}, {"foo": "n"}):
            yield piece

    collected = [chunk async for chunk in runnable.atransform(chunk_iterator())]
    assert collected == [{"foo": "an"}]


def test_default_transform_with_dicts() -> None:
    """The default Runnable.transform should accumulate dict chunks."""

    class EchoRunnable(RunnableSerializable[Input, Output]):
        # Identity runnable; the base-class transform drives this invoke.
        def invoke(
            self, input: Input, config: Optional[RunnableConfig] = None
        ) -> Output:
            return cast(Output, input)  # type: ignore

    runnable = EchoRunnable[Dict[str, str], Dict[str, str]]()
    pieces = iter([{"foo": "a"}, {"foo": "n"}])

    # Both chunks are gathered into one input before invoke is called.
    assert list(runnable.transform(pieces)) == [{"foo": "an"}]


async def test_default_atransform_with_dicts() -> None:
    """Test that the default ``atransform`` works with dicts.

    The base ``Runnable.atransform`` should upgrade the first streamed
    dict chunk so subsequent chunks can be merged with ``+``, then call
    ``invoke`` once on the combined input.
    """
    # Fixed: function name previously misspelled "defualt", and the
    # docstring said "transform" although this exercises atransform.

    class CustomRunnable(RunnableSerializable[Input, Output]):
        # Identity runnable; the base-class atransform drives this invoke.
        def invoke(
            self, input: Input, config: Optional[RunnableConfig] = None
        ) -> Output:
            return cast(Output, input)

    runnable = CustomRunnable[Dict[str, str], Dict[str, str]]()

    async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
        yield {"foo": "a"}
        yield {"foo": "n"}

    chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]

    assert chunks == [{"foo": "an"}]

0 comments on commit 0fbc89c

Please sign in to comment.