From ca6e1155767d466c8b35da3ba905951ec120c68b Mon Sep 17 00:00:00 2001
From: wulixuan
Date: Fri, 26 Jan 2024 10:21:21 +0800
Subject: [PATCH] feat: add unit and integration tests for ChatYuan2.

---
 .../chat_models/test_yuan2.py                      | 150 +++++++++++++++++
 .../unit_tests/chat_models/test_imports.py         |   1 +
 .../unit_tests/chat_models/test_yuan2.py           |  64 ++++++++
 3 files changed, 215 insertions(+)
 create mode 100644 libs/community/tests/integration_tests/chat_models/test_yuan2.py
 create mode 100644 libs/community/tests/unit_tests/chat_models/test_yuan2.py

diff --git a/libs/community/tests/integration_tests/chat_models/test_yuan2.py b/libs/community/tests/integration_tests/chat_models/test_yuan2.py
new file mode 100644
index 0000000000000..b5d8d178c13f2
--- /dev/null
+++ b/libs/community/tests/integration_tests/chat_models/test_yuan2.py
@@ -0,0 +1,150 @@
+"""Test ChatYuan2 wrapper."""
+import pytest
+from langchain_core.callbacks import CallbackManager
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import (
+    ChatGeneration,
+    LLMResult,
+)
+
+from langchain_community.chat_models.yuan2 import ChatYuan2
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
+
+@pytest.mark.scheduled
+def test_chat_yuan2() -> None:
+    """Test ChatYuan2 wrapper."""
+    chat = ChatYuan2(
+        yuan2_api_key="EMPTY",
+        yuan2_api_base="http://127.0.0.1:8001/v1",
+        temperature=1.0,
+        model_name="yuan2",
+        max_retries=3,
+        streaming=False,
+    )
+    messages = [
+        HumanMessage(content="Hello"),
+    ]
+    response = chat(messages)
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+
+def test_chat_yuan2_system_message() -> None:
+    """Test ChatYuan2 wrapper with system message."""
+    chat = ChatYuan2(
+        yuan2_api_key="EMPTY",
+        yuan2_api_base="http://127.0.0.1:8001/v1",
+        temperature=1.0,
+        model_name="yuan2",
+        max_retries=3,
+        streaming=False,
+    )
+    messages = [
+        SystemMessage(content="You are an AI assistant."),
+        HumanMessage(content="Hello"),
+    ]
+    response = chat(messages)
+    assert isinstance(response, BaseMessage)
+    assert isinstance(response.content, str)
+
+
+@pytest.mark.scheduled
+def test_chat_yuan2_generate() -> None:
+    """Test ChatYuan2 wrapper with generate."""
+    chat = ChatYuan2(
+        yuan2_api_key="EMPTY",
+        yuan2_api_base="http://127.0.0.1:8001/v1",
+        temperature=1.0,
+        model_name="yuan2",
+        max_retries=3,
+        streaming=False,
+    )
+    messages = [
+        HumanMessage(content="Hello"),
+    ]
+    response = chat.generate([messages])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    assert response.llm_output
+    generation = response.generations[0]
+    for gen in generation:
+        assert isinstance(gen, ChatGeneration)
+        assert isinstance(gen.text, str)
+        assert gen.text == gen.message.content
+
+
+@pytest.mark.scheduled
+def test_chat_yuan2_streaming() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    chat = ChatYuan2(
+        yuan2_api_key="EMPTY",
+        yuan2_api_base="http://127.0.0.1:8001/v1",
+        temperature=1.0,
+        model_name="yuan2",
+        max_retries=3,
+        streaming=True,
+        callback_manager=callback_manager,
+    )
+    messages = [
+        HumanMessage(content="Hello"),
+    ]
+    response = chat(messages)
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, BaseMessage)
+
+
+@pytest.mark.asyncio
+async def test_async_chat_yuan2() -> None:
+    """Test async generation."""
+    chat = ChatYuan2(
+        yuan2_api_key="EMPTY",
+        yuan2_api_base="http://127.0.0.1:8001/v1",
+        temperature=1.0,
+        model_name="yuan2",
+        max_retries=3,
+        streaming=False,
+    )
+    messages = [
+        HumanMessage(content="Hello"),
+    ]
+    response = await chat.agenerate([messages])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    generations = response.generations[0]
+    for generation in generations:
+        assert isinstance(generation, ChatGeneration)
+        assert isinstance(generation.text, str)
+        assert generation.text == generation.message.content
+
+
+@pytest.mark.asyncio
+async def test_async_chat_yuan2_streaming() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+
+    chat = ChatYuan2(
+        yuan2_api_key="EMPTY",
+        yuan2_api_base="http://127.0.0.1:8001/v1",
+        temperature=1.0,
+        model_name="yuan2",
+        max_retries=3,
+        streaming=True,
+        callback_manager=callback_manager,
+    )
+    messages = [
+        HumanMessage(content="Hello"),
+    ]
+    response = await chat.agenerate([messages])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    generations = response.generations[0]
+    for generation in generations:
+        assert isinstance(generation, ChatGeneration)
+        assert isinstance(generation.text, str)
+        assert generation.text == generation.message.content
diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py
index 031fb96e8937f..17eed6dfa3ff2 100644
--- a/libs/community/tests/unit_tests/chat_models/test_imports.py
+++ b/libs/community/tests/unit_tests/chat_models/test_imports.py
@@ -36,6 +36,7 @@
     "LlamaEdgeChatService",
     "GPTRouter",
     "ChatZhipuAI",
+    "ChatYuan2",
 ]
diff --git a/libs/community/tests/unit_tests/chat_models/test_yuan2.py b/libs/community/tests/unit_tests/chat_models/test_yuan2.py
new file mode 100644
index 0000000000000..508bad9f454e8
--- /dev/null
+++ b/libs/community/tests/unit_tests/chat_models/test_yuan2.py
@@ -0,0 +1,64 @@
+"""Test ChatYuan2 wrapper."""
+
+import pytest
+from langchain_core.messages import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+)
+
+from langchain_community.chat_models.yuan2 import (
+    ChatYuan2,
+    _convert_dict_to_message,
+    _convert_message_to_dict,
+)
+
+
+@pytest.mark.requires("openai")
+def test_yuan2_model_param() -> None:
+    chat = ChatYuan2(model="foo")
+    assert chat.model_name == "foo"
+    chat = ChatYuan2(model_name="foo")
+    assert chat.model_name == "foo"
+
+
+def test__convert_message_to_dict_human() -> None:
+    message = HumanMessage(content="foo")
+    result = _convert_message_to_dict(message)
+    expected_output = {"role": "user", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_ai() -> None:
+    message = AIMessage(content="foo")
+    result = _convert_message_to_dict(message)
+    expected_output = {"role": "assistant", "content": "foo"}
+    assert result == expected_output
+
+
+def test__convert_message_to_dict_system() -> None:
+    message = SystemMessage(content="foo")
+    with pytest.raises(TypeError) as e:
+        _convert_message_to_dict(message)
+    assert "Got unknown type" in str(e)
+
+
+def test__convert_dict_to_message_human() -> None:
+    message = {"role": "user", "content": "hello"}
+    result = _convert_dict_to_message(message)
+    expected_output = HumanMessage(content="hello")
+    assert result == expected_output
+
+
+def test__convert_dict_to_message_ai() -> None:
+    message = {"role": "assistant", "content": "hello"}
+    result = _convert_dict_to_message(message)
+    expected_output = AIMessage(content="hello")
+    assert result == expected_output
+
+
+def test__convert_dict_to_message_system() -> None:
+    message = {"role": "system", "content": "hello"}
+    result = _convert_dict_to_message(message)
+    expected_output = SystemMessage(content="hello")
+    assert result == expected_output
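
Reviewer note: the integration tests above all drive the wrapper the same way. Below is a minimal usage sketch (not part of the patch), assuming an OpenAI-compatible yuan2 server is already listening at http://127.0.0.1:8001/v1, the address the tests target:

    from langchain_core.messages import HumanMessage, SystemMessage

    from langchain_community.chat_models.yuan2 import ChatYuan2

    # Constructor arguments mirror the integration tests; the local backend
    # ignores the API key, so a placeholder ("EMPTY") is enough.
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        SystemMessage(content="You are an AI assistant."),
        HumanMessage(content="Hello"),
    ]
    # Calling the model returns a single AIMessage; .content is the reply text.
    # (chat(messages) is the style these tests use; newer langchain-core
    # versions prefer chat.invoke(messages).)
    response = chat(messages)
    print(response.content)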
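The unit tests pin down the private conversion helpers; a short sketch of the round trip they assert, where the dict side follows the OpenAI chat schema ("role"/"content" keys):

    from langchain_core.messages import AIMessage, HumanMessage

    from langchain_community.chat_models.yuan2 import (
        _convert_dict_to_message,
        _convert_message_to_dict,
    )

    # Outbound: LangChain message -> wire-format dict.
    assert _convert_message_to_dict(HumanMessage(content="foo")) == {
        "role": "user",
        "content": "foo",
    }

    # Inbound: wire-format dict -> LangChain message.
    assert _convert_dict_to_message(
        {"role": "assistant", "content": "hello"}
    ) == AIMessage(content="hello")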