Commit

Improve error handling and support stream method
1. The old code raised a ValidationError (pydantic_core._pydantic_core.ValidationError: 1 validation error for Xinference) when importing Xinference from xinference.py. This is resolved by adjusting the client field's type and default value.
2. Rewrite the _stream method so that chain.stream() can be used to return data streams (a usage sketch follows the diff below).
TheSongg authored Jan 14, 2025
1 parent c55af44 commit 5e1d263
Showing 1 changed file with 68 additions and 2 deletions.
70 changes: 68 additions & 2 deletions libs/community/langchain_community/llms/xinference.py
@@ -1,7 +1,8 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union, Iterator

Check failure on line 1 (GitHub Actions / cd libs/community / make lint, #3.9 and #3.13): Ruff (E501) langchain_community/llms/xinference.py:1:89: Line too long (96 > 88)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk, LLMResult

Check failure on line 5 (GitHub Actions / cd libs/community / make lint, #3.9 and #3.13): Ruff (F401) langchain_community/llms/xinference.py:5:53: `langchain_core.outputs.LLMResult` imported but unused

if TYPE_CHECKING:

Check failure on line 7 (GitHub Actions / cd libs/community / make lint, #3.9 and #3.13): Ruff (I001) langchain_community/llms/xinference.py:1:1: Import block is un-sorted or un-formatted
    from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
@@ -81,7 +82,7 @@ class Xinference(LLM):
""" # noqa: E501

client: Any
client: Optional[Any] = None
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
@@ -214,3 +215,68 @@ def _stream_generate(
                        token=token, verbose=self.verbose, log_probs=log_probs
                    )
                yield token

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        generate_config = kwargs.get("generate_config", {})
        generate_config = {**self.model_kwargs, **generate_config}
        if stop:
            generate_config["stop"] = stop
        for stream_resp in self._create_generate_stream(prompt, generate_config):
            if stream_resp:
                chunk = self._stream_response_to_generation_chunk(stream_resp)
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                    )
                yield chunk

    def _create_generate_stream(
        self,
        prompt: str,
        generate_config: Optional[Dict[str, List[str]]] = None
    ) -> Iterator[str]:
        model = self.client.get_model(self.model_uid)
        yield from self.create_stream(
            model,
            prompt,
            generate_config,
        )

    @staticmethod
    def _stream_response_to_generation_chunk(
        stream_response: str,
    ) -> GenerationChunk:
        """Convert a stream response to a generation chunk."""
        token = ''
        if isinstance(stream_response, dict):
            choices = stream_response.get("choices", [])
            if choices:
                choice = choices[0]
                if isinstance(choice, dict):
                    token = choice.get("text", "")

        if not stream_response["choices"]:
            return GenerationChunk(text=token)

        return GenerationChunk(
            text=token,
            generation_info=dict(
                finish_reason=stream_response["choices"][0].get("finish_reason", None),
                logprobs=stream_response["choices"][0].get("logprobs", None),
            ),
        )

    @staticmethod
    def create_stream(
        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
        prompt: str,
        generate_config: Optional[Dict[str, List[str]]] = None
    ) -> Iterator[str]:
        return model.generate(prompt=prompt, generate_config=generate_config)
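
For context, a minimal usage sketch (not part of the commit): with the rewritten _stream method above, a chain built on Xinference can be consumed incrementally through chain.stream(). The server URL, model UID, and prompt below are placeholder values.

# Minimal streaming sketch; endpoint and model UID are assumed placeholders.
from langchain_community.llms import Xinference
from langchain_core.prompts import PromptTemplate

llm = Xinference(
    server_url="http://127.0.0.1:9997",  # placeholder Xinference server
    model_uid="<model-uid>",  # placeholder model UID
)

chain = PromptTemplate.from_template("Q: {question}\nA:") | llm

# chain.stream() yields incremental text chunks produced by Xinference._stream.
for chunk in chain.stream({"question": "What is Xinference?"}):
    print(chunk, end="", flush=True)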
