From 4263b59af9ff8d36eb23edc3dcff92b1d7357d46 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 1 Nov 2023 09:54:26 +0800 Subject: [PATCH] split clients/mod.rs --- src/client/common.rs | 298 +++++++++++++++++++++++++++++++++++++++++ src/client/mod.rs | 307 ++----------------------------------------- src/utils/mod.rs | 8 ++ 3 files changed, 315 insertions(+), 298 deletions(-) create mode 100644 src/client/common.rs diff --git a/src/client/common.rs b/src/client/common.rs new file mode 100644 index 00000000..938bc596 --- /dev/null +++ b/src/client/common.rs @@ -0,0 +1,298 @@ +use crate::{ + config::{Message, SharedConfig}, + repl::{ReplyStreamHandler, SharedAbortSignal}, + utils::{init_tokio_runtime, prompt_input_integer, prompt_input_string, tokenize, PromptKind}, +}; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy}; +use serde::Deserialize; +use serde_json::{json, Value}; +use std::{env, time::Duration}; +use tokio::time::sleep; + +use super::{openai::OpenAIConfig, ClientConfig}; + +#[macro_export] +macro_rules! register_role { + ( + $(($name:literal, $config_key:ident, $config:ident, $client:ident),)+ + ) => { + + #[derive(Debug, Clone, Deserialize)] + #[serde(tag = "type")] + pub enum ClientConfig { + $( + #[serde(rename = $name)] + $config_key($config), + )+ + #[serde(other)] + Unknown, + } + + + $( + #[derive(Debug)] + pub struct $client { + global_config: SharedConfig, + config: $config, + model_info: ModelInfo, + } + + impl $client { + pub const NAME: &str = $name; + + pub fn init(global_config: SharedConfig) -> Option> { + let model_info = global_config.read().model_info.clone(); + let config = { + if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] { + c.clone() + } else { + return None; + } + }; + Some(Box::new(Self { + global_config, + config, + model_info, + })) + } + + pub fn name(local_config: &$config) -> &str { + local_config.name.as_deref().unwrap_or(Self::NAME) + } + } + + )+ + + pub fn init_client(config: SharedConfig) -> Result> { + None + $(.or_else(|| $client::init(config.clone())))+ + .ok_or_else(|| { + let model_info = config.read().model_info.clone(); + anyhow!( + "Unknown client {} at config.clients[{}]", + &model_info.client, + &model_info.index + ) + }) + } + + pub fn list_client_types() -> Vec<&'static str> { + vec![$($client::NAME,)+] + } + + pub fn create_client_config(client: &str) -> Result { + $( + if client == $client::NAME { + return create_config(&$client::PROMPTS, $client::NAME) + } + )+ + bail!("Unknown client {}", client) + } + + pub fn all_models(config: &Config) -> Vec { + config + .clients + .iter() + .enumerate() + .flat_map(|(i, v)| match v { + $(ClientConfig::$config_key(c) => $client::list_models(c, i),)+ + ClientConfig::Unknown => vec![], + }) + .collect() + } + + }; +} + +#[async_trait] +pub trait Client { + fn config(&self) -> (&SharedConfig, &Option); + + fn build_client(&self) -> Result { + let mut builder = ReqwestClient::builder(); + let options = self.config().1; + let timeout = options + .as_ref() + .and_then(|v| v.connect_timeout) + .unwrap_or(10); + let proxy = options.as_ref().and_then(|v| v.proxy.clone()); + builder = set_proxy(builder, &proxy)?; + let client = builder + .connect_timeout(Duration::from_secs(timeout)) + .build() + .with_context(|| "Failed to build client")?; + Ok(client) + } + + fn send_message(&self, content: &str) -> Result { + init_tokio_runtime()?.block_on(async { + let global_config = self.config().0; + if global_config.read().dry_run { + let content = global_config.read().echo_messages(content); + return Ok(content); + } + let client = self.build_client()?; + let data = global_config.read().prepare_send_data(content, false)?; + self.send_message_inner(&client, data) + .await + .with_context(|| "Failed to get awswer") + }) + } + + fn send_message_streaming( + &self, + content: &str, + handler: &mut ReplyStreamHandler, + ) -> Result<()> { + async fn watch_abort(abort: SharedAbortSignal) { + loop { + if abort.aborted() { + break; + } + sleep(Duration::from_millis(100)).await; + } + } + let abort = handler.get_abort(); + init_tokio_runtime()?.block_on(async { + tokio::select! { + ret = async { + let global_config = self.config().0; + if global_config.read().dry_run { + let content = global_config.read().echo_messages(content); + let tokens = tokenize(&content); + for token in tokens { + tokio::time::sleep(Duration::from_millis(25)).await; + handler.text(&token)?; + } + return Ok(()); + } + let client = self.build_client()?; + let data = global_config.read().prepare_send_data(content, true)?; + self.send_message_streaming_inner(&client, handler, data).await + } => { + handler.done()?; + ret.with_context(|| "Failed to get awswer") + } + _ = watch_abort(abort.clone()) => { + handler.done()?; + Ok(()) + }, + _ = tokio::signal::ctrl_c() => { + abort.set_ctrlc(); + Ok(()) + } + } + }) + } + + async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result; + + async fn send_message_streaming_inner( + &self, + client: &ReqwestClient, + handler: &mut ReplyStreamHandler, + data: SendData, + ) -> Result<()>; +} + +impl Default for ClientConfig { + fn default() -> Self { + Self::OpenAI(OpenAIConfig::default()) + } +} + +#[derive(Debug, Clone, Deserialize, Default)] +pub struct ExtraConfig { + pub proxy: Option, + pub connect_timeout: Option, +} + +#[derive(Debug)] +pub struct SendData { + pub messages: Vec, + pub temperature: Option, + pub stream: bool, +} + +pub type PromptType<'a> = (&'a str, &'a str, bool, PromptKind); + +pub fn create_config(list: &[PromptType], client: &str) -> Result { + let mut config = json!({ + "type": client, + }); + for (path, desc, required, kind) in list { + match kind { + PromptKind::String => { + let value = prompt_input_string(desc, *required)?; + set_config_value(&mut config, path, kind, &value); + } + PromptKind::Integer => { + let value = prompt_input_integer(desc, *required)?; + set_config_value(&mut config, path, kind, &value); + } + } + } + + let clients = json!(vec![config]); + Ok(clients) +} + +fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) { + let segs: Vec<&str> = path.split('.').collect(); + match segs.as_slice() { + [name] => json[name] = to_json(kind, value), + [scope, name] => match scope.split_once('[') { + None => { + if json.get(scope).is_none() { + let mut obj = json!({}); + obj[name] = to_json(kind, value); + json[scope] = obj; + } else { + json[scope][name] = to_json(kind, value); + } + } + Some((scope, _)) => { + if json.get(scope).is_none() { + let mut obj = json!({}); + obj[name] = to_json(kind, value); + json[scope] = json!([obj]); + } else { + json[scope][0][name] = to_json(kind, value); + } + } + }, + _ => {} + } +} + +fn to_json(kind: &PromptKind, value: &str) -> Value { + if value.is_empty() { + return Value::Null; + } + match kind { + PromptKind::String => value.into(), + PromptKind::Integer => match value.parse::() { + Ok(value) => value.into(), + Err(_) => value.into(), + }, + } +} + +fn set_proxy(builder: ClientBuilder, proxy: &Option) -> Result { + let proxy = if let Some(proxy) = proxy { + if proxy.is_empty() || proxy == "false" || proxy == "-" { + return Ok(builder); + } + proxy.clone() + } else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) { + proxy + } else { + return Ok(builder); + }; + let builder = + builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); + Ok(builder) +} diff --git a/src/client/mod.rs b/src/client/mod.rs index b0326650..5562eb64 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,220 +1,24 @@ +#[macro_use] +mod common; + pub mod azure_openai; pub mod localai; pub mod openai; +pub use common::*; + use self::azure_openai::AzureOpenAIConfig; use self::localai::LocalAIConfig; use self::openai::OpenAIConfig; use crate::{ - config::{Config, Message, ModelInfo, SharedConfig}, - repl::{ReplyStreamHandler, SharedAbortSignal}, - utils::{prompt_input_integer, prompt_input_string, tokenize, PromptKind}, + config::{Config, ModelInfo, SharedConfig}, + utils::PromptKind, }; -use anyhow::{anyhow, bail, Context, Result}; -use async_trait::async_trait; -use reqwest::{Client as ReqwestClient, ClientBuilder, Proxy}; +use anyhow::{anyhow, bail, Result}; use serde::Deserialize; -use serde_json::{json, Value}; -use std::{env, time::Duration}; -use tokio::time::sleep; - -#[derive(Debug, Clone, Deserialize, Default)] -pub struct ExtraConfig { - pub proxy: Option, - pub connect_timeout: Option, -} - -#[derive(Debug)] -pub struct SendData { - pub messages: Vec, - pub temperature: Option, - pub stream: bool, -} - -#[async_trait] -pub trait Client { - fn config(&self) -> (&SharedConfig, &Option); - - fn build_client(&self) -> Result { - let mut builder = ReqwestClient::builder(); - let options = self.config().1; - let timeout = options - .as_ref() - .and_then(|v| v.connect_timeout) - .unwrap_or(10); - let proxy = options.as_ref().and_then(|v| v.proxy.clone()); - builder = set_proxy(builder, &proxy)?; - let client = builder - .connect_timeout(Duration::from_secs(timeout)) - .build() - .with_context(|| "Failed to build client")?; - Ok(client) - } - - fn send_message(&self, content: &str) -> Result { - init_tokio_runtime()?.block_on(async { - let global_config = self.config().0; - if global_config.read().dry_run { - let content = global_config.read().echo_messages(content); - return Ok(content); - } - let client = self.build_client()?; - let data = global_config.read().prepare_send_data(content, false)?; - self.send_message_inner(&client, data) - .await - .with_context(|| "Failed to get awswer") - }) - } - - fn send_message_streaming( - &self, - content: &str, - handler: &mut ReplyStreamHandler, - ) -> Result<()> { - async fn watch_abort(abort: SharedAbortSignal) { - loop { - if abort.aborted() { - break; - } - sleep(Duration::from_millis(100)).await; - } - } - let abort = handler.get_abort(); - init_tokio_runtime()?.block_on(async { - tokio::select! { - ret = async { - let global_config = self.config().0; - if global_config.read().dry_run { - let content = global_config.read().echo_messages(content); - let tokens = tokenize(&content); - for token in tokens { - tokio::time::sleep(Duration::from_millis(25)).await; - handler.text(&token)?; - } - return Ok(()); - } - let client = self.build_client()?; - let data = global_config.read().prepare_send_data(content, true)?; - self.send_message_streaming_inner(&client, handler, data).await - } => { - handler.done()?; - ret.with_context(|| "Failed to get awswer") - } - _ = watch_abort(abort.clone()) => { - handler.done()?; - Ok(()) - }, - _ = tokio::signal::ctrl_c() => { - abort.set_ctrlc(); - Ok(()) - } - } - }) - } - - async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result; - - async fn send_message_streaming_inner( - &self, - client: &ReqwestClient, - handler: &mut ReplyStreamHandler, - data: SendData, - ) -> Result<()>; -} - -macro_rules! register_role { - ( - $(($name:literal, $config_key:ident, $config:ident, $client:ident),)+ - ) => { - - #[derive(Debug, Clone, Deserialize)] - #[serde(tag = "type")] - pub enum ClientConfig { - $( - #[serde(rename = $name)] - $config_key($config), - )+ - #[serde(other)] - Unknown, - } - - - $( - #[derive(Debug)] - pub struct $client { - global_config: SharedConfig, - config: $config, - model_info: ModelInfo, - } - - impl $client { - pub const NAME: &str = $name; - - pub fn init(global_config: SharedConfig) -> Option> { - let model_info = global_config.read().model_info.clone(); - let config = { - if let ClientConfig::$config_key(c) = &global_config.read().clients[model_info.index] { - c.clone() - } else { - return None; - } - }; - Some(Box::new(Self { - global_config, - config, - model_info, - })) - } - - pub fn name(local_config: &$config) -> &str { - local_config.name.as_deref().unwrap_or(Self::NAME) - } - } - - )+ - - pub fn init_client(config: SharedConfig) -> Result> { - None - $(.or_else(|| $client::init(config.clone())))+ - .ok_or_else(|| { - let model_info = config.read().model_info.clone(); - anyhow!( - "Unknown client {} at config.clients[{}]", - &model_info.client, - &model_info.index - ) - }) - } - - pub fn list_client_types() -> Vec<&'static str> { - vec![$($client::NAME,)+] - } - - pub fn create_client_config(client: &str) -> Result { - $( - if client == $client::NAME { - return create_config(&$client::PROMPTS, $client::NAME) - } - )+ - bail!("Unknown client {}", client) - } - - pub fn all_models(config: &Config) -> Vec { - config - .clients - .iter() - .enumerate() - .flat_map(|(i, v)| match v { - $(ClientConfig::$config_key(c) => $client::list_models(c, i),)+ - ClientConfig::Unknown => vec![], - }) - .collect() - } - - }; -} +use serde_json::Value; register_role!( ("openai", OpenAI, OpenAIConfig, OpenAIClient), @@ -226,96 +30,3 @@ register_role!( AzureOpenAIClient ), ); - -impl Default for ClientConfig { - fn default() -> Self { - Self::OpenAI(OpenAIConfig::default()) - } -} - -type PromptType<'a> = (&'a str, &'a str, bool, PromptKind); - -fn init_tokio_runtime() -> Result { - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .with_context(|| "Failed to init tokio") -} - -fn create_config(list: &[PromptType], client: &str) -> Result { - let mut config = json!({ - "type": client, - }); - for (path, desc, required, kind) in list { - match kind { - PromptKind::String => { - let value = prompt_input_string(desc, *required)?; - set_config_value(&mut config, path, kind, &value); - } - PromptKind::Integer => { - let value = prompt_input_integer(desc, *required)?; - set_config_value(&mut config, path, kind, &value); - } - } - } - - let clients = json!(vec![config]); - Ok(clients) -} - -fn set_config_value(json: &mut Value, path: &str, kind: &PromptKind, value: &str) { - let segs: Vec<&str> = path.split('.').collect(); - match segs.as_slice() { - [name] => json[name] = to_json(kind, value), - [scope, name] => match scope.split_once('[') { - None => { - if json.get(scope).is_none() { - let mut obj = json!({}); - obj[name] = to_json(kind, value); - json[scope] = obj; - } else { - json[scope][name] = to_json(kind, value); - } - } - Some((scope, _)) => { - if json.get(scope).is_none() { - let mut obj = json!({}); - obj[name] = to_json(kind, value); - json[scope] = json!([obj]); - } else { - json[scope][0][name] = to_json(kind, value); - } - } - }, - _ => {} - } -} - -fn to_json(kind: &PromptKind, value: &str) -> Value { - if value.is_empty() { - return Value::Null; - } - match kind { - PromptKind::String => value.into(), - PromptKind::Integer => match value.parse::() { - Ok(value) => value.into(), - Err(_) => value.into(), - }, - } -} - -fn set_proxy(builder: ClientBuilder, proxy: &Option) -> Result { - let proxy = if let Some(proxy) = proxy { - if proxy.is_empty() || proxy == "false" || proxy == "-" { - return Ok(builder); - } - proxy.clone() - } else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) { - proxy - } else { - return Ok(builder); - }; - let builder = - builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); - Ok(builder) -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 6e27199f..ab922e33 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -65,3 +65,11 @@ pub fn light_theme_from_colorfgbg(colorfgbg: &str) -> Option { let light = v > 128.0; Some(light) } + +pub fn init_tokio_runtime() -> anyhow::Result { + use anyhow::Context; + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .with_context(|| "Failed to init tokio") +}