From 49eb5329a2e763448c6f6a2f438d6b088867cdb3 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 1 Nov 2023 09:27:34 +0800 Subject: [PATCH] refactor: clients/* and config.rs - add register_clients to make it easier to add a new client - no create_client_config, just add const PROMPTS - move ModelInfo from clients/ to config/ - model's max_tokens are optional - improve code quanity on config/mod.rs --- Cargo.lock | 1 + Cargo.toml | 2 +- config.example.yaml | 11 +- src/client/azure_openai.rs | 74 ++------- src/client/localai.rs | 74 ++------- src/client/mod.rs | 315 +++++++++++++++++++++---------------- src/client/openai.rs | 97 +++--------- src/config/message.rs | 9 +- src/config/mod.rs | 117 ++++++++------ src/config/model_info.rs | 27 ++++ src/main.rs | 4 +- src/repl/prompt.rs | 6 +- src/utils/mod.rs | 2 + src/utils/prompt_input.rs | 58 +++++++ 14 files changed, 410 insertions(+), 387 deletions(-) create mode 100644 src/config/model_info.rs create mode 100644 src/utils/prompt_input.rs diff --git a/Cargo.lock b/Cargo.lock index 45ab516f..bae2ce45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1611,6 +1611,7 @@ version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ + "indexmap 2.0.2", "itoa", "ryu", "serde", diff --git a/Cargo.toml b/Cargo.toml index 121d222c..49e21d12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ inquire = "0.6.2" is-terminal = "0.4.9" reedline = "0.21.0" serde = { version = "1.0.152", features = ["derive"] } -serde_json = "1.0.93" +serde_json = { version = "1.0.93", features = ["preserve_order"] } serde_yaml = "0.9.17" tokio = { version = "1.26.0", features = ["rt", "time", "macros", "signal"] } crossbeam = "0.8.2" diff --git a/config.example.yaml b/config.example.yaml index 52c147c5..42c5f6e6 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -10,33 +10,30 @@ keybindings: emacs # REPL keybindings. values: emacs, vi clients: # All clients have the following configuration: - # ``` # - type: xxxx # name: nova # Only use it to distinguish clients with the same client type. Optional # extra: # proxy: socks5://127.0.0.1:1080 # Specify https/socks5 proxy server. Note HTTPS_PROXY/ALL_PROXY also works. # connect_timeout: 10 # Set a timeout in seconds for connect to server - # ``` # See https://platform.openai.com/docs/quickstart - type: openai api_key: sk-xxx - organization_id: org-xxx # Organization ID. Optional + organization_id: # See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart - type: azure-openai api_base: https://RESOURCE.openai.azure.com api_key: xxx - models: # Support models + models: - name: MyGPT4 # Model deployment name max_tokens: 8192 - # See https://github.com/go-skynet/LocalAI - type: localai api_base: http://localhost:8080/v1 api_key: xxx - chat_endpoint: /chat/completions # Optional - models: # Support models + chat_endpoint: /chat/completions + models: - name: gpt4all-j max_tokens: 8192 \ No newline at end of file diff --git a/src/client/azure_openai.rs b/src/client/azure_openai.rs index 45a5eccb..9d82cd39 100644 --- a/src/client/azure_openai.rs +++ b/src/client/azure_openai.rs @@ -1,8 +1,5 @@ use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming}; -use super::{ - prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name, - Client, ClientConfig, ExtraConfig, ModelInfo, SendData, -}; +use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData}; use crate::config::SharedConfig; use crate::repl::ReplyStreamHandler; @@ -14,13 +11,6 @@ use serde::Deserialize; use std::env; -#[derive(Debug)] -pub struct AzureOpenAIClient { - global_config: SharedConfig, - config: AzureOpenAIConfig, - model_info: ModelInfo, -} - #[derive(Debug, Clone, Deserialize)] pub struct AzureOpenAIConfig { pub name: Option, @@ -33,17 +23,13 @@ pub struct AzureOpenAIConfig { #[derive(Debug, Clone, Deserialize)] pub struct AzureOpenAIModel { name: String, - max_tokens: usize, + max_tokens: Option, } #[async_trait] impl Client for AzureOpenAIClient { - fn config(&self) -> &SharedConfig { - &self.global_config - } - - fn extra_config(&self) -> &Option { - &self.config.extra + fn config(&self) -> (&SharedConfig, &Option) { + (&self.global_config, &self.config.extra) } async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { @@ -63,27 +49,17 @@ impl Client for AzureOpenAIClient { } impl AzureOpenAIClient { - pub const NAME: &str = "azure-openai"; - - pub fn init(global_config: SharedConfig) -> Option> { - let model_info = global_config.read().model_info.clone(); - let config = { - if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] { - c.clone() - } else { - return None; - } - }; - Some(Box::new(Self { - global_config, - config, - model_info, - })) - } - - pub fn name(local_config: &AzureOpenAIConfig) -> &str { - local_config.name.as_deref().unwrap_or(Self::NAME) - } + pub const PROMPTS: [PromptType<'static>; 4] = [ + ("api_base", "API Base:", true, PromptKind::String), + ("api_key", "API Key:", true, PromptKind::String), + ("models[].name", "Model Name:", true, PromptKind::String), + ( + "models[].max_tokens", + "Max Tokens:", + true, + PromptKind::Integer, + ), + ]; pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec { let client = Self::name(local_config); @@ -95,26 +71,6 @@ impl AzureOpenAIClient { .collect() } - pub fn create_config() -> Result { - let mut client_config = format!("clients:\n - type: {}\n", Self::NAME); - - let api_base = prompt_input_api_base()?; - client_config.push_str(&format!(" api_base: {api_base}\n")); - - let api_key = prompt_input_api_key()?; - client_config.push_str(&format!(" api_key: {api_key}\n")); - - let model_name = prompt_input_model_name()?; - - let max_tokens = prompt_input_max_token()?; - - client_config.push_str(&format!( - " models:\n - name: {model_name}\n max_tokens: {max_tokens}\n" - )); - - Ok(client_config) - } - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.config.api_key.clone(); let api_key = api_key diff --git a/src/client/localai.rs b/src/client/localai.rs index fb5a3534..1aa2337e 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -1,8 +1,5 @@ use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming}; -use super::{ - prompt_input_api_base, prompt_input_api_key_optional, prompt_input_max_token, - prompt_input_model_name, Client, ClientConfig, ExtraConfig, ModelInfo, SendData, -}; +use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData}; use crate::config::SharedConfig; use crate::repl::ReplyStreamHandler; @@ -13,13 +10,6 @@ use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use std::env; -#[derive(Debug)] -pub struct LocalAIClient { - global_config: SharedConfig, - config: LocalAIConfig, - model_info: ModelInfo, -} - #[derive(Debug, Clone, Deserialize)] pub struct LocalAIConfig { pub name: Option, @@ -33,17 +23,13 @@ pub struct LocalAIConfig { #[derive(Debug, Clone, Deserialize)] pub struct LocalAIModel { name: String, - max_tokens: usize, + max_tokens: Option, } #[async_trait] impl Client for LocalAIClient { - fn config(&self) -> &SharedConfig { - &self.global_config - } - - fn extra_config(&self) -> &Option { - &self.config.extra + fn config(&self) -> (&SharedConfig, &Option) { + (&self.global_config, &self.config.extra) } async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { @@ -63,27 +49,17 @@ impl Client for LocalAIClient { } impl LocalAIClient { - pub const NAME: &str = "localai"; - - pub fn init(global_config: SharedConfig) -> Option> { - let model_info = global_config.read().model_info.clone(); - let config = { - if let ClientConfig::LocalAI(c) = &global_config.read().clients[model_info.index] { - c.clone() - } else { - return None; - } - }; - Some(Box::new(Self { - global_config, - config, - model_info, - })) - } - - pub fn name(local_config: &LocalAIConfig) -> &str { - local_config.name.as_deref().unwrap_or(Self::NAME) - } + pub const PROMPTS: [PromptType<'static>; 4] = [ + ("api_base", "API Base:", true, PromptKind::String), + ("api_key", "API Key:", false, PromptKind::String), + ("models[].name", "Model Name:", true, PromptKind::String), + ( + "models[].max_tokens", + "Max Tokens:", + false, + PromptKind::Integer, + ), + ]; pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec { let client = Self::name(local_config); @@ -95,26 +71,6 @@ impl LocalAIClient { .collect() } - pub fn create_config() -> Result { - let mut client_config = format!("clients:\n - type: {}\n", Self::NAME); - - let api_base = prompt_input_api_base()?; - client_config.push_str(&format!(" api_base: {api_base}\n")); - - let api_key = prompt_input_api_key_optional()?; - client_config.push_str(&format!(" api_key: {api_key}\n")); - - let model_name = prompt_input_model_name()?; - - let max_tokens = prompt_input_max_token()?; - - client_config.push_str(&format!( - " models:\n - name: {model_name}\n max_tokens: {max_tokens}\n" - )); - - Ok(client_config) - } - fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { let api_key = self.config.api_key.clone(); let api_key = api_key.or_else(|| { diff --git a/src/client/mod.rs b/src/client/mod.rs index bef6094b..b0326650 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -2,63 +2,28 @@ pub mod azure_openai; pub mod localai; pub mod openai; -use self::{ - azure_openai::{AzureOpenAIClient, AzureOpenAIConfig}, - localai::LocalAIConfig, - openai::{OpenAIClient, OpenAIConfig}, -}; +use self::azure_openai::AzureOpenAIConfig; +use self::localai::LocalAIConfig; +use self::openai::OpenAIConfig; use crate::{ - client::localai::LocalAIClient, - config::{Config, Message, SharedConfig}, + config::{Config, Message, ModelInfo, SharedConfig}, repl::{ReplyStreamHandler, SharedAbortSignal}, - utils::tokenize, + utils::{prompt_input_integer, prompt_input_string, tokenize, PromptKind}, }; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; -use inquire::{required, Text}; use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy}; use serde::Deserialize; +use serde_json::{json, Value}; use std::{env, time::Duration}; use tokio::time::sleep; -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "type")] -pub enum ClientConfig { - #[serde(rename = "openai")] - OpenAI(OpenAIConfig), - #[serde(rename = "localai")] - LocalAI(LocalAIConfig), - #[serde(rename = "azure-openai")] - AzureOpenAI(AzureOpenAIConfig), -} -#[derive(Debug, Clone)] -pub struct ModelInfo { - pub client: String, - pub name: String, - pub max_tokens: usize, - pub index: usize, -} - -impl Default for ModelInfo { - fn default() -> Self { - OpenAIClient::list_models(&OpenAIConfig::default(), 0)[0].clone() - } -} - -impl ModelInfo { - pub fn new(client: &str, name: &str, max_tokens: usize, index: usize) -> Self { - Self { - client: client.into(), - name: name.into(), - max_tokens, - index, - } - } - pub fn stringify(&self) -> String { - format!("{}:{}", self.client, self.name) - } +#[derive(Debug, Clone, Deserialize, Default)] +pub struct ExtraConfig { + pub proxy: Option, + pub connect_timeout: Option, } #[derive(Debug)] @@ -67,15 +32,14 @@ pub struct SendData { pub temperature: Option, pub stream: bool, } + #[async_trait] pub trait Client { - fn config(&self) -> &SharedConfig; - - fn extra_config(&self) -> &Option; + fn config(&self) -> (&SharedConfig, &Option); fn build_client(&self) -> Result { let mut builder = ReqwestClient::builder(); - let options = self.extra_config(); + let options = self.config().1; let timeout = options .as_ref() .and_then(|v| v.connect_timeout) @@ -91,15 +55,16 @@ pub trait Client { fn send_message(&self, content: &str) -> Result { init_tokio_runtime()?.block_on(async { - if self.config().read().dry_run { - let content = self.config().read().echo_messages(content); + let global_config = self.config().0; + if global_config.read().dry_run { + let content = global_config.read().echo_messages(content); return Ok(content); } let client = self.build_client()?; - let data = self.config().read().prepare_send_data(content, false)?; + let data = global_config.read().prepare_send_data(content, false)?; self.send_message_inner(&client, data) .await - .with_context(|| "Failed to fetch") + .with_context(|| "Failed to get awswer") }) } @@ -120,8 +85,9 @@ pub trait Client { init_tokio_runtime()?.block_on(async { tokio::select! { ret = async { - if self.config().read().dry_run { - let content = self.config().read().echo_messages(content); + let global_config = self.config().0; + if global_config.read().dry_run { + let content = global_config.read().echo_messages(content); let tokens = tokenize(&content); for token in tokens { tokio::time::sleep(Duration::from_millis(25)).await; @@ -130,11 +96,11 @@ pub trait Client { return Ok(()); } let client = self.build_client()?; - let data = self.config().read().prepare_send_data(content, true)?; + let data = global_config.read().prepare_send_data(content, true)?; self.send_message_streaming_inner(&client, handler, data).await } => { handler.done()?; - ret.with_context(|| "Failed to fetch stream") + ret.with_context(|| "Failed to get awswer") } _ = watch_abort(abort.clone()) => { handler.done()?; @@ -158,101 +124,184 @@ pub trait Client { ) -> Result<()>; } -#[derive(Debug, Clone, Deserialize, Default)] -pub struct ExtraConfig { - pub proxy: Option, - pub connect_timeout: Option, -} +macro_rules! register_role { + ( + $(($name:literal, $config_key:ident, $config:ident, $client:ident),)+ + ) => { + + #[derive(Debug, Clone, Deserialize)] + #[serde(tag = "type")] + pub enum ClientConfig { + $( + #[serde(rename = $name)] + $config_key($config), + )+ + #[serde(other)] + Unknown, + } -pub fn init_client(config: SharedConfig) -> Result> { - OpenAIClient::init(config.clone()) - .or_else(|| LocalAIClient::init(config.clone())) - .or_else(|| AzureOpenAIClient::init(config.clone())) - .ok_or_else(|| { - let model_info = config.read().model_info.clone(); - anyhow!( - "Unknown client {} at config.clients[{}]", - &model_info.client, - &model_info.index - ) - }) -} -pub fn list_client_types() -> Vec<&'static str> { - vec![ - OpenAIClient::NAME, - LocalAIClient::NAME, - AzureOpenAIClient::NAME, - ] + $( + #[derive(Debug)] + pub struct $client { + global_config: SharedConfig, + config: $config, + model_info: ModelInfo, + } + + impl $client { + pub const NAME: &str = $name; + + pub fn init(global_config: SharedConfig) -> Option> { + let model_info = global_config.read().model_info.clone(); + let config = { + if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] { + c.clone() + } else { + return None; + } + }; + Some(Box::new(Self { + global_config, + config, + model_info, + })) + } + + pub fn name(local_config: &$config) -> &str { + local_config.name.as_deref().unwrap_or(Self::NAME) + } + } + + )+ + + pub fn init_client(config: SharedConfig) -> Result> { + None + $(.or_else(|| $client::init(config.clone())))+ + .ok_or_else(|| { + let model_info = config.read().model_info.clone(); + anyhow!( + "Unknown client {} at config.clients[{}]", + &model_info.client, + &model_info.index + ) + }) + } + + pub fn list_client_types() -> Vec<&'static str> { + vec![$($client::NAME,)+] + } + + pub fn create_client_config(client: &str) -> Result { + $( + if client == $client::NAME { + return create_config(&$client::PROMPTS, $client::NAME) + } + )+ + bail!("Unknown client {}", client) + } + + pub fn all_models(config: &Config) -> Vec { + config + .clients + .iter() + .enumerate() + .flat_map(|(i, v)| match v { + $(ClientConfig::$config_key(c) => $client::list_models(c, i),)+ + ClientConfig::Unknown => vec![], + }) + .collect() + } + + }; } -pub fn create_client_config(client: &str) -> Result { - if client == OpenAIClient::NAME { - OpenAIClient::create_config() - } else if client == LocalAIClient::NAME { - LocalAIClient::create_config() - } else if client == AzureOpenAIClient::NAME { - AzureOpenAIClient::create_config() - } else { - bail!("Unknown client {}", &client) +register_role!( + ("openai", OpenAI, OpenAIConfig, OpenAIClient), + ("localai", LocalAI, LocalAIConfig, LocalAIClient), + ( + "azure-openai", + AzureOpenAI, + AzureOpenAIConfig, + AzureOpenAIClient + ), +); + +impl Default for ClientConfig { + fn default() -> Self { + Self::OpenAI(OpenAIConfig::default()) } } -pub fn list_models(config: &Config) -> Vec { - config - .clients - .iter() - .enumerate() - .flat_map(|(i, v)| match v { - ClientConfig::OpenAI(c) => OpenAIClient::list_models(c, i), - ClientConfig::LocalAI(c) => LocalAIClient::list_models(c, i), - ClientConfig::AzureOpenAI(c) => AzureOpenAIClient::list_models(c, i), - }) - .collect() -} +type PromptType<'a> = (&'a str, &'a str, bool, PromptKind); -pub(crate) fn init_tokio_runtime() -> Result { +fn init_tokio_runtime() -> Result { tokio::runtime::Builder::new_current_thread() .enable_all() .build() .with_context(|| "Failed to init tokio") } -pub(crate) fn prompt_input_api_base() -> Result { - Text::new("API Base:") - .with_validator(required!("This field is required")) - .prompt() - .map_err(prompt_op_err) -} - -pub(crate) fn prompt_input_api_key() -> Result { - Text::new("API Key:") - .with_validator(required!("This field is required")) - .prompt() - .map_err(prompt_op_err) -} - -pub(crate) fn prompt_input_api_key_optional() -> Result { - Text::new("API Key:").prompt().map_err(prompt_op_err) -} +fn create_config(list: &[PromptType], client: &str) -> Result { + let mut config = json!({ + "type": client, + }); + for (path, desc, required, kind) in list { + match kind { + PromptKind::String => { + let value = prompt_input_string(desc, *required)?; + set_config_value(&mut config, path, kind, &value); + } + PromptKind::Integer => { + let value = prompt_input_integer(desc, *required)?; + set_config_value(&mut config, path, kind, &value); + } + } + } -pub(crate) fn prompt_input_model_name() -> Result { - Text::new("Model Name:") - .with_validator(required!("This field is required")) - .prompt() - .map_err(prompt_op_err) + let clients = json!(vec![config]); + Ok(clients) } -pub(crate) fn prompt_input_max_token() -> Result { - Text::new("Max tokens:") - .with_default("4096") - .with_validator(required!("This field is required")) - .prompt() - .map_err(prompt_op_err) +fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) { + let segs: Vec<&str> = path.split('.').collect(); + match segs.as_slice() { + [name] => json[name] = to_json(kind, value), + [scope, name] => match scope.split_once('[') { + None => { + if json.get(scope).is_none() { + let mut obj = json!({}); + obj[name] = to_json(kind, value); + json[scope] = obj; + } else { + json[scope][name] = to_json(kind, value); + } + } + Some((scope, _)) => { + if json.get(scope).is_none() { + let mut obj = json!({}); + obj[name] = to_json(kind, value); + json[scope] = json!([obj]); + } else { + json[scope][0][name] = to_json(kind, value); + } + } + }, + _ => {} + } } -pub(crate) fn prompt_op_err(_: T) -> anyhow::Error { - anyhow!("An error happened, try again later.") +fn to_json(kind: &PromptKind, value: &str) -> Value { + if value.is_empty() { + return Value::Null; + } + match kind { + PromptKind::String => value.into(), + PromptKind::Integer => match value.parse::() { + Ok(value) => value.into(), + Err(_) => value.into(), + }, + } } fn set_proxy(builder: ClientBuilder, proxy: &Option) -> Result { diff --git a/src/client/openai.rs b/src/client/openai.rs index 57cd5fd4..39f73a4b 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,4 +1,4 @@ -use super::{prompt_input_api_key, Client, ClientConfig, ExtraConfig, ModelInfo, SendData}; +use super::{Client, ExtraConfig, ModelInfo, OpenAIClient, PromptKind, PromptType, SendData}; use crate::config::SharedConfig; use crate::repl::ReplyStreamHandler; @@ -14,12 +14,12 @@ use std::env; const API_BASE: &str = "https://api.openai.com/v1"; -#[derive(Debug)] -pub struct OpenAIClient { - global_config: SharedConfig, - config: OpenAIConfig, - model_info: ModelInfo, -} +const MODELS: [(&str, usize); 4] = [ + ("gpt-3.5-turbo", 4096), + ("gpt-3.5-turbo-16k", 16384), + ("gpt-4", 8192), + ("gpt-4-32k", 32768), +]; #[derive(Debug, Clone, Deserialize, Default)] pub struct OpenAIConfig { @@ -31,12 +31,8 @@ pub struct OpenAIConfig { #[async_trait] impl Client for OpenAIClient { - fn config(&self) -> &SharedConfig { - &self.global_config - } - - fn extra_config(&self) -> &Option { - &self.config.extra + fn config(&self) -> (&SharedConfig, &Option) { + (&self.global_config, &self.config.extra) } async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result { @@ -56,49 +52,15 @@ impl Client for OpenAIClient { } impl OpenAIClient { - pub const NAME: &str = "openai"; - - pub fn init(global_config: SharedConfig) -> Option> { - let model_info = global_config.read().model_info.clone(); - let config = { - if let ClientConfig::OpenAI(c) = &global_config.read().clients[model_info.index] { - c.clone() - } else { - return None; - } - }; - Some(Box::new(Self { - global_config, - config, - model_info, - })) - } - - pub fn name(local_config: &OpenAIConfig) -> &str { - local_config.name.as_deref().unwrap_or(Self::NAME) - } + pub const PROMPTS: [PromptType<'static>; 1] = + [("api_key", "API Key:", true, PromptKind::String)]; pub fn list_models(local_config: &OpenAIConfig, index: usize) -> Vec { let client = Self::name(local_config); - - [ - ("gpt-3.5-turbo", 4096), - ("gpt-3.5-turbo-16k", 16384), - ("gpt-4", 8192), - ("gpt-4-32k", 32768), - ] - .into_iter() - .map(|(name, max_tokens)| ModelInfo::new(client, name, max_tokens, index)) - .collect() - } - - pub fn create_config() -> Result { - let mut client_config = format!("clients:\n - type: {}\n", Self::NAME); - - let api_key = prompt_input_api_key()?; - client_config.push_str(&format!(" api_key: {api_key}\n")); - - Ok(client_config) + MODELS + .into_iter() + .map(|(name, max_tokens)| ModelInfo::new(client, name, Some(max_tokens), index)) + .collect() } fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result { @@ -127,20 +89,20 @@ impl OpenAIClient { } } -pub(crate) async fn openai_send_message(builder: RequestBuilder) -> Result { +pub async fn openai_send_message(builder: RequestBuilder) -> Result { let data: Value = builder.send().await?.json().await?; if let Some(err_msg) = data["error"]["message"].as_str() { - bail!("Request failed, {err_msg}"); + bail!("{err_msg}"); } let output = data["choices"][0]["message"]["content"] .as_str() - .ok_or_else(|| anyhow!("Unexpected response {data}"))?; + .ok_or_else(|| anyhow!("Invalid response data: {data}"))?; Ok(output.to_string()) } -pub(crate) async fn openai_send_message_streaming( +pub async fn openai_send_message_streaming( builder: RequestBuilder, handler: &mut ReplyStreamHandler, ) -> Result<()> { @@ -148,7 +110,7 @@ pub(crate) async fn openai_send_message_streaming( if !res.status().is_success() { let data: Value = res.json().await?; if let Some(err_msg) = data["error"]["message"].as_str() { - bail!("Request failed, {err_msg}"); + bail!("{err_msg}"); } bail!("Request failed"); } @@ -159,37 +121,30 @@ pub(crate) async fn openai_send_message_streaming( break; } let data: Value = serde_json::from_str(&chunk)?; - let text = data["choices"][0]["delta"]["content"] - .as_str() - .unwrap_or_default(); - if text.is_empty() { - continue; + if let Some(text) = data["choices"][0]["delta"]["content"].as_str() { + handler.text(text)?; } - handler.text(text)?; } Ok(()) } -pub(crate) fn openai_build_body(data: SendData, model: String) -> Value { +pub fn openai_build_body(data: SendData, model: String) -> Value { let SendData { messages, temperature, stream, } = data; + let mut body = json!({ "model": model, "messages": messages, }); - if let Some(v) = temperature { - body.as_object_mut() - .and_then(|m| m.insert("temperature".into(), json!(v))); + body["temperature"] = v.into(); } - if stream { - body.as_object_mut() - .and_then(|m| m.insert("stream".into(), json!(true))); + body["stream"] = true.into(); } body } diff --git a/src/config/message.rs b/src/config/message.rs index d5fcff1b..5882337a 100644 --- a/src/config/message.rs +++ b/src/config/message.rs @@ -17,7 +17,7 @@ impl Message { } } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Copy, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum MessageRole { System, @@ -25,6 +25,13 @@ pub enum MessageRole { User, } +impl MessageRole { + #[allow(dead_code)] + pub fn is_system(&self) -> bool { + matches!(self, MessageRole::System) + } +} + pub fn num_tokens_from_messages(messages: &[Message]) -> usize { let mut num_tokens = 0; for message in messages.iter() { diff --git a/src/config/mod.rs b/src/config/mod.rs index 956bea01..18f17317 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,19 +1,20 @@ mod message; +mod model_info; mod role; mod session; pub use self::message::Message; +pub use self::model_info::ModelInfo; use self::role::Role; use self::session::{Session, TEMP_SESSION_NAME}; -use crate::client::openai::{OpenAIClient, OpenAIConfig}; use crate::client::{ - create_client_config, list_client_types, list_models, prompt_op_err, ClientConfig, ExtraConfig, - ModelInfo, SendData, + all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, OpenAIClient, + SendData, }; use crate::config::message::num_tokens_from_messages; use crate::render::RenderOptions; -use crate::utils::{get_env_name, light_theme_from_colorfgbg, now}; +use crate::utils::{get_env_name, light_theme_from_colorfgbg, now, prompt_op_err}; use anyhow::{anyhow, bail, Context, Result}; use inquire::{Confirm, Select, Text}; @@ -49,6 +50,8 @@ const SET_COMPLETIONS: [&str; 7] = [ ".set dry_run false", ]; +const CLIENTS_FIELD: &str = "clients"; + #[derive(Debug, Clone, Deserialize)] #[serde(default)] pub struct Config { @@ -61,7 +64,7 @@ pub struct Config { pub save: bool, /// Whether to disable highlight pub highlight: bool, - /// Used only for debugging + /// Dry-run flag pub dry_run: bool, /// Whether to use a light theme pub light_theme: bool, @@ -105,7 +108,7 @@ impl Default for Config { wrap_code: false, auto_copy: false, keybindings: Default::default(), - clients: vec![ClientConfig::OpenAI(OpenAIConfig::default())], + clients: vec![ClientConfig::default()], roles: vec![], role: None, session: None, @@ -145,11 +148,11 @@ impl Config { config.temperature = config.default_temperature; - config.set_model_info()?; - config.merge_env_vars(); config.load_roles()?; - config.ensure_sessions_dir()?; - config.detect_theme()?; + + config.setup_model_info()?; + config.setup_highlight(); + config.setup_light_theme()?; Ok(config) } @@ -296,8 +299,10 @@ impl Config { vec![message] }; let tokens = num_tokens_from_messages(&messages); - if tokens >= self.model_info.max_tokens { - bail!("Exceed max tokens limit") + if let Some(max_tokens) = self.model_info.max_tokens { + if tokens >= max_tokens { + bail!("Exceed max tokens limit") + } } Ok(messages) @@ -318,7 +323,7 @@ impl Config { } pub fn set_model(&mut self, value: &str) -> Result<()> { - let models = list_models(self); + let models = all_models(self); let mut model_info = None; if value.contains(':') { if let Some(model) = models.iter().find(|v| v.stringify() == value) { @@ -339,14 +344,6 @@ impl Config { } } - pub const fn get_reamind_tokens(&self) -> usize { - let mut tokens = self.model_info.max_tokens; - if let Some(session) = self.session.as_ref() { - tokens = tokens.saturating_sub(session.tokens); - } - tokens - } - pub fn info(&self) -> Result { let path_info = |path: &Path| { let state = if path.exists() { "" } else { " ⚠️" }; @@ -390,12 +387,7 @@ impl Config { completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string)); completion.extend( - list_models(self) - .iter() - .map(|v| format!(".model {}", v.stringify())), - ); - completion.extend( - list_models(self) + all_models(self) .iter() .map(|v| format!(".model {}", v.stringify())), ); @@ -504,6 +496,14 @@ impl Config { name = Text::new("Session name:").with_default(&name).prompt()?; } let session_path = Self::session_file(&name)?; + let sessions_dir = session_path.parent().ok_or_else(|| { + anyhow!("Unable to save session file to {}", session_path.display()) + })?; + if !sessions_dir.exists() { + create_dir_all(sessions_dir).with_context(|| { + format!("Failed to create session_dir '{}'", sessions_dir.display()) + })?; + } session.save(&session_path)?; } } @@ -556,6 +556,24 @@ impl Config { Ok(RenderOptions::new(theme, wrap, self.wrap_code)) } + pub fn render_prompt_right(&self) -> String { + if let Some(session) = &self.session { + let tokens = session.tokens; + // 10000(%32) + match self.model_info.max_tokens { + Some(max_tokens) => { + let ratio = tokens as f32 / max_tokens as f32; + let percent = ratio * 100.0; + let percent = (percent * 100.0).round() / 100.0; + format!("{tokens}({percent}%)") + } + None => format!("{tokens}"), + } + } else { + String::new() + } + } + pub fn prepare_send_data(&self, content: &str, stream: bool) -> Result { let messages = self.build_messages(content)?; Ok(SendData { @@ -585,11 +603,20 @@ impl Config { } fn load_config(config_path: &Path) -> Result { - let content = read_to_string(config_path) - .with_context(|| format!("Failed to load config at {}", config_path.display()))?; + let ctx = || format!("Failed to load config at {}", config_path.display()); + let content = read_to_string(config_path).with_context(ctx)?; let config: Self = serde_yaml::from_str(&content) - .with_context(|| format!("Invalid config at {}", config_path.display()))?; + .map_err(|err| { + let err_msg = err.to_string(); + if err_msg.starts_with(&format!("{}: ", CLIENTS_FIELD)) { + anyhow!("clients: invalid value") + } else { + anyhow!("{err_msg}") + } + }) + .with_context(ctx)?; + Ok(config) } @@ -606,11 +633,11 @@ impl Config { Ok(()) } - fn set_model_info(&mut self) -> Result<()> { + fn setup_model_info(&mut self) -> Result<()> { let model = match &self.model { Some(v) => v.clone(), None => { - let models = self::list_models(self); + let models = all_models(self); if models.is_empty() { bail!("No available model"); } @@ -622,7 +649,7 @@ impl Config { Ok(()) } - fn merge_env_vars(&mut self) { + fn setup_highlight(&mut self) { if let Ok(value) = env::var("NO_COLOR") { let mut no_color = false; set_bool(&mut no_color, &value); @@ -632,17 +659,7 @@ impl Config { } } - fn ensure_sessions_dir(&self) -> Result<()> { - let sessions_dir = Self::sessions_dir()?; - if !sessions_dir.exists() { - create_dir_all(&sessions_dir).with_context(|| { - format!("Failed to create session_dir '{}'", sessions_dir.display()) - })?; - } - Ok(()) - } - - fn detect_theme(&mut self) -> Result<()> { + fn setup_light_theme(&mut self) -> Result<()> { if self.light_theme { return Ok(()); } @@ -660,7 +677,7 @@ impl Config { fn compat_old_config(&mut self, config_path: &PathBuf) -> Result<()> { let content = read_to_string(config_path)?; let value: serde_json::Value = serde_yaml::from_str(&content)?; - if value.get("clients").is_some() { + if value.get(CLIENTS_FIELD).is_some() { return Ok(()); } @@ -725,16 +742,18 @@ fn create_config_file(config_path: &Path) -> Result<()> { exit(0); } - let client = Select::new("AI Platform:", list_client_types()) + let client = Select::new("Platform:", list_client_types()) .prompt() .map_err(prompt_op_err)?; - let mut raw_config = create_client_config(client)?; + let mut config = serde_json::json!({}); + config["model"] = client.into(); + config[CLIENTS_FIELD] = create_client_config(client)?; - raw_config.push_str(&format!("model: {client}\n")); + let config_data = serde_yaml::to_string(&config).with_context(|| "Failed to create config")?; ensure_parent_exists(config_path)?; - std::fs::write(config_path, raw_config).with_context(|| "Failed to write to config file")?; + std::fs::write(config_path, config_data).with_context(|| "Failed to write to config file")?; #[cfg(unix)] { use std::os::unix::prelude::PermissionsExt; diff --git a/src/config/model_info.rs b/src/config/model_info.rs new file mode 100644 index 00000000..1f8d6f07 --- /dev/null +++ b/src/config/model_info.rs @@ -0,0 +1,27 @@ +#[derive(Debug, Clone)] +pub struct ModelInfo { + pub client: String, + pub name: String, + pub max_tokens: Option, + pub index: usize, +} + +impl Default for ModelInfo { + fn default() -> Self { + ModelInfo::new("", "", None, 0) + } +} + +impl ModelInfo { + pub fn new(client: &str, name: &str, max_tokens: Option, index: usize) -> Self { + Self { + client: client.into(), + name: name.into(), + max_tokens, + index, + } + } + pub fn stringify(&self) -> String { + format!("{}:{}", self.client, self.name) + } +} diff --git a/src/main.rs b/src/main.rs index 5cd82e41..b2985e2a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use crate::config::{Config, SharedConfig}; use anyhow::Result; use clap::Parser; -use client::{init_client, list_models}; +use client::{all_models, init_client}; use crossbeam::sync::WaitGroup; use is_terminal::IsTerminal; use parking_lot::RwLock; @@ -36,7 +36,7 @@ fn main() -> Result<()> { exit(0); } if cli.list_models { - for model in list_models(&config.read()) { + for model in all_models(&config.read()) { println!("{}", model.stringify()); } exit(0); diff --git a/src/repl/prompt.rs b/src/repl/prompt.rs index c73e5863..7bed5e04 100644 --- a/src/repl/prompt.rs +++ b/src/repl/prompt.rs @@ -32,11 +32,7 @@ impl Prompt for ReplPrompt { } fn render_prompt_right(&self) -> Cow { - if self.config.read().session.is_none() { - Cow::Borrowed("") - } else { - self.config.read().get_reamind_tokens().to_string().into() - } + Cow::Owned(self.config.read().render_prompt_right()) } fn render_prompt_indicator(&self, _prompt_mode: reedline::PromptEditMode) -> Cow { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 9e999d46..6e27199f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,8 @@ +mod prompt_input; mod split_line; mod tiktoken; +pub use self::prompt_input::*; pub use self::split_line::*; pub use self::tiktoken::cl100k_base_singleton; diff --git a/src/utils/prompt_input.rs b/src/utils/prompt_input.rs new file mode 100644 index 00000000..f63823a1 --- /dev/null +++ b/src/utils/prompt_input.rs @@ -0,0 +1,58 @@ +use inquire::{required, validator::Validation, Text}; + +const MSG_REQUIRED: &str = "This field is required"; +const MSG_OPTIONAL: &str = "Optional field - Press ↵ to skip"; + +pub fn prompt_input_string(desc: &str, required: bool) -> anyhow::Result { + let mut text = Text::new(desc); + if required { + text = text.with_validator(required!(MSG_REQUIRED)) + } else { + text = text.with_help_message(MSG_OPTIONAL) + } + text.prompt().map_err(prompt_op_err) +} + +pub fn prompt_input_integer(desc: &str, required: bool) -> anyhow::Result { + let mut text = Text::new(desc); + if required { + text = text.with_validator(|text: &str| { + let out = if text.is_empty() { + Validation::Invalid(MSG_REQUIRED.into()) + } else { + validate_integer(text) + }; + Ok(out) + }) + } else { + text = text + .with_validator(|text: &str| { + let out = if text.is_empty() { + Validation::Valid + } else { + validate_integer(text) + }; + Ok(out) + }) + .with_help_message(MSG_OPTIONAL) + } + text.prompt().map_err(prompt_op_err) +} + +pub fn prompt_op_err(_: T) -> anyhow::Error { + anyhow::anyhow!("Not finish questionnaire, try again later!") +} + +#[derive(Debug, Clone, Copy)] +pub enum PromptKind { + String, + Integer, +} + +fn validate_integer(text: &str) -> Validation { + if text.parse::().is_err() { + Validation::Invalid("Must be a integer".into()) + } else { + Validation::Valid + } +}