refactor: clean up the text generation plugin to split up and prepare for multiple backends
tomsanbear committed Jun 15, 2024
1 parent de38919 commit 5ca87d3
Showing 5 changed files with 170 additions and 100 deletions.
6 changes: 3 additions & 3 deletions mai-sdk-core/Cargo.toml
@@ -58,9 +58,9 @@ libp2p = { version = "0.53", features = [
     "yamux",
     # "upnp",
 ] }
-either = "1.12.0"
-base64 = "0.22.1"
-sqlite = "0.36.0"
+either = "1.12"
+base64 = "0.22"
+sqlite = "0.36"
 
 [dev-dependencies]
 slog-async = "2.8.0"
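Note: under Cargo's default caret semantics, "1.12" and "1.12.0" express the same requirement (>= 1.12.0, < 2.0.0), so dropping the patch component here only normalizes style; it does not loosen any version constraint.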
114 changes: 17 additions & 97 deletions mai-sdk-plugins/src/text_generation/mod.rs
@@ -1,109 +1,29 @@
-use mai_sdk_core::task_queue::{Runnable, TaskId};
-use serde::{Deserialize, Serialize};
-use slog::{info, Logger};
+mod ollama;
+mod state;
+mod task;
 
-#[derive(Debug, Clone)]
-pub struct TextGenerationPluginState {
-    logger: Logger,
-}
-
-impl TextGenerationPluginState {
-    pub fn new(logger: &Logger) -> Self {
-        Self {
-            logger: logger.clone(),
-        }
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TextGenerationPluginTask {
-    id: TaskId,
-    model: String,
-    messages: Vec<ChatRequestMessage>,
-}
+pub use ollama::{OllamaChatCompletionRequest, OllamaChatCompletionResponse};
+pub use state::TextGenerationPluginState;
+pub use task::{TextGenerationPluginTask, TextGenerationPluginTaskOutput};
 
-impl TextGenerationPluginTask {
-    pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self {
-        Self {
-            id: nanoid::nanoid!(),
-            model,
-            messages,
-        }
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
+/// ChatRequestMessage
+/// A DTO representing a chat request message, agnostic of model or platform.
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
 pub struct ChatRequestMessage {
     content: String,
     role: String,
 }
 
 impl ChatRequestMessage {
     /// Create a new ChatRequestMessage
     /// role should be one of the following:
     /// - "user"
     /// - "assistant"
     /// - "system"
     pub fn new(content: String, role: String) -> Self {
         if role != "user" && role != "assistant" && role != "system" {
             panic!("role should be one of the following: 'user', 'assistant', 'system'");
         }
         Self { content, role }
     }
 }
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct TextGenerationPluginTaskOutput {
-    pub role: String,
-    pub content: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct OllamaChatCompletionRequest {
-    model: String,
-    messages: Vec<ChatRequestMessage>,
-    stream: bool,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct OllamaChatCompletionResponse {
-    model: String,
-    created_at: String,
-    message: OllamaChatCompletionResponseMessage,
-    done: bool,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct OllamaChatCompletionResponseMessage {
-    content: String,
-    role: String,
-}
-
-impl Runnable<TextGenerationPluginTaskOutput, TextGenerationPluginState>
-    for TextGenerationPluginTask
-{
-    fn id(&self) -> TaskId {
-        self.id.clone()
-    }
-
-    async fn run(
-        &self,
-        state: TextGenerationPluginState,
-    ) -> anyhow::Result<TextGenerationPluginTaskOutput> {
-        // Send request
-        let body = OllamaChatCompletionRequest {
-            model: self.model.clone(),
-            messages: self.messages.clone(),
-            stream: false,
-        };
-        let client = reqwest::Client::new();
-        let resp = client
-            .post("http://localhost:11434/api/chat")
-            .json(&body)
-            .send()
-            .await?;
-
-        // Parse response
-        let body = resp.json::<OllamaChatCompletionResponse>().await?.message;
-        let output = TextGenerationPluginTaskOutput {
-            content: body.content,
-            role: body.role,
-        };
-        info!(state.logger, "OllamaPluginTask::ChatCompletion completed"; "output" => format!("{:?}", output));
-
-        // Return result
-        Ok(output)
-    }
-}
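For illustration, a minimal sketch (not part of the commit) of how the re-exported API is consumed; the mai_sdk_plugins::text_generation module path and the "llama3" model name are assumptions. Note that ChatRequestMessage::new panics on any role other than "user", "assistant", or "system".

use mai_sdk_plugins::text_generation::{ChatRequestMessage, TextGenerationPluginTask};

fn build_task() -> TextGenerationPluginTask {
    // Roles other than "user", "assistant", or "system" panic in new().
    let messages = vec![
        ChatRequestMessage::new("You are a helpful assistant.".to_string(), "system".to_string()),
        ChatRequestMessage::new("Hello!".to_string(), "user".to_string()),
    ];
    // "llama3" is a placeholder model name, not taken from the commit.
    TextGenerationPluginTask::new("llama3".to_string(), messages)
}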
59 changes: 59 additions & 0 deletions mai-sdk-plugins/src/text_generation/ollama.rs
@@ -0,0 +1,59 @@
+use super::ChatRequestMessage;
+
+/// See https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion for more information
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct OllamaChatCompletionRequest {
+    model: String,
+    messages: Vec<ChatRequestMessage>,
+    stream: bool,
+}
+
+impl OllamaChatCompletionRequest {
+    pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self {
+        Self {
+            model,
+            messages,
+            stream: false,
+        }
+    }
+}
+
+/// See https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion for more information
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct OllamaChatCompletionResponse {
+    model: String,
+    created_at: String,
+    message: OllamaChatCompletionResponseMessage,
+    done: bool,
+}
+
+impl OllamaChatCompletionResponse {
+    pub fn new(
+        model: String,
+        created_at: String,
+        content: String,
+        role: String,
+        done: bool,
+    ) -> Self {
+        Self {
+            model,
+            created_at,
+            message: OllamaChatCompletionResponseMessage { content, role },
+            done,
+        }
+    }
+
+    pub fn content(&self) -> String {
+        self.message.content.clone()
+    }
+
+    pub fn role(&self) -> String {
+        self.message.role.clone()
+    }
+}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+struct OllamaChatCompletionResponseMessage {
+    content: String,
+    role: String,
+}
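As a sanity check on the wire format, a minimal sketch (not part of the commit) of serializing the request type; it assumes serde_json is available as a dependency and that the types are reachable via the re-exports above. "llama3" is again a placeholder.

use mai_sdk_plugins::text_generation::{ChatRequestMessage, OllamaChatCompletionRequest};

fn main() -> Result<(), serde_json::Error> {
    let request = OllamaChatCompletionRequest::new(
        "llama3".to_string(), // placeholder model name
        vec![ChatRequestMessage::new("Hi".to_string(), "user".to_string())],
    );
    // Field order follows the struct definitions above:
    // {"model":"llama3","messages":[{"content":"Hi","role":"user"}],"stream":false}
    println!("{}", serde_json::to_string(&request)?);
    Ok(())
}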
14 changes: 14 additions & 0 deletions mai-sdk-plugins/src/text_generation/state.rs
@@ -0,0 +1,14 @@
+use slog::Logger;
+
+#[derive(Debug, Clone)]
+pub struct TextGenerationPluginState {
+    pub logger: Logger,
+}
+
+impl TextGenerationPluginState {
+    pub fn new(logger: &Logger) -> Self {
+        Self {
+            logger: logger.clone(),
+        }
+    }
+}
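Since the state only wraps a Logger, tests can construct it with slog's no-op drain; a minimal sketch under that assumption (not part of the commit):

use mai_sdk_plugins::text_generation::TextGenerationPluginState;
use slog::{o, Discard, Logger};

fn make_test_state() -> TextGenerationPluginState {
    // Discard drops all log records, which is convenient in unit tests.
    let logger = Logger::root(Discard, o!());
    TextGenerationPluginState::new(&logger)
}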
77 changes: 77 additions & 0 deletions mai-sdk-plugins/src/text_generation/task.rs
@@ -0,0 +1,77 @@
+use mai_sdk_core::task_queue::{Runnable, TaskId};
+use serde::{Deserialize, Serialize};
+use slog::info;
+
+use crate::text_generation::ollama::{OllamaChatCompletionRequest, OllamaChatCompletionResponse};
+
+use super::{state::TextGenerationPluginState, ChatRequestMessage};
+
+/// TextGenerationPluginTask
+/// This task calls an LLM to generate text.
+/// For now, chat completion is the only supported generation method, as it is the most common use case.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TextGenerationPluginTask {
+    pub(crate) id: TaskId,
+
+    /// The model to use for generating text
+    /// Models should use the following format:
+    /// - for hugging face hub: "username/model_name"
+    /// - for local models: "path/to/model"
+    /// - for ollama models: "ollama/model_name"
+    pub(crate) model: String,
+
+    /// The messages to use for generating text
+    /// The consumer is responsible for ensuring that the messages fit within the model's context window
+    pub(crate) messages: Vec<ChatRequestMessage>,
+}
+
+impl TextGenerationPluginTask {
+    pub fn new(model: String, messages: Vec<ChatRequestMessage>) -> Self {
+        Self {
+            id: nanoid::nanoid!(),
+            model,
+            messages,
+        }
+    }
+}
+
+/// TextGenerationPluginTaskOutput
+/// The output of the TextGenerationPluginTask; contains only the role and the generated text from the model.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TextGenerationPluginTaskOutput {
+    pub role: String,
+    pub content: String,
+}
+
+impl Runnable<TextGenerationPluginTaskOutput, TextGenerationPluginState>
+    for TextGenerationPluginTask
+{
+    fn id(&self) -> TaskId {
+        self.id.clone()
+    }
+
+    async fn run(
+        &self,
+        state: TextGenerationPluginState,
+    ) -> anyhow::Result<TextGenerationPluginTaskOutput> {
+        // Send request
+        let body = OllamaChatCompletionRequest::new(self.model.clone(), self.messages.clone());
+        let client = reqwest::Client::new();
+        let resp = client
+            .post("http://localhost:11434/api/chat")
+            .json(&body)
+            .send()
+            .await?;
+
+        // Parse response
+        let response_body = resp.json::<OllamaChatCompletionResponse>().await?;
+        let output = TextGenerationPluginTaskOutput {
+            content: response_body.content(),
+            role: response_body.role(),
+        };
+        info!(state.logger, "TextGenerationPluginTask::ChatCompletion completed"; "output" => format!("{:?}", output));
+
+        // Return result
+        Ok(output)
+    }
+}
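End to end, the task posts to Ollama's default endpoint, so a hypothetical driver (not part of the commit) needs an async runtime and a running Ollama server; the tokio runtime, the module paths, and the "llama3" model name are all assumptions here.

use mai_sdk_core::task_queue::Runnable;
use mai_sdk_plugins::text_generation::{
    ChatRequestMessage, TextGenerationPluginState, TextGenerationPluginTask,
};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let logger = slog::Logger::root(slog::Discard, slog::o!());
    let state = TextGenerationPluginState::new(&logger);
    let task = TextGenerationPluginTask::new(
        "llama3".to_string(), // placeholder; assumes the model is pulled locally
        vec![ChatRequestMessage::new("Hello!".to_string(), "user".to_string())],
    );
    // run() sends the chat request to http://localhost:11434/api/chat.
    let output = task.run(state).await?;
    println!("{}: {}", output.role, output.content);
    Ok(())
}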
