From 431aa09c8f5883eafbffcad5214d0d641c786b02 Mon Sep 17 00:00:00 2001 From: Bob Lin Date: Wed, 6 Dec 2023 21:32:24 +0800 Subject: [PATCH 1/7] Add ChatGLM3 --- libs/langchain/langchain/llms/chatglm3.py | 137 ++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 libs/langchain/langchain/llms/chatglm3.py diff --git a/libs/langchain/langchain/llms/chatglm3.py b/libs/langchain/langchain/llms/chatglm3.py new file mode 100644 index 0000000000000..1b87d14ce2095 --- /dev/null +++ b/libs/langchain/langchain/llms/chatglm3.py @@ -0,0 +1,137 @@ +import logging +import json +from typing import Any, List, Optional, Dict, Union + +import httpx +from langchain_core.pydantic_v1 import Field +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +from langchain.schema.messages import HumanMessage, BaseMessage, AIMessage, SystemMessage, FunctionMessage +from langchain.llms.utils import enforce_stop_tokens + +logger = logging.getLogger(__name__) +HEADERS = {"Content-Type": "application/json"} +DEFAULT_TIMEOUT = 30 + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = {"role": "function", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + return message_dict + + +class ChatGLM3(LLM): + """ChatGLM3 LLM service. + """ + model_name: str = Field(default="chatglm3-6b", alias="model") + endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions" + """Endpoint URL to use.""" + model_kwargs: Optional[dict] = None + """Keyword arguments to pass to the model.""" + max_tokens: int = 20000 + """Max token allowed to pass to the model.""" + temperature: float = 0.1 + """LLM model temperature from 0 to 10.""" + top_p: float = 0.7 + """Top P for nucleus sampling from 0 to 1""" + prefix_messages: List[BaseMessage] = Field(default_factory=list) + """Series of messages for Chat input.""" + streaming: bool = False + """Whether to stream the results or not.""" + http_client: Union[Any, None] = None + timeout: int = DEFAULT_TIMEOUT + + @property + def _llm_type(self) -> str: + return "chat_glm_3" + + @property + def _invocation_params(self) -> Dict[str, Any]: + """Get the parameters used to invoke the model.""" + params = { + "model": self.model_name, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "stream": self.streaming, + } + return {**params, **(self.model_kwargs or {})} + + @property + def client(self): + return self.http_client or httpx.Client(timeout=self.timeout) + + def _get_payload(self, prompt: str): + params = self._invocation_params + messages = self.prefix_messages + [HumanMessage(content=prompt)] + params.update({ + "messages": [_convert_message_to_dict(m) for m in messages], + }) + return params + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call out to a ChatGLM3 LLM inference endpoint. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. 
+
+        Example:
+        .. code-block:: python
+
+            response = chatglm_llm("Who are you?")
+        """
+        payload = self._get_payload(prompt)
+        print(f"ChatGLM3 payload: {payload}")
+
+        try:
+            response = self.client.post(self.endpoint_url, headers=HEADERS, json=payload)
+        except httpx.NetworkError as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}")
+
+        logger.debug(f"ChatGLM3 response: {response}")
+
+        if response.status_code != 200:
+            raise ValueError(f"Failed with response: {response}")
+
+        try:
+            parsed_response = response.json()
+
+            if isinstance(parsed_response, dict):
+                content_keys = "choices"
+                if content_keys in parsed_response:
+                    choices = parsed_response[content_keys]
+                    if len(choices):
+                        text = choices[0]['message']['content']
+                    else:
+                        raise ValueError(f"No content in response : {parsed_response}")
+                else:
+                    raise ValueError(f"Unexpected response type: {parsed_response}")
+
+        except json.JSONDecodeError as e:
+            raise ValueError(
+                f"Error raised during decoding response from inference endpoint: {e}."
+                f"\nResponse: {response.text}"
+            )
+
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+
+        return text

From 0280a782487c08e7f7686c074417a3a74a2b86ff Mon Sep 17 00:00:00 2001
From: Bob Lin
Date: Thu, 7 Dec 2023 09:28:46 +0800
Subject: [PATCH 2/7] Update chatglm.ipynb

