From 810323822cf38f1f1a1ecae5363555cd75636714 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sun, 2 Jun 2024 14:46:07 +0800 Subject: [PATCH] Add MLP & QLoRA Fused Ops and Kernels, Mixtral (#29) * refactor Signed-off-by: Yu Chin Fabian Lim * fixes Signed-off-by: Yu Chin Fabian Lim * refactor mistral Signed-off-by: Yu Chin Fabian Lim * add mixtral Signed-off-by: Yu Chin Fabian Lim * some refactoring after introducing mlp Signed-off-by: Yu Chin Fabian Lim * remove extranous files Signed-off-by: Yu Chin Fabian Lim * add bnb Signed-off-by: Yu Chin Fabian Lim * lint + fmt and improvements to readme Signed-off-by: Yu Chin Fabian Lim * bench fixes * need to handle lora adapters device due to #26 * allow replay of failed benches, addressing comment in #14 * update benches (remove l40) --------- Signed-off-by: Yu Chin Fabian Lim --- README.md | 3 +- .../fms_acceleration_peft/autogptq_utils.py | 6 +- .../framework_plugin_autogptq.py | 25 +- plugins/fused-ops-and-kernels/README.md | 36 +-- .../framework_plugin_fast_quantized_peft.py | 21 +- .../fused_ops/unsloth_lora/bnb/fast_lora.py | 7 + .../fused_ops/unsloth_lora/gptq/fast_lora.py | 7 + .../kernels/unsloth/rope_embedding.py | 5 +- .../fms_acceleration_foak/models/__init__.py | 2 +- .../src/fms_acceleration_foak/models/llama.py | 62 ++++- .../fms_acceleration_foak/models/mistral.py | 72 ++++-- .../fms_acceleration_foak/models/mixtral.py | 104 ++++++++ .../models/model_patcher.py | 25 ++ .../src/fms_acceleration_foak/models/utils.py | 229 +++++++++--------- sample-configurations/CONTENTS.yaml | 8 +- ...eft-bnb-nf4-foak-sample-configuration.yaml | 44 ++++ scripts/benchmarks/benchmark.py | 26 +- scripts/benchmarks/display_bench_results.py | 40 ++- scripts/benchmarks/refs/a100_80gb.csv | 143 ++++++----- scripts/benchmarks/refs/l40_40gb.csv | 49 ---- scripts/benchmarks/scenarios.yaml | 3 +- scripts/generate_sample_configurations.py | 15 +- scripts/run_benchmarks.sh | 20 +- 23 files changed, 626 insertions(+), 326 deletions(-) create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py create mode 100644 sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml delete mode 100644 scripts/benchmarks/refs/l40_40gb.csv diff --git a/README.md b/README.md index a7534ed1..707c8662 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status --|--|--|--|-- [framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Beta [accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface
AutoGPTQ | Apache 2.0
MIT | Beta -[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon +[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon ## Usage with FMS HF Tuning @@ -174,7 +174,6 @@ The benchmarks can be reproduced [with the provided scripts](./scripts/benchmark See below CSV files for various results: - [A100-80GB](./scripts/benchmarks/refs/a100_80gb.csv). -- [L40-40GB](./scripts/benchmarks/refs/l40_40gb.csv). ### Code Architecture diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py index b8a7558d..913a6b7e 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py @@ -28,6 +28,7 @@ # consider making a map if patching more kernels PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"] + # This function may be moved after merging # https://github.com/foundation-model-stack/fms-acceleration/pull/25 def _patch_target_module( @@ -123,6 +124,7 @@ def create_new_module_peft( # if module cannot be found, return None which results in a raise in the call-stack return new_module + # consider to move this somewhere more general def patch_forward_to_view_attributes_before_call( old_forward: Callable, @@ -133,9 +135,9 @@ def patch_forward_to_view_attributes_before_call( ): # patch old_forward to view attribtues to torch_dype # before call - + if submodule_names is None: - submodule_names = '' + submodule_names = "" if isinstance(submodule_names, str): submodule_names = [submodule_names] diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py index 30492a2b..7928d9a9 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py @@ -51,13 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]): def model_loader(self, model_name: str, **kwargs): # guarded imports # Third Party - from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error + from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error + AutoGPTQForCausalLM, + BaseQuantizeConfig, + ) + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error + QuantLinear, + ) # Local - from .autogptq_utils import ( #pylint: disable=import-outside-toplevel - patch_forward_to_view_attributes_before_call, - PATCH_FOR_FSDP_TRITON_V2 + from .autogptq_utils import ( # pylint: disable=import-outside-toplevel + PATCH_FOR_FSDP_TRITON_V2, + patch_forward_to_view_attributes_before_call, ) # Currently we allow only a quantized checkpoint to be loaded, we do not @@ -214,8 +219,14 @@ def 
augmentation( ): # guarded imports # Third Party - from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error - from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error + from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error + QuantLinear, + ) + from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error + GPTQLoraModel, + get_gptq_peft_model, + ) + # Local from .autogptq_utils import ( # pylint: disable=import-outside-toplevel create_new_module_peft, diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md index a1b01d94..a1777671 100644 --- a/plugins/fused-ops-and-kernels/README.md +++ b/plugins/fused-ops-and-kernels/README.md @@ -3,7 +3,7 @@ This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following: -1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth). +1. Fused operations and kernels extracted from [unsloth](#extracted-code-from-unsloth). - Low-Rank Adapter Fused Operations - Fast RoPE Triton Kernels - Fast RMS LayerNorm Triton Kernels @@ -13,42 +13,28 @@ This library contains fused operations and custom kernels, to be expanded over t Plugin | Description | Depends | Loading | Augmentation | Callbacks --|--|--|--|--|-- -[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅ +[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅ ### Code Extracted from Unsloth - Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth): -- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143). +- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base, see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143). ``` - it would require a commercial license if used to run on more than 4 GPUs, see - https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143 + it would require a commercial license if used to run on more than 4 GPUs ... ``` -- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc). - * These model files are **not extracted**. -- All code extracted here before the Feb 2024 Release, see table below. +- These exceptions appear to be located around the trainer improvements, see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183). 
+- These exceptions appear around the [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in any file where such exceptions occur **is not extracted**.
+- In its place we adopt a model-patching approach, as opposed to unsloth's approach of rewriting the model. Our approach is novel and **completely rewritten from scratch**.
+- All extracted code predates the Feb 2024 Release.
+- The table below records what was extracted, and the exact commit from which it was taken.
 
 Path | Description | Extracted From | Modifications | Date
 --|--|--|--|--
 [fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
-[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
+[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
 [fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`
`triton/layers.py` | 6 Feb 2024 -[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py` | 28 Jan 2024 - - - +[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`
`rms_layernorm.py` | 28 Jan 2024 ## Known Issues diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py index ad0a399c..7eab87f0 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py @@ -16,6 +16,7 @@ from typing import Callable, Dict, Tuple # Third Party +from accelerate.utils import set_module_tensor_to_device from fms_acceleration import AccelerationPlugin from peft import LoraConfig from peft.tuners.lora.layer import LoraLayer @@ -63,9 +64,20 @@ def _all_reduce_hook(grad): return grad for mod in modules: + # NOTE: assuming lora has no bias + A = mod.lora_A.default + B = mod.lora_B.default + # install hooks on the adapters - mod.lora_A.default.weight.register_hook(_all_reduce_hook) - mod.lora_B.default.weight.register_hook(_all_reduce_hook) + A.weight.register_hook(_all_reduce_hook) + B.weight.register_hook(_all_reduce_hook) + + # because we will ignore these from FSDP, we need to manually + # move them to gpu if they are already not on them + if not A.weight.is_cuda: + set_module_tensor_to_device(A, "weight", "cuda") + if not B.weight.is_cuda: + set_module_tensor_to_device(B, "weight", "cuda") class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin): @@ -82,10 +94,7 @@ def __init__(self, configurations: Dict[str, Dict]): self._base_layer = self._check_config_and_maybe_check_values( key="peft.quantization.fused_ops_and_kernels.base_layer", - values=[ - "auto_gptq", - # "bitsandbytes" # enable later when we have BNB implemented - ], + values=["auto_gptq", "bitsandbytes"], ) # only support these at the moment diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py index 82f78f74..71d7070c 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py @@ -394,3 +394,10 @@ def apply_lora_o(self, X): O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS) return O pass + +# added by flim@sg.ibm.com +# this will be patchable on the actual module +def apply_lora_o_v2(self, X): + OW, OW_quant, OA, OB, OS = get_lora_parameters(self) + O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS) + return O \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py index 3808fba7..ee5055ed 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py @@ -735,3 +735,10 @@ def apply_lora_o(self, X): Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj) O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) return O + +# added by flim@sg.ibm.com +# this version can be directly patched on the output linear +def apply_lora_o_v2(self, X): + Oqstate, OA, OB, OS = get_lora_parameters(self) + O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) + return O diff --git 
a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py index 49b04fce..3577b586 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py @@ -130,8 +130,9 @@ def backward(ctx, dY): pass pass - -def fast_rope_embedding(Q, K, cos, sin): +# modified by flim@sg.ibm.com +# NOTE: fast_rope embeddings currently does not account for position ids +def fast_rope_embedding(Q, K, cos, sin, position_ids=None): Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2) K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2) return Q, K diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py index 7d6df3bc..ebd49924 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py @@ -15,7 +15,7 @@ # Local from .model_patcher import ModelPatcher -PATCHES = [".models.llama", ".models.mistral"] +PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"] PLUGIN_PREFIX = "fms_acceleration_foak" # TODO: remove the need for the prefix diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py index 3d01311a..290d1217 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -16,14 +16,24 @@ from functools import partial # Third Party -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaMLP, + LlamaRMSNorm, +) # Local from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger -from .utils import build_lora_fused_ops, trigger_fused_ops +from .model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + combine_functions, + combine_triggers, +) +from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops # TODO: have a generic version of this rule # - do regex on RMSNorm class name @@ -42,18 +52,54 @@ ModelPatcher.register( ModelPatcherRule( rule_id="llama-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=LlamaAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + forward_builder_args=["base_type"], + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="llama-mlp", trigger=ModelPatcherTrigger( check=partial( trigger_fused_ops, - 
attn_cls=LlamaAttention, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + attn_cls=LlamaMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], ) ), forward_builder=partial( build_lora_fused_ops, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + submodule_names=["up_proj", "down_proj", "gate_proj"], + fused_op=KEY_MLP, ), forward_builder_args=["base_type"], ) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py index a8e6795f..37809fd1 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py @@ -18,22 +18,22 @@ # Third Party from transformers.models.mistral.modeling_mistral import ( MistralAttention, + MistralMLP, MistralRMSNorm, ) # Local from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm -from ..kernels.unsloth.rope_embedding import fast_rope_embedding as _fast_rope_embedding -from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger -from .utils import build_lora_fused_ops, trigger_fused_ops - - -# NOTE: fast_rope_embedding does not work with position_ids -# currently they are ignored -def fast_rope_embedding(Q, K, cos, sin, position_ids=None): - return _fast_rope_embedding(Q, K, cos, sin) - +from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from .model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + combine_functions, + combine_triggers, +) +from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops # - do regex on RMSNorm class name # - check on the tensors required for fast_rms_layernorm @@ -45,29 +45,62 @@ def fast_rope_embedding(Q, K, cos, sin, position_ids=None): ), ) -# - do regex on Attention class name -# - have a set of qkv / o module names and check on that ModelPatcher.register( ModelPatcherRule( rule_id="mistral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MistralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + forward_builder_args=["base_type"], + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mistral-mlp", trigger=ModelPatcherTrigger( check=partial( trigger_fused_ops, - attn_cls=MistralAttention, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + attn_cls=MistralMLP, + submodule_names=["up_proj", "down_proj", "gate_proj"], ) ), forward_builder=partial( build_lora_fused_ops, - qkv_module_names=["q_proj", "k_proj", "v_proj"], - o_module_name="o_proj", + submodule_names=["up_proj", "down_proj", "gate_proj"], + fused_op=KEY_MLP, ), forward_builder_args=["base_type"], ) ) -# - get the module_name and reload on that ModelPatcher.register( ModelPatcherRule( rule_id="mistral-cross-ent", @@ -79,9 +112,6 @@ def fast_rope_embedding(Q, K, cos, sin, position_ids=None): ) ) -# - get the module name -# - check if 
"apply_rotary_pos_emb" exists -# - patch ModelPatcher.register( ModelPatcherRule( rule_id="mistral-rope", diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py new file mode 100644 index 00000000..1522ef8d --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py @@ -0,0 +1,104 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from functools import partial + +# Third Party +from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralRMSNorm, +) + +# Local +from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss +from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm +from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from .model_patcher import ( + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + combine_functions, + combine_triggers, +) +from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops + +# - do regex on RMSNorm class name +# - check on the tensors required for fast_rms_layernorm +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-rms", + trigger=ModelPatcherTrigger(check=MixtralRMSNorm), + forward=fast_rms_layernorm, + ), +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-qkvo", + trigger=combine_triggers( + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["q_proj", "k_proj", "v_proj"], + ) + ), + ModelPatcherTrigger( + check=partial( + trigger_fused_ops, + attn_cls=MixtralAttention, + submodule_names=["o_proj"], + ) + ), + logic="OR", + ), + forward_builder=combine_functions( + partial( + build_lora_fused_ops, + submodule_names=["q_proj", "k_proj", "v_proj"], + fused_op=KEY_QKV, + ), + partial( + build_lora_fused_ops, + submodule_names=["o_proj"], + fused_op=KEY_O, + ), + logic="APPEND", + ), + forward_builder_args=["base_type"], + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-cross-ent", + import_and_maybe_reload=( + "torch.nn.CrossEntropyLoss", + FastCrossEntropyLoss, + "transformers.models.mixtral.modeling_mixtral", + ), + ) +) + +ModelPatcher.register( + ModelPatcherRule( + rule_id="mixtral-rope", + import_and_maybe_reload=( + "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb", + fast_rope_embedding, + None, + ), + ) +) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py index 3355aa67..7f803330 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py @@ -468,3 +468,28 @@ def patch_model(model: torch.nn.Module, **kwargs): def patch_model_summary(): return ModelPatcher.summary() + + +def combine_triggers(*triggers: 
ModelPatcherTrigger, logic: str = "OR"): + assert logic == "OR", "only OR logic implemented for combining triggers" + + # NOTE: this can be probably simplified + def _or_logic(*args, **kwargs): + for trig in triggers: + if trig.check(*args, **kwargs): + return True + return False + + return ModelPatcherTrigger(check=_or_logic) + + +def combine_functions(*funcs: Callable, logic: str = "APPEND"): + assert logic == "APPEND", "only APPEND logic implemented for combining functions" + + def _append(*args, **kwargs): + results = [] + for f in funcs: + results += f(*args, **kwargs) + return results + + return _append diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py index b048b8e4..10819fc0 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py @@ -1,34 +1,40 @@ -# Copyright The FMS HF Tuning Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # Standard +from functools import partial from typing import Callable, List, Type +import os # Third Party import torch -import os # Local -# GPTQ imports -from ..fused_ops.unsloth_lora.gptq.fast_lora import LoRA_W as LoRA_W_gptq -from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq -from ..fused_ops.unsloth_lora.gptq.fast_lora import ( - get_lora_parameters as get_lora_parameters_gptq, +# NOTE: the default activation is swiglu in both cases +from ..fused_ops.unsloth_lora.bnb.fast_lora import ( + apply_lora_mlp_swiglu as fused_op_mlp_bnb, ) -from ..fused_ops.unsloth_lora.gptq.fast_lora import unpack_gptqstate +from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_o_v2 as fused_op_o_bnb +from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_qkv as fused_op_qkv_bnb +from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq +from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq +from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq from .model_patcher import ModelPatcherTrigger +KEY_QKV = "qkv" +KEY_O = "o" +KEY_MLP = "mlp" + +FUSED_OPS = { + "auto_gptq": { + KEY_QKV: fused_op_qkv_gptq, + KEY_O: fused_op_o_gptq, + KEY_MLP: fused_op_mlp_gptq, + }, + "bitsandbytes": { + KEY_QKV: fused_op_qkv_bnb, + KEY_O: fused_op_o_bnb, + KEY_MLP: fused_op_mlp_bnb, + }, +} + # simple utility function to guess if its lora layer def _is_loralayer(module: torch.nn.Module, names: List[str] = None): @@ -45,15 +51,15 @@ def _is_loralayer(module: torch.nn.Module, names: List[str] = None): # modules are called q_proj, k_proj, and v_proj, respectively. # the fused operation can be changed, depending on what the base layer is # i.e. 
gptq or bnb -def _build_qkv_forwards( +def _build_fused_forwards( attn: torch.nn.Module, fused_operation: Callable = fused_op_qkv_gptq, - module_names: List[str] = None, + submodule_names: List[str] = None, ): - if module_names is None: - module_names = ["q_proj", "k_proj", "v_proj"] + # fused opts expected to produce singular or multiple results + # module names must be passed in order of what the fused - Q = K = V = None + outs = {} # the fused operation will be called on first one that passes in the # input X. @@ -61,62 +67,52 @@ def _build_qkv_forwards( # - subsequent calls will be a no-op until ALL Q, K, V get reset to None def _fused_op(X): - nonlocal Q, K, V - if Q is None and K is None and V is None: - Q, K, V = fused_operation(attn, X) + # if all of the outs are not yet populated + if all(x not in outs for x in submodule_names): + fused_outs = fused_operation(attn, X) + try: + fused_outs = list(fused_outs) # not sure if this is correct + except TypeError: + # if fused_outs is not iterable + fused_outs = [fused_outs] + for n, x in zip(submodule_names, fused_outs): + outs[n] = x # each of these functions # - calls the fused op # - - error_msg = ( - "QKV fused_op needs to be first reset with sequential calls to each of them" - ) - - def _forward_q(self, X): - nonlocal Q - _fused_op(X) - assert Q is not None, error_msg - out, Q = Q, None # unload - return out - - def _forward_k(self, X): - nonlocal K - _fused_op(X) - assert K is not None, error_msg - out, K = K, None # unload - return out - def _forward_v(self, X): - nonlocal V + def _forward(self, X, name: str): _fused_op(X) - assert V is not None, error_msg - out, V = V, None # unload - return out - - return zip(module_names, [_forward_q, _forward_k, _forward_v]) - + assert ( + name in outs + ), "Fused_op needs to be first reset with sequential calls to each of them" + V = outs[name] + del outs[name] + return V -# fused ops for outputs for GPTQ -def fused_op_o_gptq(self, X): - Oqstate, OA, OB, OS = get_lora_parameters_gptq(self) - O = LoRA_W_gptq.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS) - return O + return zip(submodule_names, [partial(_forward, name=n) for n in submodule_names]) -# TODO: add the MLP def build_lora_fused_ops( attn: torch.nn.Module, base_type: str = "auto_gptq", - qkv_module_names: List[str] = None, - o_module_name: str = "o_proj", + submodule_names: List[str] = None, + fused_op: str = KEY_QKV, ): - if qkv_module_names is None: - qkv_module_names = ["q_proj", "k_proj", "v_proj"] - # handle the QKVs + assert ( + len(submodule_names) > 0 + ), "When building lora fused ops requires more than one submodule." 
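+    # NOTE: 'submodule_names' must be given in the same order as the outputs
+    # of the fused operation (e.g., q_proj, k_proj, v_proj for the QKV fused op),
+    # since _build_fused_forwards zips the fused outputs against these names.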
+ + if submodule_names is None: + submodule_names = ["q_proj", "k_proj", "v_proj"] + + # get the fused op + fused_operation = FUSED_OPS[base_type][fused_op] + + # handle casting issues if base_type == "auto_gptq": - _qkv_fused_op = fused_op_qkv_gptq - _o_fused_op = fused_op_o_gptq # this is required due to this FSDP fix # https://github.com/foundation-model-stack/fms-acceleration/pull/15 @@ -131,55 +127,60 @@ def build_lora_fused_ops( ): # guarded import - from fms_acceleration_peft.autogptq_utils import ( #pylint: disable=import-outside-toplevel - patch_forward_to_view_attributes_before_call, - PATCH_FOR_FSDP_TRITON_V2 + # pylint: disable=import-outside-toplevel,import-error + # Third Party + from fms_acceleration_peft.autogptq_utils import ( + PATCH_FOR_FSDP_TRITON_V2, + patch_forward_to_view_attributes_before_call, ) # patch each of the fused ops to view the attributes # back into torch.int32 - # TODO: add the MLP fused op also - _qkv_fused_op = patch_forward_to_view_attributes_before_call( - _qkv_fused_op, - PATCH_FOR_FSDP_TRITON_V2, torch.int32, - submodule_names=[ - n + '.base_layer' for n in qkv_module_names - ], - is_method_forward=False, - ) - _o_fused_op = patch_forward_to_view_attributes_before_call( - _o_fused_op, - PATCH_FOR_FSDP_TRITON_V2, torch.int32, - submodule_names='base_layer', + # - if there are multiple submodules, then we assume that + # 'fused_operation' will be called on module that has + # submodules specified in 'submodule_names'. + # - otherwise if there is only a single 'submodule_name', then + # assume that 'fused_operation' called on the submodule specified + # by 'submodule_name' itself + if len(submodule_names) > 1: + patched_submodule_names = [n + ".base_layer" for n in submodule_names] + else: + # otherwise assume calling on the 'submodule_name' itself + # so its just the base layer. + patched_submodule_names = "base_layer" + + fused_operation = patch_forward_to_view_attributes_before_call( + fused_operation, + PATCH_FOR_FSDP_TRITON_V2, + torch.int32, + submodule_names=patched_submodule_names, is_method_forward=False, ) - else: - raise NotImplementedError( - f"Cannot build fused ops for base type '{base_type}'." 
- ) - - trigger_and_forwards = [ - (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward) - for name, forward in _build_qkv_forwards( - attn, - fused_operation=_qkv_fused_op, - module_names=qkv_module_names, - ) - ] - - # handle the self-attn output - _output_module = getattr(attn, o_module_name) - if _is_loralayer(_output_module): - trigger_and_forwards.append( + if fused_op == KEY_QKV: + return [ + (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward) + for name, forward in _build_fused_forwards( + attn, + fused_operation=fused_operation, + submodule_names=submodule_names, + ) + ] + if fused_op == KEY_O: + # otherwise its just a single op + submodule_names = submodule_names[0] + return [ ( - ModelPatcherTrigger(check=_is_loralayer, module_name=o_module_name), - _o_fused_op, + ModelPatcherTrigger(check=_is_loralayer, module_name=submodule_names), + fused_operation, ) - ) + ] + if fused_op == KEY_MLP: + # otherwise just return the fused_op that should be attached at the + # top MLP level + return fused_operation - # return - return trigger_and_forwards + raise NotImplementedError(f"Unknown fused op '{fused_op}'") # trigger if either of the conditions are met @@ -188,16 +189,10 @@ def build_lora_fused_ops( def trigger_fused_ops( module: torch.nn.Module, attn_cls: Type, - qkv_module_names: List[str] = None, - o_module_name: str = "o_proj", + submodule_names: List[str], ): - if qkv_module_names is None: - qkv_module_names = ["q_proj", "k_proj", "v_proj"] - - _o = getattr(module, o_module_name) - _qkv = [getattr(module, x) for x in qkv_module_names] - # trigger on the attention layer - return isinstance(module, attn_cls) and ( - all(_is_loralayer(x) for x in _qkv) or _is_loralayer(_o) - ) + # trigger if the module meets the attn class and the submodules + # are all loralayers + _mods = [getattr(module, x) for x in submodule_names] + return isinstance(module, attn_cls) and all(_is_loralayer(x) for x in _mods) diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index c43a5adf..75f7279b 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -25,4 +25,10 @@ framework_configs: plugins: - accelerated-peft - fused-ops-and-kernels - filename: accelerated-peft-autogptq-foak-sample-configuration.yaml \ No newline at end of file + filename: accelerated-peft-autogptq-foak-sample-configuration.yaml + + - shortname: accelerated-peft-bnb-foak + plugins: + - accelerated-peft + - fused-ops-and-kernels + filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml new file mode 100644 index 00000000..fcb9bb14 --- /dev/null +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml @@ -0,0 +1,44 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # PEFT-related acceleration + peft: + + # quantization-releated acceleration + # e.g., kernels for quantized base weights + quantization: + + # For loading BitsAndBytes quantized layers + # to serve as 4bit base-weights for LoRA PEFT-tuning. + # NOTE: currently AutoGPTQ is not properly integrated into huggingface / + # bitsandbytes, thus recommended quant_type to be either "nf4" + # or "fp4". 
+ # bitsandbytes: + bitsandbytes: + quant_type: nf4 + + # If True, then no get_peft_model and prepare_model_for_kbit_training + # will be called. + no_peft_model: false + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights. + # currently only support "auto_gptq" and "bitsandbytes" + base_layer: bitsandbytes + + # activate various unsloth optimizations + # NOTE: currently supports only all-or-nothing. + + # fused kernels for lora linear layers + fused_lora: true + + # fast loss triton kernels + fast_loss: true + + # fast rms norm triton kernels + fast_rsm_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 651b227a..f5ff4a54 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -1,5 +1,6 @@ # Standard from itertools import product +from time import sleep from typing import Any, Callable, Dict, List, Tuple, Union import argparse import json @@ -88,6 +89,7 @@ HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics" RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes" RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes" +ERROR_MESSAGES = "error_messages" def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]: @@ -357,6 +359,17 @@ def __init__( self.results_filename = os.path.join(self.save_dir, FILE_RESULTS) self.gpu_log_filename = os.path.join(self.save_dir, FILE_MEM) + @property + def is_completed(self): + if not os.path.exists(self.results_filename): + return False + # otherwise open it and check for errors + with open(self.results_filename) as f: + results = json.load(f) + + # return complete only if no errors + return not ERROR_MESSAGES in results + def run( self, run_cmd: str, @@ -552,7 +565,7 @@ def write_result(self): **self.get_experiment_final_metrics(), } else: - other_results = {"error_messages": maybe_error_messages} + other_results = {ERROR_MESSAGES: maybe_error_messages} # combine the final thing save_result = {**save_result, **other_results} @@ -781,6 +794,14 @@ def main(args): log_memory_in_trainer=args.log_memory_hf, ) ): + # store pointer to file for future result retrival + experiment_stats[experiment.tag] = experiment.results_filename + + if experiment.is_completed: + # if completed, dont proceed + sleep(0.1) # sleep a bit to allow the tqdm to update + continue + if experiment.num_gpus > 1: prefix = COMMAND_ACCELERATE.format( accelerate_config_path=args.accelerate_config, @@ -806,10 +827,9 @@ def main(args): log_nvidia_smi=args.log_nvidia_smi, ) - # write results and store pointers to files + # write results experiment.write_result() experiment.write_shell_command() - experiment_stats[experiment.tag] = experiment.results_filename # 4. 
Consolidates the experiment results into a summary for tag, path in experiment_stats.items(): diff --git a/scripts/benchmarks/display_bench_results.py b/scripts/benchmarks/display_bench_results.py index 1de9b2a5..51ba5642 100644 --- a/scripts/benchmarks/display_bench_results.py +++ b/scripts/benchmarks/display_bench_results.py @@ -1,18 +1,21 @@ # Standard +from typing import List import argparse # First Party # import this because of alot of internal contants -from scripts.benchmarks.benchmark import gather_report, DIR_SAMP_CONFIGS -from typing import List +from scripts.benchmarks.benchmark import DIR_SAMP_CONFIGS, gather_report -def main(*directories: str, output_filename: str = "results.csv", remove_columns: List[str] = None): + +def main( + *directories: str, + output_filename: str = "results.csv", + remove_columns: List[str] = None, + keep_columns: List[str] = None, +): "gather outputs from a list of directories and output to a csv" - df, constant = gather_report(*directories, raw=False) - # filter result columns to keep by the inverse of remove_columns - if remove_columns: - df = df[df.columns[~df.columns.isin(remove_columns)]] + df, constant = gather_report(directories, raw=False) errors = [] try: @@ -22,12 +25,25 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns df = df.loc[df.error_messages.isna()] except: pass + + # filter result columns to keep by the inverse of remove_columns + if remove_columns: + df = df[df.columns[~df.columns.isin(remove_columns)]] + + # assume keep and remove are disjoint + kept = 0 + if keep_columns: + for c in keep_columns: + if c in constant: + df[c] = constant[c] + kept += 1 + df = df.reset_index(drop=True).drop("output_dir", axis=1) df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False) print("***************** Report Created ******************") print(f"Total lines: '{len(df)}'") print(f"Number columns included: '{len(df.columns)}'") - print(f"Number columns excluded: '{len(constant)}'") + print(f"Number columns excluded: '{len(constant)-kept}'") print(f"Excluding number of exceptions caught: '{len(errors)}'") print(f"Written report to '{output_filename}'") @@ -53,10 +69,16 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns nargs="*", help="list of columns to ignore from results.csv", ) + parser.add_argument( + "--keep_columns", + nargs="*", + help="list of columns to always include into results.csv", + ) args = parser.parse_args() main( - args.bench_outputs, + *args.bench_outputs, output_filename=args.result_file, remove_columns=args.remove_columns, + keep_columns=args.keep_columns, ) diff --git a/scripts/benchmarks/refs/a100_80gb.csv b/scripts/benchmarks/refs/a100_80gb.csv index 4434d864..b83549a7 100644 --- a/scripts/benchmarks/refs/a100_80gb.csv +++ b/scripts/benchmarks/refs/a100_80gb.csv @@ -1,61 +1,82 @@ -epoch,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,nvidia_mem_reserved,peak_torch_mem_alloc_in_bytes,peft_method,per_device_train_batch_size,r,target_modules,torch_mem_alloc_in_bytes,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second -0.04,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,77705.0,72971724288.0,,4,,,44004763136.0,0.9278398831685384,177.1092,0.678,0.169,2775.237 -0.04,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,44706.0,36762859520.0,,2,,,29521119232.0,0.8970902442932129,91.086,1.317,0.329,2698.11 
-0.09,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,74383.0,72972117504.0,,8,,,44005156352.0,0.9879656155904134,322.458,0.744,0.093,3048.583 -0.09,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,53907.0,36763056128.0,,4,,,29521315840.0,0.9259945551554362,167.7727,1.431,0.179,2929.678 -,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,4,,,,,,,, -,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79353.0,,,2,,,,,,,, -,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,8,,,,,,,, -,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79827.0,,,4,,,,,,,, -,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,4,,,,,,,, -,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,80830.0,,,2,,,,,,,, -,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,8,,,,,,,, -,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,80834.5,,,4,,,,,,,, -0.04,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,29731.0,26108963328.0,lora,4,16,q_proj k_proj v_proj o_proj,15119590912.0,0.9096682230631511,136.624,0.878,0.22,3597.611 -0.04,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,18697.0,15123161088.0,lora,2,16,q_proj k_proj v_proj o_proj,7850391552.0,0.8918854713439941,82.0311,1.463,0.366,2995.936 -0.09,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,43195.0,37098695168.0,lora,8,16,q_proj k_proj v_proj o_proj,15119984128.0,0.962119706471761,270.6301,0.887,0.111,3632.412 -0.09,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,26235.0,21433753600.0,lora,4,16,q_proj k_proj v_proj o_proj,7850588160.0,0.9218235015869141,143.8184,1.669,0.209,3417.643 -,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,62617.0,57540387840.0,lora,2,16,q_proj k_proj v_proj o_proj,47311452160.0,0.9361546834309896,179.3128,0.669,0.167,1370.566 -,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -0.09,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,69848.0,64347637760.0,lora,4,16,q_proj k_proj v_proj o_proj,47311648768.0,0.9383139928181966,280.8919,0.854,0.107,1749.855 -,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80894.0,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,, -,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80979.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,True,baseline-peft-bnb,24,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,27023.0,22825932800.0,lora,4,16,q_proj k_proj v_proj o_proj,5368221184.0,0.9589527130126954,178.8061,0.671,0.168,2748.9 -0.04,True,baseline-peft-bnb,25,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13530.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9154380798339844,87.3652,1.374,0.343,2813.02 -0.09,True,baseline-peft-bnb,26,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,47145.0,40278956032.0,lora,8,16,q_proj k_proj v_proj o_proj,5368614400.0,0.9702634493509928,341.2286,0.703,0.088,2880.884 -0.09,True,baseline-peft-bnb,27,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21502.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.914565912882487,149.9341,1.601,0.2,3278.241 -0.04,True,baseline-peft-bnb,28,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,48313.0,46419968512.0,lora,4,16,q_proj k_proj v_proj 
o_proj,25726225920.0,0.9744932492574055,351.8623,0.341,0.085,1396.91 -0.04,True,baseline-peft-bnb,29,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25549.0,21922782720.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9303209940592448,171.4299,0.7,0.175,1433.589 -0.09,True,baseline-peft-bnb,30,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,69931.0,67089150464.0,lora,8,16,q_proj k_proj v_proj o_proj,25726619136.0,0.9745417594909668,629.837,0.381,0.048,1560.785 -0.09,True,baseline-peft-bnb,31,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29384115200.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9310146331787109,300.5119,0.799,0.1,1635.609 -,True,baseline-peft-bnb,32,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80893.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,True,baseline-peft-bnb,33,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52634.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.0399916648864747,584.3145,0.205,0.051,420.595 -,True,baseline-peft-bnb,34,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,79557.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -,True,baseline-peft-bnb,35,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80749.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, -0.04,True,accelerated-peft-bnb,36,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,19931.0,15860019712.0,lora,4,16,q_proj k_proj v_proj o_proj,4843384320.0,0.9652111371358235,143.3569,0.837,0.209,3428.645 -0.04,True,accelerated-peft-bnb,37,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13497.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9277165730794271,86.4307,1.388,0.347,2843.435 -0.09,True,accelerated-peft-bnb,38,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,34355.0,26849751552.0,lora,8,16,q_proj k_proj v_proj o_proj,4843777536.0,0.9493892669677735,279.7156,0.858,0.107,3514.427 -0.09,True,accelerated-peft-bnb,39,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21479.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.9110882759094239,149.3914,1.607,0.201,3290.15 -0.04,True,accelerated-peft-bnb,40,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,38405.0,36218024448.0,lora,4,16,q_proj k_proj v_proj o_proj,25201389056.0,0.9741149584452311,278.5888,0.431,0.108,1764.32 -0.04,True,accelerated-peft-bnb,41,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25592.0,21906697728.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9300654411315918,172.7359,0.695,0.174,1422.75 -0.09,True,accelerated-peft-bnb,42,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,50875.0,47207756288.0,lora,8,16,q_proj k_proj v_proj o_proj,25201782272.0,0.9748441060384114,512.2298,0.469,0.059,1919.139 -0.09,True,accelerated-peft-bnb,43,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29369087488.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9301350593566895,287.6381,0.834,0.104,1708.814 -0.04,True,accelerated-peft-bnb,44,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,72829.0,68159977472.0,lora,4,16,q_proj k_proj v_proj o_proj,37346815488.0,1.118430455525716,1075.2044,0.112,0.028,457.141 -0.04,True,accelerated-peft-bnb,45,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52632.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.040946865081787,586.651,0.205,0.051,418.92 -,True,accelerated-peft-bnb,46,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80405.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -,True,accelerated-peft-bnb,47,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80954.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,, 
-0.04,True,accelerated-peft-autogptq,48,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,20453.0,15890329088.0,lora,4,16,q_proj k_proj v_proj o_proj,4873693696.0,1.3805528958638509,151.0359,0.795,0.199,3254.326 -0.04,True,accelerated-peft-autogptq,49,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,17198.0,9952175616.0,lora,2,16,q_proj k_proj v_proj o_proj,3005709312.0,1.1706618309020995,87.4109,1.373,0.343,2811.548 -0.09,True,accelerated-peft-autogptq,50,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,34247.0,26880060928.0,lora,8,16,q_proj k_proj v_proj o_proj,4874086912.0,1.2741642634073893,282.6391,0.849,0.106,3478.076 -0.09,True,accelerated-peft-autogptq,51,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,24783.0,16262768128.0,lora,4,16,q_proj k_proj v_proj o_proj,3005905920.0,1.043952751159668,152.5473,1.573,0.197,3222.083 -0.04,True,accelerated-peft-autogptq,52,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,37461.0,35528093184.0,lora,4,16,q_proj k_proj v_proj o_proj,24511457792.0,0.9936613400777181,263.6066,0.455,0.114,1864.597 -0.04,True,accelerated-peft-autogptq,53,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,46641.0,25708175360.0,lora,2,16,q_proj k_proj v_proj o_proj,12788874240.0,0.9420519828796386,167.065,0.718,0.18,1471.045 -0.09,True,accelerated-peft-autogptq,54,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,49925.0,46517825024.0,lora,8,16,q_proj k_proj v_proj o_proj,24511851008.0,0.9855653127034505,498.9022,0.481,0.06,1970.406 -0.09,True,accelerated-peft-autogptq,55,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,52358.0,27739090432.0,lora,4,16,q_proj k_proj v_proj o_proj,12789070848.0,0.9389812151590983,281.8034,0.852,0.106,1744.195 -0.04,True,accelerated-peft-autogptq,56,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,71565.0,65895347200.0,lora,4,16,q_proj k_proj v_proj o_proj,36290144768.0,1.0755928039550782,1060.8387,0.113,0.028,463.331 -0.04,True,accelerated-peft-autogptq,57,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80387.0,45397678592.0,lora,2,16,q_proj k_proj v_proj o_proj,18649885696.0,1.0256956418355305,576.0422,0.208,0.052,426.635 -,True,accelerated-peft-autogptq,58,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,80293.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,, -0.08,True,accelerated-peft-autogptq,59,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80363.0,70667573760.0,lora,4,16,q_proj k_proj v_proj o_proj,18650082304.0,1.0266701062520345,1089.3291,0.22,0.028,451.214 +epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +0.15,True,baseline-peft-bnb,2e-4,16,0.0,25995.0,22825932800,5368221184,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8676117706298828,584.6749,0.684,0.171,2802.241 +0.15,True,baseline-peft-bnb,2e-4,16,0.0,12512.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8593511199951172,279.9917,1.429,0.357,2925.801 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,46117.0,40278956032,5368614400,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.86837890625,1149.6017,0.696,0.087,2850.378 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,20435.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8526134586334229,496.2449,1.612,0.202,3301.596 
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,47079.0,46427906560,25726225920,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8966263771057129,1169.4078,0.342,0.086,1401.051 +0.15,True,baseline-peft-bnb,2e-4,16,0.0,24609.0,21937980416,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8650046825408936,564.3075,0.709,0.177,1451.691 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,68071.0,67121147392,25726619136,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8866284656524658,2118.0176,0.378,0.047,1547.107 +0.29,True,baseline-peft-bnb,2e-4,16,0.0,32054.0,29375012352,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8636721038818359,959.452,0.834,0.104,1707.641 +,True,baseline-peft-bnb,2e-4,16,0.0,80631.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.14,True,baseline-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9462522315979004,1951.2462,0.205,0.051,419.834 +,True,baseline-peft-bnb,2e-4,16,0.0,79555.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,baseline-peft-bnb,2e-4,16,0.0,80801.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.935322732925415,3737.7987,0.214,0.027,438.333 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,18903.0,15860019712,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8679532146453858,480.1165,0.833,0.208,3412.505 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,12477.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8598325538635254,281.0553,1.423,0.356,2914.729 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,33327.0,26849751552,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8708646774291993,944.515,0.847,0.106,3469.294 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,20417.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8568318557739257,498.8375,1.604,0.2,3284.436 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,37321.0,36218024448,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8979199028015137,923.4329,0.433,0.108,1774.249 +0.15,True,accelerated-peft-bnb,2e-4,16,0.0,24783.0,21940224000,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8649028778076172,564.1011,0.709,0.177,1452.222 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,49847.0,47207756288,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8877867794036866,1717.1699,0.466,0.058,1908.256 +0.29,True,accelerated-peft-bnb,2e-4,16,0.0,31907.0,29336790016,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8623861598968506,952.2959,0.84,0.105,1720.474 +0.14,True,accelerated-peft-bnb,2e-4,16,0.0,71801.0,68159977472,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.999151840209961,3662.4376,0.109,0.027,447.352 +0.14,True,accelerated-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9392572689056397,1950.7659,0.205,0.051,419.938 
+,True,accelerated-peft-bnb,2e-4,16,0.0,79375.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-bnb,2e-4,16,0.0,80866.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9258937835693359,3744.4001,0.214,0.027,437.56 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,19425.0,15890329088,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0217428588867188,477.2159,0.838,0.21,3433.247 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,12056.0,9690031616,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9701251029968262,278.7874,1.435,0.359,2938.44 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,33219.0,26880060928,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9569056987762451,941.1761,0.85,0.106,3481.601 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,19530.0,16000624128,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9303163433074951,494.3287,1.618,0.202,3314.394 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19065.0,13631990784,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9736110210418701,411.3906,0.972,0.243,3982.589 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,11506.0,9174099456,2405399552,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0141907215118409,248.8178,1.608,0.402,3292.368 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,32721.0,22390647808,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9668986797332764,809.2016,0.989,0.124,4049.424 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,18635.0,15282316800,2405596160,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.942121753692627,444.2322,1.801,0.225,3688.162 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,36435.0,35528093184,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9004004192352295,879.8344,0.455,0.114,1862.169 +0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,22962.5,20697435648,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8698519325256348,537.8597,0.744,0.186,1523.074 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,48941.0,46517825024,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8974114608764648,1669.3163,0.479,0.06,1962.959 +0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,29756.0,27484941824,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8667408466339112,924.2282,0.866,0.108,1772.722 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,36613.0,33671981056,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9003233146667481,814.7613,0.491,0.123,2010.896 +0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,22421.0,20108989952,12191160320,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.867002067565918,506.3203,0.79,0.198,1617.948 +0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49691.0,42742948864,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.897435302734375,1534.4874,0.521,0.065,2135.436 
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,28865.0,26629788672,12191300608,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.866525583267212,877.2087,0.912,0.114,1867.742 +0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,71177.0,65895347200,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.99012770652771,3600.8607,0.111,0.028,455.002 +0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,49455.0,44873390592,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9539268207550049,1890.9021,0.212,0.053,433.232 +,True,accelerated-peft-autogptq,2e-4,16,0.0,79265.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq,2e-4,16,0.0,79283.0,70143285760,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9549467945098877,3679.8651,0.217,0.027,445.234 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,71223.0,65086305280,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9903428840637207,3295.1075,0.121,0.03,497.222 +0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,46207.0,41579411968,15105330176,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9634347057342529,1740.6214,0.23,0.057,470.637 +,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,80949.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,74507.0,66445605376,15105526784,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9590920734405518,3441.8985,0.232,0.029,476.016 +0.15,,none,2e-5,,,76679.0,72971724288,44004763136,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9002080440521241,558.4193,0.716,0.179,2933.996 +0.15,,none,2e-5,,,43695.0,36762859520,29521119232,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.8854282188415528,302.5551,1.322,0.331,2707.606 +0.29,,none,2e-5,,,73761.0,72972117504,44005156352,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.0202219200134277,1085.5804,0.737,0.092,3018.478 +0.29,,none,2e-5,,,52923.0,36763056128,29521315840,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.8920887660980225,561.8731,1.424,0.178,2915.961 +,,none,2e-5,,,79961.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,float16,,,,, +,,none,2e-5,,,80925.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,, +,,none,2e-5,,,80969.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,float16,,,,, +,,none,2e-5,,,80703.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,, +,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,4,,,float16,,,,, +,,none,2e-5,,,80922.0,0,0,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,, +,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,8,,,float16,,,,, +,,none,2e-5,,,80782.0,0,0,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,, +0.15,,none,2e-4,16,0.0,28703.0,26108963328,15119590912,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8848505210876465,456.0676,0.877,0.219,3592.45 +0.15,,none,2e-4,16,0.0,17655.0,15123161088,7850391552,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8546714687347412,267.0472,1.498,0.374,3067.623 +0.29,,none,2e-4,16,0.0,42167.0,37098695168,15119984128,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,1.0078722095489503,909.6399,0.879,0.11,3602.305 
+0.29,,none,2e-4,16,0.0,25207.0,21433753600,7850588160,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8803257846832275,477.2486,1.676,0.21,3433.012 +,,none,2e-4,16,0.0,78871.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,,none,2e-4,16,0.0,61532.0,57531527168,47311452160,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8628986740112304,545.0419,0.734,0.183,1503.004 +,,none,2e-4,16,0.0,80991.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.29,,none,2e-4,16,0.0,68811.0,64348470272,47311648768,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8795901584625244,919.9512,0.87,0.109,1780.964 +,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.0,80760.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,none,2e-4,16,0.0,80987.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19257.0,13636909056,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8704845142364502,417.5391,0.958,0.239,3923.944 +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5527.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,32209.0,22430791680,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8942180156707764,818.5228,0.977,0.122,4003.309 +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5675.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,37301.0,35622334464,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.887912654876709,861.4969,0.464,0.116,1901.806 +0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,49955.0,46024318976,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8887538051605225,1590.7501,0.503,0.063,2059.909 +0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,71995.0,67350935552,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0002326488494873,3357.4377,0.119,0.03,487.991 +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,80303.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,True,accelerated-peft-bnb-foak,2e-4,16,0.0,21095.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, diff --git a/scripts/benchmarks/refs/l40_40gb.csv b/scripts/benchmarks/refs/l40_40gb.csv deleted file mode 100644 index 2158c782..00000000 --- a/scripts/benchmarks/refs/l40_40gb.csv +++ /dev/null @@ -1,49 +0,0 @@ -acceleration_framework_config_file,epoch,error_messages,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,output_dir,peft_method,per_device_train_batch_size,r,target_modules,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second,training_data_path -,,,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json 
-,0.03,,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,2,,,0.9020393848419189,102.4493,0.781,0.195,1599.23,benchmark_outputs/data/cache.json -,,,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json -,0.06,,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,4,,,0.936076545715332,170.7722,0.937,0.117,1918.814,benchmark_outputs/data/cache.json -,,,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,2,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,4,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,4,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,2,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,8,,,,,,,,benchmark_outputs/data/cache.json -,,,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,4,,,,,,,,benchmark_outputs/data/cache.json -,0.03,,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9326287746429444,120.2794,0.665,0.166,2724.324,benchmark_outputs/data/cache.json -,0.03,,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9157441139221192,78.5825,1.018,0.255,2084.943,benchmark_outputs/data/cache.json -,0.06,,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.0113807678222657,241.3246,0.663,0.083,2715.679,benchmark_outputs/data/cache.json -,0.06,,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9433841228485107,133.2158,1.201,0.15,2459.768,benchmark_outputs/data/cache.json -,,,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -,,,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,36,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.6183419704437256,137.2634,0.583,0.146,2387.235,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,37,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.7251328945159912,73.906,1.082,0.271,2216.871,benchmark_outputs/data/cache.json 
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,38,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.5904263019561768,272.1958,0.588,0.073,2407.679,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,39,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.515465259552002,138.6152,1.154,0.144,2363.954,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,40,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.012540912628174,227.0536,0.352,0.088,1443.183,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,41,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.0235525131225587,121.7118,0.657,0.164,1346.13,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,42,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,43,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.0152217864990234,229.6679,0.697,0.087,1426.756,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,44,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,45,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,46,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,47,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,0,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9979345798492432,130.1845,0.615,0.154,2517.044,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,1,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.942676591873169,69.8209,1.146,0.286,2346.575,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,2,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.9919514656066895,259.8776,0.616,0.077,2521.802,benchmark_bnb_outputs/data/cache.json 
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,3,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.933735466003418,133.6157,1.197,0.15,2452.406,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,4,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.015654945373535,218.3215,0.366,0.092,1500.906,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,5,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9546889305114746,173.2373,0.462,0.115,945.755,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,6,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,7,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9585415840148925,273.4507,0.585,0.073,1198.315,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,8,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,9,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,10,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json -sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,11,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index c935ac31..42f7c753 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -52,6 +52,7 @@ scenarios: - name: accelerated-peft-bnb framework_config: - accelerated-peft-bnb + - accelerated-peft-bnb-foak arguments: fp16: True learning_rate: 2e-4 @@ -82,4 +83,4 @@ scenarios: model_name_or_path: - 'TheBloke/Mistral-7B-v0.1-GPTQ' - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ' - - 'TheBloke/Llama-2-70B-GPTQ' \ No newline at end of file + - 'TheBloke/Llama-2-70B-GPTQ' diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index fd51d965..b3485e3c 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -143,6 +143,7 @@ def read_configuration(path: str) -> Dict: KEY_BNB_NF4 = "bnb-nf4" KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4" KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak" +KEY_BNB_NF4_FOAK = "bnb-nf4-foak" CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -153,14 +154,18 @@ def read_configuration(path: str) -> Dict: KEY_BNB_NF4_BASELINE: ( "plugins/accelerated-peft/configs/bnb.yaml", [ - 
("peft.quantization.bitsandbytes.quant_type", "nf4"), - ("peft.quantization.bitsandbytes.no_peft_model", True), + ("peft.quantization.bitsandbytes.quant_type", "nf4"), + ("peft.quantization.bitsandbytes.no_peft_model", True), ], ), KEY_AUTO_GPTQ_FOAK: ( "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", [("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")], ), + KEY_BNB_NF4_FOAK: ( + "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", + [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")], + ), } # list of (tag, combi) tuples @@ -173,8 +178,10 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)), ("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)), ("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), + ("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ] + # TODO: throw error if merge conflicts def merge_configs(config_contents: List[Dict]): "helper function to merge configuration contents." @@ -183,10 +190,10 @@ def merge_configs(config_contents: List[Dict]): def _merge(result: Dict, new_contents: Dict): for k, v in new_contents.items(): if k not in result: - # if k is not in result, it means v does not + # if k is not in result, it means v does not # exist as a subtree under result, so we just do # an assingment - result[k] = v + result[k] = v else: # otherwise we call the merge _merge(result[k], v) diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index 798138bf..8f8a1f9b 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -58,10 +58,10 @@ if [ -n "$RESULT_DIR" ]; then echo "Results dir $RESULT_DIR is not empty, but NO_OVERWRITE=true" echo "If intending to overwrite please delete the folder manually" echo "or do not set NO_OVERWRITE" - exit 1 + else + echo "Deleting $RESULT_DIR" + rm -rf $RESULT_DIR fi - echo "Deleting $RESULT_DIR" - rm -rf $RESULT_DIR fi # tag on the directories @@ -98,9 +98,11 @@ elif [ "$MEMORY_LOGGING" = "all" ]; then fi # dump out the environment -echo "Creating $RESULT_DIR" -mkdir -p $RESULT_DIR -pip freeze > $PIP_REQUIREMENTS_FILE +if [ ! "$NO_OVERWRITE" = "true" ]; then + echo "Creating $RESULT_DIR" + mkdir -p $RESULT_DIR + pip freeze > $PIP_REQUIREMENTS_FILE +fi # run the bench python $WORKING_DIR/benchmark.py \ @@ -116,8 +118,10 @@ python $WORKING_DIR/benchmark.py \ # this will write to the BENCH_RESULT_FILE # Remove the columns with values already represented by other metrics in the summary report PYTHONPATH=. \ - python $WORKING_DIR/display_bench_results.py benchmark_outputs \ + python $WORKING_DIR/display_bench_results.py $RESULT_DIR \ --result_file $BENCH_RESULT_FILE \ + --keep_columns \ + 'torch_dtype' \ --remove_columns \ 'before_init_mem_cpu' \ 'before_init_mem_gpu' \ @@ -129,5 +133,7 @@ PYTHONPATH=. \ 'train_mem_cpu_peaked_delta' \ 'train_mem_gpu_alloc_delta' \ 'train_mem_gpu_peaked_delta' \ + 'training_data_path' \ + 'error_messages' \ 'acceleration_framework_config_file'