Skip to content

Commit

Permalink
axpy BLAS implementation (#424)
Browse files Browse the repository at this point in the history
* axpy impl

* do adds as mults come in

* param over adder use

* fix tests

* param over adders
  • Loading branch information
gabizon103 authored Mar 14, 2024
1 parent 515ed39 commit a2ad5b7
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 1 deletion.
90 changes: 90 additions & 0 deletions apps/blas/axpy/axpy.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import "apps/blas/scal/scal.fil";
import "apps/blas/util.fil";
import "primitives/reshape.fil";
import "primitives/core.fil";

// Performs a*x + y, where x,y are vectors and a is a scalar
// W: Width of nums
// N: Length of vectors
// M: Number of multipliers
// A: Number of adders
comp Axpy[W, N, M, A]<'G:II>(
go: interface['G],
a: ['G, 'G+1] W,
x[N]: ['G, 'G+1] W,
y[N]: ['G, 'G+1] W,
) -> (
out[N]: ['G+L, 'G+L+1] W
) with {
some L where L > 0;
some II where II > 0;
} where W > 0,
L > 0,
N > 0,
M > 0,
A > 0,
N % M == 0,
M % A == 0
{

let mult_uses = N / M;
let add_uses = M / A;

let add_latency = add_uses;
let add_ii = 1;

Mults := new Multipliers[W, M] in ['G, 'G + mult_uses*Mults::II];
Adds := new Adders[W, A] in ['G + Mults::L, 'G + mult_uses*Mults::II + Mults::L + add_uses*add_ii + mult_uses*(add_uses-1)];

let latency = (mult_uses-1)*Mults::II + Mults::L + add_uses-1 + (mult_uses-1)*(add_uses-1);
L := latency;

II := (mult_uses*Mults::II+Mults::L+add_uses*add_ii+mult_uses*(add_uses-1))-Mults::L; // compiler told me this

bundle mult_out[mult_uses][M]: for<k> ['G+k*Mults::II+Mults::L, 'G+k*Mults::II+Mults::L+1] W;
bundle a_bundle[mult_uses][M]: for<k> ['G+k*Mults::II, 'G+k*Mults::II+1] W;

for j in 0..mult_uses {
let mul_start = j*Mults::II;
let mul_end = j*Mults::II + Mults::L;

a_reg := new Shift[W, mul_start]<'G>(a);

// use inputs
if j == 0 {
// fill bundle with `a`
for i in 0..M {
a_bundle{j}{i} = a;
}
m := Mults<'G+mul_start>(x{j*M..(j+1)*M}, a_bundle{j}{0..M});
mult_out{j}{0..M} = m.out{0..M};
}
// register inputs and use
else {
// fill bundle with `a`
for i in 0..M {
a_bundle{j}{i} = a_reg.out;
}
x_reg := new Shift[W, mul_start, M]<'G>(x{j*M..(j+1)*M});
m := Mults<'G+mul_start>(x_reg.out{0..M}, a_bundle{j}{0..M});
mult_out{j}{0..M} = m.out{0..M};
}

// chunk multiply outputs based on how many adders we have
for k in 0..add_uses {
let offset = j*(add_uses-1);

bundle add_in[A]: ['G+mul_end + k*add_ii + offset, 'G+mul_end + k*add_ii + offset + 1] W;

y_reg := new Shift[W, mul_end + k*add_ii + offset, A]<'G>(y{(M*j)+(k*A)..(M*j)+(k+1)*A});

mult_out_reg := new Shift[W, k*add_ii + offset, A]<'G+mul_end>(mult_out{j}{(k*A)..(k+1)*A});
add_in{0..A} = mult_out_reg.out{0..A};

a := Adds<'G+mul_end+k*add_ii+offset>(add_in{0..A}, y_reg.out{0..A});
add_reg := new Shift[W, latency - mul_end - k*add_ii - offset, A]<'G+mul_end+k*add_ii+offset>(a.out{0..A});

out{(M*j)+(k*A)..(M*j)+(k+1)*A} = add_reg.out{0..A};
}
}
}
2 changes: 2 additions & 0 deletions apps/blas/axpy/test.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"out_0": {"0": [0], "1": [0]}, "out_1": {"0": [1], "1": [2]}, "out_2": {"0": [2], "1": [4]}, "out_3": {"0": [3], "1": [6]}, "out_4": {"0": [4], "1": [8]}, "out_5": {"0": [5], "1": [10]}, "out_6": {"0": [6], "1": [12]}, "out_7": {"0": [7], "1": [14]}, "out_8": {"0": [8], "1": [16]}, "out_9": {"0": [9], "1": [18]}, "out_10": {"0": [10], "1": [20]}, "out_11": {"0": [11], "1": [22]}, "out_12": {"0": [12], "1": [24]}, "out_13": {"0": [13], "1": [26]}, "out_14": {"0": [14], "1": [28]}, "out_15": {"0": [15], "1": [30]}, "cycles": 8}

