Skip to content

Commit

Permalink
Metal bgemm min changes (#2364)
Browse files Browse the repository at this point in the history
* Add updated mfa metallib

* Add bgemm and tests
  • Loading branch information
ivarflakstad authored Aug 1, 2024
1 parent 8696cf6 commit fea46cb
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
2 changes: 2 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const SORT: &str = include_str!("sort.metal");
Expand Down Expand Up @@ -1564,6 +1565,7 @@ pub fn call_gemm(
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
"bgemm" => 2,
other => {
return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm"
Expand Down
Binary file modified candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary file not shown.
78 changes: 74 additions & 4 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,7 @@ fn where_cond_u32_f32() {
}

fn run_gemm<T: Clone>(
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
Expand Down Expand Up @@ -1076,7 +1077,7 @@ fn run_gemm<T: Clone>(
&device,
command_buffer,
&kernels,
"sgemm",
name,
(b, m, n, k),
&lhs_stride,
lhs_offset,
Expand All @@ -1100,7 +1101,16 @@ fn gemm() {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
let results = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
lhs_stride,
0,
&rhs,
rhs_stride,
0,
);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
Expand All @@ -1111,7 +1121,16 @@ fn gemm() {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
let results = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
lhs_stride,
0,
&rhs,
rhs_stride,
0,
);
assert_eq!(
approx(results, 4),
vec![
Expand All @@ -1127,11 +1146,62 @@ fn gemm() {
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
let results = run_gemm(
"sgemm",
(1, m, n, k),
&lhs,
lhs_stride,
0,
&rhs,
rhs_stride,
12 * 4,
);
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);

// bgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
let results = run_gemm(
"bgemm",
(b, m, n, k),
&lhs,
lhs_stride,
0,
&rhs,
rhs_stride,
0,
);
assert_eq!(
approx_bf16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);

// hgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let results = run_gemm(
"hgemm",
(b, m, n, k),
&lhs,
lhs_stride,
0,
&rhs,
rhs_stride,
0,
);
assert_eq!(
approx_f16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}

fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
Expand Down

0 comments on commit fea46cb

Please sign in to comment.