Skip to content

Commit

Permalink
refactor: clients/* and config.rs
Browse files Browse the repository at this point in the history
- add register_clients to make it easier to add a new client
- no create_client_config, just add const PROMPTS
- move ModelInfo from clients/ to config/
- model's max_tokens are optional
- improve code quality in config/mod.rs
  • Loading branch information
sigoden committed Nov 1, 2023
1 parent 6420275 commit 49eb532
Show file tree
Hide file tree
Showing 14 changed files with 410 additions and 387 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ inquire = "0.6.2"
is-terminal = "0.4.9"
reedline = "0.21.0"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
serde_json = { version = "1.0.93", features = ["preserve_order"] }
serde_yaml = "0.9.17"
tokio = { version = "1.26.0", features = ["rt", "time", "macros", "signal"] }
crossbeam = "0.8.2"
Expand Down
11 changes: 4 additions & 7 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,30 @@ keybindings: emacs # REPL keybindings. values: emacs, vi

clients:
# All clients have the following configuration:
# ```
# - type: xxxx
# name: nova # Only use it to distinguish clients with the same client type. Optional
# extra:
#   proxy: socks5://127.0.0.1:1080 # Specify an https/socks5 proxy server. Note that HTTPS_PROXY/ALL_PROXY also work.
#   connect_timeout: 10 # Timeout in seconds for connecting to the server
# ```

# See https://platform.openai.com/docs/quickstart
- type: openai
api_key: sk-xxx
organization_id: org-xxx # Organization ID. Optional
organization_id:

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_key: xxx
models: # Support models
models:
- name: MyGPT4 # Model deployment name
max_tokens: 8192


# See https://github.com/go-skynet/LocalAI
- type: localai
api_base: http://localhost:8080/v1
api_key: xxx
chat_endpoint: /chat/completions # Optional
models: # Support models
chat_endpoint: /chat/completions
models:
- name: gpt4all-j
max_tokens: 8192
74 changes: 15 additions & 59 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name,
Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};
use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
Expand All @@ -14,13 +11,6 @@ use serde::Deserialize;

use std::env;

#[derive(Debug)]
pub struct AzureOpenAIClient {
global_config: SharedConfig,
config: AzureOpenAIConfig,
model_info: ModelInfo,
}

#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
pub name: Option<String>,
Expand All @@ -33,17 +23,13 @@ pub struct AzureOpenAIConfig {
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIModel {
name: String,
max_tokens: usize,
max_tokens: Option<usize>,
}

#[async_trait]
impl Client for AzureOpenAIClient {
fn config(&self) -> &SharedConfig {
&self.global_config
}

fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}

async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
Expand All @@ -63,27 +49,17 @@ impl Client for AzureOpenAIClient {
}

impl AzureOpenAIClient {
pub const NAME: &str = "azure-openai";

pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
config,
model_info,
}))
}

pub fn name(local_config: &AzureOpenAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
pub const PROMPTS: [PromptType<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", true, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
true,
PromptKind::Integer,
),
];

pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
Expand All @@ -95,26 +71,6 @@ impl AzureOpenAIClient {
.collect()
}

pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);

let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));

let api_key = prompt_input_api_key()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));

let model_name = prompt_input_model_name()?;

let max_tokens = prompt_input_max_token()?;

client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
));

Ok(client_config)
}

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key
Expand Down
74 changes: 15 additions & 59 deletions src/client/localai.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key_optional, prompt_input_max_token,
prompt_input_model_name, Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};
use super::{Client, ExtraConfig, LocalAIClient, ModelInfo, PromptKind, PromptType, SendData};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
Expand All @@ -13,13 +10,6 @@ use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use std::env;

#[derive(Debug)]
pub struct LocalAIClient {
global_config: SharedConfig,
config: LocalAIConfig,
model_info: ModelInfo,
}

#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIConfig {
pub name: Option<String>,
Expand All @@ -33,17 +23,13 @@ pub struct LocalAIConfig {
#[derive(Debug, Clone, Deserialize)]
pub struct LocalAIModel {
name: String,
max_tokens: usize,
max_tokens: Option<usize>,
}

#[async_trait]
impl Client for LocalAIClient {
fn config(&self) -> &SharedConfig {
&self.global_config
}

fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}

async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
Expand All @@ -63,27 +49,17 @@ impl Client for LocalAIClient {
}

impl LocalAIClient {
pub const NAME: &str = "localai";

pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::LocalAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
config,
model_info,
}))
}

pub fn name(local_config: &LocalAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
pub const PROMPTS: [PromptType<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", false, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
false,
PromptKind::Integer,
),
];

pub fn list_models(local_config: &LocalAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
Expand All @@ -95,26 +71,6 @@ impl LocalAIClient {
.collect()
}

pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);

let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));

let api_key = prompt_input_api_key_optional()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));

let model_name = prompt_input_model_name()?;

let max_tokens = prompt_input_max_token()?;

client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
));

Ok(client_config)
}

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key.or_else(|| {
Expand Down
Loading

0 comments on commit 49eb532

Please sign in to comment.