Skip to content

Commit

Permalink
feat: improve client (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Oct 31, 2023
1 parent 9b61460 commit a137483
Show file tree
Hide file tree
Showing 7 changed files with 349 additions and 347 deletions.
35 changes: 22 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ On first launch, aichat will guide you through the configuration.

```
> No config file, create a new one? Yes
> Select AI? openai
> API key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
> Has Organization? No
> Use proxy? No
> AI Platform: openai
> API Key: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
> Save chat messages? Yes
```

Expand All @@ -64,24 +62,35 @@ wrap_code: false # Whether wrap code block
auto_copy: false # Automatically copy the last output to the clipboard
keybindings: emacs # REPL keybindings, possible values: emacs (default), vi

clients: # Setup AIs
clients:
# All clients have the following configuration:
# ```
# - type: xxxx
#     name: nova                      # Used only to distinguish clients of the same client type. Optional
# extra:
#       proxy: socks5://127.0.0.1:1080 # Specify an https/socks5 proxy server. Note HTTPS_PROXY/ALL_PROXY also work.
#       connect_timeout: 10            # Set a timeout in seconds for connecting to the server
# ```

# See https://platform.openai.com/docs/quickstart
- type: openai # OpenAI configuration
api_key: sk-xxx # OpenAI api key, alternative to OPENAI_API_KEY
- type: openai
api_key: sk-xxx
organization_id: org-xxx # Organization ID. Optional

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai # Azure openai configuration
api_base: https://RESOURCE.openai.azure.com # Azure openai base URL
api_key: xxx # Azure openai api key, alternative to AZURE_OPENAI_KEY
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_key: xxx
    models:                                     # Supported models
- name: MyGPT4 # Model deployment name
max_tokens: 8192


# See https://github.com/go-skynet/LocalAI
- type: localai # LocalAI configuration
url: http://localhost:8080/v1/chat/completions # LocalAI api server
api_key: xxx # Api key. alternative to LOCALAI_API_KEY
- type: localai
api_base: http://localhost:8080/v1
api_key: xxx
chat_endpoint: /chat/completions # Optional
    models:                                       # Supported models
- name: gpt4all-j
max_tokens: 8192
Expand Down
37 changes: 21 additions & 16 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,35 @@ wrap_code: false # Whether wrap code block
auto_copy: false # Automatically copy the last output to the clipboard
keybindings: emacs # REPL keybindings, possible values: emacs (default), vi

clients: # Setup AIs
clients:
# All clients have the following configuration:
# ```
# - type: xxxx
#     name: nova                      # Used only to distinguish clients of the same client type. Optional
# extra:
#       proxy: socks5://127.0.0.1:1080 # Specify an https/socks5 proxy server. Note HTTPS_PROXY/ALL_PROXY also work.
#       connect_timeout: 10            # Set a timeout in seconds for connecting to the server
# ```

# See https://platform.openai.com/docs/quickstart
- type: openai # OpenAI configuration
api_key: sk-xxx # OpenAI api key, alternative to OPENAI_API_KEY
- type: openai
api_key: sk-xxx
organization_id: org-xxx # Organization ID. Optional
proxy: socks5://127.0.0.1:1080
connect_timeout: 10

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai # Azure openai configuration
api_base: https://RESOURCE.openai.azure.com # Azure openai base URL
api_key: xxx # Azure openai api key, alternative to AZURE_OPENAI_KEY
- type: azure-openai
api_base: https://RESOURCE.openai.azure.com
api_key: xxx
    models:                                     # Supported models
- name: MyGPT4 # Model deployment name
max_tokens: 8192
proxy: socks5://127.0.0.1:1080 # Set proxy server. Optional
    connect_timeout: 10                           # Set a timeout in seconds for connecting to the server. Optional


# See https://github.com/go-skynet/LocalAI
- type: localai # LocalAI configuration
url: http://localhost:8080/v1/chat/completions # LocalAI api server
api_key: xxx # Api key. alternative to LOCALAI_API_KEY
- type: localai
api_base: http://localhost:8080/v1
api_key: xxx
chat_endpoint: /chat/completions # Optional
    models:                                       # Supported models
- name: gpt4all-j
max_tokens: 8192
proxy: socks5://127.0.0.1:1080
connect_timeout: 10
max_tokens: 8192
145 changes: 51 additions & 94 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,33 @@
use super::openai::{openai_send_message, openai_send_message_streaming};
use super::{set_proxy, Client, ClientConfig, ModelInfo};
use super::openai::{openai_build_body, openai_send_message, openai_send_message_streaming};
use super::{
prompt_input_api_base, prompt_input_api_key, prompt_input_max_token, prompt_input_model_name,
Client, ClientConfig, ExtraConfig, ModelInfo, SendData,
};

use crate::config::SharedConfig;
use crate::repl::ReplyStreamHandler;

use anyhow::{anyhow, Context, Result};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use inquire::{Confirm, Text};
use reqwest::{Client as ReqwestClient, RequestBuilder};
use serde::Deserialize;
use serde_json::json;

use std::env;
use std::time::Duration;

#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct AzureOpenAIClient {
global_config: SharedConfig,
local_config: AzureOpenAIConfig,
config: AzureOpenAIConfig,
model_info: ModelInfo,
}

#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: String,
pub api_key: Option<String>,
pub models: Vec<AzureOpenAIModel>,
pub proxy: Option<String>,
/// Set a timeout in seconds for connect to server
pub connect_timeout: Option<u64>,
pub extra: Option<ExtraConfig>,
}

