core: Move json parsing in base chat model / output parser to bg thread (#24031)

- Add a version of AIMessageChunk.__add__ that can add many chunks at once,
  instead of only 2
- In agenerate_from_stream, merge and parse chunks in a bg thread
- In output parser base classes, do more work in bg threads where appropriate

---------

Co-authored-by: William FH <[email protected]>
nfcampos and hinthornw authored Jul 9, 2024
1 parent 73966e6 commit 160fc7f
Showing 6 changed files with 189 additions and 162 deletions.
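
At the core of the change, chunk addition becomes variadic. A minimal sketch of the new behavior, assuming a langchain-core build that contains this commit (chunk contents invented):

    from langchain_core.messages import AIMessageChunk

    chunks = [AIMessageChunk(content=c) for c in ("Hel", "lo", " world")]

    # Pairwise addition, as before this commit: each `+` re-merges the
    # accumulated payload with the next chunk.
    pairwise = chunks[0] + chunks[1] + chunks[2]

    # New in this commit: add a whole list of chunks in a single merge pass.
    batched = chunks[0] + chunks[1:]

    assert pairwise.content == batched.content == "Hello world"
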
29 changes: 7 additions & 22 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -94,13 +94,11 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
         ChatResult: Chat result.
     """
 
-    generation: Optional[ChatGenerationChunk] = None
-    for chunk in stream:
-        if generation is None:
-            generation = chunk
-        else:
-            generation += chunk
-    assert generation is not None
+    generation = next(stream, None)
+    if generation:
+        generation += list(stream)
+    if generation is None:
+        raise ValueError("No generations found in stream.")
     return ChatResult(
         generations=[
             ChatGeneration(
@@ -123,21 +121,8 @@ async def agenerate_from_stream(
         ChatResult: Chat result.
     """
 
-    generation: Optional[ChatGenerationChunk] = None
-    async for chunk in stream:
-        if generation is None:
-            generation = chunk
-        else:
-            generation += chunk
-    assert generation is not None
-    return ChatResult(
-        generations=[
-            ChatGeneration(
-                message=message_chunk_to_message(generation.message),
-                generation_info=generation.generation_info,
-            )
-        ]
-    )
+    chunks = [chunk async for chunk in stream]
+    return await run_in_executor(None, generate_from_stream, iter(chunks))
 
 
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
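
The new async path drains the stream on the event loop, then hands the synchronous merge-and-parse work to a worker thread. A self-contained sketch of that offloading pattern, using run_in_executor from langchain_core; the string stream below is a stand-in for real ChatGenerationChunks:

    import asyncio
    from typing import AsyncIterator, Iterator, List

    from langchain_core.runnables.config import run_in_executor

    async def drain_then_merge(stream: AsyncIterator[str]) -> str:
        # Consume the async stream on the event loop...
        chunks: List[str] = [chunk async for chunk in stream]

        def merge(it: Iterator[str]) -> str:
            # ...and do the CPU-bound merging off the loop, in a bg thread.
            return "".join(it)

        return await run_in_executor(None, merge, iter(chunks))

    async def main() -> None:
        async def gen() -> AsyncIterator[str]:
            for piece in ("a", "b", "c"):
                yield piece

        assert await drain_then_merge(gen()) == "abc"

    asyncio.run(main())
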
121 changes: 63 additions & 58 deletions libs/core/langchain_core/messages/ai.py
@@ -267,64 +267,69 @@ def init_tool_calls(cls, values: dict) -> dict:
 
     def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
         if isinstance(other, AIMessageChunk):
-            if self.example != other.example:
-                raise ValueError(
-                    "Cannot concatenate AIMessageChunks with different example values."
-                )
-
-            content = merge_content(self.content, other.content)
-            additional_kwargs = merge_dicts(
-                self.additional_kwargs, other.additional_kwargs
-            )
-            response_metadata = merge_dicts(
-                self.response_metadata, other.response_metadata
-            )
-
-            # Merge tool call chunks
-            if self.tool_call_chunks or other.tool_call_chunks:
-                raw_tool_calls = merge_lists(
-                    self.tool_call_chunks,
-                    other.tool_call_chunks,
-                )
-                if raw_tool_calls:
-                    tool_call_chunks = [
-                        ToolCallChunk(
-                            name=rtc.get("name"),
-                            args=rtc.get("args"),
-                            index=rtc.get("index"),
-                            id=rtc.get("id"),
-                        )
-                        for rtc in raw_tool_calls
-                    ]
-                else:
-                    tool_call_chunks = []
-            else:
-                tool_call_chunks = []
-
-            # Token usage
-            if self.usage_metadata or other.usage_metadata:
-                left: UsageMetadata = self.usage_metadata or UsageMetadata(
-                    input_tokens=0, output_tokens=0, total_tokens=0
-                )
-                right: UsageMetadata = other.usage_metadata or UsageMetadata(
-                    input_tokens=0, output_tokens=0, total_tokens=0
-                )
-                usage_metadata: Optional[UsageMetadata] = {
-                    "input_tokens": left["input_tokens"] + right["input_tokens"],
-                    "output_tokens": left["output_tokens"] + right["output_tokens"],
-                    "total_tokens": left["total_tokens"] + right["total_tokens"],
-                }
-            else:
-                usage_metadata = None
-
-            return self.__class__(
-                example=self.example,
-                content=content,
-                additional_kwargs=additional_kwargs,
-                tool_call_chunks=tool_call_chunks,
-                response_metadata=response_metadata,
-                usage_metadata=usage_metadata,
-                id=self.id,
-            )
-
-        return super().__add__(other)
+            return add_ai_message_chunks(self, other)
+        elif isinstance(other, (list, tuple)) and all(
+            isinstance(o, AIMessageChunk) for o in other
+        ):
+            return add_ai_message_chunks(self, *other)
+        return super().__add__(other)
+
+
+def add_ai_message_chunks(
+    left: AIMessageChunk, *others: AIMessageChunk
+) -> AIMessageChunk:
+    """Add multiple AIMessageChunks together."""
+    if any(left.example != o.example for o in others):
+        raise ValueError(
+            "Cannot concatenate AIMessageChunks with different example values."
+        )
+
+    content = merge_content(left.content, *(o.content for o in others))
+    additional_kwargs = merge_dicts(
+        left.additional_kwargs, *(o.additional_kwargs for o in others)
+    )
+    response_metadata = merge_dicts(
+        left.response_metadata, *(o.response_metadata for o in others)
+    )
+
+    # Merge tool call chunks
+    if raw_tool_calls := merge_lists(
+        left.tool_call_chunks, *(o.tool_call_chunks for o in others)
+    ):
+        tool_call_chunks = [
+            ToolCallChunk(
+                name=rtc.get("name"),
+                args=rtc.get("args"),
+                index=rtc.get("index"),
+                id=rtc.get("id"),
+            )
+            for rtc in raw_tool_calls
+        ]
+    else:
+        tool_call_chunks = []
+
+    # Token usage
+    if left.usage_metadata or any(o.usage_metadata is not None for o in others):
+        usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
+            input_tokens=0, output_tokens=0, total_tokens=0
+        )
+        for other in others:
+            if other.usage_metadata is not None:
+                usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
+                usage_metadata_["output_tokens"] += other.usage_metadata[
+                    "output_tokens"
+                ]
+                usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
+        usage_metadata: Optional[UsageMetadata] = usage_metadata_
+    else:
+        usage_metadata = None
+
+    return left.__class__(
+        example=left.example,
+        content=content,
+        additional_kwargs=additional_kwargs,
+        tool_call_chunks=tool_call_chunks,
+        response_metadata=response_metadata,
+        usage_metadata=usage_metadata,
+        id=left.id,
+    )
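
A small usage sketch for add_ai_message_chunks, assuming this commit is installed; the token counts are invented. Chunks without usage_metadata contribute zero to each field:

    from langchain_core.messages.ai import (
        AIMessageChunk,
        UsageMetadata,
        add_ai_message_chunks,
    )

    a = AIMessageChunk(
        content="foo",
        usage_metadata=UsageMetadata(input_tokens=3, output_tokens=1, total_tokens=4),
    )
    b = AIMessageChunk(content="bar")  # no usage_metadata: contributes nothing
    c = AIMessageChunk(
        content="!",
        usage_metadata=UsageMetadata(input_tokens=0, output_tokens=2, total_tokens=2),
    )

    merged = add_ai_message_chunks(a, b, c)
    assert merged.content == "foobar!"
    assert merged.usage_metadata == {
        "input_tokens": 3,
        "output_tokens": 3,
        "total_tokens": 6,
    }
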
67 changes: 42 additions & 25 deletions libs/core/langchain_core/messages/base.py
@@ -111,7 +111,7 @@ def pretty_print(self) -> None:
 
 def merge_content(
     first_content: Union[str, List[Union[str, Dict]]],
-    second_content: Union[str, List[Union[str, Dict]]],
+    *contents: Union[str, List[Union[str, Dict]]],
 ) -> Union[str, List[Union[str, Dict]]]:
     """Merge two message contents.
@@ -122,31 +122,32 @@ def merge_content(
     Returns:
         The merged content.
     """
-    # If first chunk is a string
-    if isinstance(first_content, str):
-        # If the second chunk is also a string, then merge them naively
-        if isinstance(second_content, str):
-            return first_content + second_content
-        # If the second chunk is a list, add the first chunk to the start of the list
-        else:
-            return_list: List[Union[str, Dict]] = [first_content]
-            return return_list + second_content
-    elif isinstance(second_content, List):
-        # If both are lists
-        merged_list = merge_lists(first_content, second_content)
-        return cast(list, merged_list)
-    # If the first content is a list, and the second content is a string
-    else:
-        # If the last element of the first content is a string
-        # Add the second content to the last element
-        if isinstance(first_content[-1], str):
-            return first_content[:-1] + [first_content[-1] + second_content]
-        # If second content is an empty string, treat as a no-op
-        elif second_content == "":
-            return first_content
-        else:
-            # Otherwise, add the second content as a new element of the list
-            return first_content + [second_content]
+    merged = first_content
+    for content in contents:
+        # If current is a string
+        if isinstance(merged, str):
+            # If the next chunk is also a string, then merge them naively
+            if isinstance(content, str):
+                merged = cast(str, merged) + content
+            # If the next chunk is a list, add the current to the start of the list
+            else:
+                merged = [merged] + content  # type: ignore
+        elif isinstance(content, list):
+            # If both are lists
+            merged = merge_lists(cast(List, merged), content)  # type: ignore
+        # If the first content is a list, and the second content is a string
+        else:
+            # If the last element of the first content is a string
+            # Add the second content to the last element
+            if isinstance(merged[-1], str):
+                merged[-1] += content
+            # If second content is an empty string, treat as a no-op
+            elif content == "":
+                pass
+            else:
+                # Otherwise, add the second content as a new element of the list
+                merged.append(content)
+    return merged
 
 
 class BaseMessageChunk(BaseMessage):
@@ -195,6 +196,22 @@ def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
                     self.response_metadata, other.response_metadata
                 ),
             )
+        elif isinstance(other, list) and all(
+            isinstance(o, BaseMessageChunk) for o in other
+        ):
+            content = merge_content(self.content, *(o.content for o in other))
+            additional_kwargs = merge_dicts(
+                self.additional_kwargs, *(o.additional_kwargs for o in other)
+            )
+            response_metadata = merge_dicts(
+                self.response_metadata, *(o.response_metadata for o in other)
+            )
+            return self.__class__(  # type: ignore[call-arg]
+                id=self.id,
+                content=content,
+                additional_kwargs=additional_kwargs,
+                response_metadata=response_metadata,
+            )
         else:
             raise TypeError(
                 'unsupported operand type(s) for +: "'
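
Some hedged examples of the now-variadic merge_content, assuming this commit; the inputs are invented:

    from langchain_core.messages.base import merge_content

    # Strings concatenate naively.
    assert merge_content("a", "b", "c") == "abc"

    # A trailing string element absorbs the next string...
    assert merge_content(["x"], "y") == ["xy"]

    # ...but a non-string element keeps the next string separate.
    assert merge_content("a", [{"type": "image_url"}], "b") == [
        "a",
        {"type": "image_url"},
        "b",
    ]
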
11 changes: 8 additions & 3 deletions libs/core/langchain_core/output_parsers/transform.py
@@ -17,6 +17,7 @@
     Generation,
     GenerationChunk,
 )
+from langchain_core.runnables.config import run_in_executor
 
 if TYPE_CHECKING:
     from langchain_core.runnables import RunnableConfig
@@ -37,9 +38,13 @@ async def _atransform(
     ) -> AsyncIterator[T]:
         async for chunk in input:
             if isinstance(chunk, BaseMessage):
-                yield self.parse_result([ChatGeneration(message=chunk)])
+                yield await run_in_executor(
+                    None, self.parse_result, [ChatGeneration(message=chunk)]
+                )
             else:
-                yield self.parse_result([Generation(text=chunk)])
+                yield await run_in_executor(
+                    None, self.parse_result, [Generation(text=chunk)]
+                )
 
     def transform(
         self,
@@ -153,7 +158,7 @@ async def _atransform(
             parsed = await self.aparse_result([acc_gen], partial=True)
             if parsed is not None and parsed != prev_parsed:
                 if self.diff:
-                    yield self._diff(prev_parsed, parsed)
+                    yield await run_in_executor(None, self._diff, prev_parsed, parsed)
                 else:
                     yield parsed
                 prev_parsed = parsed
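
The observable parser behavior is unchanged; the synchronous parse_result and _diff calls simply run on worker threads now, so an expensive parse (e.g. partial JSON) no longer blocks the event loop between chunks. A quick sketch with StrOutputParser, assuming this commit:

    import asyncio
    from typing import AsyncIterator

    from langchain_core.output_parsers import StrOutputParser

    async def main() -> None:
        async def gen() -> AsyncIterator[str]:
            for piece in ("hello", " ", "world"):
                yield piece

        parser = StrOutputParser()
        # Each yielded chunk is parsed in a background thread under the hood.
        out = [chunk async for chunk in parser.atransform(gen())]
        assert out == ["hello", " ", "world"]

    asyncio.run(main())
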
17 changes: 15 additions & 2 deletions libs/core/langchain_core/outputs/chat_generation.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Union
 
 from langchain_core.messages import BaseMessage, BaseMessageChunk
 from langchain_core.outputs.generation import Generation
@@ -88,7 +88,9 @@ def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "output"]
 
-    def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
+    def __add__(
+        self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]]
+    ) -> ChatGenerationChunk:
         if isinstance(other, ChatGenerationChunk):
             generation_info = merge_dicts(
                 self.generation_info or {},
@@ -98,6 +100,17 @@ def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
                 message=self.message + other.message,
                 generation_info=generation_info or None,
             )
+        elif isinstance(other, list) and all(
+            isinstance(x, ChatGenerationChunk) for x in other
+        ):
+            generation_info = merge_dicts(
+                self.generation_info or {},
+                *[chunk.generation_info for chunk in other if chunk.generation_info],
+            )
+            return ChatGenerationChunk(
+                message=self.message + [chunk.message for chunk in other],
+                generation_info=generation_info or None,
+            )
         else:
             raise TypeError(
                 f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
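
This list form is what lets generate_from_stream write `generation += list(stream)`. A short sketch, assuming this commit; contents invented:

    from langchain_core.messages import AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk

    first = ChatGenerationChunk(message=AIMessageChunk(content="Hel"))
    rest = [
        ChatGenerationChunk(message=AIMessageChunk(content="lo")),
        ChatGenerationChunk(message=AIMessageChunk(content="!")),
    ]

    # One merge over the whole tail instead of len(rest) pairwise merges.
    merged = first + rest
    assert merged.message.content == "Hello!"
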