Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: clients/* and config.rs #193

Merged
merged 3 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ inquire = "0.6.2"
is-terminal = "0.4.9"
reedline = "0.21.0"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
serde_json = { version = "1.0.93", features = ["preserve_order"] }
serde_yaml = "0.9.17"
tokio = { version = "1.26.0", features = ["rt", "time", "macros", "signal"] }
crossbeam = "0.8.2"
Expand Down
11 changes: 4 additions & 7 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,30 @@ keybindings: emacs # REPL keybindings. values: emacs, vi

clients:
# All clients have the following configuration:
# ```
# - type: xxxx
# name: nova # Only use it to distinguish clients with the same client type. Optional
# extra:
# proxy: socks5://127.0.0.1:1080 # Specify https/socks5 proxy server. Note HTTPS_PROXY/ALL_PROXY also works.
# connect_timeout: 10 # Set a timeout in seconds for connect to server
# ```

# See https://platform.openai.com/docs/quickstart
- type: openai
api_key: sk-xxx
organization_id: org-xxx # Organization ID. Optional
organization_id:

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_key: xxx
models: # Support models
models:
- name: MyGPT4 # Model deployment name
max_tokens: 8192


# See https://github.com/go-skynet/LocalAI
- type: localai
api_base: http://localhost:8080/v1
api_key: xxx
chat_endpoint: /chat/completions # Optional
models: # Support models
chat_endpoint: /chat/completions
models:
- name: gpt4all-j
max_tokens: 8192
82 changes: 21 additions & 61 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name,
Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};
use super::{AzureOpenAIClient, Client, ExtraConfig, ModelInfo, PromptKind, PromptType, SendData};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;
Expand All @@ -14,17 +11,10 @@ use serde::Deserialize;

use std::env;

#[derive(Debug)]
pub struct AzureOpenAIClient {
global_config: SharedConfig,
config: AzureOpenAIConfig,
model_info: ModelInfo,
}

#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: String,
pub api_base: Option<String>,
pub api_key: Option<String>,
pub models: Vec<AzureOpenAIModel>,
pub extra: Option<ExtraConfig>,
Expand All @@ -33,17 +23,13 @@ pub struct AzureOpenAIConfig {
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIModel {
name: String,
max_tokens: usize,
max_tokens: Option<usize>,
}

#[async_trait]
impl Client for AzureOpenAIClient {
fn config(&self) -> &SharedConfig {
&self.global_config
}

fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>) {
(&self.global_config, &self.config.extra)
}

async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
Expand All @@ -63,27 +49,19 @@ impl Client for AzureOpenAIClient {
}

impl AzureOpenAIClient {
pub const NAME: &str = "azure-openai";

pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
let config = {
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
return None;
}
};
Some(Box::new(Self {
global_config,
config,
model_info,
}))
}

pub fn name(local_config: &AzureOpenAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptType<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", true, PromptKind::String),
("models[].name", "Model Name:", true, PromptKind::String),
(
"models[].max_tokens",
"Max Tokens:",
true,
PromptKind::Integer,
),
];

pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);
Expand All @@ -95,26 +73,6 @@ impl AzureOpenAIClient {
.collect()
}

pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);

let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));

let api_key = prompt_input_api_key()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));

let model_name = prompt_input_model_name()?;

let max_tokens = prompt_input_max_token()?;

client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
));

Ok(client_config)
}

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key
Expand All @@ -127,11 +85,13 @@ impl AzureOpenAIClient {
})
.ok_or_else(|| anyhow!("Miss api_key"))?;

let api_base = self.get_api_base()?;

let body = openai_build_body(data, self.model_info.name.clone());

let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
self.config.api_base, self.model_info.name
&api_base, self.model_info.name
);

let builder = client.post(url).header("api-key", api_key).json(&body);
Expand Down
Loading