Skip to content

Commit

Permalink
Fix baichuan template (#28)
Browse files (browse the repository at this point in the history)
  • Loading branch information
npuichigo authored Mar 14, 2024
1 parent 89f8a24 commit fc37fbb
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 73 deletions.
4 changes: 3 additions & 1 deletion example/history_template_baichuan.liquid
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
{% for item in items -%}
{%- capture identity -%}
{%- case item.identity -%}
{%- when "System", "Tool" -%}
System
{%- when "User" -%}
<reserved_106>
{%- when "Assistant" -%}
Expand All @@ -11,4 +13,4 @@

{{- identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
{% endfor -%}
<reserved_107>:
<reserved_107>:
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ pub struct Config {
/// File containing the history template string
#[arg(long)]
#[serde(skip_serializing_if = "Option::is_none")]
pub history_template_file: Option<String>
pub history_template_file: Option<String>,
}
204 changes: 143 additions & 61 deletions src/history/mod.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use std::fs::File;
use std::io::Read;
use std::sync::Arc;
use crate::routes::chat::ChatCompletionMessageParams;
use anyhow::bail;
use liquid::{ParserBuilder, Template};
use serde::Serialize;
use crate::routes::chat::ChatCompletionMessageParams;
use std::fs::File;
use std::io::Read;
use std::sync::Arc;

const DEFAULT_TEMPLATE: &str =
"{% for item in items %}\
const DEFAULT_TEMPLATE: &str = "{% for item in items %}\
{{ item.identity }}{% if item.name %} {{ item.name }}{% endif %}: {{ item.content }}
{% endfor %}\
ASSISTANT:";

#[derive(Clone)]
pub struct HistoryBuilder {
history_template: Arc<Template>
history_template: Arc<Template>,
}

impl HistoryBuilder {
Expand All @@ -24,12 +23,11 @@ impl HistoryBuilder {
}
let mut _ref_holder = None;


let template = match template_file {
None => match template {
None => {DEFAULT_TEMPLATE}
Some(cfg) => {cfg.as_str()}
}
None => DEFAULT_TEMPLATE,
Some(cfg) => cfg.as_str(),
},
Some(filename) => {
_ref_holder = Some(load_template_file(filename)?);
_ref_holder.as_ref().unwrap().as_str()
Expand All @@ -38,10 +36,13 @@ impl HistoryBuilder {

let history_template = Arc::new(ParserBuilder::with_stdlib().build()?.parse(template)?);

Ok(HistoryBuilder {history_template})
Ok(HistoryBuilder { history_template })
}

pub fn build_history(&self, messages: &Vec<ChatCompletionMessageParams>) -> anyhow::Result<String> {
pub fn build_history(
&self,
messages: &Vec<ChatCompletionMessageParams>,
) -> anyhow::Result<String> {
let items: Vec<_> = messages.iter().map(|x| HistoryItem::new(x)).collect();
let context = liquid::object!({"items": items});
Ok(self.history_template.render(&context)?)
Expand All @@ -59,19 +60,31 @@ fn load_template_file(file: &String) -> anyhow::Result<String> {
struct HistoryItem {
identity: String,
content: String,
name: Option<String>
name: Option<String>,
}

impl HistoryItem {
pub fn new(message: &ChatCompletionMessageParams) -> Self {
let (identity, content, name) = match message {
ChatCompletionMessageParams::System { content, name } => { ("System".into(), content.clone(), name.clone()) }
ChatCompletionMessageParams::User { content, name } => { ("User".into(), content.clone(), name.clone()) }
ChatCompletionMessageParams::Assistant { content } => { ("Assistant".into(), content.clone(), None) }
ChatCompletionMessageParams::Tool { content, .. } => { ("Tool".into(), content.clone(), None) }
ChatCompletionMessageParams::System { content, name } => {
("System".into(), content.clone(), name.clone())
}
ChatCompletionMessageParams::User { content, name } => {
("User".into(), content.clone(), name.clone())
}
ChatCompletionMessageParams::Assistant { content } => {
("Assistant".into(), content.clone(), None)
}
ChatCompletionMessageParams::Tool { content, .. } => {
("Tool".into(), content.clone(), None)
}
};

HistoryItem { identity, content, name }
HistoryItem {
identity,
content,
name,
}
}
}

Expand All @@ -83,100 +96,169 @@ mod test {
pub fn test_default_template() {
let template = None;
let template_file = None;
let builder = HistoryBuilder::new(&template, &template_file).expect("default template should build correctly");
let builder = HistoryBuilder::new(&template, &template_file)
.expect("default template should build correctly");

let messages = vec![
ChatCompletionMessageParams::System {content: "test system 1".into(), name: Some("system 1".into())},
ChatCompletionMessageParams::System {content: "test system 2".into(), name: None},
ChatCompletionMessageParams::Assistant {content: "test assistant 1".into()},
ChatCompletionMessageParams::Tool {content: "test tool 1".into(), tool_call_id: "tool_1".into()},
ChatCompletionMessageParams::User {content: "test user 1".into(), name: Some("user 1".into())},
ChatCompletionMessageParams::User {content: "test user 2".into(), name: None}
ChatCompletionMessageParams::System {
content: "test system 1".into(),
name: Some("system 1".into()),
},
ChatCompletionMessageParams::System {
content: "test system 2".into(),
name: None,
},
ChatCompletionMessageParams::Assistant {
content: "test assistant 1".into(),
},
ChatCompletionMessageParams::Tool {
content: "test tool 1".into(),
tool_call_id: "tool_1".into(),
},
ChatCompletionMessageParams::User {
content: "test user 1".into(),
name: Some("user 1".into()),
},
ChatCompletionMessageParams::User {
content: "test user 2".into(),
name: None,
},
];

let result = builder.build_history(&messages).expect("history should build correctly");
let result = builder
.build_history(&messages)
.expect("history should build correctly");

let expected_result: String =
"System system 1: test system 1
let expected_result: String = "System system 1: test system 1
System: test system 2
Assistant: test assistant 1
Tool: test tool 1
User user 1: test user 1
User: test user 2
ASSISTANT:".into();
ASSISTANT:"
.into();

assert_eq!(expected_result, result)

}

#[test]
pub fn test_template_file() {
let template = None;
let template_file = Some(format!("{}/example/history_template.liquid", env!("CARGO_MANIFEST_DIR")));
let builder = HistoryBuilder::new(&template, &template_file).expect("default template should build correctly");
let template_file = Some(format!(
"{}/example/history_template.liquid",
env!("CARGO_MANIFEST_DIR")
));
let builder = HistoryBuilder::new(&template, &template_file)
.expect("default template should build correctly");

let messages = vec![
ChatCompletionMessageParams::System {content: "test system 1".into(), name: Some("system 1".into())},
ChatCompletionMessageParams::System {content: "test system 2".into(), name: None},
ChatCompletionMessageParams::Assistant {content: "test assistant 1".into()},
ChatCompletionMessageParams::Tool {content: "test tool 1".into(), tool_call_id: "tool_1".into()},
ChatCompletionMessageParams::User {content: "test user 1".into(), name: Some("user 1".into())},
ChatCompletionMessageParams::User {content: "test user 2".into(), name: None}
ChatCompletionMessageParams::System {
content: "test system 1".into(),
name: Some("system 1".into()),
},
ChatCompletionMessageParams::System {
content: "test system 2".into(),
name: None,
},
ChatCompletionMessageParams::Assistant {
content: "test assistant 1".into(),
},
ChatCompletionMessageParams::Tool {
content: "test tool 1".into(),
tool_call_id: "tool_1".into(),
},
ChatCompletionMessageParams::User {
content: "test user 1".into(),
name: Some("user 1".into()),
},
ChatCompletionMessageParams::User {
content: "test user 2".into(),
name: None,
},
];

let result = builder.build_history(&messages).expect("history should build correctly");
let result = builder
.build_history(&messages)
.expect("history should build correctly");

let expected_result: String =
"System system 1: test system 1
let expected_result: String = "System system 1: test system 1
System: test system 2
Assistant: test assistant 1
Tool: test tool 1
User user 1: test user 1
User: test user 2
ASSISTANT:".into();
ASSISTANT:"
.into();

assert_eq!(expected_result, result)

}

#[test]
pub fn test_template_file_custom_roles() {
let template = None;
let template_file = Some(format!("{}/example/history_template_custom_roles.liquid", env!("CARGO_MANIFEST_DIR")));
let builder = HistoryBuilder::new(&template, &template_file).expect("default template should build correctly");
let template_file = Some(format!(
"{}/example/history_template_custom_roles.liquid",
env!("CARGO_MANIFEST_DIR")
));
let builder = HistoryBuilder::new(&template, &template_file)
.expect("default template should build correctly");

let messages = vec![
ChatCompletionMessageParams::System {content: "test system 1".into(), name: Some("system 1".into())},
ChatCompletionMessageParams::System {content: "test system 2".into(), name: None},
ChatCompletionMessageParams::Assistant {content: "test assistant 1".into()},
ChatCompletionMessageParams::Tool {content: "test tool 1".into(), tool_call_id: "tool_1".into()},
ChatCompletionMessageParams::User {content: "test user 1".into(), name: Some("user 1".into())},
ChatCompletionMessageParams::User {content: "test user 2".into(), name: None}
ChatCompletionMessageParams::System {
content: "test system 1".into(),
name: Some("system 1".into()),
},
ChatCompletionMessageParams::System {
content: "test system 2".into(),
name: None,
},
ChatCompletionMessageParams::Assistant {
content: "test assistant 1".into(),
},
ChatCompletionMessageParams::Tool {
content: "test tool 1".into(),
tool_call_id: "tool_1".into(),
},
ChatCompletionMessageParams::User {
content: "test user 1".into(),
name: Some("user 1".into()),
},
ChatCompletionMessageParams::User {
content: "test user 2".into(),
name: None,
},
];

let result = builder.build_history(&messages).expect("history should build correctly");
let result = builder
.build_history(&messages)
.expect("history should build correctly");

let expected_result: String =
"Robot system 1: test system 1
let expected_result: String = "Robot system 1: test system 1
Robot: test system 2
Support: test assistant 1
Robot: test tool 1
Customer user 1: test user 1
Customer: test user 2
ASSISTANT:".into();
ASSISTANT:"
.into();

assert_eq!(expected_result, result)

}

#[test]
pub fn test_validations() {
let template = Some("abc".into());
let template_file = Some("abc".into());
match HistoryBuilder::new(&template, &template_file) {
Ok(_) => {assert!(false, "expected err")}
Err(e) => {assert_eq!("cannot set both history-template and history-template-file", e.to_string())}
Ok(_) => {
assert!(false, "expected err")
}
Err(e) => {
assert_eq!(
"cannot set both history-template and history-template-file",
e.to_string()
)
}
};

}
}
}
15 changes: 12 additions & 3 deletions src/routes/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ use crate::utils::deserialize_bytes_tensor;
#[instrument(name = "chat_completions", skip(grpc_client, history_builder, request))]
pub(crate) async fn compat_chat_completions(
headers: HeaderMap,
State(AppState{ grpc_client, history_builder }): State<AppState>,
State(AppState {
grpc_client,
history_builder,
}): State<AppState>,
request: Json<ChatCompletionCreateParams>,
) -> Response {
tracing::info!("request: {:?}", request);
Expand All @@ -46,7 +49,10 @@ pub(crate) async fn compat_chat_completions(
}
}

#[instrument(name = "streaming chat completions", skip(client, history_builder, request))]
#[instrument(
name = "streaming chat completions",
skip(client, history_builder, request)
)]
async fn chat_completions_stream(
headers: HeaderMap,
mut client: GrpcInferenceServiceClient<Channel>,
Expand Down Expand Up @@ -203,7 +209,10 @@ async fn chat_completions(
}))
}

fn build_triton_request(request: ChatCompletionCreateParams, history_builder: &HistoryBuilder) -> anyhow::Result<ModelInferRequest> {
fn build_triton_request(
request: ChatCompletionCreateParams,
history_builder: &HistoryBuilder,
) -> anyhow::Result<ModelInferRequest> {
let chat_history = history_builder.build_history(&request.messages)?;
tracing::debug!("chat history after formatting: {}", chat_history);

Expand Down
6 changes: 4 additions & 2 deletions src/routes/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::utils::{deserialize_bytes_tensor, string_or_seq_string};
#[instrument(name = "completions", skip(grpc_client, request))]
pub(crate) async fn compat_completions(
headers: HeaderMap,
State(AppState{ grpc_client, .. }): State<AppState>,
State(AppState { grpc_client, .. }): State<AppState>,
request: Json<CompletionCreateParams>,
) -> Response {
tracing::info!("request: {:?}", request);
Expand All @@ -39,7 +39,9 @@ pub(crate) async fn compat_completions(
.await
.into_response()
} else {
completions(headers, grpc_client, request).await.into_response()
completions(headers, grpc_client, request)
.await
.into_response()
}
}

Expand Down
Loading

0 comments on commit fc37fbb

Please sign in to comment.