Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rot kernel #438

Merged
merged 7 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions apps/blas/rot/rot-alt.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import "apps/blas/util.fil";
import "primitives/signed.fil";

// Applies a rotation to vectors x, y
// W: element width
// N: vector length
// M: multiplier amount
// A: adder amount
// Tiles scalar-vector products based on multiplier count,
// then does additions as multiplies finish
comp Rot[W, N, M, A]<'G:II>(
go: interface['G],
c: ['G, 'G+1] W,
s: ['G, 'G+1] W,
x[N]: ['G, 'G+1] W,
y[N]: ['G, 'G+1] W,
) -> (
out_1[N]: ['G+L, 'G+L+1] W,
out_2[N]: ['G+L, 'G+L+1] W
) with {
some L where L > 0;
some II where II > 0;
} where W > 0,
N > 0,
M > 0,
A > 0,
M % 4 == 0, // need to do at least 4 multiplies at once
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tangential fly-by comment: maybe we should add some syntax to allow asserts to add a reason for the assertion so we can generate better errors. For example when one of these asserts trigger, the comment could show up as the rationale for the restriction.

A % 2 == 0, // need to do at least 2 adds at once
N % (M/4) == 0,
(M/4) % (A/2) == 0
{
// partition mults into 4 groups, for 4 mults that need to happen at each time step
let m = M/4;

// same for adds, but only 2 adds to do
let a = A/2;

// reuse for a single scalar-vector computation
let mult_reuses = N / m;

// reuse for each of the adder groups
let add_reuses = m / a;

// dummy so we can get its params
M_ := new Multipliers[W, m];
let mult_latency = M_::L;
let mult_ii = M_::II;

A_ := new Adders[W, a];
let add_ii = A_::II;

// -s
negs := new Neg[W]<'G>(s);

// instantiate multipliers
let last_mult_invoke = (mult_reuses)*mult_ii;
M_cy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
M_sy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
M_cx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];
M_nsx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke];

// instantiate adders
let add_end = last_mult_invoke + mult_latency + add_reuses*add_ii + (add_reuses-1)*(mult_reuses-1) + 1;
A_1 := new Adders[W, a] in ['G+mult_latency, 'G+add_end];
A_2 := new Adders[W, a] in ['G+mult_latency, 'G+add_end];

// check which stage is limiting the pipeline
let ii = if (add_end-mult_latency) < (last_mult_invoke) {(last_mult_invoke)} else {(add_end-mult_latency)};

bundle cy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
bundle sy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
bundle cx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;
bundle nsx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W;

// scalar bundles for multiplications
bundle c_bundle[m]: ['G, 'G+1] W;
bundle s_bundle[m]: ['G, 'G+1] W;
bundle negs_bundle[m]: ['G, 'G+1] W;

// fill them
for i in 0..m {
c_bundle{i} = c;
s_bundle{i} = s;
negs_bundle{i} = negs.out;
}

// start multiplications
for i in 0..mult_reuses {
// some parameters
let mult_start = i*mult_ii;
let mult_end = i*mult_ii + mult_latency;

// register inputs
x_reg := new Shift[W, i*mult_ii, m]<'G>(x{i*m..(i+1)*m});
y_reg := new Shift[W, i*mult_ii, m]<'G>(y{i*m..(i+1)*m});
c_reg := new Shift[W, i*mult_ii, m]<'G>(c_bundle{0..m});
s_reg := new Shift[W, i*mult_ii, m]<'G>(s_bundle{0..m});
negs_reg := new Shift[W, i*mult_ii, m]<'G>(negs_bundle{0..m});

mult_cy := M_cy<'G+i*mult_ii>(c_reg.out{0..m}, y_reg.out{0..m});
mult_sy := M_sy<'G+i*mult_ii>(s_reg.out{0..m}, y_reg.out{0..m});
mult_cx := M_cx<'G+i*mult_ii>(c_reg.out{0..m}, x_reg.out{0..m});
mult_nsx := M_nsx<'G+i*mult_ii>(negs_reg.out{0..m}, x_reg.out{0..m});

cy{i}{0..m} = mult_cy.out{0..m};
sy{i}{0..m} = mult_sy.out{0..m};
cx{i}{0..m} = mult_cx.out{0..m};
nsx{i}{0..m} = mult_nsx.out{0..m};

for j in 0..add_reuses {
let offset = i*(add_reuses-1);
mult_cy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cy{i}{(j*a)..(j+1)*a});
mult_sy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(sy{i}{(j*a)..(j+1)*a});
mult_cx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cx{i}{(j*a)..(j+1)*a});
mult_nsx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(nsx{i}{(j*a)..(j+1)*a});

add_1 := A_1<'G + mult_end + j*add_ii + offset>(mult_cx_reg.out{0..a}, mult_sy_reg.out{0..a});
add_2 := A_2<'G + mult_end + j*add_ii + offset>(mult_nsx_reg.out{0..a}, mult_cy_reg.out{0..a});

add_1_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_1.out{0..a});
add_2_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_2.out{0..a});

out_1{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_1_reg.out{0..a};
out_2{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_2_reg.out{0..a};
}
}

