Stream completion #31

Closed · wants to merge 3 commits
Changes from 1 commit
5 changes: 3 additions & 2 deletions Cargo.toml
@@ -12,15 +12,16 @@ categories = ["api-bindings"]

[dependencies]
base64 = "0.22.0"
+futures-util = "0.3.31"
image = "0.25.1"
itertools = "0.13.0"
-reqwest = { version = "0.12.3", features = ["json"] }
+reqwest = { version = "0.12.3", features = ["json", "stream"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
thiserror = "1.0.58"
tokenizers = { version = "0.20.0", default-features = false, features = ["onig", "esaxx_fast"] }
+tokio = { version = "1.37.0", features = ["rt", "macros"] }

[dev-dependencies]
dotenv = "0.15.0"
-tokio = { version = "1.37.0", features = ["rt", "macros"] }
wiremock = "0.6.0"
15 changes: 13 additions & 2 deletions src/completion.rs
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

-use crate::{http::Task, Prompt};
+use crate::{http::Task, Prompt, TaskStreamCompletion};

/// Completes a prompt. E.g. continues a text.
pub struct TaskCompletion<'a> {
@@ -32,6 +32,9 @@ impl<'a> TaskCompletion<'a> {
self.stopping.stop_sequences = stop_sequences;
self
}
pub fn with_streaming(self) -> TaskStreamCompletion<'a> {
TaskStreamCompletion { task: self }
}
}

/// Sampling controls how the tokens ("words") are selected for the completion.
@@ -118,7 +121,7 @@ impl Default for Stopping<'_> {

/// Body sent to the Aleph Alpha API on the POST `/completion` route
#[derive(Serialize, Debug)]
-struct BodyCompletion<'a> {
+pub struct BodyCompletion<'a> {
Review comment (PR author): does this need to be public?

Review comment (PR author): move StreamCompletion to completion module?

/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// Prompt to complete. The modalities supported depend on `model`.
@@ -142,6 +145,9 @@ struct BodyCompletion<'a> {
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub completion_bias_inclusion: &'a [&'a str],
/// If true, the response will be streamed.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}

impl<'a> BodyCompletion<'a> {
@@ -155,8 +161,13 @@ impl<'a> BodyCompletion<'a> {
top_k: task.sampling.top_k,
top_p: task.sampling.top_p,
completion_bias_inclusion: task.sampling.start_with_one_of,
stream: false,
}
}
pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
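A side note on the `stream` field added above: `skip_serializing_if = "std::ops::Not::not"` omits the key whenever the value is `false`, so non-streaming request bodies stay byte-for-byte unchanged on the wire. A small self-contained check of that behavior (the struct and field names here are illustrative, not the crate's):

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Body {
    model: &'static str,
    // `Not::not` maps false -> true, so serde skips the field
    // exactly when `stream` is false.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    stream: bool,
}

fn main() {
    let off = Body { model: "luminous-base", stream: false };
    let on = Body { model: "luminous-base", stream: true };
    assert_eq!(serde_json::to_string(&off).unwrap(), r#"{"model":"luminous-base"}"#);
    assert_eq!(serde_json::to_string(&on).unwrap(), r#"{"model":"luminous-base","stream":true}"#);
}
```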
104 changes: 73 additions & 31 deletions src/http.rs
@@ -1,11 +1,13 @@
use std::{borrow::Cow, time::Duration};

+use futures_util::stream::StreamExt;
-use reqwest::{header, ClientBuilder, RequestBuilder, StatusCode};
+use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
use serde::Deserialize;
use thiserror::Error as ThisError;
use tokenizers::Tokenizer;
use tokio::sync::mpsc;

-use crate::How;
+use crate::{stream::parse_stream_event, How};

/// A job sent to the Aleph Alpha API using the http client. A job wraps all the knowledge required
/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
@@ -19,7 +21,7 @@ pub trait Job {
type Output;

/// Expected answer of the Aleph Alpha API
-type ResponseBody: for<'de> Deserialize<'de>;
+type ResponseBody: for<'de> Deserialize<'de> + Send;

/// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
/// already set.
@@ -36,7 +38,7 @@ pub trait Task {
type Output;

/// Expected answer of the Aleph Alpha API
-type ResponseBody: for<'de> Deserialize<'de>;
+type ResponseBody: for<'de> Deserialize<'de> + Send;

/// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
/// already set.
@@ -100,32 +102,8 @@ impl HttpClient {
})
}

-pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
+/// Execute a task with the aleph alpha API and return the response without awaiting the body.
+pub async fn request<T: Job>(&self, task: &T, how: &How) -> Result<Response, Error> {
let query = if how.be_nice {
[("nice", "true")].as_slice()
} else {
@@ -152,12 +130,71 @@ impl HttpClient {
reqwest_error.into()
}
})?;
-let response = translate_http_error(response).await?;
+translate_http_error(response).await
}

/// Execute a task with the aleph alpha API and fetch its result.
///
/// ```no_run
/// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error};
///
/// async fn print_completion() -> Result<(), Error> {
/// // Authenticate against API. Fetches token.
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Name of the model we want to use. Large models usually give better answers, but are
/// // also slower and more costly.
/// let model = "luminous-base";
///
/// // The task we want to perform. Here we want to continue the sentence: "An apple a day
/// // ..."
/// let task = TaskCompletion::from_text("An apple a day");
///
/// // Retrieve answer from API
/// let response = client.output_of(&task.with_model(model), &How::default()).await?;
///
/// // Print entire sentence with completion
/// println!("An apple a day{}", response.completion);
/// Ok(())
/// }
/// ```
pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
let response = self.request(task, how).await?;
let response_body: T::ResponseBody = response.json().await?;
let answer = task.body_to_output(response_body);
Ok(answer)
}

pub async fn stream_output_of<T: Job>(
&self,
task: &T,
how: &How,
) -> Result<mpsc::Receiver<Result<T::ResponseBody, Error>>, Error>
where
T::ResponseBody: Send + 'static,
{
let response = self.request(task, how).await?;
let mut stream = response.bytes_stream();

let (tx, rx) = mpsc::channel::<Result<T::ResponseBody, Error>>(100);
Review comment (PR author): 5 should be fine.

tokio::spawn(async move {
Review comment (PR author): we only want to return an iterator, not spawn the thread. (See the sketch after this function.)

while let Some(item) = stream.next().await {
match item {
Ok(bytes) => {
let events = parse_stream_event::<T::ResponseBody>(bytes.as_ref());
for event in events {
tx.send(event).await.unwrap();
}
}
Err(e) => {
tx.send(Err(e.into())).await.unwrap();
}
}
}
});
Ok(rx)
}
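Responding to the review note above, here is a minimal sketch of a spawn-free alternative: adapt the byte stream lazily, so no task is spawned, no channel capacity has to be chosen, and nothing runs until the caller polls the stream. It is meant to drop in next to `request` in the same `impl HttpClient` block and assumes, as the code above does, that `parse_stream_event` yields a `Vec<Result<_, Error>>`; it is an illustration of the suggested direction, not part of the PR:

```rust
use futures_util::stream::{self, Stream, StreamExt};

// Sketch: same request as `stream_output_of`, but events are produced
// on demand instead of being pumped through an mpsc channel by a
// spawned task.
pub async fn stream_output_of<T: Job>(
    &self,
    task: &T,
    how: &How,
) -> Result<impl Stream<Item = Result<T::ResponseBody, Error>>, Error> {
    let response = self.request(task, how).await?;
    Ok(response.bytes_stream().flat_map(|item| match item {
        // One network chunk may contain several events.
        Ok(bytes) => stream::iter(parse_stream_event::<T::ResponseBody>(bytes.as_ref())),
        Err(e) => stream::iter(vec![Err(e.into())]),
    }))
}
```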

fn header_from_token(api_token: &str) -> header::HeaderValue {
let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
// Consider marking security-sensitive headers with `set_sensitive`.
@@ -264,6 +301,11 @@ pub enum Error {
deserialization_error
)]
InvalidTokenizer { deserialization_error: String },
/// Deserialization error of the stream event.
#[error(
"Stream event could not be correctly deserialized. Caused by:\n{cause}. Event:\n{event}"
)]
StreamDeserializationError { cause: String, event: String },
/// Most likely either TLS errors creating the Client, or IO errors.
#[error(transparent)]
Other(#[from] reqwest::Error),
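The `stream` module itself is not shown in this diff, although `parse_stream_event` is called from `stream_output_of` and feeds the new `StreamDeserializationError` variant above. One plausible shape for it, assuming the API frames events as Server-Sent-Event `data: <json>` lines; the framing and the exact signature are assumptions, not taken from the PR:

```rust
use serde::Deserialize;

use crate::Error;

// Hypothetical sketch of the helper used in http.rs. Splits a network
// chunk into `data: ...` payloads and deserializes each one, mapping
// JSON failures to the StreamDeserializationError variant added above.
pub fn parse_stream_event<B>(bytes: &[u8]) -> Vec<Result<B, Error>>
where
    B: for<'de> Deserialize<'de>,
{
    String::from_utf8_lossy(bytes)
        .split("data: ")
        .skip(1) // text before the first `data: ` carries no payload
        .map(|event| {
            serde_json::from_str::<B>(event.trim()).map_err(|cause| {
                Error::StreamDeserializationError {
                    cause: cause.to_string(),
                    event: event.trim().to_owned(),
                }
            })
        })
        .collect()
}
```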
51 changes: 46 additions & 5 deletions src/lib.rs
@@ -31,12 +31,14 @@ mod http;
mod image_preprocessing;
mod prompt;
mod semantic_embedding;
mod stream;
mod tokenization;
use std::time::Duration;

use http::HttpClient;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use tokenizers::Tokenizer;
use tokio::sync::mpsc;

pub use self::{
chat::{ChatOutput, Message, TaskChat},
@@ -51,6 +53,7 @@ pub use self::{
semantic_embedding::{
SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
},
stream::{CompletionSummary, Event, StreamChunk, StreamSummary, TaskStreamCompletion},
tokenization::{TaskTokenization, TokenizationOutput},
};

@@ -190,6 +193,44 @@ impl Client {
.await
}

/// Instruct a model served by the aleph alpha API to continue writing a piece of text.
/// Stream the response as a series of events.
///
/// ```no_run
/// use aleph_alpha_client::{Client, How, TaskCompletion, Error, Event};
/// async fn print_stream_completion() -> Result<(), Error> {
/// // Authenticate against API. Fetches token.
/// let client = Client::with_authentication("AA_API_TOKEN")?;
///
/// // Name of the model we want to use. Large models usually give better answers, but are
/// // also slower and more costly.
/// let model = "luminous-base";
///
/// // The task we want to perform. Here we want to continue the sentence: "An apple a day
/// // ..."
/// let task = TaskCompletion::from_text("An apple a day").with_streaming();
///
/// // Retrieve stream from API
/// let mut response = client.stream_completion(&task, model, &How::default()).await?;
/// while let Some(Ok(event)) = response.recv().await {
/// if let Event::StreamChunk(chunk) = event {
/// println!("{}", chunk.completion);
/// }
/// }
/// Ok(())
/// }
/// ```
pub async fn stream_completion(
&self,
task: &TaskStreamCompletion<'_>,
model: &str,
how: &How,
) -> Result<mpsc::Receiver<Result<Event, Error>>, Error> {
self.http_client
.stream_output_of(&task.with_model(model), how)
.await
}
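An orientation note on the types used in the example above: `Event`, `StreamChunk`, `StreamSummary`, `CompletionSummary`, and `TaskStreamCompletion` live in the new `stream` module, which this diff does not include. The sketch below shows one way they could fit together; only `TaskStreamCompletion`'s `task` field and `StreamChunk.completion` are visible in the PR, so every other field, variant layout, and the serde tagging are assumptions:

```rust
use serde::Deserialize;

use crate::TaskCompletion;

// Wrapper returned by `TaskCompletion::with_streaming` (the `task`
// field name is taken from completion.rs in this diff).
pub struct TaskStreamCompletion<'a> {
    pub task: TaskCompletion<'a>,
}

// Hypothetical event layout. Only `StreamChunk` and its `completion`
// field appear in the PR; the summary variants are placeholders.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Event {
    StreamChunk(StreamChunk),
    StreamSummary(StreamSummary),
    CompletionSummary(CompletionSummary),
}

#[derive(Deserialize, Debug)]
pub struct StreamChunk {
    /// The piece of completion text carried by this event.
    pub completion: String,
}

#[derive(Deserialize, Debug)]
pub struct StreamSummary {}

#[derive(Deserialize, Debug)]
pub struct CompletionSummary {}
```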

/// Send a chat message to a model.
/// ```no_run
/// use aleph_alpha_client::{Client, How, TaskChat, Error, Message};
@@ -213,11 +254,11 @@ impl Client {
/// Ok(())
/// }
/// ```
-pub async fn chat<'a>(
-&'a self,
-task: &'a TaskChat<'a>,
-model: &'a str,
-how: &'a How,
+pub async fn chat(
+&self,
+task: &TaskChat<'_>,
+model: &str,
+how: &How,
) -> Result<ChatOutput, Error> {
self.http_client
.output_of(&task.with_model(model), how)