88 changes: 88 additions & 0 deletions apps/blas/axpy/test.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import "apps/blas/axpy/axpy.fil";

comp main<'G:II>(
go: interface['G],
a: ['G, 'G+1] W,
x_0: ['G, 'G+1] W,
x_1: ['G, 'G+1] W,
x_2: ['G, 'G+1] W,
x_3: ['G, 'G+1] W,
x_4: ['G, 'G+1] W,
x_5: ['G, 'G+1] W,
x_6: ['G, 'G+1] W,
x_7: ['G, 'G+1] W,
x_8: ['G, 'G+1] W,
x_9: ['G, 'G+1] W,
x_10: ['G, 'G+1] W,
x_11: ['G, 'G+1] W,
x_12: ['G, 'G+1] W,
x_13: ['G, 'G+1] W,
x_14: ['G, 'G+1] W,
x_15: ['G, 'G+1] W,
y_0: ['G, 'G+1] W,
y_1: ['G, 'G+1] W,
y_2: ['G, 'G+1] W,
y_3: ['G, 'G+1] W,
y_4: ['G, 'G+1] W,
y_5: ['G, 'G+1] W,
y_6: ['G, 'G+1] W,
y_7: ['G, 'G+1] W,
y_8: ['G, 'G+1] W,
y_9: ['G, 'G+1] W,
y_10: ['G, 'G+1] W,
y_11: ['G, 'G+1] W,
y_12: ['G, 'G+1] W,
y_13: ['G, 'G+1] W,
y_14: ['G, 'G+1] W,
y_15: ['G, 'G+1] W,
) -> (
out_0: ['G+L, 'G+L+1] W,
out_1: ['G+L, 'G+L+1] W,
out_2: ['G+L, 'G+L+1] W,
out_3: ['G+L, 'G+L+1] W,
out_4: ['G+L, 'G+L+1] W,
out_5: ['G+L, 'G+L+1] W,
out_6: ['G+L, 'G+L+1] W,
out_7: ['G+L, 'G+L+1] W,
out_8: ['G+L, 'G+L+1] W,
out_9: ['G+L, 'G+L+1] W,
out_10: ['G+L, 'G+L+1] W,
out_11: ['G+L, 'G+L+1] W,
out_12: ['G+L, 'G+L+1] W,
out_13: ['G+L, 'G+L+1] W,
out_14: ['G+L, 'G+L+1] W,
out_15: ['G+L, 'G+L+1] W,
) with {
let M = 8;
let N = 16;
let W = 32;
let A = 8;
some L where L > 0;
some II where II > 0;
} {

A := new Axpy[W, N, M, A];

bundle x[N]: ['G, 'G+1] W;
bundle y[N]: ['G, 'G+1] W;

x{0} = x_0; x{1} = x_1; x{2} = x_2; x{3} = x_3;
x{4} = x_4; x{5} = x_5; x{6} = x_6; x{7} = x_7;
x{8} = x_8; x{9} = x_9; x{10} = x_10; x{11} = x_11;
x{12} = x_12; x{13} = x_13; x{14} = x_14; x{15} = x_15;

y{0} = y_0; y{1} = y_1; y{2} = y_2; y{3} = y_3;
y{4} = y_4; y{5} = y_5; y{6} = y_6; y{7} = y_7;
y{8} = y_8; y{9} = y_9; y{10} = y_10; y{11} = y_11;
y{12} = y_12; y{13} = y_13; y{14} = y_14; y{15} = y_15;

L := A::L;
II := A::II;

a := A<'G>(a, x{0..N}, y{0..N});

out_0 = a.out{0}; out_1 = a.out{1}; out_2 = a.out{2}; out_3 = a.out{3};
out_4 = a.out{4}; out_5 = a.out{5}; out_6 = a.out{6}; out_7 = a.out{7};
out_8 = a.out{8}; out_9 = a.out{9}; out_10 = a.out{10}; out_11 = a.out{11};
out_12 = a.out{12}; out_13 = a.out{13}; out_14 = a.out{14}; out_15 = a.out{15};
}
35 changes: 35 additions & 0 deletions apps/blas/axpy/test.fil.data
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"a": [0, 1],
"x_0": [0, 0],
"x_1": [1, 1],
"x_2": [2, 2],
"x_3": [3, 3],
"x_4": [4, 4],
"x_5": [5, 5],
"x_6": [6, 6],
"x_7": [7, 7],
"x_8": [8, 8],
"x_9": [9, 9],
"x_10": [10, 10],
"x_11": [11, 11],
"x_12": [12, 12],
"x_13": [13, 13],
"x_14": [14, 14],
"x_15": [15, 15],
"y_0": [0, 0],
"y_1": [1, 1],
"y_2": [2, 2],
"y_3": [3, 3],
"y_4": [4, 4],
"y_5": [5, 5],
"y_6": [6, 6],
"y_7": [7, 7],
"y_8": [8, 8],
"y_9": [9, 9],
"y_10": [10, 10],
"y_11": [11, 11],
"y_12": [12, 12],
"y_13": [13, 13],
"y_14": [14, 14],
"y_15": [15, 15]
}
61 changes: 60 additions & 1 deletion apps/blas/util.fil
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import "primitives/math/math.fil";
import "primitives/core.fil";

