Skip to content

Commit

Permalink
refactor: improve tool calls (#995)
Browse files Browse the repository at this point in the history
- rename MessageContent:ToolResults to MessageContent:ToolCalls
- rename ToolResults to MessageContentToolCalls
- persist tool_calls to messages.md
  • Loading branch information
sigoden authored Nov 14, 2024
1 parent cfa9217 commit 80684ec
Show file tree
Hide file tree
Showing 16 changed files with 84 additions and 76 deletions.
7 changes: 3 additions & 4 deletions src/client/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,9 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
"content": content,
})]
}
MessageContent::ToolResults(results) => {
let ToolResults {
tool_results, text, ..
} = results;
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
let mut assistant_parts = vec![];
let mut user_parts = vec![];
if !text.is_empty() {
Expand Down
7 changes: 3 additions & 4 deletions src/client/claude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,9 @@ pub fn claude_build_chat_completions_body(
"content": content,
})]
}
MessageContent::ToolResults(results) => {
let ToolResults {
tool_results, text, ..
} = results;
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
let mut assistant_parts = vec![];
let mut user_parts = vec![];
if !text.is_empty() {
Expand Down
4 changes: 2 additions & 2 deletions src/client/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu
.collect();
Some(json!({ "role": role, "message": list.join("\n\n") }))
}
MessageContent::ToolResults(results) => {
tool_results = Some(results.tool_results);
MessageContent::ToolCalls(tool_calls) => {
tool_results = Some(tool_calls.tool_results);
None
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/client/ernie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,11 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::ToolResults(results) => {
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, ..
}) => {
let mut list = vec![];
for tool_result in results.tool_results {
for tool_result in tool_results {
list.push(json!({
"role": "assistant",
"content": format!("Action: {}\nAction Input: {}", tool_result.call.name, tool_result.call.arguments)
Expand Down
40 changes: 30 additions & 10 deletions src/client/message.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use super::ToolResults;

use crate::utils::dimmed_text;
use crate::{function::ToolResult, utils::dimmed_text};

use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -75,7 +73,7 @@ pub enum MessageContent {
Text(String),
Array(Vec<MessageContentPart>),
// Note: This type is primarily for convenience and does not exist in OpenAI's API.
ToolResults(ToolResults),
ToolCalls(MessageContentToolCalls),
}

impl MessageContent {
Expand Down Expand Up @@ -103,10 +101,9 @@ impl MessageContent {
}
format!(".file {}{}", files.join(" "), concated_text)
}
MessageContent::ToolResults(results) => {
let ToolResults {
tool_results, text, ..
} = results;
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
let mut lines = vec![];
if !text.is_empty() {
lines.push(text.clone())
Expand Down Expand Up @@ -139,7 +136,7 @@ impl MessageContent {
*text = replace_fn(text)
}
}
MessageContent::ToolResults(_) => {}
MessageContent::ToolCalls(_) => {}
}
}

Expand All @@ -155,7 +152,7 @@ impl MessageContent {
}
parts.join("\n\n")
}
MessageContent::ToolResults(_) => String::new(),
MessageContent::ToolCalls(_) => String::new(),
}
}
}
Expand All @@ -172,6 +169,29 @@ pub struct ImageUrl {
pub url: String,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContentToolCalls {
pub tool_results: Vec<ToolResult>,
pub text: String,
pub sequence: bool,
}

impl MessageContentToolCalls {
pub fn new(tool_results: Vec<ToolResult>, text: String) -> Self {
Self {
tool_results,
text,
sequence: false,
}
}

pub fn merge(&mut self, tool_results: Vec<ToolResult>, _text: String) {
self.tool_results.extend(tool_results);
self.text.clear();
self.sequence = true;
}
}

pub fn patch_system_message(messages: &mut Vec<Message>) {
if messages[0].role.is_system() {
let system_message = messages.remove(0);
Expand Down
2 changes: 1 addition & 1 deletion src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod macros;
mod model;
mod stream;

pub use crate::function::{ToolCall, ToolResults};
pub use crate::function::ToolCall;
pub use crate::utils::PromptKind;
pub use common::*;
pub use message::*;
Expand Down
9 changes: 4 additions & 5 deletions src/client/model.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
list_chat_models, list_embedding_models, list_reranker_models,
message::{Message, MessageContent, MessageContentPart},
ToolResults,
MessageContentToolCalls,
};

use crate::config::Config;
Expand Down Expand Up @@ -237,10 +237,9 @@ impl Model {
MessageContentPart::ImageUrl { .. } => 0,
})
.sum(),
MessageContent::ToolResults(results) => {
let ToolResults {
tool_results, text, ..
} = results;
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results, text, ..
}) => {
estimate_token_length(text)
+ tool_results
.iter()
Expand Down
5 changes: 2 additions & 3 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,11 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod
.flat_map(|message| {
let Message { role, content } = message;
match content {
MessageContent::ToolResults(results) => {
let ToolResults {
MessageContent::ToolCalls(MessageContentToolCalls {
tool_results,
text,
sequence,
} = results;
}) => {
if !sequence {
let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| {
json!({
Expand Down
3 changes: 1 addition & 2 deletions src/client/vertexai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,7 @@ pub fn gemini_build_chat_completions_body(
.collect();
vec![json!({ "role": role, "parts": parts })]
},
MessageContent::ToolResults(results) => {
let tool_results = results.tool_results;
MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. }) => {
let model_parts: Vec<Value> = tool_results.iter().map(|tool_result| {
json!({
"functionCall": {
Expand Down
26 changes: 13 additions & 13 deletions src/config/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use super::*;

use crate::client::{
init_client, patch_system_message, ChatCompletionsData, Client, ImageUrl, Message,
MessageContent, MessageContentPart, MessageRole, Model,
MessageContent, MessageContentPart, MessageContentToolCalls, MessageRole, Model,
};
use crate::function::{ToolResult, ToolResults};
use crate::function::ToolResult;
use crate::utils::{base64_encode, sha256, AbortSignal};

use anyhow::{bail, Context, Result};
Expand All @@ -29,7 +29,7 @@ pub struct Input {
regenerate: bool,
medias: Vec<String>,
data_urls: HashMap<String, String>,
tool_results: Option<ToolResults>,
tool_calls: Option<MessageContentToolCalls>,
rag_name: Option<String>,
role: Role,
with_session: bool,
Expand All @@ -48,7 +48,7 @@ impl Input {
regenerate: false,
medias: Default::default(),
data_urls: Default::default(),
tool_results: None,
tool_calls: None,
rag_name: None,
role,
with_session,
Expand Down Expand Up @@ -104,7 +104,7 @@ impl Input {
regenerate: false,
medias,
data_urls,
tool_results: Default::default(),
tool_calls: Default::default(),
rag_name: None,
role,
with_session,
Expand All @@ -120,8 +120,8 @@ impl Input {
self.data_urls.clone()
}

pub fn tool_results(&self) -> &Option<ToolResults> {
&self.tool_results
pub fn tool_calls(&self) -> &Option<MessageContentToolCalls> {
&self.tool_calls
}

pub fn text(&self) -> String {
Expand Down Expand Up @@ -187,12 +187,12 @@ impl Input {
self.rag_name.as_deref()
}

pub fn merge_tool_call(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
match self.tool_results.as_mut() {
pub fn merge_tool_results(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
match self.tool_calls.as_mut() {
Some(exist_tool_results) => {
exist_tool_results.extend(tool_results, output);
exist_tool_results.merge(tool_results, output);
}
None => self.tool_results = Some(ToolResults::new(tool_results, output)),
None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)),
}
self
}
Expand Down Expand Up @@ -232,10 +232,10 @@ impl Input {
} else {
self.role().build_messages(self)
};
if let Some(tool_results) = &self.tool_results {
if let Some(tool_calls) = &self.tool_calls {
messages.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolResults(tool_results.clone()),
MessageContent::ToolCalls(tool_calls.clone()),
))
}
Ok(messages)
Expand Down
18 changes: 16 additions & 2 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use self::session::Session;

use crate::client::{
create_client_config, list_chat_models, list_client_types, list_reranker_models, ClientConfig,
Model, OPENAI_COMPATIBLE_PLATFORMS,
MessageContentToolCalls, Model, OPENAI_COMPATIBLE_PLATFORMS,
};
use crate::function::{FunctionDeclaration, Functions, ToolResult};
use crate::rag::Rag;
Expand Down Expand Up @@ -1863,8 +1863,22 @@ impl Config {
} else {
String::new()
};
let tool_calls = match input.tool_calls() {
Some(MessageContentToolCalls {
tool_results, text, ..
}) => {
let mut lines = vec!["<tool_calls>".to_string()];
if !text.is_empty() {
lines.push(text.clone());
}
lines.push(serde_json::to_string(&tool_results).unwrap_or_default());
lines.push("</tool_calls>\n".to_string());
lines.join("\n")
}
None => String::new(),
};
let output = format!(
"# CHAT: {summary} [{timestamp}]{scope}\n{raw_input}\n--------\n{output}\n--------\n\n",
"# CHAT: {summary} [{timestamp}]{scope}\n{raw_input}\n--------\n{tool_calls}{output}\n--------\n\n",
);
file.write_all(output.as_bytes())
.with_context(|| "Failed to save message")
Expand Down
4 changes: 2 additions & 2 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,10 @@ impl Session {
.push(Message::new(MessageRole::User, input.message_content()));
}
self.data_urls.extend(input.data_urls());
if let Some(tool_results) = input.tool_results() {
if let Some(tool_calls) = input.tool_calls() {
self.messages.push(Message::new(
MessageRole::Tool,
MessageContent::ToolResults(tool_results.clone()),
MessageContent::ToolCalls(tool_calls.clone()),
))
}
self.messages.push(Message::new(
Expand Down
23 changes: 0 additions & 23 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,29 +266,6 @@ impl ToolCall {
}
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolResults {
pub tool_results: Vec<ToolResult>,
pub text: String,
pub sequence: bool,
}

impl ToolResults {
pub fn new(tool_results: Vec<ToolResult>, text: String) -> Self {
Self {
tool_results,
text,
sequence: false,
}
}

pub fn extend(&mut self, tool_results: Vec<ToolResult>, _text: String) {
self.tool_results.extend(tool_results);
self.text.clear();
self.sequence = true;
}
}

#[cfg(windows)]
fn polyfill_cmd_name<T: AsRef<Path>>(cmd_name: &str, bin_dir: &[T]) -> String {
let cmd_name = cmd_name.to_string();
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async fn start_directive(
if need_send_tool_results(&tool_results) {
start_directive(
config,
input.merge_tool_call(output, tool_results),
input.merge_tool_results(output, tool_results),
code_mode,
abort_signal,
)
Expand Down
2 changes: 1 addition & 1 deletion src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ async fn ask(
ask(
config,
abort_signal,
input.merge_tool_call(output, tool_results),
input.merge_tool_results(output, tool_results),
false,
)
.await
Expand Down
2 changes: 1 addition & 1 deletion src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ fn parse_messages(message: Vec<Value>) -> Result<Vec<Message>> {
}
output.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolResults(ToolResults::new(list, text)),
MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)),
));
tool_results = None;
} else {
Expand Down

0 comments on commit 80684ec

Please sign in to comment.