From 85d0f5e644796605a0aac731417d6165ad214e80 Mon Sep 17 00:00:00 2001
From: Bam4d
Date: Sat, 9 Dec 2023 11:41:05 +0000
Subject: [PATCH] some more docs and testings

---
 src/mistralai/async_client.py | 52 ++++++++++++++++++++++++++++++++++-
 src/mistralai/client.py       | 48 ++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/src/mistralai/async_client.py b/src/mistralai/async_client.py
index 5cd4a56..56cbfc4 100644
--- a/src/mistralai/async_client.py
+++ b/src/mistralai/async_client.py
@@ -185,6 +185,7 @@ async def _request(
         stream: bool = False,
         params: Optional[Dict[str, Any]] = None,
     ) -> Union[Dict[str, Any], aiohttp.ClientResponse]:
+
         headers = {
             "Authorization": f"Bearer {self._api_key}",
             "Content-Type": "application/json",
@@ -224,6 +225,22 @@ async def chat(
         random_seed: Optional[int] = None,
         safe_mode: bool = True,
     ) -> ChatCompletionResponse:
+        """An asynchronous chat endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to chat with, e.g. mistral-tiny
+            messages (List[ChatMessage]): an array of messages to chat with, e.g.
+                [{role: 'user', content: 'What is the best French cheese?'}]
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to True.
+
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
         request = self._make_chat_request(
             model,
             messages,
@@ -247,8 +264,26 @@ async def chat_stream(
         max_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         random_seed: Optional[int] = None,
-        safe_mode: bool = True,
+        safe_mode: bool = False,
     ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
+        """An asynchronous chat endpoint that streams responses.
+
+        Args:
+            model (str): the name of the model to chat with, e.g. mistral-tiny
+            messages (List[ChatMessage]): an array of messages to chat with, e.g.
+                [{role: 'user', content: 'What is the best French cheese?'}]
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
+
+        Returns:
+            AsyncGenerator[ChatCompletionStreamResponse, None]:
+                An async generator that yields ChatCompletionStreamResponse objects.
+        """
+
         request = self._make_chat_request(
             model,
             messages,
@@ -281,12 +316,27 @@ async def chat_stream(
     async def embeddings(
         self, model: str, input: Union[str, List[str]]
     ) -> EmbeddingResponse:
+        """An asynchronous embeddings endpoint that returns embeddings for a single input or a batch of inputs
+
+        Args:
+            model (str): The embedding model to use, e.g. mistral-embed
+            input (Union[str, List[str]]): The input to embed,
+                e.g. ['What is the best French cheese?']
+
+        Returns:
+            EmbeddingResponse: A response object containing the embeddings.
+        """
         request = {"model": model, "input": input}
         response = await self._request("post", request, "v1/embeddings")
         assert isinstance(response, dict), "Bad response from _request"
         return EmbeddingResponse(**response)

     async def list_models(self) -> ModelList:
+        """Returns a list of the available models
+
+        Returns:
+            ModelList: A response object containing the list of models.
+        """
         response = await self._request("get", {}, "v1/models")
         assert isinstance(response, dict), "Bad response from _request"
         return ModelList(**response)
diff --git a/src/mistralai/client.py b/src/mistralai/client.py
index 74a2e5c..3499c89 100644
--- a/src/mistralai/client.py
+++ b/src/mistralai/client.py
@@ -108,6 +108,22 @@ def chat(
         random_seed: Optional[int] = None,
         safe_mode: bool = True,
     ) -> ChatCompletionResponse:
+        """A chat endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to chat with, e.g. mistral-tiny
+            messages (List[ChatMessage]): an array of messages to chat with, e.g.
+                [{role: 'user', content: 'What is the best French cheese?'}]
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to True.
+
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
         request = self._make_chat_request(
             model,
             messages,
@@ -135,6 +151,23 @@ def chat_stream(
         random_seed: Optional[int] = None,
         safe_mode: bool = True,
     ) -> Iterable[ChatCompletionStreamResponse]:
+        """A chat endpoint that streams responses.
+
+        Args:
+            model (str): the name of the model to chat with, e.g. mistral-tiny
+            messages (List[ChatMessage]): an array of messages to chat with, e.g.
+                [{role: 'user', content: 'What is the best French cheese?'}]
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to True.
+
+        Returns:
+            Iterable[ChatCompletionStreamResponse]:
+                A generator that yields ChatCompletionStreamResponse objects.
+        """
         request = self._make_chat_request(
             model,
             messages,
@@ -162,12 +195,27 @@ def chat_stream(
                 yield ChatCompletionStreamResponse(**json_response)

     def embeddings(self, model: str, input: Union[str, List[str]]) -> EmbeddingResponse:
+        """An embeddings endpoint that returns embeddings for a single input or a batch of inputs
+
+        Args:
+            model (str): The embedding model to use, e.g. mistral-embed
+            input (Union[str, List[str]]): The input to embed,
+                e.g. ['What is the best French cheese?']
+
+        Returns:
+            EmbeddingResponse: A response object containing the embeddings.
+        """
         request = {"model": model, "input": input}
         response = self._request("post", request, "v1/embeddings")
         assert isinstance(response, dict), "Bad response from _request"
         return EmbeddingResponse(**response)

     def list_models(self) -> ModelList:
+        """Returns a list of the available models
+
+        Returns:
+            ModelList: A response object containing the list of models.
+        """
         response = self._request("get", {}, "v1/models")
         assert isinstance(response, dict), "Bad response from _request"
         return ModelList(**response)
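
As a usage reference for the asynchronous endpoints documented above, here is a minimal sketch. It assumes the client is constructed as MistralAsyncClient(api_key=...) with the key read from a MISTRAL_API_KEY environment variable, that ChatMessage is importable from mistralai.models.chat_completion, and that the response objects expose choices, delta, and data attributes; none of those details appear in this patch, so check them against the models module before relying on them.

import asyncio
import os

from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage


async def main():
    # Assumption: the API key is provided via the MISTRAL_API_KEY environment variable.
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

    messages = [ChatMessage(role="user", content="What is the best French cheese?")]

    # Single (non-streaming) chat completion.
    response = await client.chat(model="mistral-tiny", messages=messages)
    print(response.choices[0].message.content)  # assumed response shape

    # Streaming chat completion: chat_stream is an async generator.
    async for chunk in client.chat_stream(model="mistral-tiny", messages=messages):
        if chunk.choices[0].delta.content is not None:  # assumed chunk shape
            print(chunk.choices[0].delta.content, end="")
    print()

    # Embeddings for a single input or a batch of inputs.
    embeddings = await client.embeddings(
        model="mistral-embed", input=["What is the best French cheese?"]
    )
    print(len(embeddings.data))  # assumed response shape

    # List the available models.
    models = await client.list_models()
    print([m.id for m in models.data])  # assumed response shape


if __name__ == "__main__":
    asyncio.run(main())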
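
A corresponding sketch for the synchronous client in src/mistralai/client.py, under the same assumptions about the constructor, the ChatMessage import path, and the response attributes.

import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

# Assumption: the API key is provided via the MISTRAL_API_KEY environment variable.
client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

messages = [ChatMessage(role="user", content="What is the best French cheese?")]

# Single (non-streaming) chat completion.
response = client.chat(model="mistral-tiny", messages=messages, temperature=0.5)
print(response.choices[0].message.content)  # assumed response shape

# Streaming chat completion: chat_stream yields ChatCompletionStreamResponse objects.
for chunk in client.chat_stream(model="mistral-tiny", messages=messages):
    if chunk.choices[0].delta.content is not None:  # assumed chunk shape
        print(chunk.choices[0].delta.content, end="")
print()

# Embeddings for a single input or a batch of inputs.
embeddings = client.embeddings(model="mistral-embed", input=["What is the best French cheese?"])
print(len(embeddings.data))  # assumed response shape

# List the available models.
models = client.list_models()
print([m.id for m in models.data])  # assumed response shape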