---
 docs/docs/integrations/llms/chatglm.ipynb | 90 ++++++++++++++++++++++-
 libs/langchain/langchain/llms/chatglm3.py |  2 +-
 2 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/docs/docs/integrations/llms/chatglm.ipynb b/docs/docs/integrations/llms/chatglm.ipynb
index 82867b0f091cd..95f72daac40a8 100644
--- a/docs/docs/integrations/llms/chatglm.ipynb
+++ b/docs/docs/integrations/llms/chatglm.ipynb
@@ -11,7 +11,89 @@
     "\n",
     "[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the new features like better performance, longer context and more efficient inference.\n",
-    "This example goes over how to use LangChain to interact with ChatGLM2-6B Inference for text completion.\n",
+    "[ChatGLM3](https://github.com/THUDM/ChatGLM3) is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ChatGLM3\n",
    "\n",
    "This example goes over how to use LangChain to interact with ChatGLM3-6B Inference for text completion."
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import LLMChain\n", + "from langchain.llms.chatglm3 import ChatGLM3\n", + "from langchain.prompts import PromptTemplate\n", + "\n", + "from langchain.schema.messages import AIMessage" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "template = \"\"\"{question}\"\"\"\n", + "prompt = PromptTemplate(template=template, input_variables=[\"question\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "endpoint_url = \"http://127.0.0.1:8000/v1/chat/completions\"\n", + "\n", + "messages = [AIMessage(content=\"我将从美国到中国来旅游,出行前希望了解中国的城市\"), AIMessage(content=\"欢迎问我任何问题。\")]\n", + "\n", + "llm = ChatGLM3(\n", + " endpoint_url=endpoint_url,\n", + " max_tokens=80000,\n", + " prefix_messages=messages,\n", + " top_p=0.9,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'北京和上海是中国两个不同的城市,它们在很多方面都有所不同。\\n\\n北京是中国的首都,也是历史悠久的城市之一。它有着丰富的历史文化遗产,如故宫、颐和园等,这些景点吸引着众多游客前来观光。北京也是一个政治、文化和教育中心,有很多政府机构和学术机构总部设在北京。\\n\\n上海则是一个现代化的城市,它是中国的经济中心之一。上海拥有许多高楼大厦和国际化的金融机构,是中国最国际化的城市之一。上海也是一个美食和购物天堂,有许多著名的餐厅和购物中心。\\n\\n北京和上海的气候也不同。北京属于温带大陆性气候,冬季寒冷干燥,夏季炎热多风;而上海属于亚热带季风气候,四季分明,春秋宜人。\\n\\n北京和上海有很多不同之处,但都是中国非常重要的城市,每个城市都有自己独特的魅力和特色。'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", + "question = \"北京和上海两座城市有什么不同?\"\n", + "\n", + "llm_chain.run(question)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ChatGLM and ChatGLM2\n", + "\n", + "The following example shows how to use LangChain to interact with the ChatGLM2-6B Inference to complete text.\n", "ChatGLM-6B and ChatGLM2-6B has the same api specs, so this example should work with both." 
] }, @@ -106,7 +188,7 @@ ], "metadata": { "kernelspec": { - "display_name": "langchain-dev", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -120,9 +202,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/libs/langchain/langchain/llms/chatglm3.py b/libs/langchain/langchain/llms/chatglm3.py index 1b87d14ce2095..1912e4cab56f8 100644 --- a/libs/langchain/langchain/llms/chatglm3.py +++ b/libs/langchain/langchain/llms/chatglm3.py @@ -99,7 +99,7 @@ def _call( response = chatglm_llm("Who are you?") """ payload = self._get_payload(prompt) - print(f"ChatGLM3 payload: {payload}") + logger.debug(f"ChatGLM3 payload: {payload}") try: response = self.client.post(self.endpoint_url, headers=HEADERS, json=payload) From 66cce34dd4d3bb82ffcd67da19070f1a28f7700d Mon Sep 17 00:00:00 2001 From: Bob Lin Date: Thu, 7 Dec 2023 09:55:03 +0800 Subject: [PATCH 3/7] format --- docs/docs/integrations/llms/chatglm.ipynb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/docs/integrations/llms/chatglm.ipynb b/docs/docs/integrations/llms/chatglm.ipynb index 95f72daac40a8..2df605afaa094 100644 --- a/docs/docs/integrations/llms/chatglm.ipynb +++ b/docs/docs/integrations/llms/chatglm.ipynb @@ -32,7 +32,6 @@ "from langchain.chains import LLMChain\n", "from langchain.llms.chatglm3 import ChatGLM3\n", "from langchain.prompts import PromptTemplate\n", - "\n", "from langchain.schema.messages import AIMessage" ] }, @@ -54,7 +53,10 @@ "source": [ "endpoint_url = \"http://127.0.0.1:8000/v1/chat/completions\"\n", "\n", - "messages = [AIMessage(content=\"我将从美国到中国来旅游,出行前希望了解中国的城市\"), AIMessage(content=\"欢迎问我任何问题。\")]\n", + "messages = [\n", + " AIMessage(content=\"我将从美国到中国来旅游,出行前希望了解中国的城市\"),\n", + " AIMessage(content=\"欢迎问我任何问题。\"),\n", + "]\n", "\n", "llm = ChatGLM3(\n", " endpoint_url=endpoint_url,\n", From 2d9125b84a3f6c918d6e5a96e0650bce208ab337 Mon Sep 17 00:00:00 2001 From: Bob Lin Date: Thu, 7 Dec 2023 10:05:52 +0800 Subject: [PATCH 4/7] format --- libs/langchain/langchain/llms/chatglm3.py | 31 +++++++++++++++-------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/llms/chatglm3.py b/libs/langchain/langchain/llms/chatglm3.py index 1912e4cab56f8..e6ac9223c0ece 100644 --- a/libs/langchain/langchain/llms/chatglm3.py +++ b/libs/langchain/langchain/llms/chatglm3.py @@ -1,13 +1,20 @@ -import logging import json -from typing import Any, List, Optional, Dict, Union +import logging +from typing import Any, Dict, List, Optional, Union import httpx from langchain_core.pydantic_v1 import Field + from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.schema.messages import HumanMessage, BaseMessage, AIMessage, SystemMessage, FunctionMessage from langchain.llms.utils import enforce_stop_tokens +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) logger = logging.getLogger(__name__) HEADERS = {"Content-Type": "application/json"} @@ -29,8 +36,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: class ChatGLM3(LLM): - """ChatGLM3 LLM service. 
- """ + """ChatGLM3 LLM service.""" + model_name: str = Field(default="chatglm3-6b", alias="model") endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions" """Endpoint URL to use.""" @@ -72,9 +79,11 @@ def client(self): def _get_payload(self, prompt: str): params = self._invocation_params messages = self.prefix_messages + [HumanMessage(content=prompt)] - params.update({ - "messages": [_convert_message_to_dict(m) for m in messages], - }) + params.update( + { + "messages": [_convert_message_to_dict(m) for m in messages], + } + ) return params def _call( @@ -102,7 +111,9 @@ def _call( logger.debug(f"ChatGLM3 payload: {payload}") try: - response = self.client.post(self.endpoint_url, headers=HEADERS, json=payload) + response = self.client.post( + self.endpoint_url, headers=HEADERS, json=payload + ) except httpx.NetworkError as e: raise ValueError(f"Error raised by inference endpoint: {e}") @@ -119,7 +130,7 @@ def _call( if content_keys in parsed_response: choices = parsed_response[content_keys] if len(choices): - text = choices[0]['message']['content'] + text = choices[0]["message"]["content"] else: raise ValueError(f"No content in response : {parsed_response}") else: From 9a136f46beb13a399cb99d50d2ba3d178ae1e6f7 Mon Sep 17 00:00:00 2001 From: Bob Lin Date: Thu, 7 Dec 2023 10:28:50 +0800 Subject: [PATCH 5/7] Fix annotation --- libs/langchain/langchain/llms/chatglm3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/llms/chatglm3.py b/libs/langchain/langchain/llms/chatglm3.py index e6ac9223c0ece..c5c7672dadf84 100644 --- a/libs/langchain/langchain/llms/chatglm3.py +++ b/libs/langchain/langchain/llms/chatglm3.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, List, Optional, Union import httpx from langchain_core.pydantic_v1 import Field @@ -61,7 +61,7 @@ def _llm_type(self) -> str: return "chat_glm_3" @property - def _invocation_params(self) -> Dict[str, Any]: + def _invocation_params(self) -> dict: """Get the parameters used to invoke the model.""" params = { "model": self.model_name, @@ -73,10 +73,10 @@ def _invocation_params(self) -> Dict[str, Any]: return {**params, **(self.model_kwargs or {})} @property - def client(self): + def client(self) -> Any: return self.http_client or httpx.Client(timeout=self.timeout) - def _get_payload(self, prompt: str): + def _get_payload(self, prompt: str) -> dict: params = self._invocation_params messages = self.prefix_messages + [HumanMessage(content=prompt)] params.update( From 646ce689bb610b4e6fd772d7cf1f21e861ed390c Mon Sep 17 00:00:00 2001 From: Bob Lin Date: Tue, 12 Dec 2023 10:39:24 +0800 Subject: [PATCH 6/7] Add `langchain_community/llms/chatglm3.py` --- .../langchain_community/llms/chatglm3.py | 147 +++++++++++++++++ libs/langchain/langchain/llms/chatglm3.py | 149 +----------------- 2 files changed, 149 insertions(+), 147 deletions(-) create mode 100644 libs/community/langchain_community/llms/chatglm3.py diff --git a/libs/community/langchain_community/llms/chatglm3.py b/libs/community/langchain_community/llms/chatglm3.py new file mode 100644 index 0000000000000..950f8bc2b0fa9 --- /dev/null +++ b/libs/community/langchain_community/llms/chatglm3.py @@ -0,0 +1,147 @@ +import json +import logging +from typing import Any, List, Optional, Union + +import httpx +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +from langchain.llms.utils import enforce_stop_tokens +from 
langchain.schema.messages import (
+    AIMessage,
+    BaseMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.pydantic_v1 import Field
+
+logger = logging.getLogger(__name__)
+HEADERS = {"Content-Type": "application/json"}
+DEFAULT_TIMEOUT = 30
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+    if isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    elif isinstance(message, SystemMessage):
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, FunctionMessage):
+        message_dict = {"role": "function", "content": message.content}
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_dict
+
+
+class ChatGLM3(LLM):
+    """ChatGLM3 LLM service."""
+
+    model_name: str = Field(default="chatglm3-6b", alias="model")
+    endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
+    """Endpoint URL to use."""
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments to pass to the model."""
+    max_tokens: int = 20000
+    """Max tokens allowed to pass to the model."""
+    temperature: float = 0.1
+    """LLM model temperature from 0 to 10."""
+    top_p: float = 0.7
+    """Top P for nucleus sampling, from 0 to 1."""
+    prefix_messages: List[BaseMessage] = Field(default_factory=list)
+    """Series of messages for Chat input."""
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    http_client: Union[Any, None] = None
+    timeout: int = DEFAULT_TIMEOUT
+
+    @property
+    def _llm_type(self) -> str:
+        return "chat_glm_3"
+
+    @property
+    def _invocation_params(self) -> dict:
+        """Get the parameters used to invoke the model."""
+        params = {
+            "model": self.model_name,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "stream": self.streaming,
+        }
+        return {**params, **(self.model_kwargs or {})}
+
+    @property
+    def client(self) -> Any:
+        return self.http_client or httpx.Client(timeout=self.timeout)
+
+    def _get_payload(self, prompt: str) -> dict:
+        params = self._invocation_params
+        messages = self.prefix_messages + [HumanMessage(content=prompt)]
+        params.update(
+            {
+                "messages": [_convert_message_to_dict(m) for m in messages],
+            }
+        )
+        return params
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to a ChatGLM3 LLM inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+        .. 
code-block:: python + + response = chatglm_llm("Who are you?") + """ + payload = self._get_payload(prompt) + logger.debug(f"ChatGLM3 payload: {payload}") + + try: + response = self.client.post( + self.endpoint_url, headers=HEADERS, json=payload + ) + except httpx.NetworkError as e: + raise ValueError(f"Error raised by inference endpoint: {e}") + + logger.debug(f"ChatGLM3 response: {response}") + + if response.status_code != 200: + raise ValueError(f"Failed with response: {response}") + + try: + parsed_response = response.json() + + if isinstance(parsed_response, dict): + content_keys = "choices" + if content_keys in parsed_response: + choices = parsed_response[content_keys] + if len(choices): + text = choices[0]["message"]["content"] + else: + raise ValueError(f"No content in response : {parsed_response}") + else: + raise ValueError(f"Unexpected response type: {parsed_response}") + + except json.JSONDecodeError as e: + raise ValueError( + f"Error raised during decoding response from inference endpoint: {e}." + f"\nResponse: {response.text}" + ) + + if stop is not None: + text = enforce_stop_tokens(text, stop) + + return text diff --git a/libs/langchain/langchain/llms/chatglm3.py b/libs/langchain/langchain/llms/chatglm3.py index c5c7672dadf84..9886c152e41ea 100644 --- a/libs/langchain/langchain/llms/chatglm3.py +++ b/libs/langchain/langchain/llms/chatglm3.py @@ -1,148 +1,3 @@ -import json -import logging -from typing import Any, List, Optional, Union +from langchain_community.llms.chatglm3 import ChatGLM3 -import httpx -from langchain_core.pydantic_v1 import Field - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens -from langchain.schema.messages import ( - AIMessage, - BaseMessage, - FunctionMessage, - HumanMessage, - SystemMessage, -) - -logger = logging.getLogger(__name__) -HEADERS = {"Content-Type": "application/json"} -DEFAULT_TIMEOUT = 30 - - -def _convert_message_to_dict(message: BaseMessage) -> dict: - if isinstance(message, HumanMessage): - message_dict = {"role": "user", "content": message.content} - elif isinstance(message, AIMessage): - message_dict = {"role": "assistant", "content": message.content} - elif isinstance(message, SystemMessage): - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, FunctionMessage): - message_dict = {"role": "function", "content": message.content} - else: - raise ValueError(f"Got unknown type {message}") - return message_dict - - -class ChatGLM3(LLM): - """ChatGLM3 LLM service.""" - - model_name: str = Field(default="chatglm3-6b", alias="model") - endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions" - """Endpoint URL to use.""" - model_kwargs: Optional[dict] = None - """Keyword arguments to pass to the model.""" - max_tokens: int = 20000 - """Max token allowed to pass to the model.""" - temperature: float = 0.1 - """LLM model temperature from 0 to 10.""" - top_p: float = 0.7 - """Top P for nucleus sampling from 0 to 1""" - prefix_messages: List[BaseMessage] = Field(default_factory=list) - """Series of messages for Chat input.""" - streaming: bool = False - """Whether to stream the results or not.""" - http_client: Union[Any, None] = None - timeout: int = DEFAULT_TIMEOUT - - @property - def _llm_type(self) -> str: - return "chat_glm_3" - - @property - def _invocation_params(self) -> dict: - """Get the parameters used to invoke the model.""" - params = { - "model": self.model_name, - "temperature": 
self.temperature, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "stream": self.streaming, - } - return {**params, **(self.model_kwargs or {})} - - @property - def client(self) -> Any: - return self.http_client or httpx.Client(timeout=self.timeout) - - def _get_payload(self, prompt: str) -> dict: - params = self._invocation_params - messages = self.prefix_messages + [HumanMessage(content=prompt)] - params.update( - { - "messages": [_convert_message_to_dict(m) for m in messages], - } - ) - return params - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Call out to a ChatGLM3 LLM inference endpoint. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string generated by the model. - - Example: - .. code-block:: python - - response = chatglm_llm("Who are you?") - """ - payload = self._get_payload(prompt) - logger.debug(f"ChatGLM3 payload: {payload}") - - try: - response = self.client.post( - self.endpoint_url, headers=HEADERS, json=payload - ) - except httpx.NetworkError as e: - raise ValueError(f"Error raised by inference endpoint: {e}") - - logger.debug(f"ChatGLM3 response: {response}") - - if response.status_code != 200: - raise ValueError(f"Failed with response: {response}") - - try: - parsed_response = response.json() - - if isinstance(parsed_response, dict): - content_keys = "choices" - if content_keys in parsed_response: - choices = parsed_response[content_keys] - if len(choices): - text = choices[0]["message"]["content"] - else: - raise ValueError(f"No content in response : {parsed_response}") - else: - raise ValueError(f"Unexpected response type: {parsed_response}") - - except json.JSONDecodeError as e: - raise ValueError( - f"Error raised during decoding response from inference endpoint: {e}." 
- f"\nResponse: {response.text}" - ) - - if stop is not None: - text = enforce_stop_tokens(text, stop) - - return text +__all__ = ["ChatGLM3"] From 0c2005e8c79b605355cdf10d5e2dfe20bb025356 Mon Sep 17 00:00:00 2001 From: Bob Lin Date: Tue, 12 Dec 2023 11:00:08 +0800 Subject: [PATCH 7/7] Add httpx --- libs/community/langchain_community/llms/chatglm3.py | 9 +++++---- libs/community/poetry.lock | 2 +- libs/community/pyproject.toml | 1 + libs/community/tests/unit_tests/test_dependencies.py | 1 + libs/langchain/poetry.lock | 2 +- libs/langchain/pyproject.toml | 1 + libs/langchain/tests/unit_tests/test_dependencies.py | 1 + 7 files changed, 11 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/llms/chatglm3.py b/libs/community/langchain_community/llms/chatglm3.py index 950f8bc2b0fa9..03bd088af77e6 100644 --- a/libs/community/langchain_community/llms/chatglm3.py +++ b/libs/community/langchain_community/llms/chatglm3.py @@ -3,10 +3,9 @@ from typing import Any, List, Optional, Union import httpx -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens -from langchain.schema.messages import ( +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.llms import LLM +from langchain_core.messages import ( AIMessage, BaseMessage, FunctionMessage, @@ -15,6 +14,8 @@ ) from langchain_core.pydantic_v1 import Field +from langchain_community.llms.utils import enforce_stop_tokens + logger = logging.getLogger(__name__) HEADERS = {"Content-Type": "application/json"} DEFAULT_TIMEOUT = 30 diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index 160e328cffa44..c1d0762f1dd73 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -8485,4 +8485,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "e3bacf389a13d283c4dd29e3a673e1863826b4e98785c666fefc10cf714c2f6f" +content-hash = "7e5c043edec4677d594f4f292263bc3157a493df20ff1bbf7049ee8a65b2168b" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index f05d6580af262..90f411af3ad91 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -83,6 +83,7 @@ msal = {version = "^1.25.0", optional = true} databricks-vectorsearch = {version = "^0.21", optional = true} dgml-utils = {version = "^0.3.0", optional = true} datasets = {version = "^2.15.0", optional = true} +httpx = "^0.24.1" [tool.poetry.group.test] optional = true diff --git a/libs/community/tests/unit_tests/test_dependencies.py b/libs/community/tests/unit_tests/test_dependencies.py index 5f9c8bbd383cb..e24e9179a4f63 100644 --- a/libs/community/tests/unit_tests/test_dependencies.py +++ b/libs/community/tests/unit_tests/test_dependencies.py @@ -41,6 +41,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None: "SQLAlchemy", "aiohttp", "dataclasses-json", + "httpx", "langchain-core", "langsmith", "numpy", diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 0da142833dd52..3f5bec6c97a0e 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -9103,4 +9103,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "0b232a037505cefcdf2203edc9d750e70e2e52a297475490022402994c3036a3" +content-hash = 
"f80cddaa31f721c8ddda48024217b35c0dbc89ffe7171c000ea0a4859720a361" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index b49d8ef5816eb..68d95431f3f3b 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -22,6 +22,7 @@ numpy = "^1" aiohttp = "^3.8.3" tenacity = "^8.1.0" jsonpatch = "^1.33" +httpx = "^0.24.1" azure-core = {version = "^1.26.4", optional=true} tqdm = {version = ">=4.48.0", optional = true} openapi-pydantic = {version = "^0.3.2", optional = true} diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py b/libs/langchain/tests/unit_tests/test_dependencies.py index 872b01d6213b0..18ced0511ffad 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -42,6 +42,7 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None: "aiohttp", "async-timeout", "dataclasses-json", + "httpx", "jsonpatch", "langchain-core", "langsmith",