From f9c40e52dabda7b037805c0635b84ccb6d75f5a8 Mon Sep 17 00:00:00 2001 From: sigoden Date: Fri, 3 Nov 2023 06:52:57 +0800 Subject: [PATCH] refactor: improve code quality (#203) - update field name of ModelInfo - rename ModelInfo to Model --- src/client/azure_openai.rs | 12 +++--- src/client/common.rs | 18 ++++----- src/client/localai.rs | 10 ++--- src/client/mod.rs | 4 +- src/client/{model_info.rs => model.rs} | 30 +++++++-------- src/client/openai.rs | 10 ++--- src/config/mod.rs | 51 ++++++++++++-------------- src/config/session.rs | 31 ++++++++-------- 8 files changed, 82 insertions(+), 84 deletions(-) rename src/client/{model_info.rs => model.rs} (82%) diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index f8a9daee..d1dc43b3 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -1,5 +1,5 @@ use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS}; -use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData, ModelInfo}; +use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData, Model}; use crate::utils::PromptKind; @@ -42,14 +42,14 @@ impl AzureOpenAIClient { ), ]; - pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec { - let client = Self::name(local_config); + pub fn list_models(local_config: &AzureOpenAIConfig, client_index: usize) -> Vec { + let client_name = Self::name(local_config); local_config .models .iter() .map(|v| { - ModelInfo::new(index, client, &v.name) + Model::new(client_index, client_name, &v.name) .set_max_tokens(v.max_tokens) .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS) }) @@ -70,11 +70,11 @@ impl AzureOpenAIClient { let api_base = self.get_api_base()?; - let body = openai_build_body(data, self.model_info.name.clone()); + let body = openai_build_body(data, self.model.llm_name.clone()); let url = format!( "{}/openai/deployments/{}/chat/completions?api-version=2023-05-15", - &api_base, self.model_info.name + &api_base, self.model.llm_name ); let 
builder = client.post(url).header("api-key", api_key).json(&body); diff --git a/src/client/common.rs b/src/client/common.rs index 0dc637bd..464e46ad 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -46,16 +46,16 @@ macro_rules! register_client { pub struct $client { global_config: $crate::config::GlobalConfig, config: $config, - model_info: $crate::client::ModelInfo, + model: $crate::client::Model, } impl $client { pub const NAME: &str = $name; pub fn init(global_config: $crate::config::GlobalConfig) -> Option> { - let model_info = global_config.read().model_info.clone(); + let model = global_config.read().model.clone(); let config = { - if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] { + if let ClientConfig::$config_key(c) = &global_config.read().clients[model.client_index] { c.clone() } else { return None; @@ -64,7 +64,7 @@ macro_rules! register_client { Some(Box::new(Self { global_config, config, - model_info, + model, })) } @@ -79,11 +79,11 @@ macro_rules! register_client { None $(.or_else(|| $client::init(config.clone())))+ .ok_or_else(|| { - let model_info = config.read().model_info.clone(); + let model = config.read().model.clone(); anyhow::anyhow!( - "Unknown client {} at config.clients[{}]", - &model_info.client, - &model_info.index + "Unknown client '{}' at config.clients[{}]", + &model.client_name, + &model.client_index ) }) } @@ -101,7 +101,7 @@ macro_rules! 
register_client { anyhow::bail!("Unknown client {}", client) } - pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::ModelInfo> { + pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::Model> { config .clients .iter() diff --git a/src/client/localai.rs b/src/client/localai.rs index 5cc12cc5..eb4de654 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -1,5 +1,5 @@ use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS}; -use super::{ExtraConfig, LocalAIClient, PromptType, SendData, ModelInfo}; +use super::{ExtraConfig, LocalAIClient, PromptType, SendData, Model}; use crate::utils::PromptKind; @@ -41,14 +41,14 @@ impl LocalAIClient { ), ]; - pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec { - let client = Self::name(local_config); + pub fn list_models(local_config: &LocalAIConfig, client_index: usize) -> Vec { + let client_name = Self::name(local_config); local_config .models .iter() .map(|v| { - ModelInfo::new(index, client, &v.name) + Model::new(client_index, client_name, &v.name) .set_max_tokens(v.max_tokens) .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS) }) @@ -58,7 +58,7 @@ impl LocalAIClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key().ok(); - let body = openai_build_body(data, self.model_info.name.clone()); + let body = openai_build_body(data, self.model.llm_name.clone()); let chat_endpoint = self .config diff --git a/src/client/mod.rs b/src/client/mod.rs index 19a0875a..7ac9aa01 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,11 +1,11 @@ #[macro_use] mod common; mod message; -mod model_info; +mod model; pub use common::*; pub use message::*; -pub use model_info::*; +pub use model::*; register_client!( (openai, "openai", OpenAI, OpenAIConfig, OpenAIClient), diff --git a/src/client/model_info.rs b/src/client/model.rs similarity index 82% rename from src/client/model_info.rs rename 
to src/client/model.rs index 9e74951c..d00ad462 100644 --- a/src/client/model_info.rs +++ b/src/client/model.rs @@ -7,31 +7,35 @@ use anyhow::{bail, Result}; pub type TokensCountFactors = (usize, usize); // (per-messages, bias) #[derive(Debug, Clone)] -pub struct ModelInfo { - pub client: String, - pub name: String, - pub index: usize, +pub struct Model { + pub client_index: usize, + pub client_name: String, + pub llm_name: String, pub max_tokens: Option, pub tokens_count_factors: TokensCountFactors, } -impl Default for ModelInfo { +impl Default for Model { fn default() -> Self { - ModelInfo::new(0, "", "") + Model::new(0, "", "") } } -impl ModelInfo { - pub fn new(index: usize, client: &str, name: &str) -> Self { +impl Model { + pub fn new(client_index: usize, client_name: &str, name: &str) -> Self { Self { - index, - client: client.into(), - name: name.into(), + client_index, + client_name: client_name.into(), + llm_name: name.into(), max_tokens: None, tokens_count_factors: Default::default(), } } + pub fn id(&self) -> String { + format!("{}:{}", self.client_name, self.llm_name) + } + pub fn set_max_tokens(mut self, max_tokens: Option) -> Self { match max_tokens { None | Some(0) => self.max_tokens = None, @@ -45,10 +49,6 @@ impl ModelInfo { self } - pub fn id(&self) -> String { - format!("{}:{}", self.client, self.name) - } - pub fn messages_tokens(&self, messages: &[Message]) -> usize { messages.iter().map(|v| count_tokens(&v.content)).sum() } diff --git a/src/client/openai.rs b/src/client/openai.rs index 5589d2d8..f1243f04 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,6 +1,6 @@ use super::{ ExtraConfig, OpenAIClient, PromptType, SendData, - ModelInfo, TokensCountFactors, + Model, TokensCountFactors, }; use crate::{ @@ -44,12 +44,12 @@ impl OpenAIClient { pub const PROMPTS: [PromptType<'static>; 1] = [("api_key", "API Key:", true, PromptKind::String)]; - pub fn list_models(local_config: &OpenAIConfig, index: usize) -> Vec { - let client = 
Self::name(local_config); + pub fn list_models(local_config: &OpenAIConfig, client_index: usize) -> Vec { + let client_name = Self::name(local_config); MODELS .into_iter() .map(|(name, max_tokens)| { - ModelInfo::new(index, client, name) + Model::new(client_index, client_name, name) .set_max_tokens(Some(max_tokens)) .set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS) }) @@ -59,7 +59,7 @@ impl OpenAIClient { fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.get_api_key()?; - let body = openai_build_body(data, self.model_info.name.clone()); + let body = openai_build_body(data, self.model.llm_name.clone()); let env_prefix = Self::name(&self.config).to_uppercase(); let api_base = env::var(format!("{env_prefix}_API_BASE")) diff --git a/src/config/mod.rs b/src/config/mod.rs index 8589a8b9..00a0b679 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -6,7 +6,7 @@ use self::session::{Session, TEMP_SESSION_NAME}; use crate::client::{ create_client_config, list_client_types, list_models, ClientConfig, ExtraConfig, Message, - ModelInfo, OpenAIClient, SendData, + Model, OpenAIClient, SendData, }; use crate::render::{MarkdownRender, RenderOptions}; use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err}; @@ -41,7 +41,8 @@ const CLIENTS_FIELD: &str = "clients"; #[serde(default)] pub struct Config { /// LLM model - pub model: Option, + #[serde(rename(serialize = "model", deserialize = "model"))] + pub model_id: Option, /// GPT temperature, between 0 and 2 #[serde(rename(serialize = "temperature", deserialize = "temperature"))] pub default_temperature: Option, @@ -73,7 +74,7 @@ pub struct Config { #[serde(skip)] pub session: Option, #[serde(skip)] - pub model_info: ModelInfo, + pub model: Model, #[serde(skip)] pub last_message: Option<(String, String)>, #[serde(skip)] @@ -83,7 +84,7 @@ pub struct Config { impl Default for Config { fn default() -> Self { Self { - model: None, + model_id: None, 
default_temperature: None, save: true, highlight: true, @@ -97,7 +98,7 @@ impl Default for Config { roles: vec![], role: None, session: None, - model_info: Default::default(), + model: Default::default(), last_message: None, temperature: None, } @@ -135,7 +136,7 @@ impl Config { config.load_roles()?; - config.setup_model_info()?; + config.setup_model()?; config.setup_highlight(); config.setup_light_theme()?; @@ -304,22 +305,22 @@ impl Config { pub fn set_model(&mut self, value: &str) -> Result<()> { let models = list_models(self); - let mut model_info = None; + let mut model = None; let value = value.trim_end_matches(':'); if value.contains(':') { - if let Some(model) = models.iter().find(|v| v.id() == value) { - model_info = Some(model.clone()); + if let Some(found) = models.iter().find(|v| v.id() == value) { + model = Some(found.clone()); } - } else if let Some(model) = models.iter().find(|v| v.client == value) { - model_info = Some(model.clone()); + } else if let Some(found) = models.iter().find(|v| v.client_name == value) { + model = Some(found.clone()); } - match model_info { + match model { None => bail!("Unknown model '{}'", value), - Some(model_info) => { + Some(model) => { if let Some(session) = self.session.as_mut() { - session.set_model(model_info.clone())?; + session.set_model(model.clone())?; } - self.model_info = model_info; + self.model = model; Ok(()) } } @@ -338,7 +339,7 @@ impl Config { .clone() .map_or_else(|| String::from("no"), |v| v.to_string()); let items = vec![ - ("model", self.model_info.id()), + ("model", self.model.id()), ("temperature", temperature), ("dry_run", self.dry_run.to_string()), ("save", self.save.to_string()), @@ -471,18 +472,14 @@ impl Config { } self.session = Some(Session::new( TEMP_SESSION_NAME, - self.model_info.clone(), + self.model.clone(), self.role.clone(), )); } Some(name) => { let session_path = Self::session_file(name)?; if !session_path.exists() { - self.session = Some(Session::new( - name, - 
self.model_info.clone(), - self.role.clone(), - )); + self.session = Some(Session::new(name, self.model.clone(), self.role.clone())); } else { let session = Session::load(name, &session_path)?; let model = session.model().to_string(); @@ -608,7 +605,7 @@ impl Config { pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result { let messages = self.build_messages(content)?; - self.model_info.max_tokens_limit(&messages)?; + self.model.max_tokens_limit(&messages)?; Ok(SendData { messages, temperature: self.get_temperature(), @@ -619,7 +616,7 @@ impl Config { pub fn maybe_print_send_tokens(&self, input: &str) { if self.dry_run { if let Ok(messages) = self.build_messages(input) { - let tokens = self.model_info.total_tokens(&messages); + let tokens = self.model.total_tokens(&messages); println!(">>> This message consumes {tokens} tokens. <<<"); } } @@ -666,8 +663,8 @@ impl Config { Ok(()) } - fn setup_model_info(&mut self) -> Result<()> { - let model = match &self.model { + fn setup_model(&mut self) -> Result<()> { + let model = match &self.model_id { Some(v) => v.clone(), None => { let models = list_models(self); @@ -716,7 +713,7 @@ impl Config { if let Some(model_name) = value.get("model").and_then(|v| v.as_str()) { if model_name.starts_with("gpt") { - self.model = Some(format!("{}:{}", OpenAIClient::NAME, model_name)); + self.model_id = Some(format!("{}:{}", OpenAIClient::NAME, model_name)); } } diff --git a/src/config/session.rs b/src/config/session.rs index 92e8c2aa..1aebd640 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -1,5 +1,5 @@ use super::role::Role; -use super::ModelInfo; +use super::Model; use crate::client::{Message, MessageRole}; use crate::render::MarkdownRender; @@ -14,7 +14,8 @@ pub const TEMP_SESSION_NAME: &str = "temp"; #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct Session { - model: String, + #[serde(rename(serialize = "model", deserialize = "model"))] + model_id: String, temperature: Option, 
messages: Vec, #[serde(skip)] @@ -26,21 +27,21 @@ pub struct Session { #[serde(skip)] pub role: Option, #[serde(skip)] - pub model_info: ModelInfo, + pub model: Model, } impl Session { - pub fn new(name: &str, model_info: ModelInfo, role: Option) -> Self { + pub fn new(name: &str, model: Model, role: Option) -> Self { let temperature = role.as_ref().and_then(|v| v.temperature); Self { - model: model_info.id(), + model_id: model.id(), temperature, messages: vec![], name: name.to_string(), path: None, dirty: false, role, - model_info, + model, } } @@ -61,7 +62,7 @@ impl Session { } pub fn model(&self) -> &str { - &self.model + &self.model_id } pub fn temperature(&self) -> Option { @@ -69,7 +70,7 @@ impl Session { } pub fn tokens(&self) -> usize { - self.model_info.total_tokens(&self.messages) + self.model.total_tokens(&self.messages) } pub fn export(&self) -> Result { @@ -83,7 +84,7 @@ impl Session { data["temperature"] = temperature.into(); } data["total_tokens"] = tokens.into(); - if let Some(max_tokens) = self.model_info.max_tokens { + if let Some(max_tokens) = self.model.max_tokens { data["max_tokens"] = max_tokens.into(); } if percent != 0.0 { @@ -103,13 +104,13 @@ impl Session { items.push(("path", path.to_string())); } - items.push(("model", self.model_info.id())); + items.push(("model", self.model.id())); if let Some(temperature) = self.temperature() { items.push(("temperature", temperature.to_string())); } - if let Some(max_tokens) = self.model_info.max_tokens { + if let Some(max_tokens) = self.model.max_tokens { items.push(("max_tokens", max_tokens.to_string())); } @@ -143,7 +144,7 @@ impl Session { pub fn tokens_and_percent(&self) -> (usize, f32) { let tokens = self.tokens(); - let max_tokens = self.model_info.max_tokens.unwrap_or_default(); + let max_tokens = self.model.max_tokens.unwrap_or_default(); let percent = if max_tokens == 0 { 0.0 } else { @@ -164,9 +165,9 @@ impl Session { self.temperature = value; } - pub fn set_model(&mut self, model_info: 
ModelInfo) -> Result<()> { - self.model = model_info.id(); - self.model_info = model_info; + pub fn set_model(&mut self, model: Model) -> Result<()> { + self.model_id = model.id(); + self.model = model; Ok(()) }