From 957d604a7888bbf0243dbbca83a438db5132b48f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 1 Aug 2024 10:05:07 +0100 Subject: [PATCH] Enable BF16 on metal. (#2380) --- candle-core/src/metal_backend/mod.rs | 1 + candle-examples/examples/phi/main.rs | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 58be550227..09d5fd49cd 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1409,6 +1409,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "sgemm", DType::F16 => "hgemm", + DType::BF16 => "bgemm", dtype => { return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) } diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 1a0d9aca53..ceddc35ef4 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -361,10 +361,8 @@ fn main() -> Result<()> { let dtype = match args.dtype { Some(dtype) => std::str::FromStr::from_str(&dtype)?, None => { - if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium) - && device.is_cuda() - { - DType::BF16 + if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium { + device.bf16_default_to_f32() } else { DType::F32 }