From 926a90a2566128f66b8d07f96460374d50fef838 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 12 Dec 2023 10:06:09 -0800 Subject: [PATCH 1/5] WIP Fix tool_calls message merge --- libs/core/langchain_core/messages/base.py | 16 +- libs/core/tests/unit_tests/test_messages.py | 224 ++++++++++++++++++++ 2 files changed, 238 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index daed44f50142b..9233f70095f75 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -98,8 +98,12 @@ def _merge_kwargs_dict( merged[k] = v elif merged[k] is None and v: merged[k] = v + elif v is None: + continue + elif merged[k] == v: + continue elif type(merged[k]) != type(v): - raise ValueError( + raise TypeError( f'additional_kwargs["{k}"] already exists in this message,' " but with a different type." ) @@ -107,8 +111,16 @@ def _merge_kwargs_dict( merged[k] += v elif isinstance(merged[k], dict): merged[k] = self._merge_kwargs_dict(merged[k], v) + elif isinstance(merged[k], list): + for i, e in enumerate(v): + if isinstance(e, dict) and isinstance(e.get("index"), int): + i = e["index"] + if i < len(merged[k]): + merged[k][i] = self._merge_kwargs_dict(merged[k][i], e) + else: + merged[k].append(e) else: - raise ValueError( + raise TypeError( f"Additional kwargs key {k} already exists in this message." ) return merged diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index bcb8bc88ce4c9..698a931bf4fcd 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -203,3 +203,227 @@ def test_message_chunk_to_message() -> None: assert message_chunk_to_message( FunctionMessageChunk(name="hello", content="I am") ) == FunctionMessage(name="hello", content="I am") + + +def test_tool_calls_merge() -> None: + chunks = [ + dict(content=""), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": "call_CwGAsESnXehQEjiAIWzinlva", + "function": {"arguments": "", "name": "person"}, + "type": "function", + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": '{"na', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'me": ', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": '"jane"', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": ', "a', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": 'ge": ', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": None, + "function": {"arguments": "2}", "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": "call_zXSIylHvc5x3JUAPcHZR5GZI", + "function": {"arguments": "", "name": "person"}, + "type": "function", + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": '{"na', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": 'me": ', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": '"bob",', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": ' "ag', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": 'e": 3', "name": None}, + "type": None, + } + ] + }, + ), + dict( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 1, + "id": None, + "function": {"arguments": "}", "name": None}, + "type": None, + } + ] + }, + ), + dict(content=""), + ] + + final = None + + for chunk in chunks: + msg = AIMessageChunk(**chunk) + if final is None: + final = msg + else: + final = final + msg + + assert final == AIMessageChunk( + content="", + additional_kwargs={ + "tool_calls": [ + { + "index": 0, + "id": "call_CwGAsESnXehQEjiAIWzinlva", + "function": { + "arguments": '{"name": "jane", "age": 2}', + "name": "person", + }, + "type": "function", + }, + { + "index": 1, + "id": "call_zXSIylHvc5x3JUAPcHZR5GZI", + "function": { + "arguments": '{"name": "bob", "age": 3}', + "name": "person", + }, + "type": "function", + }, + ] + }, + ) From 5507a1035cb392146dee984ae43c6087a6cfb302 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 12 Dec 2023 15:38:12 -0800 Subject: [PATCH 2/5] Fix mutation --- libs/core/langchain_core/messages/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 9233f70095f75..96c7665ec7bcc 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -112,13 +112,14 @@ def _merge_kwargs_dict( elif isinstance(merged[k], dict): merged[k] = self._merge_kwargs_dict(merged[k], v) elif isinstance(merged[k], list): + merged[k] = merged[k].copy() for i, e in enumerate(v): if isinstance(e, dict) and isinstance(e.get("index"), int): i = e["index"] if i < len(merged[k]): merged[k][i] = self._merge_kwargs_dict(merged[k][i], e) else: - merged[k].append(e) + merged[k] = merged[k] + [e] else: raise TypeError( f"Additional kwargs key {k} already exists in this message." From 26048844f79dce832e26c7a74e051c0eaf253b57 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 12 Dec 2023 15:38:54 -0800 Subject: [PATCH 3/5] Lint --- libs/core/tests/unit_tests/test_messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 698a931bf4fcd..31059109c8671 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -206,7 +206,7 @@ def test_message_chunk_to_message() -> None: def test_tool_calls_merge() -> None: - chunks = [ + chunks: list[dict] = [ dict(content=""), dict( content="", From 5cc15f67f1ffde9037d1670f478b0734c3bea368 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 12 Dec 2023 15:58:06 -0800 Subject: [PATCH 4/5] Lint --- libs/core/tests/unit_tests/test_messages.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 31059109c8671..d9da9e0dcad48 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,3 +1,4 @@ +from typing import List import unittest import pytest @@ -206,7 +207,7 @@ def test_message_chunk_to_message() -> None: def test_tool_calls_merge() -> None: - chunks: list[dict] = [ + chunks: List[dict] = [ dict(content=""), dict( content="", From 45c16e61eac9df1ddd2f2ca8883df95cf2781a33 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 13 Dec 2023 08:55:03 -0800 Subject: [PATCH 5/5] Lint --- libs/core/tests/unit_tests/test_messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index d9da9e0dcad48..95d60a52f2b68 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,5 +1,5 @@ -from typing import List import unittest +from typing import List import pytest