Skip to content

Commit

Permalink
Make the dtype configurable for phi. (#2133)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Apr 27, 2024
1 parent 96a48e5 commit 3b429f3
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions candle-examples/examples/phi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,

/// The dtype to be used for running the model, e.g. f32, bf16, or f16.
#[arg(long)]
dtype: Option<String>,
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -345,10 +349,15 @@ fn main() -> Result<()> {
};
Model::Quantized(model)
} else {
let dtype = if args.model == WhichModel::V3 && device.is_cuda() {
DType::BF16
} else {
DType::F32
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
if args.model == WhichModel::V3 && device.is_cuda() {
DType::BF16
} else {
DType::F32
}
}
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
match args.model {
Expand Down

0 comments on commit 3b429f3

Please sign in to comment.