-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: clean up the text generation plugin to split up and prepare for multiple backends
- Loading branch information
1 parent
de38919
commit 5ca87d3
Showing 5 changed files with 170 additions and 100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,109 +1,29 @@ | ||
use mai_sdk_core::task_queue::{Runnable, TaskId}; | ||
use serde::{Deserialize, Serialize}; | ||
use slog::{info, Logger}; | ||
mod ollama; | ||
mod state; | ||
mod task; | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct TextGenerationPluginState { | ||
logger: Logger, | ||
} | ||
|
||
impl TextGenerationPluginState { | ||
pub fn new(logger: &Logger) -> Self { | ||
Self { | ||
logger: logger.clone(), | ||
} | ||
} | ||
} | ||
|
||
/// A single text-generation request: the target model plus the chat
/// history it should complete.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextGenerationPluginTask {
    // Unique identifier used by the task queue.
    id: TaskId,
    // Name of the model to run the completion with.
    model: String,
    // Conversation history sent to the backend.
    messages: Vec<ChatRequestMessage>,
}
pub use ollama::{OllamaChatCompletionRequest, OllamaChatCompletionResponse}; | ||
pub use state::TextGenerationPluginState; | ||
pub use task::{TextGenerationPluginTask, TextGenerationPluginTaskOutput}; | ||
|
||
impl TextGenerationPluginTask { | ||
pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self { | ||
Self { | ||
id: nanoid::nanoid!(), | ||
model, | ||
messages, | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
/// ChatRequestMessage | ||
/// Is a dto that represents a chat request message agnostic of model or platform | ||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] | ||
pub struct ChatRequestMessage { | ||
content: String, | ||
role: String, | ||
} | ||
|
||
impl ChatRequestMessage { | ||
/// Create a new ChatRequestMessage | ||
/// role should be one of the following: | ||
/// - "user" | ||
/// - "assistant" | ||
/// - "system" | ||
pub fn new(content: String, role: String) -> Self { | ||
if role != "user" && role != "assistant" && role != "system" { | ||
panic!("role should be one of the following: 'user', 'assistant', 'system'"); | ||
} | ||
Self { content, role } | ||
} | ||
} | ||
|
||
/// Output of a text generation task: the role and text of the message the
/// model produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextGenerationPluginTaskOutput {
    pub role: String,
    pub content: String,
}
|
||
/// Request payload posted to Ollama's /api/chat endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaChatCompletionRequest {
    model: String,
    messages: Vec<ChatRequestMessage>,
    // When false, the server returns the whole completion in one body.
    stream: bool,
}

/// Response payload returned by Ollama's /api/chat endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaChatCompletionResponse {
    model: String,
    // Timestamp string as supplied by the server; format not validated here.
    created_at: String,
    message: OllamaChatCompletionResponseMessage,
    done: bool,
}

