-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rot kernel #438
Merged
Merged
rot kernel #438
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
92a5d21
start
gabizon103 c918288
Merge branch 'main' into rot
gabizon103 f410406
works but could be better
gabizon103 54282ec
better throughput alt
gabizon103 3ea9b11
comment
gabizon103 93c5715
Merge branch 'main' into rot
gabizon103 185ef41
clean
gabizon103 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import "apps/blas/util.fil"; | ||
import "primitives/signed.fil"; | ||
|
||
// Applies a rotation to vectors x, y | ||
// W: element width | ||
// N: vector length | ||
// M: multiplier amount | ||
// A: adder amount | ||
// Tiles scalar-vector products based on multiplier count, | ||
// then does additions as multiplies finish | ||
comp Rot[W, N, M, A]<'G:II>( | ||
go: interface['G], | ||
c: ['G, 'G+1] W, | ||
s: ['G, 'G+1] W, | ||
x[N]: ['G, 'G+1] W, | ||
y[N]: ['G, 'G+1] W, | ||
) -> ( | ||
out_1[N]: ['G+L, 'G+L+1] W, | ||
out_2[N]: ['G+L, 'G+L+1] W | ||
) with { | ||
some L where L > 0; | ||
some II where II > 0; | ||
} where W > 0, | ||
N > 0, | ||
M > 0, | ||
A > 0, | ||
M % 4 == 0, // need to do at least 4 multiplies at once | ||
A % 2 == 0, // need to do at least 2 adds at once | ||
N % (M/4) == 0, | ||
(M/4) % (A/2) == 0 | ||
{ | ||
// partition mults into 4 groups, for 4 mults that need to happen at each time step | ||
let m = M/4; | ||
|
||
// same for adds, but only 2 adds to do | ||
let a = A/2; | ||
|
||
// reuse for a single scalar-vector computation | ||
let mult_reuses = N / m; | ||
|
||
// reuse for each of the adder groups | ||
let add_reuses = m / a; | ||
|
||
// dummy so we can get its params | ||
M_ := new Multipliers[W, m]; | ||
let mult_latency = M_::L; | ||
let mult_ii = M_::II; | ||
|
||
A_ := new Adders[W, a]; | ||
let add_ii = A_::II; | ||
|
||
// -s | ||
negs := new Neg[W]<'G>(s); | ||
|
||
// instantiate multipliers | ||
let last_mult_invoke = (mult_reuses)*mult_ii; | ||
M_cy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke]; | ||
M_sy := new Multipliers[W, m] in ['G, 'G+last_mult_invoke]; | ||
M_cx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke]; | ||
M_nsx := new Multipliers[W, m] in ['G, 'G+last_mult_invoke]; | ||
|
||
// instantiate adders | ||
let add_end = last_mult_invoke + mult_latency + add_reuses*add_ii + (add_reuses-1)*(mult_reuses-1) + 1; | ||
A_1 := new Adders[W, a] in ['G+mult_latency, 'G+add_end]; | ||
A_2 := new Adders[W, a] in ['G+mult_latency, 'G+add_end]; | ||
|
||
// check which stage is limiting the pipeline | ||
let ii = if (add_end-mult_latency) < (last_mult_invoke) {(last_mult_invoke)} else {(add_end-mult_latency)}; | ||
|
||
bundle cy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W; | ||
bundle sy[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W; | ||
bundle cx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W; | ||
bundle nsx[mult_reuses][m]: for<k> ['G+k*mult_ii+mult_latency, 'G+k*mult_ii+mult_latency+1] W; | ||
|
||
// scalar bundles for multiplications | ||
bundle c_bundle[m]: ['G, 'G+1] W; | ||
bundle s_bundle[m]: ['G, 'G+1] W; | ||
bundle negs_bundle[m]: ['G, 'G+1] W; | ||
|
||
// fill them | ||
for i in 0..m { | ||
c_bundle{i} = c; | ||
s_bundle{i} = s; | ||
negs_bundle{i} = negs.out; | ||
} | ||
|
||
// start multiplications | ||
for i in 0..mult_reuses { | ||
// some parameters | ||
let mult_start = i*mult_ii; | ||
let mult_end = i*mult_ii + mult_latency; | ||
|
||
// register inputs | ||
x_reg := new Shift[W, i*mult_ii, m]<'G>(x{i*m..(i+1)*m}); | ||
y_reg := new Shift[W, i*mult_ii, m]<'G>(y{i*m..(i+1)*m}); | ||
c_reg := new Shift[W, i*mult_ii, m]<'G>(c_bundle{0..m}); | ||
s_reg := new Shift[W, i*mult_ii, m]<'G>(s_bundle{0..m}); | ||
negs_reg := new Shift[W, i*mult_ii, m]<'G>(negs_bundle{0..m}); | ||
|
||
mult_cy := M_cy<'G+i*mult_ii>(c_reg.out{0..m}, y_reg.out{0..m}); | ||
mult_sy := M_sy<'G+i*mult_ii>(s_reg.out{0..m}, y_reg.out{0..m}); | ||
mult_cx := M_cx<'G+i*mult_ii>(c_reg.out{0..m}, x_reg.out{0..m}); | ||
mult_nsx := M_nsx<'G+i*mult_ii>(negs_reg.out{0..m}, x_reg.out{0..m}); | ||
|
||
cy{i}{0..m} = mult_cy.out{0..m}; | ||
sy{i}{0..m} = mult_sy.out{0..m}; | ||
cx{i}{0..m} = mult_cx.out{0..m}; | ||
nsx{i}{0..m} = mult_nsx.out{0..m}; | ||
|
||
for j in 0..add_reuses { | ||
let offset = i*(add_reuses-1); | ||
mult_cy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cy{i}{(j*a)..(j+1)*a}); | ||
mult_sy_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(sy{i}{(j*a)..(j+1)*a}); | ||
mult_cx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(cx{i}{(j*a)..(j+1)*a}); | ||
mult_nsx_reg := new Shift[W, j*add_ii + offset, a]<'G+mult_end>(nsx{i}{(j*a)..(j+1)*a}); | ||
|
||
add_1 := A_1<'G + mult_end + j*add_ii + offset>(mult_cx_reg.out{0..a}, mult_sy_reg.out{0..a}); | ||
add_2 := A_2<'G + mult_end + j*add_ii + offset>(mult_nsx_reg.out{0..a}, mult_cy_reg.out{0..a}); | ||
|
||
add_1_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_1.out{0..a}); | ||
add_2_reg := new Shift[W, latency - mult_end - j*add_ii - offset, a]<'G + mult_end + j*add_ii + offset>(add_2.out{0..a}); | ||
|
||
out_1{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_1_reg.out{0..a}; | ||
out_2{(m*i)+(j*a)..(m*i)+(j+1)*a} = add_2_reg.out{0..a}; | ||
} | ||
} | ||
|
||
let latency = (mult_reuses*mult_ii + mult_latency) + (add_reuses-1)*(mult_reuses-1) + (add_reuses-1)*add_ii; | ||
L := latency; | ||
II := ii; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import "primitives/core.fil"; | ||
import "apps/blas/scal/scal.fil"; | ||
import "apps/blas/util.fil"; | ||
import "primitives/signed.fil"; | ||
|
||
// Applies a rotation to vectors x, y | ||
// W: element width | ||
// N: vector length | ||
// M: multiplier amount | ||
// A: adder amount | ||
// Uses scal to compute the scalar-vector products cy, sy, cx, -sx | ||
comp Rot[W, N, M, A]<'G:II>( | ||
go: interface['G], | ||
c: ['G, 'G+1] W, | ||
s: ['G, 'G+1] W, | ||
x[N]: ['G, 'G+1] W, | ||
y[N]: ['G, 'G+1] W, | ||
) -> ( | ||
out_1[N]: ['G+L, 'G+L+1] W, | ||
out_2[N]: ['G+L, 'G+L+1] W | ||
) with { | ||
some L where L > 0; | ||
some II where II > 0; | ||
} where W > 0, | ||
L > 0, | ||
N > 0, | ||
M > 0, | ||
A > 0, | ||
M % 4 == 0, | ||
A % 2 == 0 | ||
{ | ||
|
||
scalex := new Scal[W, N, M/4]; | ||
let scale_latency = scalex::L; | ||
let scale_ii = scalex::II; | ||
|
||
zero := new Const[W, 0]<'G>(); | ||
neg_s := new Neg[W]<'G>(s); | ||
|
||
bundle cy[N]: ['G+scale_latency, 'G+scale_latency+1] W; | ||
bundle sy[N]: ['G+scale_latency, 'G+scale_latency+1] W; | ||
bundle cx[N]: ['G+scale_latency, 'G+scale_latency+1] W; | ||
bundle msx[N]: ['G+scale_latency, 'G+scale_latency+1] W; | ||
|
||
SCY := new Scal[W, N, M/4] in ['G, 'G+scale_ii]; | ||
scale_cy := SCY<'G>(y{0..N}, c); | ||
cy{0..N} = scale_cy.out{0..N}; | ||
|
||
SSY := new Scal[W, N, M/4] in ['G, 'G+scale_ii]; | ||
scale_sy := SSY<'G>(y{0..N}, s); | ||
sy{0..N} = scale_sy.out{0..N}; | ||
|
||
SCX := new Scal[W, N, M/4] in ['G, 'G+scale_ii]; | ||
scale_cx := SCX<'G>(x{0..N}, c); | ||
cx{0..N} = scale_cx.out{0..N}; | ||
|
||
SMSX := new Scal[W, N, M/4] in ['G, 'G+scale_ii]; | ||
scale_msx := SMSX<'G>(x{0..N}, neg_s.out); | ||
msx{0..N} = scale_msx.out{0..N}; | ||
|
||
// out_1{i} <- cx{i} + sy{i} | ||
// out_2{i} <- msx{i} + cy{i} | ||
|
||
let add_uses = N / (A/2); | ||
let add_ii = 1; | ||
|
||
// use half the adders for x, half for y | ||
A_x := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1]; | ||
A_y := new Adders[W, A/2] in ['G+scale_latency, 'G+scale_latency+(add_uses-1)*add_ii+1]; | ||
|
||
let latency = scale_latency + (add_uses-1) * add_ii; | ||
for k in 0..add_uses { | ||
// save chunked arrays based on when we are ready to add them | ||
cx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cx{k*(A/2)..(k+1)*(A/2)}); | ||
sy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(sy{k*(A/2)..(k+1)*(A/2)}); | ||
cy_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(cy{k*(A/2)..(k+1)*(A/2)}); | ||
msx_reg := new Shift[W, k*add_ii, A/2]<'G+scale_latency>(msx{k*(A/2)..(k+1)*(A/2)}); | ||
|
||
ax := A_x<'G + scale_latency + k*add_ii>(cx_reg.out{0..(A/2)}, sy_reg.out{0..(A/2)}); | ||
ay := A_y<'G + scale_latency + k*add_ii>(msx_reg.out{0..(A/2)}, cy_reg.out{0..(A/2)}); | ||
|
||
// save add result | ||
ax_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ax.out{0..(A/2)}); | ||
ay_reg := new Shift[W, latency - scale_latency - k*add_ii, A/2]<'G + scale_latency + k*add_ii>(ay.out{0..(A/2)}); | ||
|
||
out_1{k*(A/2)..(k+1)*(A/2)} = ax_reg.out{0..(A/2)}; | ||
out_2{k*(A/2)..(k+1)*(A/2)} = ay_reg.out{0..(A/2)}; | ||
} | ||
|
||
L := latency; | ||
// this is a thing we can do now? | ||
let ii = if (add_uses*add_ii) > (scale_ii) {add_uses*add_ii} else {scale_ii}; | ||
II := ii; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# for determining the right answer | ||
|
||
def rot(u, v, c, s): | ||
assert len(u) == len(v) | ||
x = [0] * len(u) | ||
y = [0] * len(v) | ||
for i in range(len(u)): | ||
x[i] = c*u[i] + s*v[i] | ||
y[i] = (-s)*u[i] + c*v[i] | ||
return (x,y) | ||
|
||
if __name__ == '__main__': | ||
u0 = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] | ||
v0 = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15] | ||
c0 = 1 | ||
s0 = 1 | ||
(x0, y0) = rot(u0, v0, c0, s0) | ||
print(f"x0: {x0}") | ||
print(f"y0: {y0}") | ||
|
||
print("\n======================\n") | ||
|
||
u1 = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] | ||
v1 = [16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1] | ||
c1 = 1 | ||
s1 = 1 | ||
(x1, y1) = rot(u1, v1, c1, s1) | ||
print(f"x1: {x1}") | ||
print(f"y1: {y1}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import "apps/blas/rot/rot-alt.fil"; | ||
|
||
comp main<'G:II>( | ||
go: interface['G], | ||
c: ['G, 'G+1] W, | ||
s: ['G, 'G+1] W, | ||
u_0: ['G, 'G+1] W, | ||
u_1: ['G, 'G+1] W, | ||
u_2: ['G, 'G+1] W, | ||
u_3: ['G, 'G+1] W, | ||
u_4: ['G, 'G+1] W, | ||
u_5: ['G, 'G+1] W, | ||
u_6: ['G, 'G+1] W, | ||
u_7: ['G, 'G+1] W, | ||
u_8: ['G, 'G+1] W, | ||
u_9: ['G, 'G+1] W, | ||
u_10: ['G, 'G+1] W, | ||
u_11: ['G, 'G+1] W, | ||
u_12: ['G, 'G+1] W, | ||
u_13: ['G, 'G+1] W, | ||
u_14: ['G, 'G+1] W, | ||
u_15: ['G, 'G+1] W, | ||
v_0: ['G, 'G+1] W, | ||
v_1: ['G, 'G+1] W, | ||
v_2: ['G, 'G+1] W, | ||
v_3: ['G, 'G+1] W, | ||
v_4: ['G, 'G+1] W, | ||
v_5: ['G, 'G+1] W, | ||
v_6: ['G, 'G+1] W, | ||
v_7: ['G, 'G+1] W, | ||
v_8: ['G, 'G+1] W, | ||
v_9: ['G, 'G+1] W, | ||
v_10: ['G, 'G+1] W, | ||
v_11: ['G, 'G+1] W, | ||
v_12: ['G, 'G+1] W, | ||
v_13: ['G, 'G+1] W, | ||
v_14: ['G, 'G+1] W, | ||
v_15: ['G, 'G+1] W, | ||
) -> ( | ||
x_0: ['G+L, 'G+L+1] W, | ||
x_1: ['G+L, 'G+L+1] W, | ||
x_2: ['G+L, 'G+L+1] W, | ||
x_3: ['G+L, 'G+L+1] W, | ||
x_4: ['G+L, 'G+L+1] W, | ||
x_5: ['G+L, 'G+L+1] W, | ||
x_6: ['G+L, 'G+L+1] W, | ||
x_7: ['G+L, 'G+L+1] W, | ||
x_8: ['G+L, 'G+L+1] W, | ||
x_9: ['G+L, 'G+L+1] W, | ||
x_10: ['G+L, 'G+L+1] W, | ||
x_11: ['G+L, 'G+L+1] W, | ||
x_12: ['G+L, 'G+L+1] W, | ||
x_13: ['G+L, 'G+L+1] W, | ||
x_14: ['G+L, 'G+L+1] W, | ||
x_15: ['G+L, 'G+L+1] W, | ||
y_0: ['G+L, 'G+L+1] W, | ||
y_1: ['G+L, 'G+L+1] W, | ||
y_2: ['G+L, 'G+L+1] W, | ||
y_3: ['G+L, 'G+L+1] W, | ||
y_4: ['G+L, 'G+L+1] W, | ||
y_5: ['G+L, 'G+L+1] W, | ||
y_6: ['G+L, 'G+L+1] W, | ||
y_7: ['G+L, 'G+L+1] W, | ||
y_8: ['G+L, 'G+L+1] W, | ||
y_9: ['G+L, 'G+L+1] W, | ||
y_10: ['G+L, 'G+L+1] W, | ||
y_11: ['G+L, 'G+L+1] W, | ||
y_12: ['G+L, 'G+L+1] W, | ||
y_13: ['G+L, 'G+L+1] W, | ||
y_14: ['G+L, 'G+L+1] W, | ||
y_15: ['G+L, 'G+L+1] W, | ||
) with { | ||
let M = 8; | ||
let N = 16; | ||
let W = 32; | ||
let A = 2; | ||
some L where L > 0; | ||
some II where II > 0; | ||
} { | ||
Rotx := new Rot[W, N, M, A]; | ||
|
||
bundle u[N]: ['G, 'G+1] W; | ||
u{0} = u_0; u{1} = u_1; u{2} = u_2; u{3} = u_3; | ||
u{4} = u_4; u{5} = u_5; u{6} = u_6; u{7} = u_7; | ||
u{8} = u_8; u{9} = u_9; u{10} = u_10; u{11} = u_11; | ||
u{12} = u_12; u{13} = u_13; u{14} = u_14; u{15} = u_15; | ||
|
||
bundle v[N]: ['G, 'G+1] W; | ||
v{0} = v_0; v{1} = v_1; v{2} = v_2; v{3} = v_3; | ||
v{4} = v_4; v{5} = v_5; v{6} = v_6; v{7} = v_7; | ||
v{8} = v_8; v{9} = v_9; v{10} = v_10; v{11} = v_11; | ||
v{12} = v_12; v{13} = v_13; v{14} = v_14; v{15} = v_15; | ||
|
||
r := Rotx<'G>(c, s, u{0..N}, v{0..N}); | ||
|
||
bundle x[N]: ['G+Rotx::L, 'G+Rotx::L+1] W; | ||
x{0..N} = r.out_1{0..N}; | ||
x_0 = x{0}; | ||
x_1 = x{1}; | ||
x_2 = x{2}; | ||
x_3 = x{3}; | ||
x_4 = x{4}; | ||
x_5 = x{5}; | ||
x_6 = x{6}; | ||
x_7 = x{7}; | ||
x_8 = x{8}; | ||
x_9 = x{9}; | ||
x_10 = x{10}; | ||
x_11 = x{11}; | ||
x_12 = x{12}; | ||
x_13 = x{13}; | ||
x_14 = x{14}; | ||
x_15 = x{15}; | ||
|
||
bundle y[N]: ['G+Rotx::L, 'G+Rotx::L+1] W; | ||
y{0..N} = r.out_2{0..N}; | ||
y_0 = y{0}; | ||
y_1 = y{1}; | ||
y_2 = y{2}; | ||
y_3 = y{3}; | ||
y_4 = y{4}; | ||
y_5 = y{5}; | ||
y_6 = y{6}; | ||
y_7 = y{7}; | ||
y_8 = y{8}; | ||
y_9 = y{9}; | ||
y_10 = y{10}; | ||
y_11 = y{11}; | ||
y_12 = y{12}; | ||
y_13 = y{13}; | ||
y_14 = y{14}; | ||
y_15 = y{15}; | ||
|
||
L := Rotx::L; | ||
II := Rotx::II; | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tangential fly-by comment: maybe we should add some syntax to allow asserts to add a reason for the assertion so we can generate better errors. For example when one of these asserts trigger, the comment could show up as the rationale for the restriction.