Skip to content

Commit

Permalink
refactor: improve code quanity (#196)
Browse files Browse the repository at this point in the history
- rewrite Repl, remove ReplHandler
- move ReplyStreamHandler to repl/ and rename it to ReplyHandler
- deprecate utils::print_now
- refactor session info
  • Loading branch information
sigoden authored Nov 2, 2023
1 parent 5c7bfd9 commit 444f4eb
Show file tree
Hide file tree
Showing 18 changed files with 443 additions and 477 deletions.
4 changes: 3 additions & 1 deletion src/client/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
use super::{AzureOpenAIClient, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};
use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData};

use crate::{config::ModelInfo, utils::PromptKind};

use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand Down
15 changes: 6 additions & 9 deletions src/client/common.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
config::{Message, SharedConfig},
repl::{ReplyStreamHandler, SharedAbortSignal},
render::ReplyHandler,
repl::AbortSignal,
utils::{init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, PromptKind},
};

Expand Down Expand Up @@ -139,7 +140,7 @@ macro_rules! openai_compatible_client {
async fn send_message_streaming_inner(
&self,
client: &reqwest::Client,
handler: &mut $crate::repl::ReplyStreamHandler,
handler: &mut $crate::render::ReplyHandler,
data: $crate::client::SendData,
) -> Result<()> {
let builder = self.request_builder(client, data)?;
Expand Down Expand Up @@ -201,12 +202,8 @@ pub trait Client {
})
}

