diff --git a/mai-sdk-core/Cargo.toml b/mai-sdk-core/Cargo.toml
index 590767a..0a5dc4f 100644
--- a/mai-sdk-core/Cargo.toml
+++ b/mai-sdk-core/Cargo.toml
@@ -58,9 +58,9 @@ libp2p = { version = "0.53", features = [
     "yamux",
     # "upnp",
 ] }
-either = "1.12.0"
-base64 = "0.22.1"
-sqlite = "0.36.0"
+either = "1.12"
+base64 = "0.22"
+sqlite = "0.36"
 
 [dev-dependencies]
 slog-async = "2.8.0"
diff --git a/mai-sdk-plugins/src/text_generation/mod.rs b/mai-sdk-plugins/src/text_generation/mod.rs
index 615b1ec..0c7afc6 100644
--- a/mai-sdk-plugins/src/text_generation/mod.rs
+++ b/mai-sdk-plugins/src/text_generation/mod.rs
@@ -1,109 +1,29 @@
-use mai_sdk_core::task_queue::{Runnable, TaskId};
-use serde::{Deserialize, Serialize};
-use slog::{info, Logger};
+mod ollama;
+mod state;
+mod task;
 
-#[derive(Debug, Clone)]
-pub struct TextGenerationPluginState {
-    logger: Logger,
-}
-
-impl TextGenerationPluginState {
-    pub fn new(logger: &Logger) -> Self {
-        Self {
-            logger: logger.clone(),
-        }
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TextGenerationPluginTask {
-    id: TaskId,
-    model: String,
-    messages: Vec<ChatRequestMessage>,
-}
+pub use ollama::{OllamaChatCompletionRequest, OllamaChatCompletionResponse};
+pub use state::TextGenerationPluginState;
+pub use task::{TextGenerationPluginTask, TextGenerationPluginTaskOutput};
 
-impl TextGenerationPluginTask {
-    pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self {
-        Self {
-            id: nanoid::nanoid!(),
-            model,
-            messages,
-        }
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
+/// ChatRequestMessage
+/// Is a dto that represents a chat request message agnostic of model or platform
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
 pub struct ChatRequestMessage {
     content: String,
     role: String,
 }
 
 impl ChatRequestMessage {
+    /// Create a new ChatRequestMessage
+    /// role should be one of the following:
+    /// - "user"
+    /// - "assistant"
+    /// - "system"
     pub fn new(content: String, role: String) -> Self {
+        if role != "user" && role != "assistant" && role != "system" {
+            panic!("role should be one of the following: 'user', 'assistant', 'system'");
+        }
         Self { content, role }
     }
 }
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TextGenerationPluginTaskOutput {
-    pub role: String,
-    pub content: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct OllamaChatCompletionRequest {
-    model: String,
-    messages: Vec<ChatRequestMessage>,
-    stream: bool,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct OllamaChatCompletionResponse {
-    model: String,
-    created_at: String,
-    message: OllamaChatCompletionResponseMessage,
-    done: bool,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct OllamaChatCompletionResponseMessage {
-    content: String,
-    role: String,
-}
-
-impl Runnable<TextGenerationPluginTaskOutput, TextGenerationPluginState>
-    for TextGenerationPluginTask
-{
-    fn id(&self) -> TaskId {
-        self.id.clone()
-    }
-
-    async fn run(
-        &self,
-        state: TextGenerationPluginState,
-    ) -> anyhow::Result<TextGenerationPluginTaskOutput> {
-        // Send request
-        let body = OllamaChatCompletionRequest {
-            model: self.model.clone(),
-            messages: self.messages.clone(),
-            stream: false,
-        };
-        let client = reqwest::Client::new();
-        let resp = client
-            .post("http://localhost:11434/api/chat")
-            .json(&body)
-            .send()
-            .await?;
-
-        // Parse response
-        let body = resp.json::<OllamaChatCompletionResponse>().await?.message;
-        let output = TextGenerationPluginTaskOutput {
-            content: body.content,
-            role: body.role,
-        };
-        info!(state.logger, "OllamaPluginTask::ChatCompletion completed"; "output" => format!("{:?}", output));
-
-        // Return result
-        Ok(output)
-    }
-}
diff --git a/mai-sdk-plugins/src/text_generation/ollama.rs b/mai-sdk-plugins/src/text_generation/ollama.rs
new file mode 100644
index 0000000..6b26148
--- /dev/null
+++ b/mai-sdk-plugins/src/text_generation/ollama.rs
@@ -0,0 +1,59 @@
+use super::ChatRequestMessage;
+
+/// See https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion for more information
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct OllamaChatCompletionRequest {
+    model: String,
+    messages: Vec<ChatRequestMessage>,
+    stream: bool,
+}
+
+impl OllamaChatCompletionRequest {
+    pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self {
+        Self {
+            model,
+            messages,
+            stream: false,
+        }
+    }
+}
+
+/// See https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion for more information
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct OllamaChatCompletionResponse {
+    model: String,
+    created_at: String,
+    message: OllamaChatCompletionResponseMessage,
+    done: bool,
+}
+
+impl OllamaChatCompletionResponse {
+    pub fn new(
+        model: String,
+        created_at: String,
+        content: String,
+        role: String,
+        done: bool,
+    ) -> Self {
+        Self {
+            model,
+            created_at,
+            message: OllamaChatCompletionResponseMessage { content, role },
+            done,
+        }
+    }
+
+    pub fn content(&self) -> String {
+        self.message.content.clone()
+    }
+
+    pub fn role(&self) -> String {
+        self.message.role.clone()
+    }
+}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+struct OllamaChatCompletionResponseMessage {
+    content: String,
+    role: String,
+}
diff --git a/mai-sdk-plugins/src/text_generation/state.rs b/mai-sdk-plugins/src/text_generation/state.rs
new file mode 100644
index 0000000..b0c2080
--- /dev/null
+++ b/mai-sdk-plugins/src/text_generation/state.rs
@@ -0,0 +1,14 @@
+use slog::Logger;
+
+#[derive(Debug, Clone)]
+pub struct TextGenerationPluginState {
+    pub logger: Logger,
+}
+
+impl TextGenerationPluginState {
+    pub fn new(logger: &Logger) -> Self {
+        Self {
+            logger: logger.clone(),
+        }
+    }
+}
diff --git a/mai-sdk-plugins/src/text_generation/task.rs b/mai-sdk-plugins/src/text_generation/task.rs
new file mode 100644
index 0000000..0c25dfe
--- /dev/null
+++ b/mai-sdk-plugins/src/text_generation/task.rs
@@ -0,0 +1,77 @@
+use mai_sdk_core::task_queue::{Runnable, TaskId};
+use serde::{Deserialize, Serialize};
+use slog::info;
+
+use crate::text_generation::ollama::{OllamaChatCompletionRequest, OllamaChatCompletionResponse};
+
+use super::{state::TextGenerationPluginState, ChatRequestMessage};
+
+/// TextGenerationPluginTask
+/// This task implements the ability to call an LLM model to generate text
+/// The only method of generating text for now is through chat completion as it is the most common use case
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TextGenerationPluginTask {
+    pub(crate) id: TaskId,
+
+    /// The model to use for generating text
+    /// Models should use the following format:
+    /// - for hugging face hub: "username/model_name"
+    /// - for local models: "path/to/model"
+    /// - for ollama models: "ollama/model_name"
+    pub(crate) model: String,
+
+    /// The messages to use for generating text
+    /// The consumer is responsible for ensuring that the messages fit within the model's context window
+    pub(crate) messages: Vec<ChatRequestMessage>,
+}
+
+impl TextGenerationPluginTask {
+    pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self {
+        Self {
+            id: nanoid::nanoid!(),
+            model,
+            messages,
+        }
+    }
+}
+
+/// TextGenerationPluginTaskOutput
+/// The output of the TextGenerationPluginTask, this will contain only the role and the generated text of the model
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TextGenerationPluginTaskOutput {
+    pub role: String,
+    pub content: String,
+}
+
+impl Runnable<TextGenerationPluginTaskOutput, TextGenerationPluginState>
+    for TextGenerationPluginTask
+{
+    fn id(&self) -> TaskId {
+        self.id.clone()
+    }
+
+    async fn run(
+        &self,
+        state: TextGenerationPluginState,
+    ) -> anyhow::Result<TextGenerationPluginTaskOutput> {
+        // Send request
+        let body = OllamaChatCompletionRequest::new(self.model.clone(), self.messages.clone());
+        let client = reqwest::Client::new();
+        let resp = client
+            .post("http://localhost:11434/api/chat")
+            .json(&body)
+            .send()
+            .await?;
+
+        // Parse response
+        let response_body = resp.json::<OllamaChatCompletionResponse>().await?;
+        let output = TextGenerationPluginTaskOutput {
+            content: response_body.content(),
+            role: response_body.role(),
+        };
+        info!(state.logger, "OllamaPluginTask::ChatCompletion completed"; "output" => format!("{:?}", output));
+
+        // Return result
+        Ok(output)
+    }
+}
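For context, a rough sketch of how the refactored plugin could be driven end to end. This is not part of the diff: the `mai_sdk_plugins::text_generation` re-export path, the `tokio`/`anyhow` wiring, the `slog::Discard` logger, and the "ollama/llama3" model name are assumptions for illustration, and it presumes an Ollama server reachable at the http://localhost:11434 address hard-coded in task.rs.

```rust
use mai_sdk_core::task_queue::Runnable;
use mai_sdk_plugins::text_generation::{
    ChatRequestMessage, TextGenerationPluginState, TextGenerationPluginTask,
};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Any slog drain works; Discard keeps the example self-contained.
    let logger = slog::Logger::root(slog::Discard, slog::o!());
    let state = TextGenerationPluginState::new(&logger);

    // "ollama/llama3" is a placeholder following the "ollama/model_name" convention
    // documented on TextGenerationPluginTask::model.
    let task = TextGenerationPluginTask::new(
        "ollama/llama3".to_string(),
        vec![ChatRequestMessage::new(
            "Why is the sky blue?".to_string(),
            "user".to_string(), // must be "user", "assistant" or "system"; new() panics otherwise
        )],
    );

    // run() posts the chat completion request to the local Ollama server.
    let output = task.run(state).await?;
    println!("{}: {}", output.role, output.content);
    Ok(())
}
```

Note that ChatRequestMessage::new panics on an unrecognized role and the Ollama endpoint is hard-coded, so callers are expected to validate roles and have Ollama running locally.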