let latency = (mult_reuses*mult_ii + mult_latency) + (add_reuses-1)*(mult_reuses-1) + (add_reuses-1)*add_ii;
L := latency;
II := ii;
}
94 changes: 94 additions & 0 deletions apps/blas/rot/rot.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import "primitives/core.fil";
import "apps/blas/scal/scal.fil";
import "apps/blas/util.fil";
import "primitives/signed.fil";

// Applies a rotation to vectors x, y
// W: element width
// N: vector length
// M: multiplier amount
// A: adder amount
// Uses scal to compute the scalar-vector products cy, sy, cx, -sx
comp Rot[W, N, M, A]<'G:II>(
go: interface['G],
c: ['G, 'G+1] W,
s: ['G, 'G+1] W,
x[N]: ['G, 'G+1] W,
y[N]: ['G, 'G+1] W,
) -> (
out_1[N]: ['G+L, 'G+L+1] W,
out_2[N]: ['G+L, 'G+L+1] W
) with {
some L where L > 0;
some II where II > 0;
} where W > 0,
L > 0,
N > 0,
M > 0,
A > 0,
M % 4 == 0,
A % 2 == 0
{

scalex := new Scal[W, N, M/4];
let scale_latency = scalex::L;
let scale_ii = scalex::II;

zero := new Const[W, 0]<'G>();
neg_s := new Neg[W]<'G>(s);

bundle cy[N]: ['G+scale_latency, 'G+scale_latency+1] W;
bundle sy[N]: ['G+scale_latency, 'G+scale_latency+1] W;
bundle cx[N]: ['G+scale_latency, 'G+scale_latency+1] W;
bundle msx[N]: ['G+scale_latency, 'G+scale_latency+1] W;

SCY := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
scale_cy := SCY<'G>(y{0..N}, c);
cy{0..N} = scale_cy.out{0..N};

SSY := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
scale_sy := SSY<'G>(y{0..N}, s);
sy{0..N} = scale_sy.out{0..N};

SCX := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
scale_cx := SCX<'G>(x{0..N}, c);
cx{0..N} = scale_cx.out{0..N};

SMSX := new Scal[W, N, M/4] in ['G, 'G+scale_ii];
scale_msx := SMSX<'G>(x{0..N}, neg_s.out);
msx{0..N} = scale_msx.out{0..N};

// out_1{i} <- cx{i} + sy{i}
// out_2{i} <- msx{i} + cy{i}

let add_uses = N / (A/2);
let add_ii = 1;

// use half the adders for x, half for y
A_x := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1];
A_y := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1];

let latency = scale_latency + (add_uses-1) * add_ii;
for k in 0..add_uses {
// save chunked arrays based on when we are ready to add them
cx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cx{k*(A/2)..(k+1)*(A/2)});
sy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(sy{k*(A/2)..(k+1)*(A/2)});
cy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cy{k*(A/2)..(k+1)*(A/2)});
msx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(msx{k*(A/2)..(k+1)*(A/2)});

ax := A_x<'G + scale_latency + k*add_ii>(cx_reg.out{0..(A/2)}, sy_reg.out{0..(A/2)});
ay := A_y<'G + scale_latency + k*add_ii>(msx_reg.out{0..(A/2)}, cy_reg.out{0..(A/2)});

// save add result
ax_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ax.out{0..(A/2)});
ay_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ay.out{0..(A/2)});

out_1{k*(A/2)..(k+1)*(A/2)} = ax_reg.out{0..(A/2)};
out_2{k*(A/2)..(k+1)*(A/2)} = ay_reg.out{0..(A/2)};
}

L := latency;
// this is a thing we can do now?
let ii = if (add_uses*add_ii) > (scale_ii) {add_uses*add_ii} else {scale_ii};
II := ii;
}
29 changes: 29 additions & 0 deletions apps/blas/rot/sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# for determining the right answer

def rot(u, v, c, s):
assert len(u) == len(v)
x = [0] * len(u)
y = [0] * len(v)
for i in range(len(u)):
x[i] = c*u[i] + s*v[i]
y[i] = (-s)*u[i] + c*v[i]
return (x,y)

if __name__ == '__main__':
u0 = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
v0 = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
c0 = 1
s0 = 1
(x0, y0) = rot(u0, v0, c0, s0)
print(f"x0: {x0}")
print(f"y0: {y0}")

print("\n======================\n")

u1 = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
v1 = [16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1]
c1 = 1
s1 = 1
(x1, y1) = rot(u1, v1, c1, s1)
print(f"x1: {x1}")
print(f"y1: {y1}")
136 changes: 136 additions & 0 deletions apps/blas/rot/test.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import "apps/blas/rot/rot-alt.fil";

