From 444f4ebe9de8aa68afdf8d39b734a8543b9b5c4b Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 2 Nov 2023 09:53:54 +0800 Subject: [PATCH] refactor: improve code quanity (#196) - rewrite Repl, remove ReplHandler - move ReplyStreamHandler to repl/ and rename it to ReplyHandler - deprecate utils::print_now - refactor session info --- src/client/azure_openai.rs | 4 +- src/client/common.rs | 15 +- src/client/localai.rs | 4 +- src/client/mod.rs | 6 - src/client/openai.rs | 11 +- src/config/mod.rs | 62 +++++++-- src/config/model_info.rs | 2 +- src/config/role.rs | 2 +- src/config/session.rs | 75 +++++----- src/main.rs | 22 ++- src/render/cmd.rs | 19 ++- src/render/mod.rs | 84 +++++++++-- src/render/repl.rs | 16 +-- src/repl/abort.rs | 12 +- src/repl/handler.rs | 205 --------------------------- src/repl/init.rs | 85 ------------ src/repl/mod.rs | 277 +++++++++++++++++++++++++++++-------- src/utils/mod.rs | 19 +-- 18 files changed, 443 insertions(+), 477 deletions(-) delete mode 100644 src/repl/handler.rs delete mode 100644 src/repl/init.rs diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index 56036662..fe3ec0f1 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -1,5 +1,7 @@ use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS}; -use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData}; +use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData}; + +use crate::{config::ModelInfo, utils::PromptKind}; use anyhow::{anyhow, Result}; use async_trait::async_trait; diff --git a/src/client/common.rs b/src/client/common.rs index 7963c76c..0d0c0e2e 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -1,6 +1,7 @@ use crate::{ config::{Message, SharedConfig}, - repl::{ReplyStreamHandler, SharedAbortSignal}, + render::ReplyHandler, + repl::AbortSignal, utils::{init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, PromptKind}, }; @@ -139,7 +140,7 @@ macro_rules! openai_compatible_client { async fn send_message_streaming_inner( &self, client: &reqwest::Client, - handler: &mut $crate::repl::ReplyStreamHandler, + handler: &mut $crate::render::ReplyHandler, data: $crate::client::SendData, ) -> Result<()> { let builder = self.request_builder(client, data)?; @@ -201,12 +202,8 @@ pub trait Client { }) } - fn send_message_streaming( - &self, - content: &str, - handler: &mut ReplyStreamHandler, - ) -> Result<()> { - async fn watch_abort(abort: SharedAbortSignal) { + fn send_message_streaming(&self, content: &str, handler: &mut ReplyHandler) -> Result<()> { + async fn watch_abort(abort: AbortSignal) { loop { if abort.aborted() { break; @@ -252,7 +249,7 @@ pub trait Client { async fn send_message_streaming_inner( &self, client: &ReqwestClient, - handler: &mut ReplyStreamHandler, + handler: &mut ReplyHandler, data: SendData, ) -> Result<()>; } diff --git a/src/client/localai.rs b/src/client/localai.rs index d438388b..796b574a 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -1,5 +1,7 @@ use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS}; -use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData}; +use super::{ExtraConfig, LocalAIClient, PromptType, SendData}; + +use crate::{config::ModelInfo, utils::PromptKind}; use anyhow::Result; use async_trait::async_trait; diff --git a/src/client/mod.rs b/src/client/mod.rs index 5fa61460..e55055d1 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,12 +3,6 @@ mod common; pub use common::*; -use crate::{ - config::{ModelInfo, TokensCountFactors}, - repl::ReplyStreamHandler, - utils::PromptKind, -}; - register_client!( (openai, "openai", OpenAI, OpenAIConfig, OpenAIClient), (localai, "localai", LocalAI, LocalAIConfig, LocalAIClient), diff --git a/src/client/openai.rs b/src/client/openai.rs index a82ee9a8..98de5a8f 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,6 +1,11 @@ use super::{ - ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, ReplyStreamHandler, SendData, - TokensCountFactors, + ExtraConfig, OpenAIClient, PromptType, SendData, +}; + +use crate::{ + config::{ModelInfo, TokensCountFactors}, + render::ReplyHandler, + utils::PromptKind, }; use anyhow::{anyhow, bail, Result}; @@ -88,7 +93,7 @@ pub async fn openai_send_message(builder: RequestBuilder) -> Result { pub async fn openai_send_message_streaming( builder: RequestBuilder, - handler: &mut ReplyStreamHandler, + handler: &mut ReplyHandler, ) -> Result<()> { let mut es = builder.eventsource()?; while let Some(event) = es.next().await { diff --git a/src/config/mod.rs b/src/config/mod.rs index ea458973..731cc101 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -12,7 +12,7 @@ use crate::client::{ all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, OpenAIClient, SendData, }; -use crate::render::RenderOptions; +use crate::render::{MarkdownRender, RenderOptions}; use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err}; use anyhow::{anyhow, bail, Context, Result}; @@ -336,7 +336,7 @@ impl Config { } } - pub fn info(&self) -> Result { + pub fn sys_info(&self) -> Result { let path_info = |path: &Path| { let state = if path.exists() { "" } else { " ⚠️" }; format!("{}{state}", path.display()) @@ -349,27 +349,63 @@ impl Config { .clone() .map_or_else(|| String::from("no"), |v| v.to_string()); let items = vec![ - ("config_file", path_info(&Self::config_file()?)), - ("roles_file", path_info(&Self::roles_file()?)), - ("messages_file", path_info(&Self::messages_file()?)), - ("sessions_dir", path_info(&Self::sessions_dir()?)), ("model", self.model_info.full_name()), ("temperature", temperature), + ("dry_run", self.dry_run.to_string()), ("save", self.save.to_string()), ("highlight", self.highlight.to_string()), - ("light_theme", self.light_theme.to_string()), ("wrap", wrap), ("wrap_code", self.wrap_code.to_string()), - ("dry_run", self.dry_run.to_string()), + ("light_theme", self.light_theme.to_string()), ("keybindings", self.keybindings.stringify().into()), + ("config_file", path_info(&Self::config_file()?)), + ("roles_file", path_info(&Self::roles_file()?)), + ("messages_file", path_info(&Self::messages_file()?)), + ("sessions_dir", path_info(&Self::sessions_dir()?)), ]; - let mut output = String::new(); - for (name, value) in items { - output.push_str(&format!("{name:<20}{value}\n")); - } + let output = items + .iter() + .map(|(name, value)| format!("{name:<20}{value}")) + .collect::>() + .join("\n"); Ok(output) } + pub fn role_info(&self) -> Result { + if let Some(role) = &self.role { + role.info() + } else { + bail!("No role") + } + } + + pub fn session_info(&self) -> Result { + if let Some(session) = &self.session { + let render_options = self.get_render_options()?; + let mut markdown_render = MarkdownRender::init(render_options)?; + session.render(&mut markdown_render) + } else { + bail!("No session") + } + } + + pub fn info(&self) -> Result { + if let Some(session) = &self.session { + session.export() + } else if let Some(role) = &self.role { + role.info() + } else { + self.sys_info() + } + } + + pub fn last_reply(&self) -> &str { + self.last_message + .as_ref() + .map(|(_, reply)| reply.as_str()) + .unwrap_or_default() + } + pub fn repl_completions(&self) -> Vec { let mut completion: Vec = self .roles @@ -423,7 +459,7 @@ impl Config { Ok(()) } - pub fn start_session(&mut self, session: &Option) -> Result<()> { + pub fn start_session(&mut self, session: Option<&str>) -> Result<()> { if self.session.is_some() { bail!("Already in a session, please use '.clear session' to exit the session first?"); } diff --git a/src/config/model_info.rs b/src/config/model_info.rs index 793c014d..7a52e633 100644 --- a/src/config/model_info.rs +++ b/src/config/model_info.rs @@ -1,4 +1,4 @@ -use super::Message; +use super::message::Message; use crate::utils::count_tokens; diff --git a/src/config/role.rs b/src/config/role.rs index 16a64002..819cc127 100644 --- a/src/config/role.rs +++ b/src/config/role.rs @@ -19,7 +19,7 @@ impl Role { pub fn info(&self) -> Result { let output = serde_yaml::to_string(&self) .with_context(|| format!("Unable to show info about role {}", &self.name))?; - Ok(output) + Ok(output.trim_end().to_string()) } pub fn embedded(&self) -> bool { diff --git a/src/config/session.rs b/src/config/session.rs index 62c1b092..7ed57145 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -82,12 +82,12 @@ impl Session { if let Some(temperature) = self.temperature() { data["temperature"] = temperature.into(); } - data["total-tokens"] = tokens.into(); + data["total_tokens"] = tokens.into(); if let Some(max_tokens) = self.model_info.max_tokens { - data["max-tokens"] = max_tokens.into(); + data["max_tokens"] = max_tokens.into(); } if percent != 0.0 { - data["total/max-tokens"] = format!("{}%", percent).into(); + data["total/max"] = format!("{}%", percent).into(); } data["messages"] = json!(self.messages); @@ -97,43 +97,46 @@ impl Session { } pub fn render(&self, render: &mut MarkdownRender) -> Result { - let path = self.path.clone().unwrap_or_else(|| "-".to_string()); - - let temperature = self - .temperature() - .map_or_else(|| String::from("-"), |v| v.to_string()); - - let max_tokens = self - .model_info - .max_tokens - .map(|v| v.to_string()) - .unwrap_or_else(|| '-'.to_string()); - - let items = vec![ - ("path", path), - ("model", self.model().to_string()), - ("temperature", temperature), - ("max_tokens", max_tokens), - ]; - let mut lines = vec![]; - for (name, value) in items { - lines.push(format!("{name:<20}{value}")); + let mut items = vec![]; + + if let Some(path) = &self.path { + items.push(("path", path.to_string())); } - lines.push("".into()); - for message in &self.messages { - match message.role { - MessageRole::System => { - continue; - } - MessageRole::Assistant => { - lines.push(render.render(&message.content)); - lines.push("".into()); - } - MessageRole::User => { - lines.push(format!("{}){}", self.name, message.content)); + + items.push(("model", self.model_info.full_name())); + + if let Some(temperature) = self.temperature() { + items.push(("temperature", temperature.to_string())); + } + + if let Some(max_tokens) = self.model_info.max_tokens { + items.push(("max_tokens", max_tokens.to_string())); + } + + let mut lines: Vec = items + .iter() + .map(|(name, value)| format!("{name:<20}{value}")) + .collect(); + + if !self.is_empty() { + lines.push("".into()); + + for message in &self.messages { + match message.role { + MessageRole::System => { + continue; + } + MessageRole::Assistant => { + lines.push(render.render(&message.content)); + lines.push("".into()); + } + MessageRole::User => { + lines.push(format!("{}){}", self.name, message.content)); + } } } } + let output = lines.join("\n"); Ok(output) } diff --git a/src/main.rs b/src/main.rs index 6492b4a3..b801e477 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,7 @@ use crossbeam::sync::WaitGroup; use is_terminal::IsTerminal; use parking_lot::RwLock; use render::{render_stream, MarkdownRender}; -use repl::{AbortSignal, Repl}; +use repl::{create_abort_signal, Repl}; use std::io::{stdin, Read}; use std::sync::Arc; use std::{io::stdout, process::exit}; @@ -59,7 +59,9 @@ fn main() -> Result<()> { config.write().set_role(name)?; } if let Some(session) = &cli.session { - config.write().start_session(session)?; + config + .write() + .start_session(session.as_ref().map(|v| v.as_str()))?; } if let Some(model) = &cli.model { config.write().set_model(model)?; @@ -68,14 +70,8 @@ fn main() -> Result<()> { config.write().highlight = false; } if cli.info { - let info = if let Some(session) = &config.read().session { - session.export()? - } else if let Some(role) = &config.read().role { - role.info()? - } else { - config.read().info()? - }; - println!("{info}"); + let info = config.read().info()?; + println!("{}", info); exit(0); } let no_stream = cli.no_stream; @@ -116,12 +112,12 @@ fn start_directive( output } else { let wg = WaitGroup::new(); - let abort = AbortSignal::new(); + let abort = create_abort_signal(); let abort_clone = abort.clone(); ctrlc::set_handler(move || { abort_clone.set_ctrlc(); }) - .expect("Error setting Ctrl-C handler"); + .expect("Failed to setting Ctrl-C handler"); let output = render_stream(input, client, config, false, abort, wg.clone())?; wg.wait(); output @@ -132,5 +128,5 @@ fn start_directive( fn start_interactive(config: SharedConfig) -> Result<()> { cl100k_base_singleton(); let mut repl: Repl = Repl::init(config.clone())?; - repl.run(config) + repl.run() } diff --git a/src/render/cmd.rs b/src/render/cmd.rs index 115170b7..e09bd71a 100644 --- a/src/render/cmd.rs +++ b/src/render/cmd.rs @@ -1,7 +1,6 @@ -use super::MarkdownRender; +use super::{MarkdownRender, ReplyEvent}; -use crate::print_now; -use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; +use crate::repl::AbortSignal; use crate::utils::{spaces, split_line_sematic, split_line_tail}; use anyhow::Result; @@ -9,9 +8,9 @@ use crossbeam::channel::Receiver; use textwrap::core::display_width; pub fn cmd_render_stream( - rx: &Receiver, + rx: &Receiver, render: &mut MarkdownRender, - abort: &SharedAbortSignal, + abort: &AbortSignal, ) -> Result<()> { let mut buffer = String::new(); let mut col = 0; @@ -21,14 +20,14 @@ pub fn cmd_render_stream( } if let Ok(evt) = rx.try_recv() { match evt { - ReplyStreamEvent::Text(text) => { + ReplyEvent::Text(text) => { if text.contains('\n') { let text = format!("{buffer}{text}"); let (head, tail) = split_line_tail(&text); buffer = tail.to_string(); let input = format!("{}{head}", spaces(col)); let output = render.render(&input); - print_now!("{}\n", &output[col..]); + println!("{}", &output[col..]); col = 0; } else { buffer = format!("{buffer}{text}"); @@ -51,14 +50,14 @@ pub fn cmd_render_stream( col += display_width(output); } } - print_now!("{}", output); + print!("{}", output); } } } } - ReplyStreamEvent::Done => { + ReplyEvent::Done => { let input = format!("{}{buffer}", spaces(col)); - print_now!("{}\n", render.render(&input)); + println!("{}", render.render(&input)); break; } } diff --git a/src/render/mod.rs b/src/render/mod.rs index 03b67ea1..27f9e816 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -8,12 +8,12 @@ use self::repl::repl_render_stream; use crate::client::Client; use crate::config::SharedConfig; -use crate::print_now; -use crate::repl::{ReplyStreamHandler, SharedAbortSignal}; +use crate::repl::AbortSignal; -use anyhow::Result; -use crossbeam::channel::unbounded; +use anyhow::{Context, Result}; +use crossbeam::channel::{unbounded, Sender}; use crossbeam::sync::WaitGroup; +use nu_ansi_term::{Color, Style}; use std::thread::spawn; pub fn render_stream( @@ -21,13 +21,14 @@ pub fn render_stream( client: &dyn Client, config: &SharedConfig, repl: bool, - abort: SharedAbortSignal, + abort: AbortSignal, wg: WaitGroup, ) -> Result { let render_options = config.read().get_render_options()?; let mut stream_handler = { let (tx, rx) = unbounded(); let abort_clone = abort.clone(); + let highlight = config.read().highlight; spawn(move || { let run = move || { if repl { @@ -39,14 +40,81 @@ pub fn render_stream( } }; if let Err(err) = run() { - let err = format!("{err:?}"); - print_now!("\n{}\n\n", err.trim()); + render_error(err, highlight); } drop(wg); }); - ReplyStreamHandler::new(tx, abort_clone) + ReplyHandler::new(tx, abort_clone) }; client.send_message_streaming(input, &mut stream_handler)?; let buffer = stream_handler.get_buffer(); Ok(buffer.to_string()) } + +pub fn render_error(err: anyhow::Error, highlight: bool) { + let err = format!("{err:?}\n"); + if highlight { + let style = Style::new().fg(Color::Red); + println!("{}", style.paint(err.trim())); + } else { + println!("{}", err.trim()); + } +} + +pub struct ReplyHandler { + sender: Sender, + buffer: String, + abort: AbortSignal, +} + +impl ReplyHandler { + pub fn new(sender: Sender, abort: AbortSignal) -> Self { + Self { + sender, + abort, + buffer: String::new(), + } + } + + pub fn text(&mut self, text: &str) -> Result<()> { + if self.buffer.is_empty() && text == "\n\n" { + return Ok(()); + } + self.buffer.push_str(text); + let ret = self + .sender + .send(ReplyEvent::Text(text.to_string())) + .with_context(|| "Failed to send ReplyEvent:Text"); + self.safe_ret(ret)?; + Ok(()) + } + + pub fn done(&mut self) -> Result<()> { + let ret = self + .sender + .send(ReplyEvent::Done) + .with_context(|| "Failed to send ReplyEvent::Done"); + self.safe_ret(ret)?; + Ok(()) + } + + pub fn get_buffer(&self) -> &str { + &self.buffer + } + + pub fn get_abort(&self) -> AbortSignal { + self.abort.clone() + } + + fn safe_ret(&self, ret: Result<()>) -> Result<()> { + if ret.is_err() && self.abort.aborted() { + return Ok(()); + } + ret + } +} + +pub enum ReplyEvent { + Text(String), + Done, +} diff --git a/src/render/repl.rs b/src/render/repl.rs index 73639230..efad0e7c 100644 --- a/src/render/repl.rs +++ b/src/render/repl.rs @@ -1,6 +1,6 @@ -use super::MarkdownRender; +use super::{MarkdownRender, ReplyEvent}; -use crate::repl::{ReplyStreamEvent, SharedAbortSignal}; +use crate::repl::AbortSignal; use crate::utils::split_line_tail; use anyhow::Result; @@ -18,9 +18,9 @@ use std::{ use textwrap::core::display_width; pub fn repl_render_stream( - rx: &Receiver, + rx: &Receiver, render: &mut MarkdownRender, - abort: &SharedAbortSignal, + abort: &AbortSignal, ) -> Result<()> { enable_raw_mode()?; let mut stdout = io::stdout(); @@ -33,9 +33,9 @@ pub fn repl_render_stream( } fn repl_render_stream_inner( - rx: &Receiver, + rx: &Receiver, render: &mut MarkdownRender, - abort: &SharedAbortSignal, + abort: &AbortSignal, writer: &mut Stdout, ) -> Result<()> { let mut last_tick = Instant::now(); @@ -51,7 +51,7 @@ fn repl_render_stream_inner( if let Ok(evt) = rx.try_recv() { match evt { - ReplyStreamEvent::Text(text) => { + ReplyEvent::Text(text) => { let (col, mut row) = cursor::position()?; // fix unexpected duplicate lines on kitty, see https://github.com/sigoden/aichat/issues/105 @@ -95,7 +95,7 @@ fn repl_render_stream_inner( writer.flush()?; } - ReplyStreamEvent::Done => { + ReplyEvent::Done => { #[cfg(target_os = "windows")] let eol = "\n\n"; #[cfg(not(target_os = "windows"))] diff --git a/src/repl/abort.rs b/src/repl/abort.rs index f76abb67..af58b354 100644 --- a/src/repl/abort.rs +++ b/src/repl/abort.rs @@ -3,15 +3,19 @@ use std::sync::{ Arc, }; -pub type SharedAbortSignal = Arc; +pub type AbortSignal = Arc; -pub struct AbortSignal { +pub struct AbortSignalInner { ctrlc: AtomicBool, ctrld: AtomicBool, } -impl AbortSignal { - pub fn new() -> SharedAbortSignal { +pub fn create_abort_signal() -> AbortSignal { + AbortSignalInner::new() +} + +impl AbortSignalInner { + pub fn new() -> AbortSignal { Arc::new(Self { ctrlc: AtomicBool::new(false), ctrld: AtomicBool::new(false), diff --git a/src/repl/handler.rs b/src/repl/handler.rs deleted file mode 100644 index 94f91a1e..00000000 --- a/src/repl/handler.rs +++ /dev/null @@ -1,205 +0,0 @@ -use crate::client::init_client; -use crate::config::SharedConfig; -use crate::print_now; -use crate::render::{render_stream, MarkdownRender}; -use std::fs; -use std::io::Read; - -use super::abort::SharedAbortSignal; - -use anyhow::{bail, Context, Result}; -use arboard::Clipboard; -use crossbeam::channel::Sender; -use crossbeam::sync::WaitGroup; -use std::cell::RefCell; - -pub enum ReplCmd { - Submit(String), - Info, - RoleInfo, - SessionInfo, - SetModel(String), - SetRole(String), - ExitRole, - StartSession(Option), - ExitSession, - Set(String), - Copy, - ReadFile(String), -} - -pub struct ReplCmdHandler { - config: SharedConfig, - abort: SharedAbortSignal, - clipboard: std::result::Result, arboard::Error>, -} - -impl ReplCmdHandler { - pub fn init(config: SharedConfig, abort: SharedAbortSignal) -> Result { - let clipboard = Clipboard::new().map(RefCell::new); - Ok(Self { - config, - abort, - clipboard, - }) - } - - pub fn handle(&self, cmd: ReplCmd) -> Result<()> { - match cmd { - ReplCmd::Submit(input) => { - if input.is_empty() { - return Ok(()); - } - self.config.read().maybe_print_send_tokens(&input); - let wg = WaitGroup::new(); - let client = init_client(self.config.clone())?; - let ret = render_stream( - &input, - client.as_ref(), - &self.config, - true, - self.abort.clone(), - wg.clone(), - ); - wg.wait(); - let buffer = ret?; - self.config.write().save_message(&input, &buffer)?; - if self.config.read().auto_copy { - let _ = self.copy(&buffer); - } - } - ReplCmd::Info => { - let output = self.config.read().info()?; - print_now!("{}\n\n", output.trim_end()); - } - ReplCmd::SetModel(name) => { - self.config.write().set_model(&name)?; - print_now!("\n"); - } - ReplCmd::SetRole(name) => { - self.config.write().set_role(&name)?; - print_now!("\n"); - } - ReplCmd::RoleInfo => { - if let Some(role) = &self.config.read().role { - print_now!("{}\n\n", role.info()?); - } else { - bail!("No role") - } - } - ReplCmd::ExitRole => { - self.config.write().clear_role()?; - print_now!("\n"); - } - ReplCmd::StartSession(name) => { - self.config.write().start_session(&name)?; - print_now!("\n"); - } - ReplCmd::SessionInfo => { - if let Some(session) = &self.config.read().session { - let render_options = self.config.read().get_render_options()?; - let mut markdown_render = MarkdownRender::init(render_options)?; - print_now!("{}\n\n", session.render(&mut markdown_render)?); - } else { - bail!("No session") - } - } - ReplCmd::ExitSession => { - self.config.write().end_session()?; - print_now!("\n"); - } - ReplCmd::Set(input) => { - self.config.write().update(&input)?; - print_now!("\n"); - } - ReplCmd::Copy => { - let reply = self - .config - .read() - .last_message - .as_ref() - .map(|v| v.1.clone()) - .unwrap_or_default(); - self.copy(&reply) - .with_context(|| "Failed to copy the last output")?; - print_now!("\n"); - } - ReplCmd::ReadFile(file) => { - let mut contents = String::new(); - let mut file = fs::File::open(file).with_context(|| "Unable to open file")?; - file.read_to_string(&mut contents) - .with_context(|| "Unable to read file")?; - self.handle(ReplCmd::Submit(contents))?; - } - } - Ok(()) - } - - fn copy(&self, text: &str) -> Result<()> { - match self.clipboard.as_ref() { - Err(err) => bail!("{}", err), - Ok(clip) => { - clip.borrow_mut().set_text(text)?; - Ok(()) - } - } - } -} - -pub struct ReplyStreamHandler { - sender: Sender, - buffer: String, - abort: SharedAbortSignal, -} - -impl ReplyStreamHandler { - pub fn new(sender: Sender, abort: SharedAbortSignal) -> Self { - Self { - sender, - abort, - buffer: String::new(), - } - } - - pub fn text(&mut self, text: &str) -> Result<()> { - if self.buffer.is_empty() && text == "\n\n" { - return Ok(()); - } - self.buffer.push_str(text); - let ret = self - .sender - .send(ReplyStreamEvent::Text(text.to_string())) - .with_context(|| "Failed to send StreamEvent:Text"); - self.safe_ret(ret)?; - Ok(()) - } - - pub fn done(&mut self) -> Result<()> { - let ret = self - .sender - .send(ReplyStreamEvent::Done) - .with_context(|| "Failed to send StreamEvent:Done"); - self.safe_ret(ret)?; - Ok(()) - } - - pub fn get_buffer(&self) -> &str { - &self.buffer - } - - pub fn get_abort(&self) -> SharedAbortSignal { - self.abort.clone() - } - - fn safe_ret(&self, ret: Result<()>) -> Result<()> { - if ret.is_err() && self.abort.aborted() { - return Ok(()); - } - ret - } -} - -pub enum ReplyStreamEvent { - Text(String), - Done, -} diff --git a/src/repl/init.rs b/src/repl/init.rs deleted file mode 100644 index fbbcf222..00000000 --- a/src/repl/init.rs +++ /dev/null @@ -1,85 +0,0 @@ -use super::{ - highlighter::ReplHighlighter, prompt::ReplPrompt, validator::ReplValidator, REPL_COMMANDS, -}; - -use crate::config::SharedConfig; - -use anyhow::Result; -use reedline::{ - default_emacs_keybindings, default_vi_insert_keybindings, default_vi_normal_keybindings, - ColumnarMenu, DefaultCompleter, EditMode, Emacs, KeyCode, KeyModifiers, Keybindings, Reedline, - ReedlineEvent, ReedlineMenu, Vi, -}; - -const MENU_NAME: &str = "completion_menu"; - -pub struct Repl { - pub(crate) editor: Reedline, - pub(crate) prompt: ReplPrompt, -} - -impl Repl { - pub fn init(config: SharedConfig) -> Result { - let commands: Vec = REPL_COMMANDS - .into_iter() - .map(|(v, _)| v.to_string()) - .collect(); - - let completer = Self::create_completer(&config, &commands); - let highlighter = ReplHighlighter::new(commands, config.clone()); - let menu = Self::create_menu(); - let edit_mode: Box = if config.read().keybindings.is_vi() { - let mut normal_keybindings = default_vi_normal_keybindings(); - let mut insert_keybindings = default_vi_insert_keybindings(); - Self::extra_keybindings(&mut normal_keybindings); - Self::extra_keybindings(&mut insert_keybindings); - Box::new(Vi::new(insert_keybindings, normal_keybindings)) - } else { - let mut keybindings = default_emacs_keybindings(); - Self::extra_keybindings(&mut keybindings); - Box::new(Emacs::new(keybindings)) - }; - let mut editor = Reedline::create() - .with_completer(Box::new(completer)) - .with_highlighter(Box::new(highlighter)) - .with_menu(menu) - .with_edit_mode(edit_mode) - .with_quick_completions(true) - .with_partial_completions(true) - .with_validator(Box::new(ReplValidator)) - .with_ansi_colors(true); - editor.enable_bracketed_paste()?; - let prompt = ReplPrompt::new(config); - Ok(Self { editor, prompt }) - } - - fn create_completer(config: &SharedConfig, commands: &[String]) -> DefaultCompleter { - let mut completion = commands.to_vec(); - completion.extend(config.read().repl_completions()); - let mut completer = - DefaultCompleter::with_inclusions(&['.', '-', '_', ':']).set_min_word_len(2); - completer.insert(completion.clone()); - completer - } - - fn extra_keybindings(keybindings: &mut Keybindings) { - keybindings.add_binding( - KeyModifiers::NONE, - KeyCode::Tab, - ReedlineEvent::UntilFound(vec![ - ReedlineEvent::Menu(MENU_NAME.to_string()), - ReedlineEvent::MenuNext, - ]), - ); - keybindings.add_binding( - KeyModifiers::CONTROL, - KeyCode::Char('s'), - ReedlineEvent::Submit, - ); - } - - fn create_menu() -> ReedlineMenu { - let completion_menu = ColumnarMenu::default().with_name(MENU_NAME); - ReedlineMenu::EngineCompleter(Box::new(completion_menu)) - } -} diff --git a/src/repl/mod.rs b/src/repl/mod.rs index eb99fda1..c9c3c737 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -1,24 +1,35 @@ mod abort; -mod handler; mod highlighter; -mod init; mod prompt; mod validator; -pub use self::abort::*; -pub use self::handler::*; -pub use self::init::Repl; +pub use self::abort::{create_abort_signal, AbortSignal}; +use self::highlighter::ReplHighlighter; +use self::prompt::ReplPrompt; +use self::validator::ReplValidator; + +use crate::client::init_client; use crate::config::SharedConfig; -use crate::print_now; +use crate::render::{render_error, render_stream}; -use anyhow::Result; +use anyhow::{bail, Context, Result}; +use arboard::Clipboard; +use crossbeam::sync::WaitGroup; use fancy_regex::Regex; use lazy_static::lazy_static; use reedline::Signal; -use std::rc::Rc; +use reedline::{ + default_emacs_keybindings, default_vi_insert_keybindings, default_vi_normal_keybindings, + ColumnarMenu, DefaultCompleter, EditMode, Emacs, KeyCode, KeyModifiers, Keybindings, Reedline, + ReedlineEvent, ReedlineMenu, Vi, +}; +use std::cell::RefCell; +use std::io::Read; + +const MENU_NAME: &str = "completion_menu"; -pub const REPL_COMMANDS: [(&str, &str); 14] = [ +const REPL_COMMANDS: [(&str, &str); 14] = [ (".help", "Print this help message"), (".info", "Print system info"), (".edit", "Multi-line editing (CTRL+S to finish)"), @@ -40,128 +51,284 @@ lazy_static! { static ref EDIT_RE: Regex = Regex::new(r"^\s*\.edit\s*").unwrap(); } +pub struct Repl { + config: SharedConfig, + editor: Reedline, + prompt: ReplPrompt, + abort: AbortSignal, + clipboard: std::result::Result, arboard::Error>, +} + impl Repl { - pub fn run(&mut self, config: SharedConfig) -> Result<()> { - let abort = AbortSignal::new(); - let handler = ReplCmdHandler::init(config, abort.clone())?; - print_now!("Welcome to aichat {}\n", env!("CARGO_PKG_VERSION")); - print_now!("Type \".help\" for more information.\n"); + pub fn init(config: SharedConfig) -> Result { + let commands: Vec = REPL_COMMANDS + .into_iter() + .map(|(v, _)| v.to_string()) + .collect(); + + let completer = Self::create_completer(&config, &commands); + let highlighter = ReplHighlighter::new(commands, config.clone()); + let menu = Self::create_menu(); + let edit_mode: Box = if config.read().keybindings.is_vi() { + let mut normal_keybindings = default_vi_normal_keybindings(); + let mut insert_keybindings = default_vi_insert_keybindings(); + Self::extra_keybindings(&mut normal_keybindings); + Self::extra_keybindings(&mut insert_keybindings); + Box::new(Vi::new(insert_keybindings, normal_keybindings)) + } else { + let mut keybindings = default_emacs_keybindings(); + Self::extra_keybindings(&mut keybindings); + Box::new(Emacs::new(keybindings)) + }; + let mut editor = Reedline::create() + .with_completer(Box::new(completer)) + .with_highlighter(Box::new(highlighter)) + .with_menu(menu) + .with_edit_mode(edit_mode) + .with_quick_completions(true) + .with_partial_completions(true) + .with_validator(Box::new(ReplValidator)) + .with_ansi_colors(true); + + editor.enable_bracketed_paste()?; + + let prompt = ReplPrompt::new(config.clone()); + + let abort = create_abort_signal(); + + let clipboard = Clipboard::new().map(RefCell::new); + + Ok(Self { + config, + editor, + prompt, + clipboard, + abort, + }) + } + + pub fn run(&mut self) -> Result<()> { + self.banner(); + let mut already_ctrlc = false; - let handler = Rc::new(handler); + loop { - if abort.aborted_ctrld() { + if self.abort.aborted_ctrld() { break; } - if abort.aborted_ctrlc() && !already_ctrlc { + if self.abort.aborted_ctrlc() && !already_ctrlc { already_ctrlc = true; } let sig = self.editor.read_line(&self.prompt); match sig { Ok(Signal::Success(line)) => { already_ctrlc = false; - abort.reset(); - match self.handle_line(&handler, &line) { + self.abort.reset(); + match self.handle(&line) { Ok(quit) => { if quit { break; } } Err(err) => { - let err = format!("{err:?}"); - print_now!("Error: {}\n\n", err.trim()); + render_error(err, self.config.read().highlight); } } } Ok(Signal::CtrlC) => { - abort.set_ctrlc(); + self.abort.set_ctrlc(); if already_ctrlc { break; } already_ctrlc = true; - print_now!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n\n"); + println!("(To exit, press Ctrl+C again or Ctrl+D or type .exit)\n"); } Ok(Signal::CtrlD) => { - abort.set_ctrld(); + self.abort.set_ctrld(); break; } _ => {} } } - handler.handle(ReplCmd::ExitSession)?; + self.handle(".exit session")?; Ok(()) } - fn handle_line(&mut self, handler: &Rc, line: &str) -> Result { + fn handle(&self, line: &str) -> Result { match parse_command(line) { Some((cmd, args)) => match cmd { ".help" => { dump_repl_help(); } ".info" => match args { - Some("role") => handler.handle(ReplCmd::RoleInfo)?, - Some("session") => handler.handle(ReplCmd::SessionInfo)?, - Some(_) => unknown_command(), + Some("role") => { + let info = self.config.read().role_info()?; + println!("{}", info); + } + Some("session") => { + let info = self.config.read().session_info()?; + println!("{}", info); + } + Some(_) => unknown_command()?, None => { - handler.handle(ReplCmd::Info)?; + let output = self.config.read().sys_info()?; + println!("{}", output); } }, ".edit" => { if let Some(text) = args { - handler.handle(ReplCmd::Submit(text.to_string()))?; + self.ask(text)?; } } ".model" => match args { - Some(name) => handler.handle(ReplCmd::SetModel(name.to_string()))?, - None => print_now!("Usage: .model \n\n"), + Some(name) => { + self.config.write().set_model(name)?; + } + None => println!("Usage: .model "), }, ".role" => match args { - Some(name) => handler.handle(ReplCmd::SetRole(name.to_string()))?, - None => print_now!("Usage: .role \n\n"), + Some(name) => { + self.config.write().set_role(name)?; + } + None => println!("Usage: .role "), }, ".session" => { - handler.handle(ReplCmd::StartSession(args.map(|v| v.to_string())))?; + self.config.write().start_session(args)?; } ".set" => { - handler.handle(ReplCmd::Set(args.unwrap_or_default().to_string()))?; + if let Some(args) = args { + self.config.write().update(args)?; + } } ".copy" => { - handler.handle(ReplCmd::Copy)?; + let config = self.config.read(); + self.copy(config.last_reply()) + .with_context(|| "Failed to copy the last output")?; } ".read" => match args { - Some(file) => handler.handle(ReplCmd::ReadFile(file.to_string()))?, - None => print_now!("Usage: .read \n\n"), + Some(file) => { + let mut content = String::new(); + let mut file = + std::fs::File::open(file).with_context(|| "Unable to open file")?; + file.read_to_string(&mut content) + .with_context(|| "Unable to read file")?; + self.ask(&content)?; + } + None => println!("Usage: .read "), }, ".exit" => match args { - Some("role") => handler.handle(ReplCmd::ExitRole)?, - Some("session") => handler.handle(ReplCmd::ExitSession)?, - Some(_) => unknown_command(), + Some("role") => { + self.config.write().clear_role()?; + } + Some("session") => { + self.config.write().end_session()?; + } + Some(_) => unknown_command()?, None => { return Ok(true); } }, - // deprecated + // deprecated this command ".clear" => match args { Some("role") => { - print_now!("Deprecated. Use '.exit role' instead.\n\n"); + println!(r#"Deprecated. Use ".exit role" instead."#); } - Some("session") => { - print_now!("Deprecated. Use '.exit session' instead.\n\n"); + Some("conversation") => { + println!(r#"Deprecated. Use ".exit session" instead."#); } - _ => unknown_command(), + _ => unknown_command()?, }, - _ => unknown_command(), + _ => unknown_command()?, }, None => { - handler.handle(ReplCmd::Submit(line.to_string()))?; + self.ask(line)?; } } + println!(); + Ok(false) } + + fn ask(&self, input: &str) -> Result<()> { + if input.is_empty() { + return Ok(()); + } + self.config.read().maybe_print_send_tokens(input); + let wg = WaitGroup::new(); + let client = init_client(self.config.clone())?; + let ret = render_stream( + input, + client.as_ref(), + &self.config, + true, + self.abort.clone(), + wg.clone(), + ); + wg.wait(); + let buffer = ret?; + self.config.write().save_message(input, &buffer)?; + if self.config.read().auto_copy { + let _ = self.copy(&buffer); + } + Ok(()) + } + + fn banner(&self) { + let version = env!("CARGO_PKG_VERSION"); + print!( + r#"Welcome to aichat {version} +Type ".help" for more information. +"# + ) + } + + fn create_completer(config: &SharedConfig, commands: &[String]) -> DefaultCompleter { + let mut completion = commands.to_vec(); + completion.extend(config.read().repl_completions()); + let mut completer = + DefaultCompleter::with_inclusions(&['.', '-', '_', ':']).set_min_word_len(2); + completer.insert(completion.clone()); + completer + } + + fn extra_keybindings(keybindings: &mut Keybindings) { + keybindings.add_binding( + KeyModifiers::NONE, + KeyCode::Tab, + ReedlineEvent::UntilFound(vec![ + ReedlineEvent::Menu(MENU_NAME.to_string()), + ReedlineEvent::MenuNext, + ]), + ); + keybindings.add_binding( + KeyModifiers::CONTROL, + KeyCode::Char('s'), + ReedlineEvent::Submit, + ); + } + + fn create_menu() -> ReedlineMenu { + let completion_menu = ColumnarMenu::default().with_name(MENU_NAME); + ReedlineMenu::EngineCompleter(Box::new(completion_menu)) + } + + fn copy(&self, text: &str) -> Result<()> { + if text.is_empty() { + bail!("No text") + } + match self.clipboard.as_ref() { + Err(err) => bail!("{}", err), + Ok(clip) => { + clip.borrow_mut().set_text(text)?; + Ok(()) + } + } + } } -fn unknown_command() { - print_now!("Unknown command. Try `.help`.\n\n"); +fn unknown_command() -> Result<()> { + bail!(r#"Unknown command. Type ".help" for more information."#); } fn dump_repl_help() { @@ -170,12 +337,10 @@ fn dump_repl_help() { .map(|(name, desc)| format!("{name:<24} {desc}")) .collect::>() .join("\n"); - print_now!( + println!( r###"{head} -Press Ctrl+C to abort readline, Ctrl+D to exit the REPL - -"###, +Press Ctrl+C to abort readline, Ctrl+D to exit the REPL"###, ); } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ab922e33..938e07cb 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -6,24 +6,9 @@ pub use self::prompt_input::*; pub use self::split_line::*; pub use self::tiktoken::cl100k_base_singleton; -use chrono::prelude::*; -use std::io::{stdout, Write}; - -#[macro_export] -macro_rules! print_now { - ($($arg:tt)*) => { - $crate::utils::print_now(&format!($($arg)*)) - }; -} - -pub fn print_now(text: &T) { - print!("{}", text.to_string()); - let _ = stdout().flush(); -} - pub fn now() -> String { - let now = Local::now(); - now.to_rfc3339_opts(SecondsFormat::Secs, false) + let now = chrono::Local::now(); + now.to_rfc3339_opts(chrono::SecondsFormat::Secs, false) } pub fn get_env_name(key: &str) -> String {