Skip to content

Commit

Permalink
Merge pull request #47 from stanford-futuredata/fix-topology-kernel
Browse files Browse the repository at this point in the history
Fix bug in topology kernel for ffn_hidden_size>4096.
  • Loading branch information
tgale96 authored Dec 8, 2023
2 parents 15634e0 + e74bc20 commit 0460181
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
14 changes: 5 additions & 9 deletions csrc/indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
namespace megablocks {
namespace construct_indices {

// We expect the number of outputs per block to be
// small. With ffn_hidden_size=4096, we only need
// to write 32 elements per block per iteration.
// This is the largest we're every likely to use
// so we keep the blocks small.
// We expect the number of outputs per block to be small. For
// example, with ffn_hidden_size=4096, we only need to write
// 32 elements per block per iteration.
const int kThreadsPerBlock = 32;

__global__ void __launch_bounds__(kThreadsPerBlock)
Expand All @@ -39,13 +37,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)

// Write the indices to the output.
int bin_offset = blockIdx.y;
int tid = threadIdx.x;
int num_rows = end - start;
for (; bin_offset < num_rows; num_rows -= gridDim.y) {
int elements = num_columns;
short *out = indices;
for (; tid < elements; elements -= kThreadsPerBlock) {
*out = threadIdx.x + (blockIdx.x * num_columns);
for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
*out = bid + (blockIdx.x * num_columns);
out += kThreadsPerBlock;
}
indices += gridDim.y * num_columns;
Expand Down
4 changes: 3 additions & 1 deletion megablocks/ops/topology_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
(16384, 768, 128),
(16384, 768, 256),
(16384, 768, 512),
(16384, 768, 1024))
(16384, 768, 1024),
(8, 14336, 8),
)


class TopologyTest(parameterized.TestCase):
Expand Down

0 comments on commit 0460181

Please sign in to comment.