diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 58be550227..09d5fd49cd 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1409,6 +1409,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "sgemm", DType::F16 => "hgemm", + DType::BF16 => "bgemm", dtype => { return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) } diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 1a0d9aca53..ceddc35ef4 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -361,10 +361,8 @@ fn main() -> Result<()> { let dtype = match args.dtype { Some(dtype) => std::str::FromStr::from_str(&dtype)?, None => { - if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium) - && device.is_cuda() - { - DType::BF16 + if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium { + device.bf16_default_to_f32() } else { DType::F32 }