Stream completion #31

Closed · wants to merge 3 commits
5 changes: 3 additions & 2 deletions Cargo.toml
@@ -12,15 +12,16 @@ categories = ["api-bindings"]

[dependencies]
base64 = "0.22.0"
+futures-util = "0.3.31"
image = "0.25.1"
itertools = "0.13.0"
-reqwest = { version = "0.12.3", features = ["json"] }
+reqwest = { version = "0.12.3", features = ["json", "stream"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
thiserror = "1.0.58"
tokenizers = { version = "0.20.0", default-features = false, features = ["onig", "esaxx_fast"] }
+tokio = { version = "1.37.0", features = ["rt", "macros"] }

[dev-dependencies]
dotenv = "0.15.0"
-tokio = { version = "1.37.0", features = ["rt", "macros"] }
wiremock = "0.6.0"
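Note on the dependency changes: the `stream` feature is what exposes `reqwest::Response::bytes_stream()`, and `futures-util` supplies the `StreamExt` adapter used to drive it; `tokio` moves from dev-dependencies into regular dependencies because `src/http.rs` now uses `tokio::spawn` and `tokio::sync::mpsc`. A minimal sketch of what the feature unlocks (the helper function is illustrative, not part of this PR):

```rust
use futures_util::stream::StreamExt;

// Illustrative helper: with reqwest's "stream" feature enabled, the response
// body can be consumed chunk by chunk as it arrives, instead of being
// buffered in full before use.
async fn print_chunk_sizes(response: reqwest::Response) -> Result<(), reqwest::Error> {
    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        println!("received {} bytes", chunk?.len());
    }
    Ok(())
}
```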
20 changes: 17 additions & 3 deletions src/chat.rs
@@ -2,7 +2,7 @@ use std::borrow::Cow;

use serde::{Deserialize, Serialize};

-use crate::Task;
+use crate::{stream::TaskStreamChat, Task};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
@@ -95,6 +95,11 @@ impl<'a> TaskChat<'a> {
self.top_p = Some(top_p);
self
}

+    /// Creates a wrapper `TaskStreamChat` for this TaskChat.
+    pub fn with_streaming(self) -> TaskStreamChat<'a> {
+        TaskStreamChat { task: self }
+    }
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
@@ -109,7 +114,7 @@ pub struct ResponseChat {
}

#[derive(Serialize)]
-struct ChatBody<'a> {
+pub struct ChatBody<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// The list of messages comprising the conversation so far.
@@ -126,6 +131,9 @@ struct ChatBody<'a> {
/// When no value is provided, the default value of 1 will be used.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
+    /// Whether to stream the response or not.
+    #[serde(skip_serializing_if = "std::ops::Not::not")]
+    pub stream: bool,
}

impl<'a> ChatBody<'a> {
@@ -136,8 +144,14 @@ impl<'a> ChatBody<'a> {
maximum_tokens: task.maximum_tokens,
temperature: task.temperature,
top_p: task.top_p,
+            stream: false,
}
}

+    pub fn with_streaming(mut self) -> Self {
+        self.stream = true;
+        self
+    }
}

impl<'a> Task for TaskChat<'a> {
@@ -155,7 +169,7 @@ impl<'a> Task for TaskChat<'a> {
client.post(format!("{base}/chat/completions")).json(&body)
}

-    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
+    fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
response.choices.pop().unwrap()
}
}
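The `skip_serializing_if = "std::ops::Not::not"` attribute keeps the wire format of non-streaming requests unchanged: `Not::not` on `&bool` returns `true` exactly when the field is `false`, so `stream` only shows up in the JSON once it is enabled. A standalone sketch of that behavior (toy struct, not the crate's actual `ChatBody`):

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Body {
    model: &'static str,
    // Skipped whenever `stream == false`, because `Not::not(&false) == true`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    stream: bool,
}

fn main() {
    let off = Body { model: "luminous-base", stream: false };
    let on = Body { model: "luminous-base", stream: true };
    assert_eq!(serde_json::to_string(&off).unwrap(), r#"{"model":"luminous-base"}"#);
    assert_eq!(
        serde_json::to_string(&on).unwrap(),
        r#"{"model":"luminous-base","stream":true}"#
    );
}
```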
17 changes: 14 additions & 3 deletions src/completion.rs
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

-use crate::{http::Task, Prompt};
+use crate::{http::Task, Prompt, TaskStreamCompletion};

/// Completes a prompt. E.g. continues a text.
pub struct TaskCompletion<'a> {
@@ -32,6 +32,9 @@ impl<'a> TaskCompletion<'a> {
self.stopping.stop_sequences = stop_sequences;
self
}
+    pub fn with_streaming(self) -> TaskStreamCompletion<'a> {
+        TaskStreamCompletion { task: self }
+    }
}

/// Sampling controls how the tokens ("words") are selected for the completion.
@@ -118,7 +121,7 @@ impl Default for Stopping<'_> {

/// Body sent to the Aleph Alpha API on the POST `/completion` route
#[derive(Serialize, Debug)]
-struct BodyCompletion<'a> {
+pub struct BodyCompletion<'a> {
Contributor Author: does this need to be public?

Contributor Author: move StreamCompletion to completion module?

/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// Prompt to complete. The modalities supported depend on `model`.
@@ -142,6 +145,9 @@ struct BodyCompletion<'a> {
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub completion_bias_inclusion: &'a [&'a str],
+    /// If true, the response will be streamed.
+    #[serde(skip_serializing_if = "std::ops::Not::not")]
+    pub stream: bool,
}

impl<'a> BodyCompletion<'a> {
@@ -155,8 +161,13 @@ impl<'a> BodyCompletion<'a> {
top_k: task.sampling.top_k,
top_p: task.sampling.top_p,
completion_bias_inclusion: task.sampling.start_with_one_of,
+            stream: false,
}
}
+    pub fn with_streaming(mut self) -> Self {
+        self.stream = true;
+        self
+    }
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
@@ -201,7 +212,7 @@ impl Task for TaskCompletion<'_> {
client.post(format!("{base}/complete")).json(&body)
}

-    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
+    fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
response.completions.pop().unwrap()
}
}
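Since the `stream` module itself is not part of this diff, the exact output type of `TaskStreamCompletion` is not visible here. Assuming it implements `Job` analogously to `TaskCompletion`, end-to-end usage might look roughly like this (hypothetical sketch; names and types are assumptions):

```rust
// Hypothetical usage sketch: `TaskStreamCompletion`'s Job implementation and
// its per-event output type are not shown in this diff.
async fn stream_completion(client: &Client) -> Result<(), Error> {
    let task = TaskCompletion::from_text("An apple a day").with_streaming();
    let mut rx = client
        .stream_output_of(&task.with_model("luminous-base"), &How::default())
        .await?;
    while let Some(event) = rx.recv().await {
        // Each channel item is one parsed stream event.
        println!("{:?}", event?);
    }
    Ok(())
}
```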
2 changes: 1 addition & 1 deletion src/detokenization.rs
@@ -51,7 +51,7 @@ impl<'a> Task for TaskDetokenization<'a> {
client.post(format!("{base}/detokenize")).json(&body)
}

-    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
+    fn body_to_output(response: Self::ResponseBody) -> Self::Output {
DetokenizationOutput::from(response)
}
}
2 changes: 1 addition & 1 deletion src/explanation.rs
@@ -174,7 +174,7 @@ impl Task for TaskExplanation<'_> {
client.post(format!("{base}/explain")).json(&body)
}

-    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
+    fn body_to_output(response: Self::ResponseBody) -> Self::Output {
ExplanationOutput::from(response)
}
}
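The same `body_to_output` change repeats across every task: it drops `&self` and becomes an associated function. The reason shows up in `src/http.rs` below: the streaming path converts response bodies inside a spawned `'static` task, and an associated function can be called through the type alone, without moving or borrowing a task value. A toy illustration of the pattern (not the crate's actual traits):

```rust
// Toy trait showing why dropping `&self` helps: the spawned task only names
// the type `T`, it never has to capture a `T` value.
trait Convert {
    type Out: Send;
    fn convert(input: i32) -> Self::Out;
}

struct Doubler;
impl Convert for Doubler {
    type Out = i32;
    fn convert(input: i32) -> i32 {
        input * 2
    }
}

fn spawn_conversion<T>() -> tokio::task::JoinHandle<T::Out>
where
    T: Convert + 'static,
    T::Out: 'static,
{
    // Only the type is referenced inside the 'static task; a `&self` method
    // would force us to move or clone a value in instead.
    tokio::spawn(async move { T::convert(21) })
}
```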
119 changes: 81 additions & 38 deletions src/http.rs
@@ -1,11 +1,13 @@
use std::{borrow::Cow, time::Duration};

-use reqwest::{header, ClientBuilder, RequestBuilder, StatusCode};
+use futures_util::stream::StreamExt;
+use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
use serde::Deserialize;
use thiserror::Error as ThisError;
use tokenizers::Tokenizer;
+use tokio::sync::mpsc;

-use crate::How;
+use crate::{stream::parse_stream_event, How};

/// A job sent to the Aleph Alpha API using the http client. A job wraps all the knowledge required
/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
@@ -16,34 +18,34 @@ use crate::How;
/// [`Task::with_model`].
pub trait Job {
/// Output returned by [`crate::Client::output_of`]
-    type Output;
+    type Output: Send;

/// Expected answer of the Aleph Alpha API
-    type ResponseBody: for<'de> Deserialize<'de>;
+    type ResponseBody: for<'de> Deserialize<'de> + Send;

/// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
/// already set.
fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;

/// Parses the response of the server into higher level structs for the user.
-    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
+    fn body_to_output(response: Self::ResponseBody) -> Self::Output;
}

/// A task sent to the Aleph Alpha API using the http client. A model must be specified before it
/// can be executed.
pub trait Task {
/// Output returned by [`crate::Client::output_of`]
-    type Output;
+    type Output: Send;

/// Expected answer of the Aleph Alpha API
-    type ResponseBody: for<'de> Deserialize<'de>;
+    type ResponseBody: for<'de> Deserialize<'de> + Send;

/// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
/// already set.
fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;

/// Parses the response of the server into higher level structs for the user.
-    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output;
+    fn body_to_output(response: Self::ResponseBody) -> Self::Output;

/// Turn your task into [`Job`] by annotating it with a model name.
fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
@@ -75,8 +77,8 @@ where
self.task.build_request(client, base, self.model)
}

-    fn body_to_output(&self, response: T::ResponseBody) -> T::Output {
-        self.task.body_to_output(response)
+    fn body_to_output(response: T::ResponseBody) -> T::Output {
+        T::body_to_output(response)
}
}

Expand All @@ -100,32 +102,8 @@ impl HttpClient {
})
}

-    /// Execute a task with the aleph alpha API and fetch its result.
-    ///
-    /// ```no_run
-    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
-    ///
-    /// async fn print_completion() -> Result<(), Error> {
-    ///     // Authenticate against API. Fetches token.
-    ///     let client = Client::with_authentication("AA_API_TOKEN")?;
-    ///
-    ///     // Name of the model we we want to use. Large models give usually better answer, but are
-    ///     // also slower and more costly.
-    ///     let model = "luminous-base";
-    ///
-    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
-    ///     // ..."
-    ///     let task = TaskCompletion::from_text("An apple a day");
-    ///
-    ///     // Retrieve answer from API
-    ///     let response = client.output_of(&task.with_model(model), &How::default()).await?;
-    ///
-    ///     // Print entire sentence with completion
-    ///     println!("An apple a day{}", response.completion);
-    ///     Ok(())
-    /// }
-    /// ```
-    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
+    /// Execute a task with the aleph alpha API and return the response without awaiting the body.
+    pub async fn request<T: Job>(&self, task: &T, how: &How) -> Result<Response, Error> {
let query = if how.be_nice {
[("nice", "true")].as_slice()
} else {
@@ -152,12 +130,72 @@ impl HttpClient {
reqwest_error.into()
}
})?;
-        let response = translate_http_error(response).await?;
+        translate_http_error(response).await
}

+    /// Execute a task with the aleph alpha API and fetch its result.
+    ///
+    /// ```no_run
+    /// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
+    ///
+    /// async fn print_completion() -> Result<(), Error> {
+    ///     // Authenticate against API. Fetches token.
+    ///     let client = Client::with_authentication("AA_API_TOKEN")?;
+    ///
+    ///     // Name of the model we want to use. Large models usually give better answers, but
+    ///     // are also slower and more costly.
+    ///     let model = "luminous-base";
+    ///
+    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
+    ///     // ..."
+    ///     let task = TaskCompletion::from_text("An apple a day");
+    ///
+    ///     // Retrieve answer from API
+    ///     let response = client.output_of(&task.with_model(model), &How::default()).await?;
+    ///
+    ///     // Print entire sentence with completion
+    ///     println!("An apple a day{}", response.completion);
+    ///     Ok(())
+    /// }
+    /// ```
+    pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
+        let response = self.request(task, how).await?;
let response_body: T::ResponseBody = response.json().await?;
-        let answer = task.body_to_output(response_body);
+        let answer = T::body_to_output(response_body);
Ok(answer)
}

+    pub async fn stream_output_of<T: Job>(
+        &self,
+        task: &T,
+        how: &How,
+    ) -> Result<mpsc::Receiver<Result<T::Output, Error>>, Error>
+    where
+        T::Output: 'static,
+    {
+        let response = self.request(task, how).await?;
+        let mut stream = response.bytes_stream();
+
+        let (tx, rx) = mpsc::channel::<Result<T::Output, Error>>(100);
+        tokio::spawn(async move {
Contributor Author: we only want to return an iterator, not spawn the thread.

+            while let Some(item) = stream.next().await {
+                match item {
+                    Ok(bytes) => {
+                        let events = parse_stream_event::<T::ResponseBody>(bytes.as_ref());
+                        for event in events {
+                            let output = event.map(|b| T::body_to_output(b));
+                            tx.send(output).await.unwrap();
+                        }
+                    }
+                    Err(e) => {
+                        tx.send(Err(e.into())).await.unwrap();
+                    }
+                }
+            }
+        });
+        Ok(rx)
+    }
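On the inline comment above: one way to hand the caller a stream directly, with no spawned task and no channel, is to adapt the byte stream with `flat_map`, flattening each chunk's events into the output stream. A sketch under the assumption that the parser keeps its chunk-to-events shape (`parse_chunk` stands in for `parse_stream_event`, whose real signature is not shown in this diff):

```rust
use futures_util::stream::{self, Stream, StreamExt};

// Sketch of the reviewer's suggestion: adapt the byte stream into a stream
// of outputs instead of spawning a task that feeds an mpsc channel.
fn into_output_stream<B, C, O, E>(
    bytes_stream: B,
    parse_chunk: fn(&[u8]) -> Vec<Result<O, E>>,
) -> impl Stream<Item = Result<O, E>>
where
    B: Stream<Item = Result<C, E>>,
    C: AsRef<[u8]>,
{
    bytes_stream.flat_map(move |item| {
        let outputs = match item {
            Ok(chunk) => parse_chunk(chunk.as_ref()),
            Err(e) => vec![Err(e)],
        };
        stream::iter(outputs)
    })
}
```

The caller then pulls items with `StreamExt::next` instead of `Receiver::recv`, and no channel capacity (the `100` above) needs to be chosen.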

fn header_from_token(api_token: &str) -> header::HeaderValue {
let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
// Consider marking security-sensitive headers with `set_sensitive`.
@@ -264,6 +302,11 @@ pub enum Error {
deserialization_error
)]
InvalidTokenizer { deserialization_error: String },
+    /// Deserialization error of the stream event.
+    #[error(
+        "Stream event could not be correctly deserialized. Caused by:\n{cause}. Event:\n{event}"
+    )]
+    StreamDeserializationError { cause: String, event: String },
/// Most likely either TLS errors creating the Client, or IO errors.
#[error(transparent)]
Other(#[from] reqwest::Error),
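The new `StreamDeserializationError` variant carries both the cause and the raw event, which suggests the (unshown) `stream` module splits each chunk into server-sent-event style payloads before deserializing them into `ResponseBody` values. A rough sketch of that idea; the `data: ` prefix and `[DONE]` sentinel are assumptions about the wire format, not something this diff confirms:

```rust
use serde::Deserialize;

// Rough sketch of what a helper like `parse_stream_event` might do; the real
// implementation lives in the `stream` module, which this diff omits.
fn parse_sse_chunk<T: for<'de> Deserialize<'de>>(chunk: &[u8]) -> Vec<Result<T, String>> {
    std::str::from_utf8(chunk)
        .into_iter()
        .flat_map(|text| text.lines())
        .filter_map(|line| line.strip_prefix("data: "))
        .filter(|payload| *payload != "[DONE]")
        .map(|payload| serde_json::from_str::<T>(payload).map_err(|err| err.to_string()))
        .collect()
}
```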