From 3b429f30235f20d6c678e3167e20dcf915c56367 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 27 Apr 2024 21:32:49 +0200 Subject: [PATCH] Make the dtype configurable for phi. (#2133) --- candle-examples/examples/phi/main.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index b65a803de..371b389f7 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -209,6 +209,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// The dtype to be used for running the model, e.g. f32, bf16, or f16. + #[arg(long)] + dtype: Option, } fn main() -> Result<()> { @@ -345,10 +349,15 @@ fn main() -> Result<()> { }; Model::Quantized(model) } else { - let dtype = if args.model == WhichModel::V3 && device.is_cuda() { - DType::BF16 - } else { - DType::F32 + let dtype = match args.dtype { + Some(dtype) => std::str::FromStr::from_str(&dtype)?, + None => { + if args.model == WhichModel::V3 && device.is_cuda() { + DType::BF16 + } else { + DType::F32 + } + } }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; match args.model {