Skip to content

Commit

Permalink
refactor: improve code quality (#203)
Browse files Browse the repository at this point in the history
- update field name of ModelInfo
- rename ModelInfo to Model
  • Loading branch information
sigoden authored Nov 2, 2023
1 parent dce6877 commit f9c40e5
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 84 deletions.
12 changes: 6 additions & 6 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData, ModelInfo};
use super::{AzureOpenAIClient, ExtraConfig, PromptType, SendData, Model};

use crate::utils::PromptKind;

Expand Down Expand Up @@ -42,14 +42,14 @@ impl AzureOpenAIClient {
),
];

pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
pub fn list_models(local_config: &AzureOpenAIConfig, client_index: usize) -> Vec<Model> {
let client_name = Self::name(local_config);

local_config
.models
.iter()
.map(|v| {
ModelInfo::new(index, client, &v.name)
Model::new(client_index, client_name, &v.name)
.set_max_tokens(v.max_tokens)
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
Expand All @@ -70,11 +70,11 @@ impl AzureOpenAIClient {

let api_base = self.get_api_base()?;

let body = openai_build_body(data, self.model_info.name.clone());
let body = openai_build_body(data, self.model.llm_name.clone());

let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
&api_base, self.model_info.name
&api_base, self.model.llm_name
);

let builder = client.post(url).header("api-key", api_key).json(&body);
Expand Down
18 changes: 9 additions & 9 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ macro_rules! register_client {
pub struct $client {
global_config: $crate::config::GlobalConfig,
config: $config,
model_info: $crate::client::ModelInfo,
model: $crate::client::Model,
}

impl $client {
pub const NAME: &str = $name;

pub fn init(global_config: $crate::config::GlobalConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let model = global_config.read().model.clone();
let config = {
if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] {
if let ClientConfig::$config_key(c) = &global_config.read().clients[model.client_index] {
c.clone()
} else {
return None;
Expand All @@ -64,7 +64,7 @@ macro_rules! register_client {
Some(Box::new(Self {
global_config,
config,
model_info,
model,
}))
}

Expand All @@ -79,11 +79,11 @@ macro_rules! register_client {
None
$(.or_else(|| $client::init(config.clone())))+
.ok_or_else(|| {
let model_info = config.read().model_info.clone();
let model = config.read().model.clone();
anyhow::anyhow!(
"Unknown client {} at config.clients[{}]",
&model_info.client,
&model_info.index
"Unknown client '{}' at config.clients[{}]",
&model.client_name,
&model.client_index
)
})
}
Expand All @@ -101,7 +101,7 @@ macro_rules! register_client {
anyhow::bail!("Unknown client {}", client)
}

pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::ModelInfo> {
pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::Model> {
config
.clients
.iter()
Expand Down
10 changes: 5 additions & 5 deletions src/client/localai.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::openai::{openai_build_body, OPENAI_TOKENS_COUNT_FACTORS};
use super::{ExtraConfig, LocalAIClient, PromptType, SendData, ModelInfo};
use super::{ExtraConfig, LocalAIClient, PromptType, SendData, Model};

use crate::utils::PromptKind;

Expand Down Expand Up @@ -41,14 +41,14 @@ impl LocalAIClient {
),
];

pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
pub fn list_models(local_config: &LocalAIConfig, client_index: usize) -> Vec<Model> {
let client_name = Self::name(local_config);

local_config
.models
.iter()
.map(|v| {
ModelInfo::new(index, client, &v.name)
Model::new(client_index, client_name, &v.name)
.set_max_tokens(v.max_tokens)
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
Expand All @@ -58,7 +58,7 @@ impl LocalAIClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key().ok();

let body = openai_build_body(data, self.model_info.name.clone());
let body = openai_build_body(data, self.model.llm_name.clone());

let chat_endpoint = self
.config
Expand Down
4 changes: 2 additions & 2 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#[macro_use]
mod common;
mod message;
mod model_info;
mod model;

pub use common::*;
pub use message::*;
pub use model_info::*;
pub use model::*;

register_client!(
(openai, "openai", OpenAI, OpenAIConfig, OpenAIClient),
Expand Down
30 changes: 15 additions & 15 deletions src/client/model_info.rs → src/client/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,35 @@ use anyhow::{bail, Result};
pub type TokensCountFactors = (usize, usize); // (per-messages, bias)

#[derive(Debug, Clone)]
pub struct ModelInfo {
pub client: String,
pub name: String,
pub index: usize,
pub struct Model {
pub client_index: usize,
pub client_name: String,
pub llm_name: String,
pub max_tokens: Option<usize>,
pub tokens_count_factors: TokensCountFactors,
}

impl Default for ModelInfo {
impl Default for Model {
fn default() -> Self {
ModelInfo::new(0, "", "")
Model::new(0, "", "")
}
}

impl ModelInfo {
pub fn new(index: usize, client: &str, name: &str) -> Self {
impl Model {
pub fn new(client_index: usize, client_name: &str, name: &str) -> Self {
Self {
index,
client: client.into(),
name: name.into(),
client_index,
client_name: client_name.into(),
llm_name: name.into(),
max_tokens: None,
tokens_count_factors: Default::default(),
}
}

pub fn id(&self) -> String {
format!("{}:{}", self.client_name, self.llm_name)
}

pub fn set_max_tokens(mut self, max_tokens: Option<usize>) -> Self {
match max_tokens {
None | Some(0) => self.max_tokens = None,
Expand All @@ -45,10 +49,6 @@ impl ModelInfo {
self
}

pub fn id(&self) -> String {
format!("{}:{}", self.client, self.name)
}

pub fn messages_tokens(&self, messages: &[Message]) -> usize {
messages.iter().map(|v| count_tokens(&v.content)).sum()
}
Expand Down
10 changes: 5 additions & 5 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
ExtraConfig, OpenAIClient, PromptType, SendData,
ModelInfo, TokensCountFactors,
Model, TokensCountFactors,
};

use crate::{
Expand Down Expand Up @@ -44,12 +44,12 @@ impl OpenAIClient {
pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

pub fn list_models(local_config: &OpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
pub fn list_models(local_config: &OpenAIConfig, client_index: usize) -> Vec<Model> {
let client_name = Self::name(local_config);
MODELS
.into_iter()
.map(|(name, max_tokens)| {
ModelInfo::new(index, client, name)
Model::new(client_index, client_name, name)
.set_max_tokens(Some(max_tokens))
.set_tokens_count_factors(OPENAI_TOKENS_COUNT_FACTORS)
})
Expand All @@ -59,7 +59,7 @@ impl OpenAIClient {
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.get_api_key()?;

let body = openai_build_body(data, self.model_info.name.clone());
let body = openai_build_body(data, self.model.llm_name.clone());

let env_prefix = Self::name(&self.config).to_uppercase();
let api_base = env::var(format!("{env_prefix}_API_BASE"))
Expand Down
51 changes: 24 additions & 27 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use self::session::{Session, TEMP_SESSION_NAME};

use crate::client::{
create_client_config, list_client_types, list_models, ClientConfig, ExtraConfig, Message,
ModelInfo, OpenAIClient, SendData,
Model, OpenAIClient, SendData,
};
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err};
Expand Down Expand Up @@ -41,7 +41,8 @@ const CLIENTS_FIELD: &str = "clients";
#[serde(default)]
pub struct Config {
/// LLM model
pub model: Option<String>,
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_id: Option<String>,
/// GPT temperature, between 0 and 2
#[serde(rename(serialize = "temperature", deserialize = "temperature"))]
pub default_temperature: Option<f64>,
Expand Down Expand Up @@ -73,7 +74,7 @@ pub struct Config {
#[serde(skip)]
pub session: Option<Session>,
#[serde(skip)]
pub model_info: ModelInfo,
pub model: Model,
#[serde(skip)]
pub last_message: Option<(String, String)>,
#[serde(skip)]
Expand All @@ -83,7 +84,7 @@ pub struct Config {
impl Default for Config {
fn default() -> Self {
Self {
model: None,
model_id: None,
default_temperature: None,
save: true,
highlight: true,
Expand All @@ -97,7 +98,7 @@ impl Default for Config {
roles: vec![],
role: None,
session: None,
model_info: Default::default(),
model: Default::default(),
last_message: None,
temperature: None,
}
Expand Down Expand Up @@ -135,7 +136,7 @@ impl Config {

config.load_roles()?;

config.setup_model_info()?;
config.setup_model()?;
config.setup_highlight();
config.setup_light_theme()?;

Expand Down Expand Up @@ -304,22 +305,22 @@ impl Config {

pub fn set_model(&mut self, value: &str) -> Result<()> {
let models = list_models(self);
let mut model_info = None;
let mut model = None;
let value = value.trim_end_matches(':');
if value.contains(':') {
if let Some(model) = models.iter().find(|v| v.id() == value) {
model_info = Some(model.clone());
if let Some(found) = models.iter().find(|v| v.id() == value) {
model = Some(found.clone());
}
} else if let Some(model) = models.iter().find(|v| v.client == value) {
model_info = Some(model.clone());
} else if let Some(found) = models.iter().find(|v| v.client_name == value) {
model = Some(found.clone());
}
match model_info {
match model {
None => bail!("Unknown model '{}'", value),
Some(model_info) => {
Some(model) => {
if let Some(session) = self.session.as_mut() {
session.set_model(model_info.clone())?;
session.set_model(model.clone())?;
}
self.model_info = model_info;
self.model = model;
Ok(())
}
}
Expand All @@ -338,7 +339,7 @@ impl Config {
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![
("model", self.model_info.id()),
("model", self.model.id()),
("temperature", temperature),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
Expand Down Expand Up @@ -471,18 +472,14 @@ impl Config {
}
self.session = Some(Session::new(
TEMP_SESSION_NAME,
self.model_info.clone(),
self.model.clone(),
self.role.clone(),
));
}
Some(name) => {
let session_path = Self::session_file(name)?;
if !session_path.exists() {
self.session = Some(Session::new(
name,
self.model_info.clone(),
self.role.clone(),
));
self.session = Some(Session::new(name, self.model.clone(), self.role.clone()));
} else {
let session = Session::load(name, &session_path)?;
let model = session.model().to_string();
Expand Down Expand Up @@ -608,7 +605,7 @@ impl Config {

pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result<SendData> {
let messages = self.build_messages(content)?;
self.model_info.max_tokens_limit(&messages)?;
self.model.max_tokens_limit(&messages)?;
Ok(SendData {
messages,
temperature: self.get_temperature(),
Expand All @@ -619,7 +616,7 @@ impl Config {
pub fn maybe_print_send_tokens(&self, input: &str) {
if self.dry_run {
if let Ok(messages) = self.build_messages(input) {
let tokens = self.model_info.total_tokens(&messages);
let tokens = self.model.total_tokens(&messages);
println!(">>> This message consumes {tokens} tokens. <<<");
}
}
Expand Down Expand Up @@ -666,8 +663,8 @@ impl Config {
Ok(())
}

fn setup_model_info(&mut self) -> Result<()> {
let model = match &self.model {
fn setup_model(&mut self) -> Result<()> {
let model = match &self.model_id {
Some(v) => v.clone(),
None => {
let models = list_models(self);
Expand Down Expand Up @@ -716,7 +713,7 @@ impl Config {

if let Some(model_name) = value.get("model").and_then(|v| v.as_str()) {
if model_name.starts_with("gpt") {
self.model = Some(format!("{}:{}", OpenAIClient::NAME, model_name));
self.model_id = Some(format!("{}:{}", OpenAIClient::NAME, model_name));
}
}

Expand Down
Loading

0 comments on commit f9c40e5

Please sign in to comment.