Skip to content

Commit

Permalink
refactor: improve handling of no_stream/no_system_message (#936)
Browse files: browse the repository at this point in the history
  • Loading branch information
sigoden authored Oct 19, 2024
1 parent 7e29f64 commit 7b20ab0
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub trait Client: Sync + Send {
input: &Input,
handler: &mut SseHandler,
) -> Result<()> {
let abort_signal = handler.get_abort();
let abort_signal = handler.abort();
let input = input.clone();
tokio::select! {
ret = async {
Expand Down
6 changes: 1 addition & 5 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,13 @@ struct EmbeddingsResBodyEmbedding {

pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Value {
let ChatCompletionsData {
mut messages,
messages,
temperature,
top_p,
functions,
stream,
} = data;

if model.no_system_message() {
patch_system_message(&mut messages);
}

let messages: Vec<Value> = messages
.into_iter()
.flat_map(|message| {
Expand Down
4 changes: 2 additions & 2 deletions src/client/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ impl SseHandler {
Ok(())
}

pub fn get_abort(&self) -> AbortSignal {
pub fn abort(&self) -> AbortSignal {
self.abort.clone()
}

pub fn get_tool_calls(&self) -> &[ToolCall] {
pub fn tool_calls(&self) -> &[ToolCall] {
&self.tool_calls
}

Expand Down
9 changes: 6 additions & 3 deletions src/config/input.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::*;

use crate::client::{
init_client, ChatCompletionsData, Client, ImageUrl, Message, MessageContent,
MessageContentPart, MessageRole, Model,
init_client, patch_system_message, ChatCompletionsData, Client, ImageUrl, Message,
MessageContent, MessageContentPart, MessageRole, Model,
};
use crate::function::{ToolResult, ToolResults};
use crate::utils::{base64_encode, sha256, AbortSignal};
Expand Down Expand Up @@ -206,7 +206,10 @@ impl Input {
if !self.medias.is_empty() && !model.supports_vision() {
bail!("The current model does not support vision. Is the model configured with `supports_vision: true`?");
}
let messages = self.build_messages()?;
let mut messages = self.build_messages()?;
if model.no_system_message() {
patch_system_message(&mut messages);
}
model.guard_max_input_tokens(&messages)?;
let temperature = self.role().temperature();
let top_p = self.role().top_p();
Expand Down
55 changes: 39 additions & 16 deletions src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ impl Server {
tools,
} = req_body;

let messages =
let mut messages =
parse_messages(messages).map_err(|err| anyhow!("Invalid request body, {err}"))?;

let functions = parse_tools(tools).map_err(|err| anyhow!("Invalid request body, {err}"))?;
Expand Down Expand Up @@ -313,7 +313,9 @@ impl Server {

let completion_id = generate_completion_id();
let created = Utc::now().timestamp();

if client.model().no_system_message() {
patch_system_message(&mut messages);
}
let data: ChatCompletionsData = ChatCompletionsData {
messages,
temperature,
Expand Down Expand Up @@ -357,20 +359,41 @@ impl Server {
tx: &UnboundedSender<ResEvent>,
is_first: Arc<AtomicBool>,
) {
let ret = client
.chat_completions_streaming_inner(http_client, handler, data)
.await;
let first = match ret {
Ok(()) => None,
Err(err) => Some(format!("{err:?}")),
};
if is_first.load(Ordering::SeqCst) {
let _ = tx.send(ResEvent::First(first));
is_first.store(false, Ordering::SeqCst)
}
let tool_calls = handler.get_tool_calls();
if !tool_calls.is_empty() {
let _ = tx.send(ResEvent::ToolCalls(tool_calls.to_vec()));
if client.model().no_stream() {
let ret = client.chat_completions_inner(http_client, data).await;
match ret {
Ok(output) => {
let ChatCompletionsOutput {
text, tool_calls, ..
} = output;
let _ = tx.send(ResEvent::First(None));
is_first.store(false, Ordering::SeqCst);
let _ = tx.send(ResEvent::Text(text));
if !tool_calls.is_empty() {
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
}
}
Err(err) => {
let _ = tx.send(ResEvent::First(Some(format!("{err:?}"))));
is_first.store(false, Ordering::SeqCst)
}
};
} else {
let ret = client
.chat_completions_streaming_inner(http_client, handler, data)
.await;
let first = match ret {
Ok(()) => None,
Err(err) => Some(format!("{err:?}")),
};
if is_first.load(Ordering::SeqCst) {
let _ = tx.send(ResEvent::First(first));
is_first.store(false, Ordering::SeqCst)
}
let tool_calls = handler.tool_calls().to_vec();
if !tool_calls.is_empty() {
let _ = tx.send(ResEvent::ToolCalls(tool_calls));
}
}
handler.done();
}
Expand Down

0 comments on commit 7b20ab0

Please sign in to comment.