libs/community/tests/integration_tests/chat_models/test_yuan2.py
@@ -0,0 +1,159 @@
"""Test ChatYuan2 wrapper.""" | ||
from typing import Any, Optional | ||
Check failure on line 2 in libs/community/tests/integration_tests/chat_models/test_yuan2.py GitHub Actions / ci (libs/community) / lint / build (3.11)Ruff (F401)
Check failure on line 2 in libs/community/tests/integration_tests/chat_models/test_yuan2.py GitHub Actions / ci (libs/community) / lint / build (3.11)Ruff (F401)
Check failure on line 2 in libs/community/tests/integration_tests/chat_models/test_yuan2.py GitHub Actions / ci (libs/community) / lint / build (3.8)Ruff (F401)
|
||
|
||
import pytest | ||
from langchain_core.callbacks import CallbackManager | ||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage | ||
Check failure on line 6 in libs/community/tests/integration_tests/chat_models/test_yuan2.py GitHub Actions / ci (libs/community) / lint / build (3.11)Ruff (F401)
Check failure on line 6 in libs/community/tests/integration_tests/chat_models/test_yuan2.py GitHub Actions / ci (libs/community) / lint / build (3.8)Ruff (F401)
|
||
from langchain_core.outputs import ( | ||
ChatGeneration, | ||
ChatResult, | ||
Check failure on line 9 in libs/community/tests/integration_tests/chat_models/test_yuan2.py GitHub Actions / ci (libs/community) / lint / build (3.11)Ruff (F401)
Check failure on line 9 in libs/community/tests/integration_tests/chat_models/test_yuan2.py GitHub Actions / ci (libs/community) / lint / build (3.8)Ruff (F401)
|
||
LLMResult, | ||
) | ||
from langchain_core.prompts import ChatPromptTemplate | ||
from langchain_core.pydantic_v1 import BaseModel, Field | ||
|
||
from langchain_community.chat_models.yuan2 import ChatYuan2 | ||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler | ||
|
||
|
||
@pytest.mark.scheduled
def test_chat_yuan2() -> None:
    """Test ChatYuan2 wrapper."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    response = chat(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_chat_yuan2_system_message() -> None:
    """Test ChatYuan2 wrapper with system message."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        SystemMessage(content="You are an AI assistant."),
        HumanMessage(content="Hello"),
    ]
    response = chat(messages)
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled
def test_chat_yuan2_generate() -> None:
    """Test ChatYuan2 wrapper with generate."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    response = chat.generate([messages])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    assert response.llm_output
    generation = response.generations[0]
    for gen in generation:
        assert isinstance(gen, ChatGeneration)
        assert isinstance(gen.text, str)
        assert gen.text == gen.message.content

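As the assertions above show, generate returns an LLMResult whose generations field holds one list of ChatGeneration objects per input prompt. A minimal sketch of walking that structure, reusing a chat instance configured as in the test (the second prompt is purely illustrative):

results = chat.generate(
    [
        [HumanMessage(content="Hello")],
        [HumanMessage(content="How are you?")],
    ]
)
for prompt_generations in results.generations:  # one entry per input prompt
    for gen in prompt_generations:  # one ChatGeneration per candidate reply
        print(gen.text)  # identical to gen.message.content
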
@pytest.mark.scheduled
def test_chat_yuan2_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=True,
        callback_manager=callback_manager,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    response = chat(messages)
    assert callback_handler.llm_streams > 0
    assert isinstance(response, BaseMessage)
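FakeCallbackHandler only counts stream events; a sketch of a handler that surfaces the streamed tokens themselves, via the standard on_llm_new_token hook (the handler name is illustrative):

from langchain_core.callbacks import BaseCallbackHandler


class PrintTokenHandler(BaseCallbackHandler):
    """Illustrative handler: print each streamed token as it arrives."""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


# Wired up the same way the test wires FakeCallbackHandler:
# callback_manager = CallbackManager([PrintTokenHandler()])
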
@pytest.mark.asyncio
async def test_async_chat_yuan2() -> None:
    """Test async generation."""
    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=False,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    response = await chat.agenerate([messages])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    generations = response.generations[0]
    for generation in generations:
        assert isinstance(generation, ChatGeneration)
        assert isinstance(generation.text, str)
        assert generation.text == generation.message.content


@pytest.mark.asyncio
async def test_async_chat_yuan2_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])

    chat = ChatYuan2(
        yuan2_api_key="EMPTY",
        yuan2_api_base="http://127.0.0.1:8001/v1",
        temperature=1.0,
        model_name="yuan2",
        max_retries=3,
        streaming=True,
        callback_manager=callback_manager,
    )
    messages = [
        HumanMessage(content="Hello"),
    ]
    response = await chat.agenerate([messages])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 1
    generations = response.generations[0]
    for generation in generations:
        assert isinstance(generation, ChatGeneration)
        assert isinstance(generation.text, str)
        assert generation.text == generation.message.content
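For orientation, a minimal end-to-end sketch distilled from the tests above. The endpoint, "EMPTY" key, and model name mirror the test fixtures and assume a locally served, OpenAI-compatible Yuan2 deployment:

from langchain_core.messages import HumanMessage, SystemMessage

from langchain_community.chat_models.yuan2 import ChatYuan2

# Same constructor arguments the tests exercise.
chat = ChatYuan2(
    yuan2_api_key="EMPTY",
    yuan2_api_base="http://127.0.0.1:8001/v1",
    temperature=1.0,
    model_name="yuan2",
    max_retries=3,
)

response = chat(
    [
        SystemMessage(content="You are an AI assistant."),
        HumanMessage(content="Hello"),
    ]
)
print(response.content)  # the assistant's reply as a plain string
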
langchain_community/chat_models/__init__.py
@@ -36,6 +36,7 @@
"LlamaEdgeChatService", | ||
"GPTRouter", | ||
"ChatZhipuAI", | ||
"ChatYuan2", | ||
] | ||
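With "ChatYuan2" added to __all__, the wrapper should be importable from the package root as well as from its module, assuming the accompanying module-level import in __init__.py, as with the package's other chat wrappers:

# Resolvable once "ChatYuan2" is exported by chat_models/__init__.py:
from langchain_community.chat_models import ChatYuan2

# The module path used by the tests continues to work as well:
from langchain_community.chat_models.yuan2 import ChatYuan2
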
libs/community/tests/unit_tests/chat_models/test_yuan2.py
@@ -0,0 +1,60 @@
"""Test ChatYuan2 wrapper.""" | ||
|
||
import pytest | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
HumanMessage, | ||
SystemMessage, | ||
) | ||
|
||
from langchain_community.chat_models.yuan2 import ChatYuan2, _convert_message_to_dict, _convert_dict_to_message | ||
|
||
|
||
@pytest.mark.requires("openai") | ||
def test_yuan2_model_param() -> None: | ||
chat = ChatYuan2(model="foo") | ||
assert chat.model_name == "foo" | ||
chat = ChatYuan2(model_name="foo") | ||
assert chat.model_name == "foo" | ||
|
||
|
||
def test__convert_message_to_dict_human() -> None: | ||
message = HumanMessage(content="foo") | ||
result = _convert_message_to_dict(message) | ||
expected_output = {"role": "user", "content": "foo"} | ||
assert result == expected_output | ||
|
||
|
||
def test__convert_message_to_dict_ai() -> None: | ||
message = AIMessage(content="foo") | ||
result = _convert_message_to_dict(message) | ||
expected_output = {"role": "assistant", "content": "foo"} | ||
assert result == expected_output | ||
|
||
|
||
def test__convert_message_to_dict_system() -> None: | ||
message = SystemMessage(content="foo") | ||
with pytest.raises(TypeError) as e: | ||
_convert_message_to_dict(message) | ||
assert "Got unknown type" in str(e) | ||
|
||
|
||
def test__convert_dict_to_message_human() -> None: | ||
message = {"role": "user", "content": "hello"} | ||
result = _convert_dict_to_message(message) | ||
expected_output = HumanMessage(content="hello") | ||
assert result == expected_output | ||
|
||
|
||
def test__convert_dict_to_message_ai() -> None: | ||
message = {"role": "assistant", "content": "hello"} | ||
result = _convert_dict_to_message(message) | ||
expected_output = AIMessage(content="hello") | ||
assert result == expected_output | ||
|
||
|
||
def test__convert_dict_to_message_system() -> None: | ||
message = {"role": "system", "content": "hello"} | ||
result = _convert_dict_to_message(message) | ||
expected_output = SystemMessage(content="hello") | ||
assert result == expected_output |
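Taken together, these conversion tests pin down a round-trip property for human and AI messages (system messages parse from dicts but, per the TypeError test above, do not serialize back). A quick sketch of the invariant:

from langchain_core.messages import AIMessage, HumanMessage

from langchain_community.chat_models.yuan2 import (
    _convert_dict_to_message,
    _convert_message_to_dict,
)

# Human and AI messages survive a dict round trip unchanged.
for msg in [HumanMessage(content="hi"), AIMessage(content="hello")]:
    assert _convert_dict_to_message(_convert_message_to_dict(msg)) == msg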