// a multiplier with II, latency as output params
comp Mult[W]<'G:1>(
Expand All @@ -24,7 +25,7 @@ comp Multipliers[W, N]<'G:1>(
) -> (
out[N]: ['G+L, 'G+L+1] W
) with {
some L where L >= 0;
some L where L > 0;
some II where II > 0;
} where W > 0 {
Mx := new Mult[W];
Expand All @@ -35,4 +36,62 @@ comp Multipliers[W, N]<'G:1>(
m := new Mult[W]<'G>(x{i}, y{i});
out{i} = m.out;
}
}

// a component that does N additions at once
comp Adders[W, N]<'G:1>(
x[N]: ['G, 'G+1] W,
y[N]: ['G, 'G+1] W
) -> (
out[N]: ['G, 'G+1] W
) where W > 0,
N > 0
{
for i in 0..N {
a := new Add[W]<'G>(x{i}, y{i});
out{i} = a.out;
}
}

// vector addition
// parameterized over adder use if you care about that for some reason
comp VecAdd[W, N, A]<'G:1>(
go: interface['G],
x[N]: ['G, 'G+1] W,
y[N]: ['G, 'G+1] W
) -> (
out[N]: ['G+L, 'G+L+1] W
) with {
some L where L > 0;
} where W > 0,
N > 0,
A > 0
{
let uses = N / A;
Adds := new Adders[W, A];


let latency = uses;
L := latency;

bundle add_out[uses][A]: for<k> ['G+k, 'G+k+1] W;

for j in 0..uses {
// use inputs
if j == 0 {
a := Adds<'G+j>(x{j*A..(j+1)*A}, y{j*A..(j+1)*A});
add_out{j}{0..A} = a.out{0..A};
}
// register inputs and use
else {
x_reg := new Shift[W, j, A]<'G>(x{j*A..(j+1)*A});
y_reg := new Shift[W, j, A]<'G>(y{j*A..(j+1)*A});
a := Adds<'G+j>(x_reg.out{0..A}, y_reg.out{0..A});
add_out{j}{0..A} = a.out{0..A};
}

add_reg := new Shift[W, latency - j, A]<'G+j>(add_out{j}{0..A});
out{j*A..(j+1)*A} = add_reg.out{0..A};
}

}

0 comments on commit a2ad5b7

Please sign in to comment.