Merge pull request #32 from anilaltuner/main
vllm added
andthattoo authored Jan 13, 2025
2 parents f622133 + a469721 commit 60a87f5
Showing 4 changed files with 231 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/api_interface/mod.rs
@@ -1,3 +1,4 @@
pub mod gem_api;
pub mod open_router;
pub mod openai_api;
pub mod vllm;
210 changes: 210 additions & 0 deletions src/api_interface/vllm.rs
@@ -0,0 +1,210 @@
use crate::program::atomics::MessageInput;
use ollama_rs::{
error::OllamaError, generation::functions::tools::Tool,
generation::functions::OpenAIFunctionCall,
};
use openai_dive::v1::api::Client;
use openai_dive::v1::resources::chat::*;
use serde_json::{json, Value};
use std::sync::Arc;

pub struct VLLMExecutor {
model: String,
client: Client,
}

impl VLLMExecutor {
pub fn new(model: String, base_url: String) -> Self {
Self {
model,
client: Client::new_with_base(base_url.as_str(), "".to_string()),
}
}

pub async fn generate_text(
&self,
input: Vec<MessageInput>,
schema: &Option<String>,
) -> Result<String, OllamaError> {
let messages: Vec<ChatMessage> = input
.into_iter()
.map(|msg| match msg.role.as_str() {
"user" => ChatMessage::User {
content: ChatMessageContent::Text(msg.content),
name: None,
},
"assistant" => ChatMessage::Assistant {
content: Some(ChatMessageContent::Text(msg.content)),
tool_calls: None,
name: None,
refusal: None,
},
"system" => ChatMessage::System {
content: ChatMessageContent::Text(msg.content),
name: None,
},
_ => ChatMessage::User {
content: ChatMessageContent::Text(msg.content),
name: None,
},
})
.collect();

let parameters = if let Some(schema_str) = schema {
let mut schema_json: Value = serde_json::from_str(schema_str)
.map_err(|e| OllamaError::from(format!("Invalid schema JSON: {:?}", e)))?;

if let Value::Object(ref mut map) = schema_json {
map.insert("additionalProperties".to_string(), Value::Bool(false));
}

ChatCompletionParametersBuilder::default()
.model(self.model.clone())
.messages(messages)
.response_format(ChatCompletionResponseFormat::JsonSchema(
JsonSchemaBuilder::default()
.name("structured_output")
.schema(schema_json)
.strict(true)
.build()
.map_err(|e| {
OllamaError::from(format!("Could not build JSON schema: {:?}", e))
})?,
))
.build()
} else {
ChatCompletionParametersBuilder::default()
.model(self.model.clone())
.messages(messages)
.response_format(ChatCompletionResponseFormat::Text)
.build()
}
.map_err(|e| OllamaError::from(format!("Could not build message parameters: {:?}", e)))?;

let result = self.client.chat().create(parameters).await.map_err(|e| {
OllamaError::from(format!("Failed to parse VLLM API response: {:?}", e))
})?;

let message = match &result.choices[0].message {
ChatMessage::Assistant { content, .. } => {
if let Some(ChatMessageContent::Text(text)) = content {
text.clone()
} else {
return Err(OllamaError::from(
"Unexpected message content format".to_string(),
));
}
}
_ => return Err(OllamaError::from("Unexpected message type".to_string())),
};

Ok(message)
}

pub async fn function_call(
&self,
prompt: &str,
tools: Vec<Arc<dyn Tool>>,
raw_mode: bool,
oai_parser: Arc<OpenAIFunctionCall>,
) -> Result<String, OllamaError> {
let openai_tools: Vec<_> = tools
.iter()
.map(|tool| ChatCompletionTool {
r#type: ChatCompletionToolType::Function,
function: ChatCompletionFunction {
name: tool.name().to_lowercase().replace(' ', "_"),
description: Some(tool.description()),
parameters: tool.parameters(),
},
})
.collect();

let messages = vec![ChatMessage::User {
content: ChatMessageContent::Text(prompt.to_string()),
name: None,
}];

let parameters = ChatCompletionParametersBuilder::default()
.model(self.model.clone())
.messages(messages)
.tools(openai_tools)
.build()
.map_err(|e| {
OllamaError::from(format!("Could not build message parameters: {:?}", e))
})?;

let result = self.client.chat().create(parameters).await.map_err(|e| {
OllamaError::from(format!("Failed to parse VLLM API response: {:?}", e))
})?;
let message = result.choices[0].message.clone();

if raw_mode {
self.handle_raw_mode(message)
} else {
self.handle_normal_mode(message, tools, oai_parser).await
}
}

fn handle_raw_mode(&self, message: ChatMessage) -> Result<String, OllamaError> {
let mut raw_calls = Vec::new();

if let ChatMessage::Assistant {
tool_calls: Some(tool_calls),
..
} = message
{
for tool_call in tool_calls {
let call_json = json!({
"name": tool_call.function.name,
"arguments": serde_json::from_str::<serde_json::Value>(&tool_call.function.arguments)?
});
raw_calls.push(serde_json::to_string(&call_json)?);
}
}

Ok(raw_calls.join("\n\n"))
}

async fn handle_normal_mode(
&self,
message: ChatMessage,
tools: Vec<Arc<dyn Tool>>,
oai_parser: Arc<OpenAIFunctionCall>,
) -> Result<String, OllamaError> {
let mut results = Vec::<String>::new();

if let ChatMessage::Assistant {
tool_calls: Some(tool_calls),
..
} = message
{
for tool_call in tool_calls {
for tool in &tools {
if tool.name().to_lowercase().replace(' ', "_") == tool_call.function.name {
let tool_params: Value =
serde_json::from_str(&tool_call.function.arguments)?;
let res = oai_parser
.function_call_with_history(
tool_call.function.name.clone(),
tool_params,
tool.clone(),
)
.await;
match res {
Ok(result) => results.push(result.message.unwrap().content),
Err(e) => {
return Err(OllamaError::from(format!(
"Could not generate text: {:?}",
e
)))
}
}
}
}
}
}

Ok(results.join("\n"))
}
}
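
For orientation, a minimal sketch of driving the new executor directly against a local vLLM server. It assumes the OpenAI-compatible endpoint at http://localhost:8000 (the same default used in executor.rs below) and that MessageInput can be built from plain role/content strings, as its usage above suggests; neither detail is part of this commit.

// Illustrative sketch, not code from this commit. Assumes a vLLM server is
// already serving Qwen/Qwen2.5-1.5B-Instruct on http://localhost:8000 and
// that MessageInput is a simple { role, content } struct.
async fn vllm_smoke_test() -> Result<String, OllamaError> {
    let executor = VLLMExecutor::new(
        "Qwen/Qwen2.5-1.5B-Instruct".to_string(),
        "http://localhost:8000".to_string(),
    );
    let input = vec![MessageInput {
        role: "user".to_string(),
        content: "Reply with the single word: pong".to_string(),
    }];
    // Passing None as the schema takes the plain-text branch of generate_text.
    executor.generate_text(input, &None).await
}
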
13 changes: 13 additions & 0 deletions src/program/executor.rs
@@ -5,6 +5,7 @@ use super::workflow::Workflow;
use crate::api_interface::gem_api::GeminiExecutor;
use crate::api_interface::open_router::OpenRouterExecutor;
use crate::api_interface::openai_api::OpenAIExecutor;
use crate::api_interface::vllm::VLLMExecutor;
use crate::memory::types::Entry;
use crate::memory::{MemoryReturnType, ProgramMemory};
use crate::program::atomics::MessageInput;
@@ -585,6 +586,11 @@ impl Executor {
OpenRouterExecutor::new(self.model.to_string(), api_key.clone());
openai_executor.generate_text(input, schema).await?
}
ModelProvider::VLLM => {
let executor =
VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
executor.generate_text(input, schema).await?
}
};

Ok(response)
@@ -669,6 +675,13 @@ impl Executor {
.function_call(prompt, tools, raw_mode, oai_parser)
.await?
}
ModelProvider::VLLM => {
let executor =
VLLMExecutor::new(self.model.to_string(), "http://localhost:8000".to_string());
executor
.function_call(prompt, tools, raw_mode, oai_parser)
.await?
}
};

Ok(result)
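
The dispatch above pins the vLLM base URL to http://localhost:8000 and forwards the workflow's optional schema straight through to generate_text. A hedged sketch of that structured-output path, reusing the same illustrative MessageInput construction as before and a hypothetical schema string:

// Illustrative sketch; the schema below is an example, not taken from the
// repository. generate_text wraps it with additionalProperties: false and
// strict: true before sending it to the vLLM server.
async fn structured_city(executor: &VLLMExecutor) -> Result<String, OllamaError> {
    let schema = Some(
        r#"{"type":"object","properties":{"city":{"type":"string"},"population":{"type":"integer"}},"required":["city","population"]}"#
            .to_string(),
    );
    let input = vec![MessageInput {
        role: "user".to_string(),
        content: "Describe Istanbul as a JSON object.".to_string(),
    }];
    executor.generate_text(input, &schema).await
}
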
7 changes: 7 additions & 0 deletions src/program/models.rs
@@ -197,6 +197,9 @@ pub enum Model {

#[serde(rename = "openai/o1")]
OROpenAIO1,

#[serde(rename = "Qwen/Qwen2.5-1.5B-Instruct")]
Qwen25Vllm,
}

impl Model {
@@ -264,6 +267,8 @@ pub enum ModelProvider {
Gemini,
#[serde(rename = "openrouter")]
OpenRouter,
#[serde(rename = "VLLM")]
VLLM,
}

impl From<Model> for ModelProvider {
@@ -331,6 +336,8 @@ impl From<Model> for ModelProvider {
Model::ORNemotron70B => ModelProvider::OpenRouter,
Model::ORNousHermes405B => ModelProvider::OpenRouter,
Model::OROpenAIO1 => ModelProvider::OpenRouter,
//vllm
Model::Qwen25Vllm => ModelProvider::VLLM,
}
}
}
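
For context, the new model is selected purely through its serde rename string and then routed to the vLLM provider by the From<Model> impl above. A small sketch of that round trip, assuming Model and ModelProvider derive serde::Deserialize as the existing rename attributes imply:

#[test]
fn qwen_model_maps_to_vllm_provider() {
    // Assumes Model derives serde::Deserialize; the string literal is the
    // rename attribute added in this diff.
    let model: Model = serde_json::from_str("\"Qwen/Qwen2.5-1.5B-Instruct\"")
        .expect("model name should deserialize");
    assert!(matches!(model, Model::Qwen25Vllm));
    // The provider mapping added above routes this model to the vLLM executor.
    assert!(matches!(ModelProvider::from(model), ModelProvider::VLLM));
}
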