Add ChatGLM3 LLM usage #14370

Closed · wants to merge 7 commits
92 changes: 88 additions & 4 deletions docs/docs/integrations/llms/chatglm.ipynb
@@ -11,7 +11,91 @@
"\n",
"[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the new features like better performance, longer context and more efficient inference.\n",
"\n",
"This example goes over how to use LangChain to interact with ChatGLM2-6B Inference for text completion.\n",
"[ChatGLM3](https://github.com/THUDM/ChatGLM3) is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ChatGLM3\n",
"\n",
"This examples goes over how to use LangChain to interact with ChatGLM3-6B Inference for text completion."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.llms.chatglm3 import ChatGLM3\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.schema.messages import AIMessage"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"{question}\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"endpoint_url = \"http://127.0.0.1:8000/v1/chat/completions\"\n",
"\n",
"messages = [\n",
" AIMessage(content=\"我将从美国到中国来旅游,出行前希望了解中国的城市\"),\n",
" AIMessage(content=\"欢迎问我任何问题。\"),\n",
"]\n",
"\n",
"llm = ChatGLM3(\n",
" endpoint_url=endpoint_url,\n",
" max_tokens=80000,\n",
" prefix_messages=messages,\n",
" top_p=0.9,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'北京和上海是中国两个不同的城市,它们在很多方面都有所不同。\\n\\n北京是中国的首都,也是历史悠久的城市之一。它有着丰富的历史文化遗产,如故宫、颐和园等,这些景点吸引着众多游客前来观光。北京也是一个政治、文化和教育中心,有很多政府机构和学术机构总部设在北京。\\n\\n上海则是一个现代化的城市,它是中国的经济中心之一。上海拥有许多高楼大厦和国际化的金融机构,是中国最国际化的城市之一。上海也是一个美食和购物天堂,有许多著名的餐厅和购物中心。\\n\\n北京和上海的气候也不同。北京属于温带大陆性气候,冬季寒冷干燥,夏季炎热多风;而上海属于亚热带季风气候,四季分明,春秋宜人。\\n\\n北京和上海有很多不同之处,但都是中国非常重要的城市,每个城市都有自己独特的魅力和特色。'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"question = \"北京和上海两座城市有什么不同?\"\n",
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ChatGLM and ChatGLM2\n",
"\n",
"The following example shows how to use LangChain to interact with the ChatGLM2-6B Inference to complete text.\n",
"ChatGLM-6B and ChatGLM2-6B has the same api specs, so this example should work with both."
]
},
@@ -106,7 +190,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "langchain-dev",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -120,9 +204,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
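Note: the notebook assumes a ChatGLM3 OpenAI-style API server is already serving at `http://127.0.0.1:8000`. As a rough sketch of the raw request the integration sends — field names are taken from the `chatglm3.py` implementation below; the local URL, a running server, and the prompt text are assumptions:

```python
import httpx

# Hypothetical direct call to the same endpoint the ChatGLM3 class targets.
payload = {
    "model": "chatglm3-6b",
    "temperature": 0.1,
    "max_tokens": 2048,
    "top_p": 0.7,
    "stream": False,
    "messages": [{"role": "user", "content": "Hello"}],
}
response = httpx.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json=payload,
    timeout=30,
)
print(response.json()["choices"][0]["message"]["content"])
```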
148 changes: 148 additions & 0 deletions libs/community/langchain_community/llms/chatglm3.py
@@ -0,0 +1,148 @@
import json
import logging
from typing import Any, List, Optional, Union

import httpx
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.pydantic_v1 import Field

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)
HEADERS = {"Content-Type": "application/json"}
DEFAULT_TIMEOUT = 30


def _convert_message_to_dict(message: BaseMessage) -> dict:
    if isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {"role": "function", "content": message.content}
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_dict
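

# For illustration (message contents are hypothetical):
#   _convert_message_to_dict(HumanMessage(content="hi"))
#   -> {"role": "user", "content": "hi"}
#   _convert_message_to_dict(SystemMessage(content="be brief"))
#   -> {"role": "system", "content": "be brief"}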


class ChatGLM3(LLM):
    """ChatGLM3 LLM service."""

    model_name: str = Field(default="chatglm3-6b", alias="model")
    endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
    """Endpoint URL to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    max_tokens: int = 20000
    """Max tokens allowed to pass to the model."""
    temperature: float = 0.1
    """LLM model temperature from 0 to 10."""
    top_p: float = 0.7
    """Top P for nucleus sampling from 0 to 1."""
    prefix_messages: List[BaseMessage] = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""
    http_client: Union[Any, None] = None
    """Optional preconfigured HTTP client to reuse across calls."""
    timeout: int = DEFAULT_TIMEOUT
    """Request timeout in seconds."""

    @property
    def _llm_type(self) -> str:
        return "chat_glm_3"

    @property
    def _invocation_params(self) -> dict:
        """Get the parameters used to invoke the model."""
        params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "stream": self.streaming,
        }
        return {**params, **(self.model_kwargs or {})}

    @property
    def client(self) -> Any:
        # Fall back to a fresh client per call when none was injected.
        return self.http_client or httpx.Client(timeout=self.timeout)

    def _get_payload(self, prompt: str) -> dict:
        params = self._invocation_params
        messages = self.prefix_messages + [HumanMessage(content=prompt)]
        params.update(
            {
                "messages": [_convert_message_to_dict(m) for m in messages],
            }
        )
        return params
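
    # For illustration, _get_payload("Hello") with the defaults above yields
    # (the prompt text is hypothetical):
    #   {"model": "chatglm3-6b", "temperature": 0.1, "max_tokens": 20000,
    #    "top_p": 0.7, "stream": False,
    #    "messages": [{"role": "user", "content": "Hello"}]}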

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM3 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = chatglm_llm("Who are you?")
        """
        payload = self._get_payload(prompt)
        logger.debug(f"ChatGLM3 payload: {payload}")

        try:
            response = self.client.post(
                self.endpoint_url, headers=HEADERS, json=payload
            )
        except httpx.NetworkError as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        logger.debug(f"ChatGLM3 response: {response}")

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            )

        # Previously a non-dict body fell through with `text` unbound;
        # raise explicitly instead.
        if not isinstance(parsed_response, dict):
            raise ValueError(f"Unexpected response type: {parsed_response}")

        choices = parsed_response.get("choices")
        if not choices:
            raise ValueError(f"No content in response: {parsed_response}")

        text = choices[0]["message"]["content"]

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text
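
A minimal usage sketch for the new class, exercising the injectable `http_client` and `timeout` fields; the endpoint URL, a locally running server, and the prompt are assumptions:

```python
import httpx

from langchain_community.llms.chatglm3 import ChatGLM3

# Reusing one client avoids constructing a new httpx.Client per request,
# which the default `client` property otherwise does.
shared_client = httpx.Client(timeout=60)

llm = ChatGLM3(
    endpoint_url="http://127.0.0.1:8000/v1/chat/completions",  # assumed server
    http_client=shared_client,
    max_tokens=1024,
)
print(llm("What is ChatGLM3?"))  # returns the assistant's reply as a string
```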
2 changes: 1 addition & 1 deletion libs/community/poetry.lock


1 change: 1 addition & 0 deletions libs/community/pyproject.toml
@@ -83,6 +83,7 @@ msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
datasets = {version = "^2.15.0", optional = true}
httpx = "^0.24.1"

[tool.poetry.group.test]
optional = true
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/test_dependencies.py
@@ -41,6 +41,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"SQLAlchemy",
"aiohttp",
"dataclasses-json",
"httpx",
"langchain-core",
"langsmith",
"numpy",
3 changes: 3 additions & 0 deletions libs/langchain/langchain/llms/chatglm3.py
@@ -0,0 +1,3 @@
from langchain_community.llms.chatglm3 import ChatGLM3

__all__ = ["ChatGLM3"]
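
Both import paths now resolve to the same class; a quick check, assuming both `langchain` and `langchain-community` are installed:

```python
from langchain.llms.chatglm3 import ChatGLM3 as ShimChatGLM3
from langchain_community.llms.chatglm3 import ChatGLM3

assert ShimChatGLM3 is ChatGLM3  # the langchain module simply re-exports it
```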
2 changes: 1 addition & 1 deletion libs/langchain/poetry.lock


1 change: 1 addition & 0 deletions libs/langchain/pyproject.toml
@@ -22,6 +22,7 @@ numpy = "^1"
aiohttp = "^3.8.3"
tenacity = "^8.1.0"
jsonpatch = "^1.33"
httpx = "^0.24.1"
azure-core = {version = "^1.26.4", optional=true}
tqdm = {version = ">=4.48.0", optional = true}
openapi-pydantic = {version = "^0.3.2", optional = true}
1 change: 1 addition & 0 deletions libs/langchain/tests/unit_tests/test_dependencies.py
@@ -42,6 +42,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"aiohttp",
"async-timeout",
"dataclasses-json",
"httpx",
"jsonpatch",
"langchain-core",
"langsmith",