some more docs and testings
Bam4d committed Dec 9, 2023
1 parent b021a56 commit 85d0f5e
Showing 2 changed files with 99 additions and 1 deletion.
52 changes: 51 additions & 1 deletion src/mistralai/async_client.py
@@ -185,6 +185,7 @@ async def _request(
stream: bool = False,
params: Optional[Dict[str, Any]] = None,
) -> Union[Dict[str, Any], aiohttp.ClientResponse]:

headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
@@ -224,6 +225,22 @@ async def chat(
random_seed: Optional[int] = None,
safe_mode: bool = True,
) -> ChatCompletionResponse:
""" A asynchronous chat endpoint that returns a single response.
Args:
model (str): model the name of the model to chat with, e.g. mistral-tiny
messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
[{role: 'user', content: 'What is the best French cheese?'}]
temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
Returns:
ChatCompletionResponse: a response object containing the generated text.
"""
request = self._make_chat_request(
model,
messages,
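For reference, a minimal usage sketch of the async chat endpoint documented above. The client class name (MistralAsyncClient), the ChatMessage import path, the api_key constructor argument, and the response attribute names are assumptions based on the package layout; they are not shown in this diff:

import asyncio

from mistralai.async_client import MistralAsyncClient  # class name assumed
from mistralai.models.chat_completion import ChatMessage  # import path assumed

async def main() -> None:
    client = MistralAsyncClient(api_key="YOUR_API_KEY")  # constructor argument assumed
    # Single, non-streaming chat completion as described in the docstring above.
    response = await client.chat(
        model="mistral-tiny",
        messages=[ChatMessage(role="user", content="What is the best French cheese?")],
        temperature=0.5,
        max_tokens=100,
    )
    # The generated text is expected under choices[0].message.content (attribute names assumed).
    print(response.choices[0].message.content)

asyncio.run(main())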
@@ -247,8 +264,26 @@ async def chat_stream(
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
- safe_mode: bool = True,
+ safe_mode: bool = False,
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
""" An Asynchronous chat endpoint that streams responses.
Args:
model (str): model the name of the model to chat with, e.g. mistral-tiny
messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
[{role: 'user', content: 'What is the best French cheese?'}]
temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
Returns:
AsyncGenerator[ChatCompletionStreamResponse, None]:
An async generator that yields ChatCompletionStreamResponse objects.
"""

request = self._make_chat_request(
model,
messages,
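A corresponding sketch for the streaming endpoint, assuming chat_stream is an async generator function (consumed with async for, without awaiting the call itself); import paths and chunk attribute names (choices[0].delta.content) are assumptions:

import asyncio

from mistralai.async_client import MistralAsyncClient  # class name assumed
from mistralai.models.chat_completion import ChatMessage  # import path assumed

async def stream_demo() -> None:
    client = MistralAsyncClient(api_key="YOUR_API_KEY")
    # Each yielded chunk is a ChatCompletionStreamResponse carrying an incremental delta.
    async for chunk in client.chat_stream(
        model="mistral-tiny",
        messages=[ChatMessage(role="user", content="Name three French cheeses.")],
    ):
        delta = chunk.choices[0].delta.content  # attribute names assumed
        if delta is not None:
            print(delta, end="", flush=True)

asyncio.run(stream_demo())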
@@ -281,12 +316,27 @@ async def chat_stream(
async def embeddings(
self, model: str, input: Union[str, List[str]]
) -> EmbeddingResponse:
"""An asynchronous embeddings endpoint that returns embeddings for a single, or batch of inputs
Args:
model (str): The embedding model to use, e.g. mistral-embed
input (Union[str, List[str]]): The input to embed,
e.g. ['What is the best French cheese?']
Returns:
EmbeddingResponse: A response object containing the embeddings.
"""
request = {"model": model, "input": input}
response = await self._request("post", request, "v1/embeddings")
assert isinstance(response, dict), "Bad response from _request"
return EmbeddingResponse(**response)
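A short sketch of the async embeddings call documented above; the attribute names on EmbeddingResponse (data, embedding) are assumptions:

import asyncio

from mistralai.async_client import MistralAsyncClient  # class name assumed

async def embed_demo() -> None:
    client = MistralAsyncClient(api_key="YOUR_API_KEY")
    # input accepts either a single string or a batch (list) of strings.
    response = await client.embeddings(
        model="mistral-embed",
        input=["What is the best French cheese?", "What is the best Italian cheese?"],
    )
    for item in response.data:  # attribute names assumed
        print(len(item.embedding))

asyncio.run(embed_demo())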

async def list_models(self) -> ModelList:
"""Returns a list of the available models
Returns:
ModelList: A response object containing the list of models.
"""
response = await self._request("get", {}, "v1/models")
assert isinstance(response, dict), "Bad response from _request"
return ModelList(**response)
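And a sketch of listing models asynchronously; the ModelList attributes (data, id) are assumptions:

import asyncio

from mistralai.async_client import MistralAsyncClient  # class name assumed

async def models_demo() -> None:
    client = MistralAsyncClient(api_key="YOUR_API_KEY")
    models = await client.list_models()
    for model in models.data:  # attribute names assumed
        print(model.id)

asyncio.run(models_demo())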
48 changes: 48 additions & 0 deletions src/mistralai/client.py
@@ -108,6 +108,22 @@ def chat(
random_seed: Optional[int] = None,
safe_mode: bool = True,
) -> ChatCompletionResponse:
""" A chat endpoint that returns a single response.
Args:
model (str): model the name of the model to chat with, e.g. mistral-tiny
messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
[{role: 'user', content: 'What is the best French cheese?'}]
temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
Returns:
ChatCompletionResponse: a response object containing the generated text.
"""
request = self._make_chat_request(
model,
messages,
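The synchronous client mirrors the async API without await; a minimal sketch, with the MistralClient class name, ChatMessage import path, and response attributes assumed rather than taken from this diff:

from mistralai.client import MistralClient  # class name assumed
from mistralai.models.chat_completion import ChatMessage  # import path assumed

client = MistralClient(api_key="YOUR_API_KEY")  # constructor argument assumed
response = client.chat(
    model="mistral-tiny",
    messages=[ChatMessage(role="user", content="What is the best French cheese?")],
    temperature=0.5,
)
print(response.choices[0].message.content)  # attribute names assumed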
@@ -135,6 +151,23 @@ def chat_stream(
random_seed: Optional[int] = None,
safe_mode: bool = True,
) -> Iterable[ChatCompletionStreamResponse]:
""" A chat endpoint that streams responses.
Args:
model (str): model the name of the model to chat with, e.g. mistral-tiny
messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
[{role: 'user', content: 'What is the best French cheese?'}]
temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
Returns:
Iterable[ChatCompletionStreamResponse]:
A generator that yields ChatCompletionStreamResponse objects.
"""
request = self._make_chat_request(
model,
messages,
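A sketch for the synchronous streaming variant, which returns a plain generator consumed with a regular for loop (names outside this diff are assumed):

from mistralai.client import MistralClient  # class name assumed
from mistralai.models.chat_completion import ChatMessage  # import path assumed

client = MistralClient(api_key="YOUR_API_KEY")
for chunk in client.chat_stream(
    model="mistral-tiny",
    messages=[ChatMessage(role="user", content="Name three French cheeses.")],
):
    delta = chunk.choices[0].delta.content  # attribute names assumed
    if delta is not None:
        print(delta, end="", flush=True)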
@@ -162,12 +195,27 @@ def chat_stream(
yield ChatCompletionStreamResponse(**json_response)

def embeddings(self, model: str, input: Union[str, List[str]]) -> EmbeddingResponse:
"""An embeddings endpoint that returns embeddings for a single, or batch of inputs
Args:
model (str): The embedding model to use, e.g. mistral-embed
input (Union[str, List[str]]): The input to embed,
e.g. ['What is the best French cheese?']
Returns:
EmbeddingResponse: A response object containing the embeddings.
"""
request = {"model": model, "input": input}
response = self._request("post", request, "v1/embeddings")
assert isinstance(response, dict), "Bad response from _request"
return EmbeddingResponse(**response)
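A sketch of the synchronous embeddings call (class name and response attributes assumed):

from mistralai.client import MistralClient  # class name assumed

client = MistralClient(api_key="YOUR_API_KEY")
response = client.embeddings(
    model="mistral-embed",
    input="What is the best French cheese?",  # a single string or a list of strings
)
print(len(response.data[0].embedding))  # attribute names assumed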

def list_models(self) -> ModelList:
"""Returns a list of the available models
Returns:
ModelList: A response object containing the list of models.
"""
response = self._request("get", {}, "v1/models")
assert isinstance(response, dict), "Bad response from _request"
return ModelList(**response)
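And a sketch of listing models with the synchronous client (ModelList attributes assumed):

from mistralai.client import MistralClient  # class name assumed

client = MistralClient(api_key="YOUR_API_KEY")
models = client.list_models()
for model in models.data:  # attribute names assumed
    print(model.id)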
