Add MLP & QLoRA Fused Ops and Kernels, Mixtral (#29)
* refactor
* fixes
* refactor mistral
* add mixtral
* some refactoring after introducing mlp
* remove extraneous files
* add bnb
* lint + fmt and improvements to readme
* bench fixes
  * need to handle lora adapters device due to #26
  * allow replay of failed benches, addressing comment in #14
  * update benches (remove l40)

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Showing 23 changed files with 626 additions and 326 deletions.
```diff
@@ -394,3 +394,10 @@ def apply_lora_o(self, X):
     O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
     return O
 pass
+
+# added by [email protected]
+# this will be patchable on the actual module
+def apply_lora_o_v2(self, X):
+    OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
+    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
+    return O
```
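The only difference between the two variants is which module supplies the LoRA parameters: the original `apply_lora_o` is written against the parent attention module and reaches into `self.o_proj`, while `apply_lora_o_v2` reads from `self`, so it can be bound directly onto the projection module ("patchable on the actual module"). A minimal, dependency-free sketch of that patching pattern, using scalar stand-ins for `get_lora_parameters` and `LoRA_W.apply` (none of this is the real fused kernel):

```python
import types

# Stand-in for the real helper, which extracts weight, quant state,
# LoRA A/B matrices, and scale from a module (scalars here).
def get_lora_parameters(module):
    return module.W, None, module.A, module.B, module.S

# Stand-in for LoRA_W.apply: base projection plus scaled low-rank update.
def lora_apply(X, W, W_quant, A, B, S):
    return X * W + S * (X * A * B)

# v1 expects the *parent* module and reaches into self.o_proj:
def apply_lora_o(self, X):
    OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
    return lora_apply(X, OW, OW_quant, OA, OB, OS)

# v2 reads parameters from `self`, so it can be bound to the
# projection module itself:
def apply_lora_o_v2(self, X):
    OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
    return lora_apply(X, OW, OW_quant, OA, OB, OS)

class Proj:  # plays the role of the o_proj linear layer
    def __init__(self):
        self.W, self.A, self.B, self.S = 2.0, 0.5, 0.5, 1.0

class Attn:  # plays the role of the attention module
    def __init__(self):
        self.o_proj = Proj()

attn = Attn()
proj = attn.o_proj
# Patch v2 directly onto the projection module as a bound method.
proj.forward = types.MethodType(apply_lora_o_v2, proj)

x = 4.0
# Both paths compute the same projection on this toy data.
assert apply_lora_o(attn, x) == proj.forward(x)
```

The practical point is that v2 needs no reference to the enclosing attention block, which makes it easier to swap in per-layer.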
```diff
@@ -735,3 +735,10 @@ def apply_lora_o(self, X):
     Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj)
     O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
     return O
+
+# added by [email protected]
+# this version can be directly patched on the output linear
+def apply_lora_o_v2(self, X):
+    Oqstate, OA, OB, OS = get_lora_parameters(self)
+    O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
+    return O
```
```diff
@@ -130,8 +130,9 @@ def backward(ctx, dY):
     pass
 pass


-def fast_rope_embedding(Q, K, cos, sin):
+# modified by [email protected]
+# NOTE: fast_rope_embedding currently does not account for position ids
+def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
     Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
     K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
     return Q, K
```
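Shapes and the Triton kernel aside, the math `Fast_RoPE_Embedding` implements is the standard "rotate-half" rotary position embedding. The sketch below mirrors the patched signature, accepting `position_ids` but ignoring it, as the NOTE in the diff says. The helper `rope_rotate` and the flat-list layout are illustrative stand-ins, not the repository's implementation:

```python
import math

def rope_rotate(vec, cos, sin):
    # "Rotate-half" formulation: for even-length vec, each pair
    # (vec[i], vec[i + half]) is rotated by the angle encoded in
    # (cos[i], sin[i]).
    half = len(vec) // 2
    out = [0.0] * len(vec)
    for i in range(half):
        out[i] = vec[i] * cos[i] - vec[i + half] * sin[i]
        out[i + half] = vec[i + half] * cos[i] + vec[i] * sin[i]
    return out

def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
    # Mirrors the patched signature: position_ids is accepted for
    # API compatibility but not yet used, per the diff's NOTE.
    return rope_rotate(Q, cos, sin), rope_rotate(K, cos, sin)

# Rotate 4-dim query/key vectors by 90 degrees.
theta = math.pi / 2
cos = [math.cos(theta)] * 2
sin = [math.sin(theta)] * 2
q = [1.0, 0.0, 0.0, 0.0]
k = [0.0, 1.0, 0.0, 0.0]
Qr, Kr = fast_rope_embedding(q, k, cos, sin)
```

Because `position_ids` is ignored, every position is rotated with the same `cos`/`sin` rows the caller passes in; callers that need per-position offsets must pre-gather the cos/sin tables themselves.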