Add chat completion support
This PR adds support for the following features:

- `client.chat().create()` and `client.chat().prompt()` APIs for multi-turn
  or single-turn chat completions (see the sketch after this list).
- `api/llm` support for the publish, enqueue, and batch APIs.
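For illustration, here is a minimal sketch of the single-turn `prompt()` API, mirroring the tests added in this PR; the prompt string is arbitrary, and the response fields follow the schema those tests assert on:

```python
from upstash_qstash import Client

client = Client("<QSTASH_TOKEN>")

# prompt() takes a single user message, so no explicit
# "messages" list is needed.
res = client.chat().prompt(
    {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "user": "What is the capital of Turkey?",
    }
)

print(res["choices"][0]["message"]["content"])
```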
mdumandag committed Jun 6, 2024
1 parent 0ba4f80 commit 72e12ac
Showing 19 changed files with 1,185 additions and 96 deletions.
129 changes: 75 additions & 54 deletions README.md
@@ -15,15 +15,15 @@

```python
from upstash_qstash import Client

client = Client("<QSTASH_TOKEN>")
res = client.publish_json(
    {
        "url": "https://my-api...",
        "body": {
            "hello": "world"
        },
        "headers": {
            "test-header": "test-value",
        },
    }
)

print(res["messageId"])
```

@@ -37,10 +37,10 @@

```python
from upstash_qstash import Client

client = Client("<QSTASH_TOKEN>")
schedules = client.schedules()
res = schedules.create(
    {
        "destination": "https://my-api...",
        "cron": "*/5 * * * *",
    }
)

print(res["scheduleId"])
```

@@ -53,25 +53,46 @@

```python
from upstash_qstash import Receiver

# Keys available from the QStash console
receiver = Receiver(
    {
        "current_signing_key": "CURRENT_SIGNING_KEY",
        "next_signing_key": "NEXT_SIGNING_KEY",
    }
)

# ... in your request handler

signature, body = req.headers["Upstash-Signature"], req.body

is_valid = receiver.verify(
    {
        "body": body,
        "signature": signature,
        "url": "https://my-api...",  # Optional
    }
)
```

#### Create Chat Completions

```python
from upstash_qstash import Client

client = Client("<QSTASH_TOKEN>")
chat = client.chat()

res = chat.create({
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": "What is the capital of Turkey?"
        }
    ]
})

print(res["choices"][0]["message"]["content"])
```
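Streaming is also supported: per the tests added in this PR, setting `"stream": True` makes the call return an iterable of chunks. A minimal sketch, continuing the example above; the exact shape of each chunk's `delta` payload is an assumption here:

```python
res = chat.create({
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [{"role": "user", "content": "What is the capital of Turkey?"}],
    "stream": True,
})

# Each chunk carries an incremental delta rather than a complete message;
# some chunks may omit the content field, hence the .get() fallback.
for chunk in res:
    print(chunk["choices"][0]["delta"].get("content", ""), end="")
```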

#### Additional configuration

```python
from upstash_qstash import Client

# "backoff": lambda retry_count: math.exp(retry_count) * 50,
# }
client = Client("<QSTASH_TOKEN>", {
"attempts": 2,
"backoff": lambda retry_count: (2 ** retry_count) * 20,
"attempts": 2,
"backoff": lambda retry_count: (2 ** retry_count) * 20,
})

# Create Topic
topics = client.topics()
topics.upsert_or_add_endpoints(
    {
        "name": "topic_name",
        "endpoints": [
            {"url": "https://my-endpoint-1"},
            {"url": "https://my-endpoint-2"}
        ],
    }
)

# Publish to Topic
client.publish_json(
    {
        "topic": "my-topic",
        "body": {
            "key": "value"
        },
        # Retry sending message to API 3 times
        # https://upstash.com/docs/qstash/features/retry
        "retries": 3,
        # Schedule message to be sent 4 seconds from now
        "delay": 4,
        # When message is sent, send a request to this URL
        # https://upstash.com/docs/qstash/features/callbacks
        "callback": "https://my-api.com/callback",
        # When message fails to send, send a request to this URL
        "failure_callback": "https://my-api.com/failure_callback",
        # Headers to forward to the endpoint
        "headers": {
            "test-header": "test-value",
        },
        # Enable content-based deduplication
        # https://upstash.com/docs/qstash/features/deduplication#content-based-deduplication
        "content_based_deduplication": True,
    }
)
```
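The new `api/llm` option works with the same publish, enqueue, and batch calls. Below is a minimal publish sketch based on the tests added in this PR; the callback URL is a placeholder, and QStash is assumed to deliver the completion there asynchronously:

```python
# Offload a chat completion: QStash calls the LLM API and delivers
# the response to the callback URL. (Callback URL is a placeholder.)
res = client.publish_json(
    {
        "api": "llm",
        "body": {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "messages": [{"role": "user", "content": "hello"}],
        },
        "callback": "https://my-api.com/llm-callback",
    }
)

print(res["messageId"])
```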

71 changes: 71 additions & 0 deletions tests/asyncio/test_chat.py
@@ -0,0 +1,71 @@
from typing import AsyncIterable

import pytest

from qstash_tokens import QSTASH_TOKEN
from upstash_qstash.asyncio import Client


@pytest.fixture
def client():
    return Client(QSTASH_TOKEN)


@pytest.mark.asyncio
async def test_chat_async(client):
    res = await client.chat().create(
        {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "messages": [{"role": "user", "content": "hello"}],
        }
    )

    assert res["id"] is not None

    assert res["choices"][0]["message"]["content"] is not None
    assert res["choices"][0]["message"]["role"] == "assistant"


@pytest.mark.asyncio
async def test_chat_streaming_async(client):
    res = await client.chat().create(
        {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        }
    )

    async for r in res:
        assert r["id"] is not None
        assert r["choices"][0]["delta"] is not None


@pytest.mark.asyncio
async def test_prompt_async(client):
    res = await client.chat().prompt(
        {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "user": "hello",
        }
    )

    assert res["id"] is not None

    assert res["choices"][0]["message"]["content"] is not None
    assert res["choices"][0]["message"]["role"] == "assistant"


@pytest.mark.asyncio
async def test_prompt_streaming_async(client):
    res = await client.chat().prompt(
        {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "user": "hello",
            "stream": True,
        }
    )

    async for r in res:
        assert r["id"] is not None
        assert r["choices"][0]["delta"] is not None
49 changes: 48 additions & 1 deletion tests/asyncio/test_publish.py
@@ -39,7 +39,7 @@ async def test_publish_to_url_async(client):
    assert (
        found_event
    ), f"Event with messageId {res['messageId']} not found. This may be because the latency of the event is too high"
    assert (
        event["state"] != "ERROR"
    ), f"Event with messageId {res['messageId']} was not delivered"


@@ -87,3 +87,50 @@ async def test_batch_json_async(client):
    assert len(res) == N
    for i in range(N):
        assert res[i]["messageId"] is not None


@pytest.mark.asyncio
async def test_publish_api_llm_async(client):
    # Not a proper end-to-end test, since the callback URL is a dummy.
    res = await client.publish_json(
        {
            "api": "llm",
            "body": {
                "model": "meta-llama/Meta-Llama-3-8B-Instruct",
                "messages": [
                    {
                        "role": "user",
                        "content": "hello",
                    }
                ],
            },
            "callback": "https://example.com",
        }
    )

    assert res["messageId"] is not None


@pytest.mark.asyncio
async def test_batch_api_llm_async(client):
    # Not a proper end-to-end test, since the callback URL is a dummy.
    res = await client.batch_json(
        [
            {
                "api": "llm",
                "body": {
                    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": "hello",
                        }
                    ],
                },
                "callback": "https://example.com",
            }
        ]
    )

    assert len(res) == 1
    assert res[0]["messageId"] is not None
27 changes: 27 additions & 0 deletions tests/asyncio/test_queue.py
@@ -94,3 +94,30 @@ async def test_enqueue(client):

print("Deleting queue")
await queue.delete()


@pytest.mark.asyncio
async def test_enqueue_api_llm_async(client):
    # Not a proper end-to-end test, since the callback URL is a dummy.
    queue = client.queue({"queue_name": "test_queue"})

    try:
        res = await queue.enqueue_json(
            {
                "api": "llm",
                "body": {
                    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": "hello",
                        }
                    ],
                },
                "callback": "https://example.com/",
            }
        )

        assert res["messageId"] is not None
    finally:
        await queue.delete()
