-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stream completion #31
Closed
Closed
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
use std::{borrow::Cow, time::Duration}; | ||
|
||
use reqwest::{header, ClientBuilder, RequestBuilder, StatusCode}; | ||
use futures_util::stream::StreamExt; | ||
use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode}; | ||
use serde::Deserialize; | ||
use thiserror::Error as ThisError; | ||
use tokenizers::Tokenizer; | ||
use tokio::sync::mpsc; | ||
|
||
use crate::How; | ||
use crate::{stream::parse_stream_event, How}; | ||
|
||
/// A job sent to the Aleph Alpha API using the HTTP client. A job wraps all the knowledge required | ||
/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is | ||
|
@@ -19,7 +21,7 @@ pub trait Job { | |
type Output; | ||
|
||
/// Expected answer of the Aleph Alpha API | ||
type ResponseBody: for<'de> Deserialize<'de>; | ||
type ResponseBody: for<'de> Deserialize<'de> + Send; | ||
|
||
/// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be | ||
/// already set. | ||
|
@@ -36,7 +38,7 @@ pub trait Task { | |
type Output; | ||
|
||
/// Expected answer of the Aleph Alpha API | ||
type ResponseBody: for<'de> Deserialize<'de>; | ||
type ResponseBody: for<'de> Deserialize<'de> + Send; | ||
|
||
/// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be | ||
/// already set. | ||
|
@@ -100,32 +102,8 @@ impl HttpClient { | |
}) | ||
} | ||
|
||
/// Execute a task with the aleph alpha API and fetch its result. | ||
/// | ||
/// ```no_run | ||
/// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error}; | ||
/// | ||
/// async fn print_completion() -> Result<(), Error> { | ||
/// // Authenticate against API. Fetches token. | ||
/// let client = Client::with_authentication("AA_API_TOKEN")?; | ||
/// | ||
/// // Name of the model we want to use. Large models usually give better answers, but are | ||
/// // also slower and more costly. | ||
/// let model = "luminous-base"; | ||
/// | ||
/// // The task we want to perform. Here we want to continue the sentence: "An apple a day | ||
/// // ..." | ||
/// let task = TaskCompletion::from_text("An apple a day"); | ||
/// | ||
/// // Retrieve answer from API | ||
/// let response = client.output_of(&task.with_model(model), &How::default()).await?; | ||
/// | ||
/// // Print entire sentence with completion | ||
/// println!("An apple a day{}", response.completion); | ||
/// Ok(()) | ||
/// } | ||
/// ``` | ||
pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> { | ||
/// Execute a task with the aleph alpha API and return the response without awaiting the body. | ||
pub async fn request<T: Job>(&self, task: &T, how: &How) -> Result<Response, Error> { | ||
let query = if how.be_nice { | ||
[("nice", "true")].as_slice() | ||
} else { | ||
|
@@ -152,12 +130,71 @@ impl HttpClient { | |
reqwest_error.into() | ||
} | ||
})?; | ||
let response = translate_http_error(response).await?; | ||
translate_http_error(response).await | ||
} | ||
|
||
/// Execute a task with the Aleph Alpha API and fetch its result. | ||
/// | ||
/// Sends the request via `request`, deserializes the JSON response body into | ||
/// `T::ResponseBody` and converts it into `T::Output` using `Job::body_to_output`. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error if `request` fails or if the response body cannot be deserialized as | ||
/// `T::ResponseBody`. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ```no_run | ||
/// use aleph_alpha_client::{Client, How, TaskCompletion, Task, Error}; | ||
/// | ||
/// async fn print_completion() -> Result<(), Error> { | ||
/// // Authenticate against API. Fetches token. | ||
/// let client = Client::with_authentication("AA_API_TOKEN")?; | ||
/// | ||
/// // Name of the model we want to use. Large models usually give better answers, but are | ||
/// // also slower and more costly. | ||
/// let model = "luminous-base"; | ||
/// | ||
/// // The task we want to perform. Here we want to continue the sentence: "An apple a day | ||
/// // ..." | ||
/// let task = TaskCompletion::from_text("An apple a day"); | ||
/// | ||
/// // Retrieve answer from API | ||
/// let response = client.output_of(&task.with_model(model), &How::default()).await?; | ||
/// | ||
/// // Print entire sentence with completion | ||
/// println!("An apple a day{}", response.completion); | ||
/// Ok(()) | ||
/// } | ||
/// ``` | ||
pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> { | ||
// `request` passes the response through `translate_http_error` before returning it. | ||
let response = self.request(task, how).await?; | ||
// Awaits the complete body and deserializes it as JSON. | ||
let response_body: T::ResponseBody = response.json().await?; | ||
let answer = task.body_to_output(response_body); | ||
Ok(answer) | ||
} | ||
|
||
/// Execute a task with the Aleph Alpha API and stream the events of its response. | ||
/// | ||
/// Sends the request via `request`, then spawns a Tokio task that reads the response body | ||
/// chunk by chunk, passes each chunk to `parse_stream_event` and forwards every resulting | ||
/// item through a bounded channel with capacity 100. The receiving half of that channel is | ||
/// returned. Errors yielded by the byte stream are forwarded as `Err` items; reading | ||
/// continues afterwards until the stream ends. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error if `request` fails. Failures that occur while streaming are delivered | ||
/// through the returned receiver instead. | ||
/// | ||
/// NOTE(review): both `tx.send(..).await.unwrap()` calls panic inside the spawned task once | ||
/// the returned receiver has been dropped or closed, since sending then yields `Err`. | ||
/// NOTE(review): every chunk is parsed on its own; presumably an event split across two | ||
/// chunks is not reassembled -- confirm against `parse_stream_event`. | ||
pub async fn stream_output_of<T: Job>( | ||
&self, | ||
task: &T, | ||
how: &How, | ||
) -> Result<mpsc::Receiver<Result<T::ResponseBody, Error>>, Error> | ||
where | ||
T::ResponseBody: Send + 'static, | ||
{ | ||
let response = self.request(task, how).await?; | ||
// Consume the body incrementally instead of awaiting it as a whole. | ||
let mut stream = response.bytes_stream(); | ||
|
||
// Bounded channel: the spawned task waits in `send` while 100 items are still unread. | ||
let (tx, rx) = mpsc::channel::<Result<T::ResponseBody, Error>>(100); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 5 should be fine |
||
// The task owns `stream` and `tx`; its join handle is dropped, so it runs detached. | ||
tokio::spawn(async move { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we only want to return an iterator, not spawn the thread. |
||
while let Some(item) = stream.next().await { | ||
match item { | ||
Ok(bytes) => { | ||
// One chunk may yield several items; each is forwarded individually. | ||
let events = parse_stream_event::<T::ResponseBody>(bytes.as_ref()); | ||
for event in events { | ||
tx.send(event).await.unwrap(); | ||
} | ||
} | ||
Err(e) => { | ||
// Transport error, converted into the crate's `Error` via `into`. | ||
tx.send(Err(e.into())).await.unwrap(); | ||
} | ||
} | ||
} | ||
}); | ||
Ok(rx) | ||
} | ||
|
||
fn header_from_token(api_token: &str) -> header::HeaderValue { | ||
let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap(); | ||
// Consider marking security-sensitive headers with `set_sensitive`. | ||
|
@@ -264,6 +301,11 @@ pub enum Error { | |
deserialization_error | ||
)] | ||
InvalidTokenizer { deserialization_error: String }, | ||
/// Deserialization error of the stream event. | ||
#[error( | ||
"Stream event could not be correctly deserialized. Caused by:\n{cause}. Event:\n{event}" | ||
)] | ||
StreamDeserializationError { cause: String, event: String }, | ||
/// Most likely either TLS errors creating the Client, or IO errors. | ||
#[error(transparent)] | ||
Other(#[from] reqwest::Error), | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this need to be public?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move StreamCompletion to completion module?