feat: add liger kernel with fused cross entropy loss #93

Open · wants to merge 14 commits into main
plugins/fused-ops-and-kernels/configs/fast_kernels.yaml (5 changes: 4 additions, 1 deletion)
@@ -22,4 +22,7 @@ training:
fast_rms_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True

# fused linear cross entropy loss
fused_linear_loss: False
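
Since the plugin code below treats `fused_linear_loss` and `fast_loss` as mutually exclusive, enabling the fused path means flipping both flags. In parsed form, the `training` stanza the plugin would receive looks roughly like this (a sketch; the exact key set is illustrative):

# the training section of fast_kernels.yaml as the plugin sees it (illustrative)
configurations = {
    "fast_loss": False,            # must be off when the fused path is on
    "fused_linear_loss": True,     # enables the fused linear + cross entropy kernel
    "fast_rms_layernorm": True,
    "fast_rope_embeddings": True,
}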
@@ -26,6 +26,12 @@
from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp


def validate_plugin_args(configurations):
# Consider making this a more graceful fallback?
assert not (
configurations["fused_linear_loss"] and configurations["fast_loss"]
), "If using `fused_linear_loss`, `fast_loss` must be set to False"

# consider rewriting register_foak_model_patch_rules into something
# like this also
def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None):
@@ -68,6 +74,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
"fast_loss": "cross-ent",
"fast_rms_layernorm": "rms",
"fast_rope_embeddings": "rope",
"fused_linear_loss": "fused-lce",
}
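
The new "fused-lce" entry extends this map from config keys to patch-rule name suffixes. Presumably it feeds the `filter_endswith` argument of `register_foak_model_patch_rules2`, so that only rules matching enabled flags get patched; a sketch of that wiring, where the `FILTER_MAP` name and the `base_type` value are assumptions:

# illustrative: collect rule-name suffixes for the flags the user enabled
terms = {suffix for key, suffix in FILTER_MAP.items() if configurations.get(key)}
register_foak_model_patch_rules2(base_type="auto_gptq", filter_endswith=terms)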


@@ -115,6 +122,14 @@ def __init__(self, configurations: Dict[str, Dict]):
key="fast_rope_embeddings", values=[False, True], default=True
)
)
self.configurations["fused_linear_loss"] = (
self._check_config_and_maybe_check_values(
key="fused_linear_loss", values=[False, True], default=False
)
)

validate_plugin_args(self.configurations)


@property
def requires_agumentation(self):
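A quick check of what the guard added above accepts and rejects (a sketch; only the two keys `validate_plugin_args` reads are shown):

validate_plugin_args({"fast_loss": False, "fused_linear_loss": True})   # passes
validate_plugin_args({"fast_loss": True, "fused_linear_loss": True})    # AssertionError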
@@ -0,0 +1,222 @@
# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved.
#
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import torch
import triton
import triton.language as tl

@triton.jit
def liger_cross_entropy_kernel(
X_ptr,
X_stride,
Y_ptr,
Y_stride,
loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
ignore_index,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
BLOCK_SIZE: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
We only consider hard labels for now; both 'mean' and 'sum' reductions are handled. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
loss_ptr: Pointer to tensor to store the loss.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply, either "mean" or "sum".
BLOCK_SIZE (int): The block size for Triton operations.
"""

# https://github.com/triton-lang/triton/issues/1058
# If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
program_id = tl.program_id(0).to(tl.int64)

# 1. Load Y_ptr first because if the target is ignore_index, we can return right away
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)

# 2. locate the start index
X_ptr += program_id * X_stride

if y == ignore_index:
# zero out the gradient row stored at X_ptr
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
return

loss_ptr += program_id * loss_stride

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
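# A worked pass of the recurrence with BLOCK_SIZE = 2 over logits [1, 3, 2, 0]:
#   block [1, 3]: m = 3, d = e^(1-3) + e^(3-3)            ≈ 1.1353
#   block [2, 0]: m stays 3, d = 1.1353 + e^(-1) + e^(-3) ≈ 1.5530
# so logsumexp = m + log(d) ≈ 3.4402, matching log(e^1 + e^3 + e^2 + e^0)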

# 3. [Online softmax] first pass: find max + sum
m = float("-inf")  # m is the running max (notation from the paper)
d = 0.0  # d is the running sum (notation from the paper)
ori_X_y = tl.load(
X_ptr + y
) # we need to store the original value of X_y for the loss calculation

# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
scaled_x_sum = 0.0
eps = label_smoothing / n_cols

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new

# 4. [Online Softmax] Second pass: compute gradients
# For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
#
# For 'sum' reduction, no normalization is applied:
# dx_y = softmax(x_y) - 1
# dx_i = softmax(x_i), for i ≠ y
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V), V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
# = dx_i - (1 - label_smoothing)
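#
# Continuing the worked example with y = 1 (mean reduction, N = 1, no smoothing):
#   softmax([1, 3, 2, 0]) ≈ [0.0871, 0.6439, 0.2369, 0.0321]
#   dx ≈ [0.0871, -0.3561, 0.2369, 0.0321]   (dx_y = 0.6439 - 1; the row sums to 0)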

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
if reduction == "mean":
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
tl.debug_barrier()

# 5. Calculate the loss

# loss = -log(softmax(X_y)), where
# log(softmax(X_y)) = log(e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
#                   = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
# sum(e ^ (X - max(X))) must be >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log(softmax(X_y)) without overflow
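# In the worked example: loss = -(3 - 3 - log(1.5530)) ≈ 0.4402 = -log(0.6439)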
loss = -(ori_X_y - m - tl.log(d))

# Original loss = H(q, p); with label smoothing regularization it becomes H(q', p), with (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
#          = (1 - label_smoothing) * H(q, p) - eps * sum(logsoftmax(x_i))
# By using m (the global max of x_i) and d (the sum of e^(x_i - m)), we can simplify as:
#          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + log(d)))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
loss = loss * (1 - label_smoothing) + smooth_loss

# Normalize the loss by the number of non-ignored elements if reduction is "mean"
if reduction == "mean":
loss = loss / n_non_ignore

# 6. Specially handle the i == y case, where `dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N`
X_y = tl.load(X_ptr + y)
if reduction == "mean":
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)

tl.store(loss_ptr, loss)
tl.store(X_ptr + y, X_y)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
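
The launch-side wrapper (not included in this hunk) would typically derive the block size from the vocabulary dimension; a sketch of the usual Triton pattern, with `n_cols` assumed to be the vocab size V:

# illustrative: smallest power of two covering n_cols, capped at MAX_FUSED_SIZE
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))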


@triton.jit
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
This kernel multiplies each element of the tensor pointed to by X_ptr by the scalar value pointed to by grad_output_ptr.
The multiplication is performed in-place on the tensor pointed to by X_ptr.

Parameters:
X_ptr: Pointer to the input tensor.
X_stride (int): The stride of the input tensor.
grad_output_ptr: Pointer to the gradient output value.
n_cols (int): The number of columns in the input tensor.
BLOCK_SIZE (int): The block size for Triton operations.
"""

# Get the program ID and convert it to int64 to avoid overflow
program_id = tl.program_id(0).to(tl.int64)

# Locate the start index
X_ptr += program_id * X_stride

# Load the gradient output value
grad_output = tl.load(grad_output_ptr)

# Perform the element-wise multiplication
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
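
In the backward pass, this is what scales the stored per-logit gradients by the incoming upstream gradient. A hedged sketch of a launch, with the tensor names `_input` and `grad_output` as assumptions:

# _input: (n_rows, V) tensor whose rows already hold d(loss)/d(logits)
# grad_output: scalar tensor from autograd
n_rows, n_cols = _input.shape
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
element_mul_kernel[(n_rows,)](
    _input,
    _input.stride(0),
    grad_output,
    n_cols,
    BLOCK_SIZE=BLOCK_SIZE,
)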
