From 72e12ac385dcfb1d9b2303414360dd8a95f78c44 Mon Sep 17 00:00:00 2001 From: Metin Dumandag <29387993+mdumandag@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:41:24 +0300 Subject: [PATCH] Add chat completion support This PR adds support for the following features: - client.chat().create() and client.chat().prompt() APIs for multi or single turn chat completions. - api/llm support for publish, enqueue, and batch APIs --- README.md | 129 ++++---- tests/asyncio/test_chat.py | 71 +++++ tests/asyncio/test_publish.py | 49 ++- tests/asyncio/test_queue.py | 27 ++ tests/test_chat.py | 67 ++++ tests/test_publish.py | 47 ++- tests/test_queue.py | 26 ++ upstash_qstash/__init__.py | 5 +- upstash_qstash/asyncio/chat.py | 52 ++++ upstash_qstash/asyncio/client.py | 21 +- upstash_qstash/asyncio/publish.py | 12 +- upstash_qstash/asyncio/queue.py | 3 +- upstash_qstash/chat.py | 501 ++++++++++++++++++++++++++++++ upstash_qstash/client.py | 23 +- upstash_qstash/error.py | 23 +- upstash_qstash/publish.py | 40 ++- upstash_qstash/qstash_types.py | 1 - upstash_qstash/queue.py | 3 +- upstash_qstash/upstash_http.py | 181 ++++++++++- 19 files changed, 1185 insertions(+), 96 deletions(-) create mode 100644 tests/asyncio/test_chat.py create mode 100644 tests/test_chat.py create mode 100644 upstash_qstash/asyncio/chat.py create mode 100644 upstash_qstash/chat.py diff --git a/README.md b/README.md index 65461cb..a189287 100644 --- a/README.md +++ b/README.md @@ -15,15 +15,15 @@ from upstash_qstash import Client client = Client("") res = client.publish_json( - { - "url": "https://my-api...", - "body": { - "hello": "world" - }, - "headers": { - "test-header": "test-value", - }, - } + { + "url": "https://my-api...", + "body": { + "hello": "world" + }, + "headers": { + "test-header": "test-value", + }, + } ) print(res["messageId"]) @@ -37,10 +37,10 @@ from upstash_qstash import Client client = Client("") schedules = client.schedules() res = schedules.create( - { - "destination": "https://my-api...", - "cron": "*/5 * * * *", - } + { + "destination": "https://my-api...", + "cron": "*/5 * * * *", + } ) print(res["scheduleId"]) @@ -53,10 +53,10 @@ from upstash_qstash import Receiver # Keys available from the QStash console receiver = Receiver( - { - "current_signing_key": "CURRENT_SIGNING_KEY", - "next_signing_key": "NEXT_SIGNING_KEY", - } + { + "current_signing_key": "CURRENT_SIGNING_KEY", + "next_signing_key": "NEXT_SIGNING_KEY", + } ) # ... in your request handler @@ -64,14 +64,35 @@ receiver = Receiver( signature, body = req.headers["Upstash-Signature"], req.body is_valid = receiver.verify( - { - "body": body, - "signature": signature, - "url": "https://my-api...", # Optional - } + { + "body": body, + "signature": signature, + "url": "https://my-api...", # Optional + } ) ``` +#### Create Chat Completions + +```python +from upstash_qstash import Client + +client = Client("") +chat = client.chat() + +res = chat.create({ + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "What is the capital of Turkey?" 
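+            # "role" can also be "system" or "assistant"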
+ } + ] +}) + +print(res["choices"][0]["message"]["content"]) +``` + #### Additional configuration ```python @@ -86,47 +107,47 @@ from upstash_qstash import Client # "backoff": lambda retry_count: math.exp(retry_count) * 50, # } client = Client("", { - "attempts": 2, - "backoff": lambda retry_count: (2 ** retry_count) * 20, + "attempts": 2, + "backoff": lambda retry_count: (2 ** retry_count) * 20, }) # Create Topic topics = client.topics() topics.upsert_or_add_endpoints( - { - "name": "topic_name", - "endpoints": [ - {"url": "https://my-endpoint-1"}, - {"url": "https://my-endpoint-2"} - ], - } + { + "name": "topic_name", + "endpoints": [ + {"url": "https://my-endpoint-1"}, + {"url": "https://my-endpoint-2"} + ], + } ) # Publish to Topic client.publish_json( - { - "topic": "my-topic", - "body": { - "key": "value" - }, - # Retry sending message to API 3 times - # https://upstash.com/docs/qstash/features/retry - "retries": 3, - # Schedule message to be sent 4 seconds from now - "delay": 4, - # When message is sent, send a request to this URL - # https://upstash.com/docs/qstash/features/callbacks - "callback": "https://my-api.com/callback", - # When message fails to send, send a request to this URL - "failure_callback": "https://my-api.com/failure_callback", - # Headers to forward to the endpoint - "headers": { - "test-header": "test-value", - }, - # Enable content-based deduplication - # https://upstash.com/docs/qstash/features/deduplication#content-based-deduplication - "content_based_deduplication": True, - } + { + "topic": "my-topic", + "body": { + "key": "value" + }, + # Retry sending message to API 3 times + # https://upstash.com/docs/qstash/features/retry + "retries": 3, + # Schedule message to be sent 4 seconds from now + "delay": 4, + # When message is sent, send a request to this URL + # https://upstash.com/docs/qstash/features/callbacks + "callback": "https://my-api.com/callback", + # When message fails to send, send a request to this URL + "failure_callback": "https://my-api.com/failure_callback", + # Headers to forward to the endpoint + "headers": { + "test-header": "test-value", + }, + # Enable content-based deduplication + # https://upstash.com/docs/qstash/features/deduplication#content-based-deduplication + "content_based_deduplication": True, + } ) ``` diff --git a/tests/asyncio/test_chat.py b/tests/asyncio/test_chat.py new file mode 100644 index 0000000..48301ee --- /dev/null +++ b/tests/asyncio/test_chat.py @@ -0,0 +1,71 @@ +from typing import AsyncIterable + +import pytest + +from qstash_tokens import QSTASH_TOKEN +from upstash_qstash.asyncio import Client + + +@pytest.fixture +def client(): + return Client(QSTASH_TOKEN) + + +@pytest.mark.asyncio +async def test_chat_async(client): + res = await client.chat().create( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [{"role": "user", "content": "hello"}], + } + ) + + assert res["id"] is not None + + assert res["choices"][0]["message"]["content"] is not None + assert res["choices"][0]["message"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_chat_streaming_async(client): + res = await client.chat().create( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + } + ) + + async for r in res: + assert r["id"] is not None + assert r["choices"][0]["delta"] is not None + + +@pytest.mark.asyncio +async def test_prompt_async(client): + res = await client.chat().prompt( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + 
"user": "hello", + } + ) + + assert res["id"] is not None + + assert res["choices"][0]["message"]["content"] is not None + assert res["choices"][0]["message"]["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_prompt_streaming_async(client): + res = await client.chat().prompt( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "user": "hello", + "stream": True, + } + ) + + async for r in res: + assert r["id"] is not None + assert r["choices"][0]["delta"] is not None diff --git a/tests/asyncio/test_publish.py b/tests/asyncio/test_publish.py index 25a4634..1b9860e 100644 --- a/tests/asyncio/test_publish.py +++ b/tests/asyncio/test_publish.py @@ -39,7 +39,7 @@ async def test_publish_to_url_async(client): found_event ), f"Event with messageId {res['messageId']} not found This may be because the latency of the event is too high" assert ( - event["state"] != "ERROR" + event["state"] != "ERROR" ), f"Event with messageId {res['messageId']} was not delivered" @@ -87,3 +87,50 @@ async def test_batch_json_async(client): assert len(res) == N for i in range(N): assert res[i]["messageId"] is not None + + +@pytest.mark.asyncio +async def test_publish_api_llm_async(client): + # not a proper test, because of a dummy callback. + res = await client.publish_json( + { + "api": "llm", + "body": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "hello", + } + ], + }, + "callback": "https://example.com", + } + ) + + assert res["messageId"] is not None + + +@pytest.mark.asyncio +async def test_batch_api_llm_async(client): + # not a proper test, because of a dummy callback. + res = await client.batch_json( + [ + { + "api": "llm", + "body": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "hello", + } + ], + }, + "callback": "https://example.com", + } + ] + ) + + assert len(res) == 1 + assert res[0]["messageId"] is not None diff --git a/tests/asyncio/test_queue.py b/tests/asyncio/test_queue.py index 91173fc..cc8a537 100644 --- a/tests/asyncio/test_queue.py +++ b/tests/asyncio/test_queue.py @@ -94,3 +94,30 @@ async def test_enqueue(client): print("Deleting queue") await queue.delete() + + +@pytest.mark.asyncio +async def test_enqueue_api_llm_async(client): + # not a proper test, because of a dummy callback. 
+ queue = client.queue({"queue_name": "test_queue"}) + + try: + res = await queue.enqueue_json( + { + "api": "llm", + "body": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "hello", + } + ], + }, + "callback": "https://example.com/", + } + ) + + assert res["messageId"] is not None + finally: + await queue.delete() diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..8bc3a00 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,67 @@ +from typing import Iterable + +import pytest + +from qstash_tokens import QSTASH_TOKEN +from upstash_qstash import Client + + +@pytest.fixture +def client(): + return Client(QSTASH_TOKEN) + + +def test_chat(client): + res = client.chat().create( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [{"role": "user", "content": "hello"}], + } + ) + + assert res["id"] is not None + + assert res["choices"][0]["message"]["content"] is not None + assert res["choices"][0]["message"]["role"] == "assistant" + + +def test_chat_streaming(client): + res = client.chat().create( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + } + ) + + for r in res: + assert r["id"] is not None + assert r["choices"][0]["delta"] is not None + + +def test_prompt(client): + res = client.chat().prompt( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "user": "hello", + } + ) + + assert res["id"] is not None + + assert res["choices"][0]["message"]["content"] is not None + assert res["choices"][0]["message"]["role"] == "assistant" + + +def test_prompt_streaming(client): + res = client.chat().prompt( + { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "user": "hello", + "stream": True, + } + ) + + for r in res: + assert r["id"] is not None + assert r["choices"][0]["delta"] is not None diff --git a/tests/test_publish.py b/tests/test_publish.py index 236fc98..77f81e8 100644 --- a/tests/test_publish.py +++ b/tests/test_publish.py @@ -39,7 +39,7 @@ def test_publish_to_url(client): found_event ), f"Event with messageId {res['messageId']} not found This may be because the latency of the event is too high" assert ( - event["state"] != "ERROR" + event["state"] != "ERROR" ), f"Event with messageId {res['messageId']} was not delivered" @@ -95,3 +95,48 @@ def test_batch_json(client): assert len(res) == N for i in range(N): assert res[i]["messageId"] is not None + + +def test_publish_api_llm(client): + # not a proper test, because of a dummy callback. + res = client.publish_json( + { + "api": "llm", + "body": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "hello", + } + ], + }, + "callback": "https://example.com", + } + ) + + assert res["messageId"] is not None + + +def test_batch_api_llm(client): + # not a proper test, because of a dummy callback. + res = client.batch_json( + [ + { + "api": "llm", + "body": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "hello", + } + ], + }, + "callback": "https://example.com", + } + ] + ) + + assert len(res) == 1 + assert res[0]["messageId"] is not None diff --git a/tests/test_queue.py b/tests/test_queue.py index e43ee57..0ec05d1 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -91,3 +91,29 @@ def test_enqueue(client): print("Deleting queue") queue.delete() + + +def test_enqueue_api_llm(client): + # not a proper test, because of a dummy callback. 
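+    # the completed chat response would be delivered to the callback URL, which is a dummy here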
+ queue = client.queue({"queue_name": "test_queue"}) + + try: + res = queue.enqueue_json( + { + "api": "llm", + "body": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "user", + "content": "hello", + } + ], + }, + "callback": "https://example.com/", + } + ) + + assert res["messageId"] is not None + finally: + queue.delete() diff --git a/upstash_qstash/__init__.py b/upstash_qstash/__init__.py index 1305195..5021b1c 100644 --- a/upstash_qstash/__init__.py +++ b/upstash_qstash/__init__.py @@ -1,5 +1,6 @@ +from upstash_qstash.asyncio.client import Client as AsyncClient from upstash_qstash.client import Client from upstash_qstash.receiver import Receiver -__version__ = "0.1.0" -__all__ = ["Client", "Receiver"] +__version__ = "1.1.0" +__all__ = ["Client", "AsyncClient", "Receiver"] diff --git a/upstash_qstash/asyncio/chat.py b/upstash_qstash/asyncio/chat.py new file mode 100644 index 0000000..84c96a6 --- /dev/null +++ b/upstash_qstash/asyncio/chat.py @@ -0,0 +1,52 @@ +import json +from typing import AsyncIterable, Union + +from upstash_qstash.chat import Chat as SyncChat +from upstash_qstash.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatRequest, + PromptRequest, +) +from upstash_qstash.upstash_http import HttpClient + + +class Chat: + def __init__(self, http: HttpClient): + self.http = http + + async def create( + self, req: ChatRequest + ) -> Union[ChatCompletion, AsyncIterable[ChatCompletionChunk]]: + SyncChat._validate_request(req) + body = json.dumps(req) + + if req.get("stream"): + return self.http.request_stream_async( + { + "path": ["llm", "v1", "chat", "completions"], + "method": "POST", + "headers": { + "Content-Type": "application/json", + "Connection": "keep-alive", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + }, + "body": body, + } + ) + + return await self.http.request_async( + { + "path": ["llm", "v1", "chat", "completions"], + "method": "POST", + "headers": {"Content-Type": "application/json"}, + "body": body, + } + ) + + async def prompt( + self, req: PromptRequest + ) -> Union[ChatCompletion, AsyncIterable[ChatCompletionChunk]]: + chat_req = SyncChat._to_chat_request(req) + return await self.create(chat_req) diff --git a/upstash_qstash/asyncio/client.py b/upstash_qstash/asyncio/client.py index de34c15..c4debb9 100644 --- a/upstash_qstash/asyncio/client.py +++ b/upstash_qstash/asyncio/client.py @@ -1,5 +1,6 @@ -from typing import Optional, Union +from typing import AsyncIterable, Optional, Union +from upstash_qstash.asyncio.chat import Chat from upstash_qstash.asyncio.dlq import DLQ from upstash_qstash.asyncio.events import Events, EventsRequest, GetEventsResponse from upstash_qstash.asyncio.keys import Keys @@ -8,6 +9,13 @@ from upstash_qstash.asyncio.queue import Queue, QueueOpts from upstash_qstash.asyncio.schedules import Schedules from upstash_qstash.asyncio.topics import Topics +from upstash_qstash.chat import Chat as SyncChat +from upstash_qstash.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatRequest, + PromptRequest, +) from upstash_qstash.qstash_types import RetryConfig from upstash_qstash.upstash_http import HttpClient @@ -29,7 +37,8 @@ def __init__( async def publish(self, req: PublishRequest): """ - If publishing to a URL (req contains 'url'), this method returns a PublishToUrlResponse: + If publishing to a URL (req contains 'url') or an API (req contains 'api'), + this method returns a PublishToUrlResponse: - PublishToUrlResponse: Contains 'messageId' indicating the unique ID of the 
message and an optional 'deduplicated' boolean indicating if the message is a duplicate. @@ -41,7 +50,7 @@ async def publish(self, req: PublishRequest): :param req: An instance of PublishRequest containing the request details. :return: Response details including the message_id, url (if publishing to a topic), and possibly a deduplicated boolean. The exact return type depends on the publish target. - :raises ValueError: If neither 'url' nor 'topic' is provided, or both are provided. + :raises ValueError: If neither 'url', 'topic', nor 'api' is provided, or more than one of them are provided. """ return await Publish.publish_async(self.http, req) @@ -143,3 +152,9 @@ async def events(self, req: Optional[EventsRequest] = None) -> GetEventsResponse >>> break """ return await Events.get(self.http, req) + + def chat(self) -> Chat: + """ + Access chat completion APIs. + """ + return Chat(self.http) diff --git a/upstash_qstash/asyncio/publish.py b/upstash_qstash/asyncio/publish.py index 71a7814..12601c6 100644 --- a/upstash_qstash/asyncio/publish.py +++ b/upstash_qstash/asyncio/publish.py @@ -14,17 +14,18 @@ class Publish: @staticmethod async def publish_async( - http: HttpClient, req: PublishRequest + http: HttpClient, req: PublishRequest ) -> Union[PublishToUrlResponse, PublishToTopicResponse]: """ Asynchronously publish a message to QStash. """ SyncPublish._validate_request(req) headers = SyncPublish._prepare_headers(req) + destination = SyncPublish._get_destination(req) return await http.request_async( { - "path": ["v2", "publish", req.get("url") or req["topic"]], + "path": ["v2", "publish", destination], "body": req.get("body"), "headers": headers, "method": "POST", @@ -45,7 +46,7 @@ async def publish_json_async(http: HttpClient, req: PublishRequest): @staticmethod async def batch_async( - http: HttpClient, req: BatchRequest + http: HttpClient, req: BatchRequest ) -> List[Union[PublishToUrlResponse, PublishToTopicResponse]]: """ Publish a batch of messages to QStash. @@ -56,9 +57,10 @@ async def batch_async( messages = [] for message in req: + destination = SyncPublish._get_destination(message) messages.append( { - "destination": message.get("url") or message["topic"], + "destination": destination, "headers": message["headers"], "body": message.get("body"), } @@ -77,7 +79,7 @@ async def batch_async( @staticmethod async def batch_json_async( - http: HttpClient, req: BatchRequest + http: HttpClient, req: BatchRequest ) -> List[Union[PublishToUrlResponse, PublishToTopicResponse]]: """ Asynchronously publish a batch of messages to QStash, automatically serializing the body of each message into JSON. 
diff --git a/upstash_qstash/asyncio/queue.py b/upstash_qstash/asyncio/queue.py index d870a79..696fe0d 100644 --- a/upstash_qstash/asyncio/queue.py +++ b/upstash_qstash/asyncio/queue.py @@ -89,6 +89,7 @@ async def enqueue( Publish._validate_request(req) headers = Publish._prepare_headers(req) + destination = Publish._get_destination(req) return await self.http.request_async( { @@ -96,7 +97,7 @@ async def enqueue( "v2", "enqueue", self.queue_name, - req.get("url") or req["topic"], + destination, ], "body": req.get("body"), "headers": headers, diff --git a/upstash_qstash/chat.py b/upstash_qstash/chat.py new file mode 100644 index 0000000..4b81d69 --- /dev/null +++ b/upstash_qstash/chat.py @@ -0,0 +1,501 @@ +import json +from typing import Dict, Iterable, List, Literal, TypedDict, Union + +from upstash_qstash.error import QstashException +from upstash_qstash.upstash_http import HttpClient + + +class ChatCompletionMessage(TypedDict, total=False): + role: Literal["system", "assistant", "user"] + """The role of the message author.""" + + content: str + """The content of the message.""" + + +ChatModel = Literal[ + "meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2" +] + + +class ChatResponseFormat(TypedDict, total=False): + type: Literal["text", "json_object"] + """Must be one of `text` or `json_object`.""" + + +class TopLogprob(TypedDict, total=False): + token: str + """The token.""" + + bytes: List[int] + """A list of integers representing the UTF-8 bytes representation of the token. + + Useful in instances where characters are represented by multiple tokens and + their byte representations must be combined to generate the correct text + representation. Can be `null` if there is no bytes representation for the token. + """ + + logprob: float + """The log probability of this token, if it is within the top 20 most likely + tokens. + + Otherwise, the value `-9999.0` is used to signify that the token is very + unlikely. + """ + + +class ChatCompletionTokenLogprob(TypedDict, total=False): + token: str + """The token.""" + + bytes: List[int] + """A list of integers representing the UTF-8 bytes representation of the token. + + Useful in instances where characters are represented by multiple tokens and + their byte representations must be combined to generate the correct text + representation. Can be `null` if there is no bytes representation for the token. + """ + + logprob: float + """The log probability of this token, if it is within the top 20 most likely + tokens. + + Otherwise, the value `-9999.0` is used to signify that the token is very + unlikely. + """ + + top_logprobs: List[TopLogprob] + """List of the most likely tokens and their log probability, at this token + position. + + In rare cases, there may be fewer than the number of requested `top_logprobs` + returned. 
+ """ + + +class ChoiceLogprobs(TypedDict, total=False): + content: List[ChatCompletionTokenLogprob] + """A list of message content tokens with log probability information.""" + + +class Choice(TypedDict, total=False): + finish_reason: Literal["stop", "length"] + """The reason the model stopped generating tokens.""" + + index: int + """The index of the choice in the list of choices.""" + + logprobs: ChoiceLogprobs + """Log probability information for the choice.""" + + message: ChatCompletionMessage + """A chat completion message generated by the model.""" + + +class CompletionUsage(TypedDict, total=False): + completion_tokens: int + """Number of tokens in the generated completion.""" + + prompt_tokens: int + """Number of tokens in the prompt.""" + + total_tokens: int + """Total number of tokens used in the request (prompt + completion).""" + + +class ChatCompletion(TypedDict, total=False): + id: str + """A unique identifier for the chat completion.""" + + choices: List[Choice] + """A list of chat completion choices. + + Can be more than one if `n` is greater than 1. + """ + + created: int + """The Unix timestamp (in seconds) of when the chat completion was created.""" + + model: str + """The model used for the chat completion.""" + + object: Literal["chat.completion"] + """The object type, which is always `chat.completion`.""" + + system_fingerprint: str + """This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when + backend changes have been made that might impact determinism. + """ + + usage: CompletionUsage + """Usage statistics for the completion request.""" + + +class ChunkChoice(TypedDict, total=False): + delta: ChatCompletionMessage + """A chat completion delta generated by streamed model responses.""" + + finish_reason: Literal["stop", "length"] + """The reason the model stopped generating tokens.""" + + index: int + """The index of the choice in the list of choices.""" + + logprobs: ChoiceLogprobs + """Log probability information for the choice.""" + + +class ChatCompletionChunk(TypedDict, total=False): + id: str + """A unique identifier for the chat completion. Each chunk has the same ID.""" + + choices: List[ChunkChoice] + """A list of chat completion choices. + + Can contain more than one elements if `n` is greater than 1. Can also be empty + for the last chunk. + """ + + created: int + """The Unix timestamp (in seconds) of when the chat completion was created. + + Each chunk has the same timestamp. + """ + + model: str + """The model to generate the completion.""" + + object: Literal["chat.completion.chunk"] + """The object type, which is always `chat.completion.chunk`.""" + + system_fingerprint: str + """ + This fingerprint represents the backend configuration that the model runs with. + Can be used in conjunction with the `seed` request parameter to understand when + backend changes have been made that might impact determinism. + """ + + usage: CompletionUsage + """ + Contains a null value except for the last chunk which contains + the token usage statistics for the entire request. + """ + + +class ChatRequest(TypedDict, total=False): + messages: List[ChatCompletionMessage] + """A list of messages comprising the conversation so far.""" + + model: ChatModel + """ID of the model to use""" + + frequency_penalty: float + """ + Number between -2.0 and 2.0. 
+ Positive values penalize new tokens based on their existing frequency + in the text so far, decreasing the model's likelihood to repeat + the same line verbatim. + """ + + logit_bias: Dict[str, int] + """ + Modify the likelihood of specified tokens appearing in the completion. + """ + + logprobs: bool + """ + Whether to return log probabilities of the output tokens or not. + If true, returns the log probabilities of each output token returned + in the content of message. + """ + + top_logprobs: int + """ + An integer between 0 and 20 specifying the number of most likely tokens + to return at each token position, each with an associated log probability. + logprobs must be set to true if this parameter is used. + """ + + max_tokens: int + """ + The maximum number of tokens that can be generated in the chat completion. + + The total length of input tokens and generated tokens is limited by the + model's context length. + """ + + n: int + """ + How many chat completion choices to generate for each input message. + Note that you will be charged based on the number of generated tokens + across all of the choices. Keep n as 1 to minimize costs. + """ + + presence_penalty: float + """ + Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's + likelihood to talk about new topics. + """ + + response_format: ChatResponseFormat + """ + An object specifying the format that the model must output. + + **Important**: when using JSON mode, you must also instruct the model + to produce JSON yourself via a system or user message. Without this, + the model may generate an unending stream of whitespace until the + generation reaches the token limit, resulting in a long-running and + seemingly "stuck" request. Also note that the message content may + be partially cut off if `finish_reason="length"`, which indicates the + generation exceeded max_tokens or the conversation exceeded the max + context length. + """ + + seed: int + """ + If specified, our system will make a best effort to sample deterministically, + such that repeated requests with the same seed and parameters should return + the same result. Determinism is not guaranteed, and you should refer to the + `system_fingerprint` response parameter to monitor changes in the backend. + """ + + stop: Union[str, List[str]] + """ + Up to 4 sequences where the API will stop generating further tokens. + """ + + stream: bool + """ + If set, partial message deltas will be sent. Tokens will be sent as + they become available. + """ + + temperature: float + """ + What sampling temperature to use, between 0 and 2. Higher values like 0.8 + will make the output more random, while lower values like 0.2 will make + it more focused and deterministic. + + We generally recommend altering this or top_p but not both. + """ + + top_p: float + """ + An alternative to sampling with temperature, called nucleus sampling, + where the model considers the results of the tokens with top_p probability + mass. So 0.1 means only the tokens comprising the top 10% probability + mass are considered. + + We generally recommend altering this or temperature but not both. + """ + + +class PromptRequest(TypedDict, total=False): + system: str + """The contents of the system message.""" + + user: str + """The contents of the user message.""" + + model: ChatModel + """ID of the model to use""" + + frequency_penalty: float + """ + Number between -2.0 and 2.0. 
+ Positive values penalize new tokens based on their existing frequency + in the text so far, decreasing the model's likelihood to repeat + the same line verbatim. + """ + + logit_bias: Dict[str, int] + """ + Modify the likelihood of specified tokens appearing in the completion. + """ + + logprobs: bool + """ + Whether to return log probabilities of the output tokens or not. + If true, returns the log probabilities of each output token returned + in the content of message. + """ + + top_logprobs: int + """ + An integer between 0 and 20 specifying the number of most likely tokens + to return at each token position, each with an associated log probability. + logprobs must be set to true if this parameter is used. + """ + + max_tokens: int + """ + The maximum number of tokens that can be generated in the chat completion. + + The total length of input tokens and generated tokens is limited by the + model's context length. + """ + + n: int + """ + How many chat completion choices to generate for each input message. + Note that you will be charged based on the number of generated tokens + across all of the choices. Keep n as 1 to minimize costs. + """ + + presence_penalty: float + """ + Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's + likelihood to talk about new topics. + """ + + response_format: ChatResponseFormat + """ + An object specifying the format that the model must output. + + **Important**: when using JSON mode, you must also instruct the model + to produce JSON yourself via a system or user message. Without this, + the model may generate an unending stream of whitespace until the + generation reaches the token limit, resulting in a long-running and + seemingly "stuck" request. Also note that the message content may + be partially cut off if `finish_reason="length"`, which indicates the + generation exceeded max_tokens or the conversation exceeded the max + context length. + """ + + seed: int + """ + If specified, our system will make a best effort to sample deterministically, + such that repeated requests with the same seed and parameters should return + the same result. Determinism is not guaranteed, and you should refer to the + `system_fingerprint` response parameter to monitor changes in the backend. + """ + + stop: Union[str, List[str]] + """ + Up to 4 sequences where the API will stop generating further tokens. + """ + + stream: bool + """ + If set, partial message deltas will be sent. Tokens will be sent as + they become available. + """ + + temperature: float + """ + What sampling temperature to use, between 0 and 2. Higher values like 0.8 + will make the output more random, while lower values like 0.2 will make + it more focused and deterministic. + + We generally recommend altering this or top_p but not both. + """ + + top_p: float + """ + An alternative to sampling with temperature, called nucleus sampling, + where the model considers the results of the tokens with top_p probability + mass. So 0.1 means only the tokens comprising the top 10% probability + mass are considered. + + We generally recommend altering this or temperature but not both. 
+ """ + + +class Chat: + def __init__(self, http: HttpClient): + self.http = http + + @staticmethod + def _validate_request(req: ChatRequest): + has_messages = "messages" in req + has_model = "model" in req + + if not has_messages or not has_model: + raise QstashException("'messages' and 'model' must be provided.") + + @staticmethod + def _to_chat_request(req: PromptRequest) -> ChatRequest: + system_msg = req.get("system") + user_msg = req.get("user") + + if not system_msg and not user_msg: + raise QstashException( + "At least one of 'system' or 'user' prompt is required" + ) + + messages: List[ChatCompletionMessage] = [] + if system_msg: + messages.append({"role": "system", "content": system_msg}) + + if user_msg: + messages.append({"role": "user", "content": user_msg}) + + chat_req: ChatRequest = {"messages": messages} + + for k, v in req.items(): + if k == "system" or k == "user": + continue + + chat_req[k] = v # type: ignore[literal-required] + + return chat_req + + def create( + self, req: ChatRequest + ) -> Union[ChatCompletion, Iterable[ChatCompletionChunk]]: + """ + Creates a model response for the given chat conversation. + + When `stream` is set to `True`, it returns an iterable + that can be used to receive chat completion delta chunks + one by one. + + Otherwise, response is returned in one go as a chat + completion object. + """ + self._validate_request(req) + body = json.dumps(req) + + if req.get("stream"): + return self.http.request_stream( + { + "path": ["llm", "v1", "chat", "completions"], + "method": "POST", + "headers": { + "Content-Type": "application/json", + "Connection": "keep-alive", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + }, + "body": body, + } + ) + + return self.http.request( + { + "path": ["llm", "v1", "chat", "completions"], + "method": "POST", + "headers": {"Content-Type": "application/json"}, + "body": body, + } + ) + + def prompt( + self, req: PromptRequest + ) -> Union[ChatCompletion, Iterable[ChatCompletionChunk]]: + """ + Creates a model response for the given prompt. + + When `stream` is set to `True`, it returns an iterable + that can be used to receive chat completion delta chunks + one by one. + + Otherwise, response is returned in one go as a chat + completion object. + """ + chat_req = self._to_chat_request(req) + return self.create(chat_req) diff --git a/upstash_qstash/client.py b/upstash_qstash/client.py index db7ac58..89be68b 100644 --- a/upstash_qstash/client.py +++ b/upstash_qstash/client.py @@ -1,5 +1,12 @@ -from typing import Optional, Union - +from typing import Iterable, Optional, Union + +from upstash_qstash.chat import ( + Chat, + ChatCompletion, + ChatCompletionChunk, + ChatRequest, + PromptRequest, +) from upstash_qstash.dlq import DLQ from upstash_qstash.events import Events, EventsRequest, GetEventsResponse from upstash_qstash.keys import Keys @@ -23,7 +30,6 @@ def __init__( ): """ Synchronous QStash client. - To use the blocking version, use the upstash_qstash client instead. """ self.http = HttpClient(token, retry, base_url or DEFAULT_BASE_URL) @@ -31,7 +37,8 @@ def publish(self, req: PublishRequest): """ Publish a message to QStash. - If publishing to a URL (req contains 'url'), this method returns a PublishToUrlResponse: + If publishing to a URL (req contains 'url') or an API (req contains 'api'), + this method returns a PublishToUrlResponse: - PublishToUrlResponse: Contains 'messageId' indicating the unique ID of the message and an optional 'deduplicated' boolean indicating if the message is a duplicate. 
@@ -43,7 +50,7 @@ def publish(self, req: PublishRequest): :param req: An instance of PublishRequest containing the request details. :return: Response details including the message_id, url (if publishing to a topic), and possibly a deduplicated boolean. The exact return type depends on the publish target. - :raises ValueError: If neither 'url' nor 'topic' is provided, or both are provided. + :raises ValueError: If neither 'url', 'topic', nor 'api' is provided, or more than one of them are provided. """ return Publish.publish(self.http, req) @@ -146,3 +153,9 @@ def events(self, req: Optional[EventsRequest] = None) -> GetEventsResponse: >>> break """ return Events.get(self.http, req) + + def chat(self) -> Chat: + """ + Access chat completion APIs. + """ + return Chat(self.http) diff --git a/upstash_qstash/error.py b/upstash_qstash/error.py index f3cd138..353d25e 100644 --- a/upstash_qstash/error.py +++ b/upstash_qstash/error.py @@ -1,8 +1,8 @@ import json from typing import TypedDict -RateLimitConfig = TypedDict( - "RateLimitConfig", +RateLimit = TypedDict( + "RateLimit", { "limit": int, "remaining": int, @@ -10,6 +10,18 @@ }, ) +ChatRateLimit = TypedDict( + "ChatRateLimit", + { + "limit-requests": int, + "limit-tokens": int, + "remaining-requests": int, + "remaining-tokens": int, + "reset-requests": str, + "reset-tokens": str, + }, +) + class QstashException(Exception): def __init__(self, message: str): @@ -18,7 +30,12 @@ def __init__(self, message: str): class QstashRateLimitException(QstashException): - def __init__(self, args: RateLimitConfig): + def __init__(self, args: RateLimit): + super().__init__(f"You have been rate limited. {json.dumps(args)}") + + +class QstashChatRateLimitException(QstashException): + def __init__(self, args: ChatRateLimit): super().__init__(f"You have been rate limited. {json.dumps(args)}") diff --git a/upstash_qstash/publish.py b/upstash_qstash/publish.py index 26797eb..168689b 100644 --- a/upstash_qstash/publish.py +++ b/upstash_qstash/publish.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, TypedDict, Union +from typing import Any, Dict, List, Literal, Optional, TypedDict, Union from upstash_qstash.error import QstashException from upstash_qstash.qstash_types import Method, UpstashHeaders @@ -21,6 +21,7 @@ "failure_callback": str, "method": Method, "topic": str, + "api": Literal["llm"], }, total=False, ) @@ -41,16 +42,37 @@ class Publish: + @staticmethod + def _get_destination(req: PublishRequest) -> str: + url = req.get("url") + if url is not None: + return url + + topic = req.get("topic") + if topic is not None: + return topic + + api = req.get("api") + return f"api/{api}" + @staticmethod def _validate_request(req: PublishRequest): """ - Validate the publish request to ensure it has either url or topic. + Validate the publish request to ensure it has either url, topic or api. """ - if (req.get("url") is None and req.get("topic") is None) or ( - req.get("url") is not None and req.get("topic") is not None - ): + destination_count = 0 + if "url" in req: + destination_count += 1 + + if "topic" in req: + destination_count += 1 + + if "api" in req: + destination_count += 1 + + if destination_count != 1: raise QstashException( - "Either 'url' or 'topic' must be provided, but not both." + "Only and only one of 'url', 'topic', or 'api' must be provided." 
) @staticmethod @@ -95,10 +117,11 @@ def publish( """ Publish._validate_request(req) headers = Publish._prepare_headers(req) + destination = Publish._get_destination(req) return http.request( { - "path": ["v2", "publish", req.get("url") or req["topic"]], + "path": ["v2", "publish", destination], "body": req.get("body"), "headers": headers, "method": "POST", @@ -130,9 +153,10 @@ def batch( messages = [] for message in req: + destination = Publish._get_destination(message) messages.append( { - "destination": message.get("url") or message["topic"], + "destination": destination, "headers": message["headers"], "body": message.get("body"), } diff --git a/upstash_qstash/qstash_types.py b/upstash_qstash/qstash_types.py index 7d36280..b2d7b9c 100644 --- a/upstash_qstash/qstash_types.py +++ b/upstash_qstash/qstash_types.py @@ -25,7 +25,6 @@ "path": List[str], "body": Any, "headers": UpstashHeaders, - "keepalive": bool, "method": Method, "query": Dict[str, str], "parse_response_as_json": bool, diff --git a/upstash_qstash/queue.py b/upstash_qstash/queue.py index 0459bed..6da7c23 100644 --- a/upstash_qstash/queue.py +++ b/upstash_qstash/queue.py @@ -115,6 +115,7 @@ def enqueue( Publish._validate_request(req) headers = Publish._prepare_headers(req) + destination = Publish._get_destination(req) return self.http.request( { @@ -122,7 +123,7 @@ def enqueue( "v2", "enqueue", self.queue_name, - req.get("url") or req["topic"], + destination, ], "body": req.get("body"), "headers": headers, diff --git a/upstash_qstash/upstash_http.py b/upstash_qstash/upstash_http.py index b3f0791..2ca35ff 100644 --- a/upstash_qstash/upstash_http.py +++ b/upstash_qstash/upstash_http.py @@ -1,4 +1,5 @@ import asyncio +import json import math import time from typing import Optional, Union @@ -7,7 +8,11 @@ import aiohttp import requests -from upstash_qstash.error import QstashException, QstashRateLimitException +from upstash_qstash.error import ( + QstashChatRateLimitException, + QstashException, + QstashRateLimitException, +) from upstash_qstash.qstash_types import RetryConfig, UpstashHeaders, UpstashRequest NO_RETRY: RetryConfig = {"attempts": 1, "backoff": lambda _: 0} @@ -73,7 +78,6 @@ def request(self, req: UpstashRequest): method=req["method"], url=url, headers=headers, - stream=req.get("keepalive", False), data=req.get("body"), ) return self._handle_response(res, req) @@ -84,13 +88,51 @@ def request(self, req: UpstashRequest): "Exhausted all retries without a successful response" ) - def _handle_response(self, res, req: UpstashRequest): + def request_stream(self, req: UpstashRequest): """ - Synchronously handle the response from a request. - Raises an exception if the response is not successful. - If the response is successful, returns the response body. + Synchronously make a request to QStash, returning a generator that decodes + SSE events until done message is received. + + :param req: The request to make. + :return: The response from the request. 
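+        The returned generator yields each SSE data payload parsed as JSON
+        and stops once the [DONE] sentinel is received.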
""" + url, headers = self._prepare_request_details(req) + error = None + for i in range(self.retry["attempts"]): + try: + res = requests.request( + method=req["method"], + url=url, + headers=headers, + stream=True, + data=req.get("body"), + ) + return self._handle_stream_response(res) + except Exception as e: + error = e + time.sleep(self.retry["backoff"](i) / 1000) + raise error or QstashException( + "Exhausted all retries without a successful response" + ) + + def _check_status(self, res): if res.status_code == 429: + if res.headers.get("x-ratelimit-limit-requests") is not None: + raise QstashChatRateLimitException( + { + "limit-requests": res.headers.get("x-ratelimit-limit-requests"), + "limit-tokens": res.headers.get("x-ratelimit-limit-tokens"), + "remaining-requests": res.headers.get( + "x-ratelimit-remaining-requests" + ), + "remaining-tokens": res.headers.get( + "x-ratelimit-remaining-tokens" + ), + "reset-requests": res.headers.get("x-ratelimit-reset-requests"), + "reset-tokens": res.headers.get("x-ratelimit-reset-tokens"), + } + ) + raise QstashRateLimitException( { "limit": res.headers.get("Burst-RateLimit-Limit"), @@ -98,12 +140,42 @@ def _handle_response(self, res, req: UpstashRequest): "reset": res.headers.get("Burst-RateLimit-Reset"), } ) + if res.status_code < 200 or res.status_code >= 300: raise QstashException( f"Qstash request failed with status {res.status_code}: {res.text}" ) + + def _handle_response(self, res: requests.Response, req: UpstashRequest): + """ + Synchronously handle the response from a request. + Raises an exception if the response is not successful. + If the response is successful, returns the response body. + """ + self._check_status(res) + return res.json() if req.get("parse_response_as_json", True) else res.text + def _handle_stream_response(self, res: requests.Response): + """ + Synchronously handle the response from a request in a streaming fashion + until the done message is received. + Raises an exception if the response is not successful. + If the response is successful, returns a generator that yields response body in chunks. + """ + try: + self._check_status(res) + + for chunk in res.iter_lines(delimiter=b"\n\n"): + if chunk.startswith(b"data: "): + chunk = chunk[6:] # skip data header + if chunk == b"[DONE]": + break + + yield json.loads(chunk) + finally: + res.close() + async def request_async(self, req: UpstashRequest): """ Asynchronously make a request to QStash. @@ -130,14 +202,54 @@ async def request_async(self, req: UpstashRequest): "Exhausted all retries without a successful response" ) - async def _handle_response_async(self, res, req: UpstashRequest): + async def request_stream_async(self, req: UpstashRequest): """ - Asynchronously handle the response from a request. - Raises an exception if the response is not successful. - If the response is successful, returns the response body. + Asynchronously make a request to QStash, returning a generator that decodes + SSE events until done message is received. + + :param req: The request to make. + :return: The response from the request. 
""" + url, headers = self._prepare_request_details(req) + error = None + for i in range(self.retry["attempts"]): + try: + async with aiohttp.ClientSession() as session: + async with session.request( + method=req["method"], + url=url, + headers=headers, + data=req.get("body"), + ) as res: + await self._check_status_async(res) + async for chunk in self._handle_stream_response_async(res): + yield chunk + + return + except Exception as e: + error = e + await asyncio.sleep(self.retry["backoff"](i) / 1000) + raise error or QstashException( + "Exhausted all retries without a successful response" + ) + + async def _check_status_async(self, res): if res.status == 429: headers = res.headers + if headers.get("x-ratelimit-limit-requests") is not None: + raise QstashChatRateLimitException( + { + "limit-requests": headers.get("x-ratelimit-limit-requests"), + "limit-tokens": headers.get("x-ratelimit-limit-tokens"), + "remaining-requests": headers.get( + "x-ratelimit-remaining-requests" + ), + "remaining-tokens": headers.get("x-ratelimit-remaining-tokens"), + "reset-requests": headers.get("x-ratelimit-reset-requests"), + "reset-tokens": headers.get("x-ratelimit-reset-tokens"), + } + ) + raise QstashRateLimitException( { "limit": headers.get("Burst-RateLimit-Limit"), @@ -145,13 +257,60 @@ async def _handle_response_async(self, res, req: UpstashRequest): "reset": headers.get("Burst-RateLimit-Reset"), } ) + if res.status < 200 or res.status >= 300: text = await res.text() raise QstashException( f"Qstash request failed with status {res.status}: {text}" ) + + async def _handle_response_async(self, res, req: UpstashRequest): + """ + Asynchronously handle the response from a request. + Raises an exception if the response is not successful. + If the response is successful, returns the response body. + """ + await self._check_status_async(res) + return ( - await res.json() + await res.json(content_type=None) if req.get("parse_response_as_json", True) else await res.text() ) + + async def _handle_stream_response_async(self, res: aiohttp.ClientResponse): + """ + Asynchronously handle the response from a request in a streaming fashion + until the done message is received. + Raises an exception if the response is not successful. + If the response is successful, returns a generator that yields response body in chunks. + """ + + # Adapted from requests#iterlines + pending = None + async for data in res.content.iter_any(): + if pending is not None: + data = pending + data + + chunks = data.split(b"\n\n") + + if chunks and chunks[-1] and data and chunks[-1][-1] == data[-1]: + pending = chunks.pop() + else: + pending = None + + for chunk in chunks: + if chunk.startswith(b"data: "): + chunk = chunk[6:] # skip data header + if chunk == b"[DONE]": + return + + yield json.loads(chunk) + + if pending is not None: + if pending.startswith(b"data: "): + pending = pending[6:] # skip data header + if pending == b"[DONE]": + return + + yield json.loads(pending)