#[derive(Debug, Clone, Deserialize)]
Expand All @@ -39,32 +38,36 @@ pub struct AzureOpenAIModel {

#[async_trait]
impl Client for AzureOpenAIClient {
fn get_config(&self) -> &SharedConfig {
fn config(&self) -> &SharedConfig {
&self.global_config
}

async fn send_message_inner(&self, content: &str) -> Result<String> {
let builder = self.request_builder(content, false)?;
fn extra_config(&self) -> &Option<ExtraConfig> {
&self.config.extra
}

async fn send_message_inner(&self, client: &ReqwestClient, data: SendData) -> Result<String> {
let builder = self.request_builder(client, data)?;
openai_send_message(builder).await
}

async fn send_message_streaming_inner(
&self,
content: &str,
client: &ReqwestClient,
handler: &mut ReplyStreamHandler,
data: SendData,
) -> Result<()> {
let builder = self.request_builder(content, true)?;
let builder = self.request_builder(client, data)?;
openai_send_message_streaming(builder, handler).await
}
}

impl AzureOpenAIClient {
pub const NAME: &str = "azure-openai";

pub fn init(global_config: SharedConfig) -> Option<Box<dyn Client>> {
let model_info = global_config.read().model_info.clone();
if model_info.client != AzureOpenAIClient::name() {
return None;
}
let local_config = {
let config = {
if let ClientConfig::AzureOpenAI(c) = &global_config.read().clients[model_info.index] {
c.clone()
} else {
Expand All @@ -73,59 +76,37 @@ impl AzureOpenAIClient {
};
Some(Box::new(Self {
global_config,
local_config,
config,
model_info,
}))
}

pub fn name() -> &'static str {
"azure-openai"
pub fn name(local_config: &AzureOpenAIConfig) -> &str {
local_config.name.as_deref().unwrap_or(Self::NAME)
}

pub fn list_models(local_config: &AzureOpenAIConfig, index: usize) -> Vec<ModelInfo> {
let client = Self::name(local_config);

local_config
.models
.iter()
.map(|v| ModelInfo::new(Self::name(), &v.name, v.max_tokens, index))
.map(|v| ModelInfo::new(client, &v.name, v.max_tokens, index))
.collect()
}

pub fn create_config() -> Result<String> {
let mut client_config = format!("clients:\n - type: {}\n", Self::name());

let api_base = Text::new("api_base:")
.prompt()
.map_err(|_| anyhow!("An error happened when asking for api base, try again later."))?;
let mut client_config = format!("clients:\n - type: {}\n", Self::NAME);

let api_base = prompt_input_api_base()?;
client_config.push_str(&format!(" api_base: {api_base}\n"));

if env::var("AZURE_OPENAI_KEY").is_err() {
let api_key = Text::new("API key:").prompt().map_err(|_| {
anyhow!("An error happened when asking for api key, try again later.")
})?;
let api_key = prompt_input_api_key()?;
client_config.push_str(&format!(" api_key: {api_key}\n"));

client_config.push_str(&format!(" api_key: {api_key}\n"));
}
let model_name = prompt_input_model_name()?;

let model_name = Text::new("Model Name:").prompt().map_err(|_| {
anyhow!("An error happened when asking for model name, try again later.")
})?;

let max_tokens = Text::new("Max tokens:").prompt().map_err(|_| {
anyhow!("An error happened when asking for max tokens, try again later.")
})?;

let ans = Confirm::new("Use proxy?")
.with_default(false)
.prompt()
.map_err(|_| anyhow!("Not finish questionnaire, try again later."))?;

if ans {
let proxy = Text::new("Set proxy:").prompt().map_err(|_| {
anyhow!("An error happened when asking for proxy, try again later.")
})?;
client_config.push_str(&format!(" proxy: {proxy}\n"));
}
let max_tokens = prompt_input_max_token()?;

client_config.push_str(&format!(
" models:\n - name: {model_name}\n max_tokens: {max_tokens}\n"
Expand All @@ -134,50 +115,26 @@ impl AzureOpenAIClient {
Ok(client_config)
}

fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let messages = self.global_config.read().build_messages(content)?;

let mut body = json!({
"messages": messages,
});

if let Some(v) = self.global_config.read().get_temperature() {
body.as_object_mut()
.and_then(|m| m.insert("temperature".into(), json!(v)));
}

if stream {
body.as_object_mut()
.and_then(|m| m.insert("stream".into(), json!(true)));
}

let client = {
let mut builder = ReqwestClient::builder();
builder = set_proxy(builder, &self.local_config.proxy)?;
let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10));
builder
.connect_timeout(timeout)
.build()
.with_context(|| "Failed to build client")?
};
let mut api_base = self.local_config.api_base.clone();
if !api_base.ends_with('/') {
api_base = format!("{api_base}/");
}
fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key
.or_else(|| {
let env_prefix = match &self.config.name {
None => "AZURE".into(),
Some(v) => v.to_uppercase(),
};
env::var(format!("{env_prefix}_OPENAI_KEY")).ok()
})
.ok_or_else(|| anyhow!("Miss api_key"))?;

let body = openai_build_body(data, self.model_info.name.clone());

let url = format!(
"{api_base}openai/deployments/{}/chat/completions?api-version=2023-05-15",
self.model_info.name
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
self.config.api_base, self.model_info.name
);

let mut builder = client.post(url);

if let Some(api_key) = &self.local_config.api_key {
builder = builder.header("api-key", api_key)
} else if let Ok(api_key) = env::var("AZURE_OPENAI_KEY") {
builder = builder.header("api-key", api_key)
}
builder = builder.json(&body);
let builder = client.post(url).header("api-key", api_key).json(&body);

Ok(builder)
}
Expand Down
Loading

0 comments on commit a137483

Please sign in to comment.