Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add stream completion method #32

Merged
merged 1 commit into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ categories = ["api-bindings"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-stream = "0.3.6"
base64 = "0.22.0"
futures-util = "0.3.31"
image = "0.25.1"
itertools = "0.13.0"
reqwest = { version = "0.12.3", features = ["json"] }
reqwest = { version = "0.12.3", features = ["json", "stream"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
thiserror = "1.0.58"
Expand Down
63 changes: 62 additions & 1 deletion src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::Task;
use crate::{StreamTask, Task};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
Expand Down Expand Up @@ -126,6 +126,9 @@ struct ChatBody<'a> {
/// When no value is provided, the default value of 1 will be used.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Whether to stream the response or not.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}

impl<'a> ChatBody<'a> {
Expand All @@ -136,8 +139,14 @@ impl<'a> ChatBody<'a> {
maximum_tokens: task.maximum_tokens,
temperature: task.temperature,
top_p: task.top_p,
stream: false,
}
}

/// Builder-style toggle: marks this request body as a streaming request
/// (`stream: true`), so the API returns incremental chunks instead of a
/// single blocking response.
pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
}

impl<'a> Task for TaskChat<'a> {
Expand All @@ -159,3 +168,55 @@ impl<'a> Task for TaskChat<'a> {
response.choices.pop().unwrap()
}
}

/// Incremental message delta received while streaming a chat completion.
//
// Derives `Debug`, `PartialEq` and `Eq` to match the crate's convention for
// response types (see e.g. `Message` and the completion response types).
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct StreamMessage {
    /// The role of the current chat completion. Will be assistant for the first chunk of every
    /// completion stream and missing for the remaining chunks.
    pub role: Option<String>,
    /// The content of the current chat completion. Will be empty for the first chunk of every
    /// completion stream and non-empty for the remaining chunks.
    pub content: String,
}

/// One chunk of a chat completion stream.
#[derive(Deserialize)]
pub struct ChatStreamChunk {
/// The reason the model stopped generating tokens.
/// The value is only set in the last chunk of a completion and null otherwise.
pub finish_reason: Option<String>,
/// Chat completion chunk generated by the model when streaming is enabled.
/// Holds the incremental role/content delta for this chunk.
pub delta: StreamMessage,
}

/// Event received from a chat completion stream. As the crate does not support multiple
/// chat completions, there will always be exactly one choice item.
#[derive(Deserialize)]
pub struct ChatEvent {
/// Completion chunks for this event; expected to contain exactly one entry
/// (the `n` parameter is not supported by this crate).
pub choices: Vec<ChatStreamChunk>,
}

impl<'a> StreamTask for TaskChat<'a> {
    type Output = ChatStreamChunk;

    type ResponseBody = ChatEvent;

    /// Builds the streaming chat-completion request: same body as the blocking
    /// variant, but with `stream: true` so the server emits incremental chunks.
    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        // `self` is already `&TaskChat`; the previous `&self` produced a needless
        // `&&TaskChat` that only compiled through deref coercion (clippy: needless_borrow).
        let body = ChatBody::new(model, self).with_streaming();
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    /// Unwraps the single choice out of a stream event.
    fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
        // We always expect there to be exactly one choice, as the `n` parameter is not
        // supported by this crate.
        response
            .choices
            .pop()
            .expect("There must always be at least one choice")
    }
}
75 changes: 74 additions & 1 deletion src/completion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

use crate::{http::Task, Prompt};
use crate::{http::Task, Prompt, StreamTask};

/// Completes a prompt. E.g. continues a text.
pub struct TaskCompletion<'a> {
Expand Down Expand Up @@ -142,6 +142,9 @@ struct BodyCompletion<'a> {
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub completion_bias_inclusion: &'a [&'a str],
/// If true, the response will be streamed.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}

impl<'a> BodyCompletion<'a> {
Expand All @@ -155,8 +158,13 @@ impl<'a> BodyCompletion<'a> {
top_k: task.sampling.top_k,
top_p: task.sampling.top_p,
completion_bias_inclusion: task.sampling.start_with_one_of,
stream: false,
}
}
/// Builder-style toggle: marks this request body as a streaming request
/// (`stream: true`), so the API returns incremental chunks instead of a
/// single blocking response.
pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -205,3 +213,68 @@ impl Task for TaskCompletion<'_> {
response.completions.pop().unwrap()
}
}

/// Describes a chunk of a completion stream
//
// Adds `PartialEq, Eq` alongside the existing `Debug` to match the crate's
// derive convention for response types.
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct StreamChunk {
    /// The index of the stream that this chunk belongs to.
    /// This is relevant if multiple completion streams are requested (see parameter `n`).
    pub index: u32,
    /// The completion of the stream.
    pub completion: String,
}

/// Denotes the end of a completion stream.
///
/// The index of the stream that is being terminated is not deserialized.
/// It is only relevant if multiple completion streams are requested (see parameter `n`),
/// which is not supported by this crate yet.
//
// Public response type: derive `Debug` (and `PartialEq, Eq`) per the crate's
// convention for response types.
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct StreamSummary {
    /// Model name and version (if any) of the used model for inference.
    pub model_version: String,
    /// The reason why the model stopped generating new tokens.
    pub finish_reason: String,
}

/// Denotes the end of all completion streams.
//
// Public response type: derive `Debug` (and `PartialEq, Eq`) per the crate's
// convention for response types.
#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct CompletionSummary {
    /// Number of tokens combined across all completion tasks.
    /// In particular, if you set best_of or n to a number larger than 1 then we report the
    /// combined prompt token count for all best_of or n tasks.
    pub num_tokens_prompt_total: u32,
    /// Number of tokens combined across all completion tasks.
    /// If multiple completions are returned or best_of is set to a value greater than 1 then
    /// this value contains the combined generated token count.
    pub num_tokens_generated: u32,
}

/// A single event of a completion stream.
///
/// Deserialized from the JSON `"type"` field (internally tagged, snake_case),
/// i.e. `"stream_chunk"`, `"stream_summary"` or `"completion_summary"`.
#[derive(Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent {
StreamChunk(StreamChunk),
StreamSummary(StreamSummary),
CompletionSummary(CompletionSummary),
}

impl StreamTask for TaskCompletion<'_> {
    type Output = CompletionEvent;

    type ResponseBody = CompletionEvent;

    /// Builds the streaming completion request: same body as the blocking
    /// variant, but with `stream: true` so the server emits incremental events.
    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        // `self` is already `&TaskCompletion`; the previous `&self` produced a needless
        // `&&TaskCompletion` that only compiled through deref coercion (clippy: needless_borrow).
        let body = BodyCompletion::new(model, self).with_streaming();
        client.post(format!("{base}/complete")).json(&body)
    }

    /// The deserialized event already is the output type; pass it through unchanged.
    fn body_to_output(response: Self::ResponseBody) -> Self::Output {
        response
    }
}
Loading
Loading