Skip to content

Commit

Permalink
add/use macro config_get_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden committed Nov 1, 2023
1 parent 4263b59 commit d519245
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
8 changes: 6 additions & 2 deletions src/client/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::env;
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAIConfig {
pub name: Option<String>,
pub api_base: String,
pub api_base: Option<String>,
pub api_key: Option<String>,
pub models: Vec<AzureOpenAIModel>,
pub extra: Option<ExtraConfig>,
Expand Down Expand Up @@ -49,6 +49,8 @@ impl Client for AzureOpenAIClient {
}

impl AzureOpenAIClient {
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptType<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", true, PromptKind::String),
Expand Down Expand Up @@ -83,11 +85,13 @@ impl AzureOpenAIClient {
})
.ok_or_else(|| anyhow!("Miss api_key"))?;

let api_base = self.get_api_base()?;

let body = openai_build_body(data, self.model_info.name.clone());

let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2023-05-15",
self.config.api_base, self.model_info.name
&api_base, self.model_info.name
);

let builder = client.post(url).header("api-key", api_key).json(&body);
Expand Down
16 changes: 16 additions & 0 deletions src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ macro_rules! register_role {
};
}

macro_rules! config_get_fn {
($field_name:ident, $fn_name:ident) => {
fn $fn_name(&self) -> Result<String> {
let api_key = self.config.$field_name.clone();
api_key
.or_else(|| {
let env_prefix = Self::name(&self.config);
let env_name =
format!("{}_{}", env_prefix, stringify!($field_name)).to_ascii_uppercase();
env::var(&env_name).ok()
})
.ok_or_else(|| anyhow::anyhow!("Miss {}", stringify!($field_name)))
}
};
}

#[async_trait]
pub trait Client {
fn config(&self) -> (&SharedConfig, &Option<ExtraConfig>);
Expand Down
8 changes: 3 additions & 5 deletions src/client/localai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ impl Client for LocalAIClient {
}

impl LocalAIClient {
config_get_fn!(api_key, get_api_key);

pub const PROMPTS: [PromptType<'static>; 4] = [
("api_base", "API Base:", true, PromptKind::String),
("api_key", "API Key:", false, PromptKind::String),
Expand All @@ -72,11 +74,7 @@ impl LocalAIClient {
}

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let api_key = self.config.api_key.clone();
let api_key = api_key.or_else(|| {
let env_prefix = Self::name(&self.config).to_uppercase();
env::var(format!("{env_prefix}_API_KEY")).ok()
});
let api_key = self.get_api_key().ok();

let body = openai_build_body(data, self.model_info.name.clone());

Expand Down
10 changes: 4 additions & 6 deletions src/client/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ impl Client for OpenAIClient {
}

impl OpenAIClient {
config_get_fn!(api_key, get_api_key);

pub const PROMPTS: [PromptType<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];

Expand All @@ -64,15 +66,11 @@ impl OpenAIClient {
}

fn request_builder(&self, client: &ReqwestClient, data: SendData) -> Result<RequestBuilder> {
let env_prefix = Self::name(&self.config).to_uppercase();

let api_key = self.config.api_key.clone();
let api_key = api_key
.or_else(|| env::var(format!("{env_prefix}_API_KEY")).ok())
.ok_or_else(|| anyhow!("Miss api_key"))?;
let api_key = self.get_api_key()?;

let body = openai_build_body(data, self.model_info.name.clone());

let env_prefix = Self::name(&self.config).to_uppercase();
let api_base = env::var(format!("{env_prefix}_API_BASE"))
.ok()
.unwrap_or_else(|| API_BASE.to_string());
Expand Down

0 comments on commit d519245

Please sign in to comment.