Skip to content

Commit

Permalink
feat: add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
cauwulixuan committed Jan 26, 2024
1 parent 40accbd commit ca6e115
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 0 deletions.
159 changes: 159 additions & 0 deletions libs/community/tests/integration_tests/chat_models/test_yuan2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Test ChatYuan2 wrapper."""
from typing import Any, Optional

Check failure on line 2 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.11)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:2:20: F401 `typing.Any` imported but unused

Check failure on line 2 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.11)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:2:25: F401 `typing.Optional` imported but unused

Check failure on line 2 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.8)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:2:20: F401 `typing.Any` imported but unused

Check failure on line 2 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.8)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:2:25: F401 `typing.Optional` imported but unused

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage

Check failure on line 6 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.11)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:6:37: F401 `langchain_core.messages.AIMessage` imported but unused

Check failure on line 6 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.8)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:6:37: F401 `langchain_core.messages.AIMessage` imported but unused
from langchain_core.outputs import (
ChatGeneration,
ChatResult,

Check failure on line 9 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.11)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:9:5: F401 `langchain_core.outputs.ChatResult` imported but unused

Check failure on line 9 in libs/community/tests/integration_tests/chat_models/test_yuan2.py

View workflow job for this annotation

GitHub Actions / ci (libs/community) / lint / build (3.8)

Ruff (F401)

tests/integration_tests/chat_models/test_yuan2.py:9:5: F401 `langchain_core.outputs.ChatResult` imported but unused
LLMResult,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_community.chat_models.yuan2 import ChatYuan2
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@pytest.mark.scheduled
def test_chat_yuan2() -> None:
    """Smoke test: a single human message yields a text chat reply."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    # ``chat(messages)`` (BaseChatModel.__call__) is deprecated in
    # langchain-core; ``invoke`` is the supported entry point and returns
    # the same message object.
    response = chat.invoke(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled  # consistent with the other live-endpoint tests in this file
def test_chat_yuan2_system_message() -> None:
    """Test ChatYuan2 with a system message preceding the human turn."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        SystemMessage(content="You are an AI assistant."),
        HumanMessage(content="Hello"),
    ]
    # ``invoke`` replaces the deprecated ``chat(messages)`` call form.
    response = chat.invoke(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled
def test_chat_yuan2_generate() -> None:
    """Batch generation via ``generate`` returns chat generations with text."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    result = chat.generate([[HumanMessage(content="Hello")]])
    assert isinstance(result, LLMResult)
    assert len(result.generations) == 1
    assert result.llm_output
    # Every candidate for the single prompt must be a ChatGeneration whose
    # text mirrors its underlying message content.
    for candidate in result.generations[0]:
        assert isinstance(candidate, ChatGeneration)
        assert isinstance(candidate.text, str)
        assert candidate.text == candidate.message.content

@pytest.mark.scheduled
def test_chat_yuan2_streaming() -> None:
    """Streaming mode invokes on_llm_new_token for at least one token."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=True,
        callback_manager=callback_manager,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    # ``invoke`` replaces the deprecated ``chat(messages)`` call form.
    response = chat.invoke(messages)
    # The fake handler counts on_llm_new_token calls; streaming must emit some.
    assert callback_handler.llm_streams > 0
    assert isinstance(response, BaseMessage)


@pytest.mark.asyncio
async def test_async_chat_yuan2() -> None:
    """Async generation returns one batch of chat generations with text."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    assert isinstance(result, LLMResult)
    assert len(result.generations) == 1
    for candidate in result.generations[0]:
        assert isinstance(candidate, ChatGeneration)
        assert isinstance(candidate.text, str)
        assert candidate.text == candidate.message.content


@pytest.mark.asyncio
async def test_async_chat_yuan2_streaming() -> None:
    """Async streaming invokes on_llm_new_token and yields chat generations."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=True,
        callback_manager=callback_manager,
    )
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    # The fake handler counts on_llm_new_token calls; streaming must emit some.
    assert callback_handler.llm_streams > 0
    assert isinstance(result, LLMResult)
    assert len(result.generations) == 1
    for candidate in result.generations[0]:
        assert isinstance(candidate, ChatGeneration)
        assert isinstance(candidate.text, str)
        assert candidate.text == candidate.message.content

Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"LlamaEdgeChatService",
"GPTRouter",
"ChatZhipuAI",
"ChatYuan2",
]


Expand Down
60 changes: 60 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_yuan2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Test ChatYuan2 wrapper."""

import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)

from langchain_community.chat_models.yuan2 import ChatYuan2, _convert_message_to_dict, _convert_dict_to_message


@pytest.mark.requires("openai")
def test_yuan2_model_param() -> None:
    """Both ``model`` and its alias ``model_name`` populate ``model_name``."""
    for field in ("model", "model_name"):
        chat = ChatYuan2(**{field: "foo"})
        assert chat.model_name == "foo"


def test__convert_message_to_dict_human() -> None:
    """A HumanMessage maps onto the OpenAI-style ``user`` role dict."""
    assert _convert_message_to_dict(HumanMessage(content="foo")) == {
        "role": "user",
        "content": "foo",
    }


def test__convert_message_to_dict_ai() -> None:
    """An AIMessage maps onto the OpenAI-style ``assistant`` role dict."""
    assert _convert_message_to_dict(AIMessage(content="foo")) == {
        "role": "assistant",
        "content": "foo",
    }


def test__convert_message_to_dict_system() -> None:
    """A SystemMessage is rejected by ``_convert_message_to_dict``."""
    message = SystemMessage(content="foo")
    with pytest.raises(TypeError) as excinfo:
        _convert_message_to_dict(message)
    # Assert on the raised exception itself (``excinfo.value``), per the
    # pytest docs; stringifying the ExceptionInfo wrapper is fragile across
    # pytest versions.
    assert "Got unknown type" in str(excinfo.value)


def test__convert_dict_to_message_human() -> None:
    """A ``user`` role dict converts back to a HumanMessage."""
    converted = _convert_dict_to_message({"role": "user", "content": "hello"})
    assert converted == HumanMessage(content="hello")


def test__convert_dict_to_message_ai() -> None:
    """An ``assistant`` role dict converts back to an AIMessage."""
    converted = _convert_dict_to_message({"role": "assistant", "content": "hello"})
    assert converted == AIMessage(content="hello")


def test__convert_dict_to_message_system() -> None:
    """A ``system`` role dict converts back to a SystemMessage."""
    converted = _convert_dict_to_message({"role": "system", "content": "hello"})
    assert converted == SystemMessage(content="hello")

0 comments on commit ca6e115

Please sign in to comment.