Use set_params!
EricLBuehler committed Oct 31, 2024
1 parent 018bda7 commit fa9cef3
Showing 2 changed files with 36 additions and 50 deletions.
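For context: set_params! is the argument-binding macro from candle-metal-kernels' utils module (note the new EncoderParam import in the first hunk below). Instead of binding each kernel argument by hand with an explicit buffer index, the macro walks a tuple of arguments and binds each one at the next consecutive index, dispatching on the argument's type through the EncoderParam trait. A minimal sketch of the pattern, reconstructed from how the diff uses it rather than quoted from utils.rs, so the exact impls shown here are assumptions:

use std::ffi::c_void;
use metal::{Buffer, ComputeCommandEncoderRef};

// Each argument type knows how to bind itself at a given buffer index.
pub trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}

// A (buffer, offset) pair becomes a set_buffer call with that offset.
impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

// A bare buffer binds at offset 0.
impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}

// A slice is copied inline via set_bytes (pointer plus byte length).
impl<T> EncoderParam for &[T] {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of_val(data) as u64,
            data.as_ptr() as *const c_void,
        );
    }
}

// The macro binds every argument in order, starting at index 0.
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        let mut _index = 0;
        $(
            EncoderParam::set_param($encoder, _index, $param);
            _index += 1;
        )+
    );
}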
82 changes: 34 additions & 48 deletions candle-metal-kernels/src/lib.rs
@@ -8,7 +8,7 @@ use std::sync::RwLock;
 
 pub mod utils;
 pub use utils::BufferOffset;
-use utils::{get_block_dims, linear_split, EncoderProvider};
+use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
 
 const AFFINE: &str = include_str!("affine.metal");
 const BINARY: &str = include_str!("binary.metal");
@@ -1809,25 +1809,27 @@ pub fn call_sdpa_full(
     };
     let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
 
-    encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
-    encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
-    encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
-    encoder.set_buffer(3, Some(&output), 0);
+    impl EncoderParam for MLXFastAttentionParams {
+        fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+            encoder.set_bytes(
+                position,
+                core::mem::size_of::<MLXFastAttentionParams>() as u64,
+                &data as *const MLXFastAttentionParams as *const c_void,
+            );
+        }
+    }
 
-    encoder.set_bytes(
-        4,
-        std::mem::size_of::<MLXFastAttentionParams>() as u64,
-        &params as *const MLXFastAttentionParams as *const c_void,
-    );
-    encoder.set_bytes(
-        6,
-        (std::mem::size_of::<i32>() * batch_shape.len()) as u64,
-        batch_shape.as_ptr() as *const i32 as *const c_void,
-    );
-    encoder.set_bytes(
-        7,
-        (std::mem::size_of::<usize>() * batch_strides.len()) as u64,
-        batch_strides.as_ptr() as *const c_void,
+    set_params!(
+        encoder,
+        (
+            (q_buffer, q_offset),
+            (k_buffer, k_offset),
+            (v_buffer, v_offset),
+            output,
+            params,
+            &batch_shape[..],
+            &batch_strides[..]
+        )
     );
 
     let grid_dims = MTLSize {
@@ -1917,35 +1919,19 @@ pub fn call_sdpa_vector(
     // q = (bs, qhead, seq, hidden)
     // k/v = (bs, kv_head, kv_seq, hidden)
 
-    encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
-    encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
-    encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
-    encoder.set_buffer(3, Some(&output), 0);
-
-    encoder.set_bytes(
-        4,
-        std::mem::size_of::<i32>() as u64,
-        &gqa_factor as *const i32 as *const c_void,
-    );
-    encoder.set_bytes(
-        5,
-        std::mem::size_of::<i32>() as u64,
-        &n as *const i32 as *const c_void,
-    );
-    encoder.set_bytes(
-        6,
-        std::mem::size_of::<usize>() as u64,
-        &stride as *const usize as *const c_void,
-    );
-    encoder.set_bytes(
-        7,
-        std::mem::size_of::<f32>() as u64,
-        &alpha as *const f32 as *const c_void,
-    );
-    encoder.set_bytes(
-        8,
-        std::mem::size_of::<f32>() as u64,
-        &softcapping as *const f32 as *const c_void,
+    set_params!(
+        encoder,
+        (
+            (q_buffer, q_offset),
+            (k_buffer, k_offset),
+            (v_buffer, v_offset),
+            output,
+            gqa_factor,
+            n,
+            stride,
+            alpha,
+            softcapping
+        )
     );
 
     let grid_dims = MTLSize {
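The renumbering in the Metal kernel below follows from this change: set_params! assigns indices sequentially from 0, so the gap the old hand-written calls left at buffer(5) in call_sdpa_full disappears. Roughly, and assuming the macro shape sketched above, the new invocation expands to:

// Approximate expansion of the set_params! call in call_sdpa_full
// (a sketch, not compiler output); each argument takes the next index.
encoder.set_buffer(0, Some(&q_buffer), q_offset as u64); // q             -> buffer(0)
encoder.set_buffer(1, Some(&k_buffer), k_offset as u64); // k             -> buffer(1)
encoder.set_buffer(2, Some(&v_buffer), v_offset as u64); // v             -> buffer(2)
encoder.set_buffer(3, Some(&output), 0);                 // output        -> buffer(3)
// params copied by value via set_bytes                  // params        -> buffer(4)
// &batch_shape[..] copied via set_bytes                 // batch_shape   -> buffer(5), was 6
// &batch_strides[..] copied via set_bytes               // batch_strides -> buffer(6), was 7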
4 changes: 2 additions & 2 deletions candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -1170,8 +1170,8 @@ template <
     const device itype* V [[buffer(2)]], \
     device otype* O [[buffer(3)]], \
     const constant MLXFastAttentionParams* params [[buffer(4)]], \
-    const constant int* batch_shape [[buffer(6)]], \
-    const constant size_t* batch_strides [[buffer(7)]], \
+    const constant int* batch_shape [[buffer(5)]], \
+    const constant size_t* batch_strides [[buffer(6)]], \
     uint simd_lane_id [[thread_index_in_simdgroup]], \
     uint simd_group_id [[simdgroup_index_in_threadgroup]], \
     uint3 tid [[threadgroup_position_in_grid]], \
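Note that call_sdpa_vector needed no kernel-side change: its arguments were already bound at consecutive indices 0 through 8, so the sequential assignment performed by set_params! matches the existing buffer attributes in that kernel's signature. Only the full-attention kernel had skipped an index, which is why batch_shape and batch_strides shift down by one above.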