fn send_message_streaming(
&self,
content: &str,
handler: &mut ReplyStreamHandler,
) -> Result<()> {
async fn watch_abort(abort: SharedAbortSignal) {
fn send_message_streaming(&self, content: &str, handler: &mut ReplyHandler) -> Result<()> {
async fn watch_abort(abort: AbortSignal) {
loop {
if abort.aborted() {
break;
Expand Down Expand Up @@ -252,7 +249,7 @@ pub trait Client {
async fn send_message_streaming_inner(
&self,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
handler: &mut ReplyHandler,
data: SendData,
) -> Result<()>;
}
Expand Down
4 changes: 3 additions & 1 deletion src/client/localai.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
use super::{ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};
use super::{ExtraConfig, LocalAIClient, PromptType, SendData};

use crate::{config::ModelInfo, utils::PromptKind};

use anyhow::Result;
use async_trait::async_trait;
Expand Down
6 changes: 0 additions & 6 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@ mod common;

pub use common::*;

use crate::{
config::{ModelInfo, TokensCountFactors},
repl::ReplyStreamHandler,
utils::PromptKind,
};

register_client!(
(openai, "openai", OpenAI, OpenAIConfig, OpenAIClient),
(localai, "localai", LocalAI, LocalAIConfig, LocalAIClient),
Expand Down
11 changes: 8 additions & 3 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use super::{
ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, ReplyStreamHandler, SendData,
TokensCountFactors,
ExtraConfig, OpenAIClient, PromptType, SendData,
};

use crate::{
config::{ModelInfo, TokensCountFactors},
render::ReplyHandler,
utils::PromptKind,
};

use anyhow::{anyhow, bail, Result};
Expand Down Expand Up @@ -88,7 +93,7 @@ pub async fn openai_send_message(builder: RequestBuilder) -> Result<String> {

pub async fn openai_send_message_streaming(
builder: RequestBuilder,
handler: &mut ReplyStreamHandler,
handler: &mut ReplyHandler,
) -> Result<()> {
let mut es = builder.eventsource()?;
while let Some(event) = es.next().await {
Expand Down
62 changes: 49 additions & 13 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::client::{
all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, OpenAIClient,
SendData,
};
use crate::render::RenderOptions;
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err};

use anyhow::{anyhow, bail, Context, Result};
Expand Down Expand Up @@ -336,7 +336,7 @@ impl Config {
}
}

pub fn info(&self) -> Result<String> {
pub fn sys_info(&self) -> Result<String> {
let path_info = |path: &Path| {
let state = if path.exists() { "" } else { " ⚠️" };
format!("{}{state}", path.display())
Expand All @@ -349,27 +349,63 @@ impl Config {
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![
("config_file", path_info(&Self::config_file()?)),
("roles_file", path_info(&Self::roles_file()?)),
("messages_file", path_info(&Self::messages_file()?)),
("sessions_dir", path_info(&Self::sessions_dir()?)),
("model", self.model_info.full_name()),
("temperature", temperature),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
("highlight", self.highlight.to_string()),
("light_theme", self.light_theme.to_string()),
("wrap", wrap),
("wrap_code", self.wrap_code.to_string()),
("dry_run", self.dry_run.to_string()),
("light_theme", self.light_theme.to_string()),
("keybindings", self.keybindings.stringify().into()),
("config_file", path_info(&Self::config_file()?)),
("roles_file", path_info(&Self::roles_file()?)),
("messages_file", path_info(&Self::messages_file()?)),
("sessions_dir", path_info(&Self::sessions_dir()?)),
];
let mut output = String::new();
for (name, value) in items {
output.push_str(&format!("{name:<20}{value}\n"));
}
let output = items
.iter()
.map(|(name, value)| format!("{name:<20}{value}"))
.collect::<Vec<String>>()
.join("\n");
Ok(output)
}

pub fn role_info(&self) -> Result<String> {
if let Some(role) = &self.role {
role.info()
} else {
bail!("No role")
}
}

pub fn session_info(&self) -> Result<String> {
if let Some(session) = &self.session {
let render_options = self.get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
session.render(&mut markdown_render)
} else {
bail!("No session")
}
}

pub fn info(&self) -> Result<String> {
if let Some(session) = &self.session {
session.export()
} else if let Some(role) = &self.role {
role.info()
} else {
self.sys_info()
}
}

pub fn last_reply(&self) -> &str {
self.last_message
.as_ref()
.map(|(_, reply)| reply.as_str())
.unwrap_or_default()
}

pub fn repl_completions(&self) -> Vec<String> {
let mut completion: Vec<String> = self
.roles
Expand Down Expand Up @@ -423,7 +459,7 @@ impl Config {
Ok(())
}

pub fn start_session(&mut self, session: &Option<String>) -> Result<()> {
pub fn start_session(&mut self, session: Option<&str>) -> Result<()> {
if self.session.is_some() {
bail!("Already in a session, please use '.clear session' to exit the session first?");
}
Expand Down
2 changes: 1 addition & 1 deletion src/config/model_info.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Message;
use super::message::Message;

use crate::utils::count_tokens;

Expand Down
2 changes: 1 addition & 1 deletion src/config/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl Role {
pub fn info(&self) -> Result<String> {
let output = serde_yaml::to_string(&self)
.with_context(|| format!("Unable to show info about role {}", &self.name))?;
Ok(output)
Ok(output.trim_end().to_string())
}

pub fn embedded(&self) -> bool {
Expand Down
75 changes: 39 additions & 36 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ impl Session {
if let Some(temperature) = self.temperature() {
data["temperature"] = temperature.into();
}
data["total-tokens"] = tokens.into();
data["total_tokens"] = tokens.into();
if let Some(max_tokens) = self.model_info.max_tokens {
data["max-tokens"] = max_tokens.into();
data["max_tokens"] = max_tokens.into();
}
if percent != 0.0 {
data["total/max-tokens"] = format!("{}%", percent).into();
data["total/max"] = format!("{}%", percent).into();
}
data["messages"] = json!(self.messages);

Expand All @@ -97,43 +97,46 @@ impl Session {
}

pub fn render(&self, render: &mut MarkdownRender) -> Result<String> {
let path = self.path.clone().unwrap_or_else(|| "-".to_string());

let temperature = self
.temperature()
.map_or_else(|| String::from("-"), |v| v.to_string());

let max_tokens = self
.model_info
.max_tokens
.map(|v| v.to_string())
.unwrap_or_else(|| '-'.to_string());

let items = vec![
("path", path),
("model", self.model().to_string()),
("temperature", temperature),
("max_tokens", max_tokens),
];
let mut lines = vec![];
for (name, value) in items {
lines.push(format!("{name:<20}{value}"));
let mut items = vec![];

if let Some(path) = &self.path {
items.push(("path", path.to_string()));
}
lines.push("".into());
for message in &self.messages {
match message.role {
MessageRole::System => {
continue;
}
MessageRole::Assistant => {
lines.push(render.render(&message.content));
lines.push("".into());
}
MessageRole::User => {
lines.push(format!("{}){}", self.name, message.content));

items.push(("model", self.model_info.full_name()));

if let Some(temperature) = self.temperature() {
items.push(("temperature", temperature.to_string()));
}

if let Some(max_tokens) = self.model_info.max_tokens {
items.push(("max_tokens", max_tokens.to_string()));
}

let mut lines: Vec<String> = items
.iter()
.map(|(name, value)| format!("{name:<20}{value}"))
.collect();

if !self.is_empty() {
lines.push("".into());

for message in &self.messages {
match message.role {
MessageRole::System => {
continue;
}
MessageRole::Assistant => {
lines.push(render.render(&message.content));
lines.push("".into());
}
MessageRole::User => {
lines.push(format!("{}){}", self.name, message.content));
}
}
}
}

let output = lines.join("\n");
Ok(output)
}
Expand Down
22 changes: 9 additions & 13 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crossbeam::sync::WaitGroup;
use is_terminal::IsTerminal;
use parking_lot::RwLock;
use render::{render_stream, MarkdownRender};
use repl::{AbortSignal, Repl};
use repl::{create_abort_signal, Repl};
use std::io::{stdin, Read};
use std::sync::Arc;
use std::{io::stdout, process::exit};
Expand Down Expand Up @@ -59,7 +59,9 @@ fn main() -> Result<()> {
config.write().set_role(name)?;
}
if let Some(session) = &cli.session {
config.write().start_session(session)?;
config
.write()
.start_session(session.as_ref().map(|v| v.as_str()))?;
}
if let Some(model) = &cli.model {
config.write().set_model(model)?;
Expand All @@ -68,14 +70,8 @@ fn main() -> Result<()> {
config.write().highlight = false;
}
if cli.info {
let info = if let Some(session) = &config.read().session {
session.export()?
} else if let Some(role) = &config.read().role {
role.info()?
} else {
config.read().info()?
};
println!("{info}");
let info = config.read().info()?;
println!("{}", info);
exit(0);
}
let no_stream = cli.no_stream;
Expand Down Expand Up @@ -116,12 +112,12 @@ fn start_directive(
output
} else {
let wg = WaitGroup::new();
let abort = AbortSignal::new();
let abort = create_abort_signal();
let abort_clone = abort.clone();
ctrlc::set_handler(move || {
abort_clone.set_ctrlc();
})
.expect("Error setting Ctrl-C handler");
.expect("Failed to setting Ctrl-C handler");
let output = render_stream(input, client, config, false, abort, wg.clone())?;
wg.wait();
output
Expand All @@ -132,5 +128,5 @@ fn start_directive(
fn start_interactive(config: SharedConfig) -> Result<()> {
cl100k_base_singleton();
let mut repl: Repl = Repl::init(config.clone())?;
repl.run(config)
repl.run()
}
Loading

0 comments on commit 444f4eb

Please sign in to comment.