comp main<'G:II>(
go: interface['G],
c: ['G, 'G+1] W,
s: ['G, 'G+1] W,
u_0: ['G, 'G+1] W,
u_1: ['G, 'G+1] W,
u_2: ['G, 'G+1] W,
u_3: ['G, 'G+1] W,
u_4: ['G, 'G+1] W,
u_5: ['G, 'G+1] W,
u_6: ['G, 'G+1] W,
u_7: ['G, 'G+1] W,
u_8: ['G, 'G+1] W,
u_9: ['G, 'G+1] W,
u_10: ['G, 'G+1] W,
u_11: ['G, 'G+1] W,
u_12: ['G, 'G+1] W,
u_13: ['G, 'G+1] W,
u_14: ['G, 'G+1] W,
u_15: ['G, 'G+1] W,
v_0: ['G, 'G+1] W,
v_1: ['G, 'G+1] W,
v_2: ['G, 'G+1] W,
v_3: ['G, 'G+1] W,
v_4: ['G, 'G+1] W,
v_5: ['G, 'G+1] W,
v_6: ['G, 'G+1] W,
v_7: ['G, 'G+1] W,
v_8: ['G, 'G+1] W,
v_9: ['G, 'G+1] W,
v_10: ['G, 'G+1] W,
v_11: ['G, 'G+1] W,
v_12: ['G, 'G+1] W,
v_13: ['G, 'G+1] W,
v_14: ['G, 'G+1] W,
v_15: ['G, 'G+1] W,
) -> (
x_0: ['G+L, 'G+L+1] W,
x_1: ['G+L, 'G+L+1] W,
x_2: ['G+L, 'G+L+1] W,
x_3: ['G+L, 'G+L+1] W,
x_4: ['G+L, 'G+L+1] W,
x_5: ['G+L, 'G+L+1] W,
x_6: ['G+L, 'G+L+1] W,
x_7: ['G+L, 'G+L+1] W,
x_8: ['G+L, 'G+L+1] W,
x_9: ['G+L, 'G+L+1] W,
x_10: ['G+L, 'G+L+1] W,
x_11: ['G+L, 'G+L+1] W,
x_12: ['G+L, 'G+L+1] W,
x_13: ['G+L, 'G+L+1] W,
x_14: ['G+L, 'G+L+1] W,
x_15: ['G+L, 'G+L+1] W,
y_0: ['G+L, 'G+L+1] W,
y_1: ['G+L, 'G+L+1] W,
y_2: ['G+L, 'G+L+1] W,
y_3: ['G+L, 'G+L+1] W,
y_4: ['G+L, 'G+L+1] W,
y_5: ['G+L, 'G+L+1] W,
y_6: ['G+L, 'G+L+1] W,
y_7: ['G+L, 'G+L+1] W,
y_8: ['G+L, 'G+L+1] W,
y_9: ['G+L, 'G+L+1] W,
y_10: ['G+L, 'G+L+1] W,
y_11: ['G+L, 'G+L+1] W,
y_12: ['G+L, 'G+L+1] W,
y_13: ['G+L, 'G+L+1] W,
y_14: ['G+L, 'G+L+1] W,
y_15: ['G+L, 'G+L+1] W,
) with {
let M = 8;
let N = 16;
let W = 32;
let A = 2;
some L where L > 0;
some II where II > 0;
} {
Rotx := new Rot[W, N, M, A];

bundle u[N]: ['G, 'G+1] W;
u{0} = u_0; u{1} = u_1; u{2} = u_2; u{3} = u_3;
u{4} = u_4; u{5} = u_5; u{6} = u_6; u{7} = u_7;
u{8} = u_8; u{9} = u_9; u{10} = u_10; u{11} = u_11;
u{12} = u_12; u{13} = u_13; u{14} = u_14; u{15} = u_15;

bundle v[N]: ['G, 'G+1] W;
v{0} = v_0; v{1} = v_1; v{2} = v_2; v{3} = v_3;
v{4} = v_4; v{5} = v_5; v{6} = v_6; v{7} = v_7;
v{8} = v_8; v{9} = v_9; v{10} = v_10; v{11} = v_11;
v{12} = v_12; v{13} = v_13; v{14} = v_14; v{15} = v_15;

r := Rotx<'G>(c, s, u{0..N}, v{0..N});

bundle x[N]: ['G+Rotx::L, 'G+Rotx::L+1] W;
x{0..N} = r.out_1{0..N};
x_0 = x{0};
x_1 = x{1};
x_2 = x{2};
x_3 = x{3};
x_4 = x{4};
x_5 = x{5};
x_6 = x{6};
x_7 = x{7};
x_8 = x{8};
x_9 = x{9};
x_10 = x{10};
x_11 = x{11};
x_12 = x{12};
x_13 = x{13};
x_14 = x{14};
x_15 = x{15};

bundle y[N]: ['G+Rotx::L, 'G+Rotx::L+1] W;
y{0..N} = r.out_2{0..N};
y_0 = y{0};
y_1 = y{1};
y_2 = y{2};
y_3 = y{3};
y_4 = y{4};
y_5 = y{5};
y_6 = y{6};
y_7 = y{7};
y_8 = y{8};
y_9 = y{9};
y_10 = y{10};
y_11 = y{11};
y_12 = y{12};
y_13 = y{13};
y_14 = y{14};
y_15 = y{15};

L := Rotx::L;
II := Rotx::II;
}
Loading
Loading