Support both llama v1 and llama v2. (#272)
LaurentMazare authored Jul 28, 2023
1 parent 7513a5e commit 50d8273
Showing 2 changed files with 20 additions and 2 deletions.
candle-examples/examples/llama/main.rs (6 changes: 5 additions & 1 deletion)
@@ -127,7 +127,11 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
     let device = candle_examples::device(args.cpu)?;
-    let config = Config::config_7b(args.use_flash_attn);
+    let config = if args.v1 {
+        Config::config_7b_v1(args.use_flash_attn)
+    } else {
+        Config::config_7b_v2(args.use_flash_attn)
+    };
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
     let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
     let (llama, tokenizer_filename) = match args.npy {
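The hunk above references an args.v1 field whose declaration is not shown in this excerpt. A minimal sketch of how such a flag could be declared with clap's derive API (the field name v1 is inferred from the usage above; the doc comment and the other listed fields are assumptions, not part of this commit):

#[derive(Parser, Debug)]
struct Args {
    /// Select the LLaMA v1 configuration instead of the v2 default.
    /// (assumed declaration; only `args.v1` is visible in this diff)
    #[arg(long)]
    v1: bool,
    // ...plus the other flags used above: cpu, use_flash_attn, use_f32, no_kv_cache, npy
}

Assuming clap derives the flag name from the field, the example would then default to the v2 config and switch on request:

cargo run --example llama --release          # LLaMA v2 config (default)
cargo run --example llama --release -- --v1  # LLaMA v1 config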
candle-examples/examples/llama/model.rs (16 changes: 15 additions & 1 deletion)
@@ -18,7 +18,21 @@ pub struct Config {
 }
 
 impl Config {
-    pub fn config_7b(use_flash_attn: bool) -> Self {
+    pub fn config_7b_v1(use_flash_attn: bool) -> Self {
+        Self {
+            hidden_size: 4096,
+            intermediate_size: 11008,
+            vocab_size: 32000,
+            n_layer: 32,
+            n_head: 32,
+            n_embd: 4096,
+            n_key_value_head: 32,
+            use_flash_attn,
+            rms_norm_eps: 1e-6,
+        }
+    }
+
+    pub fn config_7b_v2(use_flash_attn: bool) -> Self {
         Self {
             hidden_size: 4096,
             intermediate_size: 11008,
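The hunk is cut off inside config_7b_v2, so the remaining fields are not shown here. As a hedged sketch only: at the 7B size, LLaMA v2 keeps the same shape hyperparameters as v1 (hidden size, layer and head counts, vocabulary), with the documented difference being the RMS-norm epsilon, 1e-5 rather than v1's 1e-6. Under that assumption, the rest of the struct would plausibly read:

            vocab_size: 32000,
            n_layer: 32,
            n_head: 32,
            n_embd: 4096,
            n_key_value_head: 32, // 7B v2 still has one KV head per attention head; grouped-query attention only appears at 70B
            use_flash_attn,
            rms_norm_eps: 1e-5, // assumed: LLaMA-2's published eps; v1 above uses 1e-6
        }
    }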
