Skip to content

Commit

Permalink
fix(sgmv): deadlock in sgmv_shrink kernel caused by skewed segments (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Jan 9, 2024
1 parent d2bc01c commit 591b598
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
20 changes: 17 additions & 3 deletions csrc/sgmv_flashinfer/sgmv_flashinfer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp,
constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero;
const uint32_t problem_id = blockIdx.y;
const uint32_t bx = blockIdx.x;
const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1];
constexpr uint32_t num_stages = 2;
constexpr uint32_t num_k_frags = 8;
constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity<T>();
Expand All @@ -45,8 +44,9 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp,
uint32_t w_frag[num_k_frags][num_blocks_n][4];
float y_frag[num_blocks_n][8];

for (uint32_t i = 0;
i < (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16); ++i) {
const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1];
const uint32_t num_steps = (s_start < s_end) ? (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16) : 0;
for (uint32_t i = 0; i < num_steps; ++i) {
// init y_frag
if (bx == 0) {
if constexpr (num_blocks_n == 1) {
Expand Down Expand Up @@ -335,6 +335,20 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp,
}
}
}

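// Segments can need different numbers of steps. If another segment needs more
// steps than this one, issue the remaining grid.sync() calls here so that the
// blocks still working on the longer segment do not deadlock at the barrier.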
if constexpr (cooperative) {
uint32_t max_segment_size = 0;
for (uint32_t i = 0; i < num_problems; ++i) {
max_segment_size = max(max_segment_size, s[i + 1] - s[i]);
}

const uint32_t max_steps = (max_segment_size + (num_warps * 16 - 1)) / (num_warps * 16);
for (uint32_t i = 0; i < max_steps - num_steps; ++i) {
grid.sync();
}
}
}

} // namespace sgmv
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sgmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def get_lora_lens(bs: int, popularity: str) -> list[int]:
a *= alpha
lens.append(bs - sum(lens))
return sorted(lens, reverse=True)
if popularity.startswith("skewed"):
if bs < 3:
return [bs]
# Create a highly imbalanced distribution by setting the first segment
# length to 1 and the remainder to the second segment.
return [1, bs - 1]
raise KeyError(popularity)


Expand Down Expand Up @@ -81,7 +87,7 @@ def lora_ref_impl(
pytest.param("expand", marks=pytest.mark.xfail(reason="TODO: sgmv expand")),
],
)
@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical"])
@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical", "skewed"])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 7, 10, 16, 32, 64, 133])
@torch.inference_mode()
def test_sgmv_correctness(dtype_str, h, r, direction, popularity, batch_size):
Expand Down

0 comments on commit 591b598

Please sign in to comment.