/// The generated message nested inside an Ollama chat response.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaChatCompletionResponseMessage {
    content: String,
    role: String,
}
|
||
impl Runnable<TextGenerationPluginTaskOutput, TextGenerationPluginState> | ||
for TextGenerationPluginTask | ||
{ | ||
fn id(&self) -> TaskId { | ||
self.id.clone() | ||
} | ||
|
||
async fn run( | ||
&self, | ||
state: TextGenerationPluginState, | ||
) -> anyhow::Result<TextGenerationPluginTaskOutput> { | ||
// Send request | ||
let body = OllamaChatCompletionRequest { | ||
model: self.model.clone(), | ||
messages: self.messages.clone(), | ||
stream: false, | ||
}; | ||
let client = reqwest::Client::new(); | ||
let resp = client | ||
.post("http://localhost:11434/api/chat") | ||
.json(&body) | ||
.send() | ||
.await?; | ||
|
||
// Parse response | ||
let body = resp.json::<OllamaChatCompletionResponse>().await?.message; | ||
let output = TextGenerationPluginTaskOutput { | ||
content: body.content, | ||
role: body.role, | ||
}; | ||
info!(state.logger, "OllamaPluginTask::ChatCompletion completed"; "output" => format!("{:?}", output)); | ||
|
||
// Return result | ||
Ok(output) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
use super::ChatRequestMessage; | ||
|
||
/// Request body for Ollama's chat completion endpoint.
/// See https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion for more information
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OllamaChatCompletionRequest {
    model: String,
    messages: Vec<ChatRequestMessage>,
    // When false, the server returns the whole completion in one body.
    stream: bool,
}
|
||
impl OllamaChatCompletionRequest { | ||
pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self { | ||
Self { | ||
model, | ||
messages, | ||
stream: false, | ||
} | ||
} | ||
} | ||
|
||
/// Response body returned by Ollama's chat completion endpoint.
/// See https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion for more information
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OllamaChatCompletionResponse {
    model: String,
    // Timestamp string as supplied by the server; format not validated here.
    created_at: String,
    // The generated message; exposed via the content()/role() accessors.
    message: OllamaChatCompletionResponseMessage,
    done: bool,
}
|
||
impl OllamaChatCompletionResponse { | ||
pub fn new( | ||
model: String, | ||
created_at: String, | ||
content: String, | ||
role: String, | ||
done: bool, | ||
) -> Self { | ||
Self { | ||
model, | ||
created_at, | ||
message: OllamaChatCompletionResponseMessage { content, role }, | ||
done, | ||
} | ||
} | ||
|
||
pub fn content(&self) -> String { | ||
self.message.content.clone() | ||
} | ||
|
||
pub fn role(&self) -> String { | ||
self.message.role.clone() | ||
} | ||
} | ||
|
||
/// The generated message nested inside an Ollama chat completion response.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct OllamaChatCompletionResponseMessage {
    content: String,
    role: String,
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use slog::Logger; | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct TextGenerationPluginState { | ||
pub logger: Logger, | ||
} | ||
|
||
impl TextGenerationPluginState { | ||
pub fn new(logger: &Logger) -> Self { | ||
Self { | ||
logger: logger.clone(), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
use mai_sdk_core::task_queue::{Runnable, TaskId}; | ||
use serde::{Deserialize, Serialize}; | ||
use slog::info; | ||
|
||
use crate::text_generation::ollama::{OllamaChatCompletionRequest, OllamaChatCompletionResponse}; | ||
|
||
use super::{state::TextGenerationPluginState, ChatRequestMessage}; | ||
|
||
/// TextGenerationPluginTask
/// This task implements the ability to call an LLM model to generate text
/// The only method of generating text for now is through chat completion as it is the most common use case
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextGenerationPluginTask {
    // Unique identifier used by the task queue.
    pub(crate) id: TaskId,

    /// The model to use for generating text
    /// Models should use the following format:
    /// - for hugging face hub: "username/model_name"
    /// - for local models: "path/to/model"
    /// - for ollama models: "ollama/model_name"
    pub(crate) model: String,

    /// The messages to use for generating text
    /// The consumer is responsible for ensuring that the messages fit within the model's context window
    pub(crate) messages: Vec<ChatRequestMessage>,
}
|
||
impl TextGenerationPluginTask { | ||
pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self { | ||
Self { | ||
id: nanoid::nanoid!(), | ||
model, | ||
messages, | ||
} | ||
} | ||
} | ||
|
||
/// TextGenerationPluginTaskOutput
/// The output of the TextGenerationPluginTask, this will contain only the role and the generated text of the model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextGenerationPluginTaskOutput {
    // Role of the generated message as reported by the backend.
    pub role: String,
    // The generated text itself.
    pub content: String,
}
|
||
impl Runnable<TextGenerationPluginTaskOutput, TextGenerationPluginState> | ||
for TextGenerationPluginTask | ||
{ | ||
fn id(&self) -> TaskId { | ||
self.id.clone() | ||
} | ||
|
||
async fn run( | ||
&self, | ||
state: TextGenerationPluginState, | ||
) -> anyhow::Result<TextGenerationPluginTaskOutput> { | ||
// Send request | ||
let body = OllamaChatCompletionRequest::new(self.model.clone(), self.messages.clone()); | ||
let client = reqwest::Client::new(); | ||
let resp = client | ||
.post("http://localhost:11434/api/chat") | ||
.json(&body) | ||
.send() | ||
.await?; | ||
|
||
// Parse response | ||
let response_body = resp.json::<OllamaChatCompletionResponse>().await?; | ||
let output = TextGenerationPluginTaskOutput { | ||
content: response_body.content(), | ||
role: response_body.role(), | ||
}; | ||
info!(state.logger, "OllamaPluginTask::ChatCompletion completed"; "output" => format!("{:?}", output)); | ||
|
||
// Return result | ||
Ok(output) | ||
} | ||
} |