From 7cff5898ec2d6fca33a513ea02414c957f3ea0f1 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 17 Aug 2024 20:29:01 +0100 Subject: [PATCH] Support Minus(u) for arbitrary values of u, e.g. Minus(3). (#2428) * Support Minus(u) for arbitrary values of u, e.g. Minus(3). * Forces u to be strictly positive. --- candle-core/src/shape.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 567a711b3c..90a37be663 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -304,6 +304,7 @@ impl Dim for usize { pub enum D { Minus1, Minus2, + Minus(usize), } impl D { @@ -311,6 +312,7 @@ impl D { let dim = match self { Self::Minus1 => -1, Self::Minus2 => -2, + Self::Minus(u) => -(*u as i32), }; Error::DimOutOfRange { shape: shape.clone(), @@ -327,6 +329,7 @@ impl Dim for D { match self { Self::Minus1 if rank >= 1 => Ok(rank - 1), Self::Minus2 if rank >= 2 => Ok(rank - 2), + Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u), _ => Err(self.out_of_range(shape, op)), } } @@ -336,6 +339,7 @@ impl Dim for D { match self { Self::Minus1 => Ok(rank), Self::Minus2 if rank >= 1 => Ok(rank - 1), + Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u), _ => Err(self.out_of_range(shape, op)), } }