Skip to content

Commit

Permalink
Add the dim method to layout and shape.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Nov 5, 2024
1 parent a2471b1 commit fcd2cb3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
6 changes: 6 additions & 0 deletions candle-core/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ impl Layout {
self.shape.dims()
}

/// The dimension size for a specified dimension index.
pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(&self.shape, "dim")?;
Ok(self.dims()[dim])
}

pub fn shape(&self) -> &Shape {
&self.shape
}
Expand Down
6 changes: 6 additions & 0 deletions candle-core/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ impl Shape {
&self.0
}

/// The dimension size for a specified dimension index.
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(self, "dim")?;
Ok(self.dims()[dim])
}

/// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
Expand Down
16 changes: 8 additions & 8 deletions candle-nn/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -986,29 +986,29 @@ impl candle::CustomOp3 for Sdpa {

let device = q.device();

let out_dims = vec![q_l.dims()[0], q_l.dims()[1], q_l.dims()[2], v_l.dims()[3]];
let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
let elem_count: usize = out_dims.iter().product();

let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;

// q,k must have matching emb dim
if q_l.dims()[q_l.dims().len() - 1] != k_l.dims()[k_l.dims().len() - 1] {
if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
candle::bail!("`q` and `k` last dims must match");
}

// k,v must have matching n kv heads
if v_l.dims()[v_l.dims().len() - 3] != k_l.dims()[k_l.dims().len() - 3] {
if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
candle::bail!("`k` and `v` head dims must match");
}

// n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
if q_l.dims()[q_l.dims().len() - 3] % k_l.dims()[k_l.dims().len() - 3] != 0 {
if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}

let k_head = k_l.dims()[k_l.dims().len() - 1];
let q_head = q_l.dims()[q_l.dims().len() - 1];
let q_seq = q_l.dims()[2];
let k_head = k_l.dim(D::Minus1)?;
let q_head = q_l.dim(D::Minus1)?;
let q_seq = q_l.dim(2)?;

let mut implementation_supports_use_case = q_head == k_head;
let supported_head_dim =
Expand Down Expand Up @@ -1076,7 +1076,7 @@ impl candle::CustomOp3 for Sdpa {
)
.map_err(candle::Error::wrap)?;
} else if supports_sdpa_full {
if q_l.dims()[2] != k_l.dims()[2] {
if q_l.dim(2)? != k_l.dim(2)? {
candle::bail!(
"query and key sequence length must be equal if using full metal sdpa"
)
Expand Down

0 comments on commit fcd2cb3

Please sign in to comment.