Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Community: integrate chat models with Yuan2.0 #16575

Merged
merged 18 commits into from
Feb 13, 2024
Merged
Prev Previous commit
Next Next commit
fmt
baskaryan committed Feb 13, 2024
commit 320939f36f9703d9a18a1811c77a83d74b8ee742
225 changes: 113 additions & 112 deletions libs/community/langchain_community/chat_models/yuan2.py
Original file line number Diff line number Diff line change
@@ -59,102 +59,6 @@
logger = logging.getLogger(__name__)


def _create_retry_decorator(llm: ChatYuan2) -> Callable[[Any], Any]:
import openai

min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.APITimeoutError)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)


async def acompletion_with_retry(llm: ChatYuan2, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)

@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.async_client.create(**kwargs)

return await _completion_with_retry(**kwargs)


def _convert_delta_to_message_chunk(
_dict: ChatCompletionMessage, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""

if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)


def _convert_dict_to_message(_dict: ChatCompletionMessage) -> BaseMessage:
role = _dict.role
if role == "user":
return HumanMessage(content=_dict.content)
elif role == "assistant":
content = _dict.content or ""
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=_dict.content)
else:
return ChatMessage(content=_dict.content, role=role)


def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.

Args:
message: The LangChain message.

Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"name": message.name,
"content": message.content,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict


class ChatYuan2(BaseChatModel):
"""`Yuan2.0` Chat models API.

@@ -170,25 +74,10 @@ class ChatYuan2(BaseChatModel):
.. code-block:: python

from langchain_community.chat_models import ChatYuan2

chat = ChatYuan2()
"""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"yuan2_api_key": "YUAN2_API_KEY"}

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

if self.yuan2_api_base:
attributes["yuan2_api_base"] = self.yuan2_api_base

if self.yuan2_api_key:
attributes["yuan2_api_key"] = self.yuan2_api_key

return attributes

client: Any #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:

@@ -238,6 +127,22 @@ class Config:

allow_population_by_field_name = True

@property
def lc_secrets(self) -> Dict[str, str]:
return {"yuan2_api_key": "YUAN2_API_KEY"}

@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

if self.yuan2_api_base:
attributes["yuan2_api_base"] = self.yuan2_api_base

if self.yuan2_api_key:
attributes["yuan2_api_key"] = self.yuan2_api_key

return attributes

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
@@ -483,3 +388,99 @@ def _invocation_params(self) -> Mapping[str, Any]:
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-yuan2"


def _create_retry_decorator(llm: ChatYuan2) -> Callable[[Any], Any]:
import openai

min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.APITimeoutError)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)


async def acompletion_with_retry(llm: ChatYuan2, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)

@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.async_client.create(**kwargs)

return await _completion_with_retry(**kwargs)


def _convert_delta_to_message_chunk(
_dict: ChatCompletionMessage, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""

if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)


def _convert_dict_to_message(_dict: ChatCompletionMessage) -> BaseMessage:
role = _dict.role
if role == "user":
return HumanMessage(content=_dict.content)
elif role == "assistant":
content = _dict.content or ""
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=_dict.content)
else:
return ChatMessage(content=_dict.content, role=role)


def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.

Args:
message: The LangChain message.

Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"name": message.name,
"content": message.content,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict