Is there a way to write a gather function for slices instead of indices? #542
noughtmare asked this question in Q&A. Answered by tomsmeding.
Question

I think I need a function like this:

```haskell
gatherSlices
  :: (Shape sh, Elt e, Slice slix)
  => Acc (Array sh slix)                          -- ^ slices of the source that should be gathered
  -> Acc (Array (FullShape slix) e)               -- ^ source values
  -> Acc (Array (Append sh (SliceShape slix)) e)
```

where

```haskell
type family Append sh sh' where
  Append sh Z = sh
  Append sh (sh' :. x) = (Append sh sh') :. x
```

Or, for my current use case, it would be enough to have a function for this particular case:

```haskell
gather'
  :: (Shape sh, Elt e)
  => Acc (Array sh Int)
  -> Acc (Array (Z :. Int :. x) e)
  -> Acc (Array (sh :. x) e)
```

Is it possible to implement this? If so, what is the best way to do it?

My use case

The particular example I'm working with is this function from llm.c:

```c
void encoder_forward(float* out,
                     int* inp, float* wte, float* wpe,
                     int B, int T, int C) {
    // out is (B,T,C). At each position (b,t), a C-dimensional vector summarizing token & position
    // inp is (B,T) of integers, holding the token ids at each (b,t) position
    // wte is (V,C) of token embeddings, short for "weight token embeddings"
    // wpe is (maxT,C) of position embeddings, short for "weight positional embedding"
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            // get the index of the token at inp[b, t]
            int ix = inp[b * T + t];
            // seek to the position in wte corresponding to the token
            float* wte_ix = wte + ix * C;
            // seek to the position in wpe corresponding to the position
            float* wpe_t = wpe + t * C;
            // add the two vectors and store the result in out[b,t,:]
            for (int i = 0; i < C; i++) {
                out_bt[i] = wte_ix[i] + wpe_t[i];
            }
        }
    }
}
```

I can translate it to Haskell:

```haskell
encoderForward :: MutFloatArray -> IntArray -> FloatArray -> FloatArray -> Int -> Int -> Int -> IO ()
encoderForward !out {- (B,T,C) ? -} !inp {- (B,T) V -} !wte {- (V,C) ? -} !wpe {- (maxT, C) ? -} !bb !tt !cc =
  for_ [0 .. bb - 1] $ \b ->
    for_ [0 .. tt - 1] $ \t ->
      let !ix = inp ! (b * tt + t) in
      for_ [0 .. cc - 1] $ \i ->
        writeFloatArray out (b * tt * cc + t * cc + i)
          $ (wte ! (ix * cc + i)) + (wpe ! (t * cc + i))
```

I think that in Accelerate it would look something like this, if I had such a function (`Acc` arrays have no `Num` instance, so the addition goes through `zipWith`, and the batch size has to be read out of the `Exp` shape with `I2`):

```haskell
encoderForward inp wte wpe =
  zipWith (+) (gather' inp wte) (replicate (lift (Z :. b :. All :. All)) wpe)
  where
    I2 b _ = shape inp
```
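For checking any Accelerate version against the C loop, a pure-list reference model can help. This is a minimal sketch assuming nested Haskell lists in place of the flat C buffers; `gatherRowsRef` and `encoderForwardRef` are hypothetical names for illustration, not part of Accelerate. It makes the structure explicit: gather whole rows of `wte` by token id, then add the positional rows of `wpe` element-wise.

```haskell
-- Hypothetical helper: pick whole rows of a (V x C) table by index.
-- This is the list analogue of gather' with a one-dimensional index array.
gatherRowsRef :: [Int] -> [[a]] -> [[a]]
gatherRowsRef ixs table = map (table !!) ixs

-- Hypothetical reference model of llm.c's encoder_forward on nested lists.
--   inp :: B x T token ids, wte :: V x C, wpe :: maxT x C, result :: B x T x C.
-- zipWith stops at the shorter list, so only the first T rows of wpe are used,
-- matching the C loop, which only reads wpe[t] for t < T.
encoderForwardRef :: [[Int]] -> [[Float]] -> [[Float]] -> [[[Float]]]
encoderForwardRef inp wte wpe =
  [ zipWith (zipWith (+)) (gatherRowsRef row wte) wpe | row <- inp ]
```

With `wte = [[1,2],[10,20],[100,200]]`, `wpe = [[0.5,0.5],[1,1],[2,2]]` and `inp = [[2,0],[1,1]]`, position `(0,0)` reads token row 2 plus position row 0, giving `[100.5,200.5]`.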
Answer by tomsmeding, Apr 13, 2024
Is `gather'` not "just" this?

```haskell
gather' :: (Shape sh, Elt e, x ~ Int)
        => Acc (Array sh Int)
        -> Acc (Array (Z :. Int :. x) e)
        -> Acc (Array (sh :. x) e)
gather' ixs source =
  let I2 _ len = shape source
  in backpermute
       (shape ixs ::. len)                  -- result shape: index shape extended by the row length
       (\(ix ::. i) -> I2 (ixs ! ix) i)     -- element (ix, i) comes from row (ixs ! ix), column i
       source
```

Please verify that I correctly interpreted which index goes where :)
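One way to check which index goes where is to run `gather'` on a tiny input with Accelerate's reference interpreter and compare against the C loop by hand. The sketch below assumes accelerate 1.3 (for the `I2` and `::.` pattern synonyms) and `Data.Array.Accelerate.Interpreter`; `gather'` is the answer's version with `x` specialised to `Int`, and `encoderForward` is a hypothetical completion of the question's sketch, not code from the thread.

```haskell
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeOperators   #-}

import Data.Array.Accelerate
  ( Acc, All (..), Array, DIM2, Elt, Shape, Z (..), (:.) (..)
  , pattern I2, pattern (::.) )
import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate.Interpreter as I

-- The answer's gather', with x ~ Int written out directly.
-- Output element (ix, i) is read from source row (ixs ! ix), column i.
gather'
  :: (Shape sh, Elt e)
  => Acc (Array sh Int)
  -> Acc (Array (Z :. Int :. Int) e)
  -> Acc (Array (sh :. Int) e)
gather' ixs source =
  let I2 _ len = A.shape source
  in A.backpermute
       (A.shape ixs ::. len)
       (\(ix ::. i) -> I2 (ixs A.! ix) i)
       source

-- Hypothetical encoder_forward: inp is (B,T), wte is (V,C), wpe is (maxT,C).
-- A.zipWith keeps the intersection of its inputs' extents, so only the
-- first T rows of the replicated wpe contribute (T <= maxT, as in llm.c).
encoderForward
  :: Acc (Array DIM2 Int)
  -> Acc (Array DIM2 Float)
  -> Acc (Array DIM2 Float)
  -> Acc (Array (Z :. Int :. Int :. Int) Float)
encoderForward inp wte wpe =
  A.zipWith (+) (gather' inp wte) (A.replicate (A.lift (Z :. b :. All :. All)) wpe)
  where
    I2 b _ = A.shape inp
```

With `inp = [[2,0],[1,1]]` and `wte = [[1,2],[10,20],[100,200]]`, `I.run (gather' (A.use inp) (A.use wte))` should have shape `Z :. 2 :. 2 :. 2`, with `out[b][t]` equal to row `inp[b][t]` of `wte`, which is what the C code's `wte + ix * C` computes. Relying on `zipWith`'s intersection semantics avoids slicing `wpe` down to `T` rows explicitly.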
Answer selected by noughtmare.