From c859361995442b7f1e3cd6a7bb4aedd70412dae9 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Sun, 30 Jun 2024 15:51:24 +0000
Subject: [PATCH] minor fixes

Signed-off-by: Yu Chin Fabian Lim
---
 .../acceleration_framework_config.py | 19 ++++++++++---------
 .../fused_ops_and_kernels.py         | 18 +-----------------
 .../quantized_lora_config.py         | 12 +-----------
 3 files changed, 12 insertions(+), 37 deletions(-)

diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py
index c93ea9e50..7dc5737bb 100644
--- a/tuning/config/acceleration_configs/acceleration_framework_config.py
+++ b/tuning/config/acceleration_configs/acceleration_framework_config.py
@@ -162,6 +162,8 @@ def from_dataclasses(*dataclasses: Type):
         return config
 
     def get_framework(self):
+        if self.is_empty():
+            return
 
         if is_fms_accelerate_available():
 
@@ -176,14 +178,13 @@ def get_framework(self):
             self.to_yaml(f.name)
             return AccelerationFramework(f.name)
         else:
-            if not self.is_empty():
-                raise ValueError(
-                    "No acceleration framework package found. To use, first "
-                    "ensure that 'pip install fms-hf-tuning[fms-accel]' is done first to "
-                    "obtain the acceleration framework dependency. Additional "
-                    "acceleration plugins make be required depending on the requsted "
-                    "acceleration. See README.md for instructions."
-                )
+            raise ValueError(
+                "No acceleration framework package found. To use, first "
+                "ensure that 'pip install fms-hf-tuning[fms-accel]' is done to "
+                "obtain the acceleration framework dependency. Additional "
+                "acceleration plugins may be required depending on the requested "
+                "acceleration. See README.md for instructions."
+            )
 
     def is_empty(self):
         "check if the configuration is empty"
@@ -244,7 +245,7 @@ def _descend_and_set(path: List[str], d: Dict):
                     "to be installed. Please do:\n"
                     + "\n".join(
                         [
-                            "- python -m fms_acceleration install "
+                            "- python -m fms_acceleration.cli install "
                             f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}"
                             for x in annotate.required_packages
                         ]
diff --git a/tuning/config/acceleration_configs/fused_ops_and_kernels.py b/tuning/config/acceleration_configs/fused_ops_and_kernels.py
index 91df8c9dc..ded51415e 100644
--- a/tuning/config/acceleration_configs/fused_ops_and_kernels.py
+++ b/tuning/config/acceleration_configs/fused_ops_and_kernels.py
@@ -18,20 +18,13 @@
 from typing import List
 
 # Local
-from .utils import (
-    EnsureTypes,
-    ensure_nested_dataclasses_initialized,
-    parsable_dataclass,
-)
+from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
 
 
 @parsable_dataclass
 @dataclass
 class FusedLoraConfig(List):
 
-    # to help the HfArgumentParser arrive at correct types
-    __args__ = [EnsureTypes(str, bool)]
-
     # load unsloth optimizations for these 4bit base layer weights.
     # currently only support "auto_gptq" and "bitsandbytes"
     base_layer: str = None
 
@@ -41,9 +34,6 @@ class FusedLoraConfig(List):
 
     def __post_init__(self):
 
-        # reset for another parse
-        self.__args__[0].reset()
-
         if self.base_layer is not None and self.base_layer not in {
             "auto_gptq",
             "bitsandbytes",
@@ -60,9 +50,6 @@ def __post_init__(self):
 @dataclass
 class FastKernelsConfig(List):
 
-    # to help the HfArgumentParser arrive at correct types
-    __args__ = [EnsureTypes(bool, bool, bool)]
-
     # fast loss triton kernels
     fast_loss: bool = False
 
@@ -74,9 +61,6 @@ class FastKernelsConfig(List):
 
     def __post_init__(self):
 
-        # reset for another parse
-        self.__args__[0].reset()
-
         if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings:
             raise ValueError(
                 "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled "
diff --git a/tuning/config/acceleration_configs/quantized_lora_config.py b/tuning/config/acceleration_configs/quantized_lora_config.py
index d8174438c..a55ac55d6 100644
--- a/tuning/config/acceleration_configs/quantized_lora_config.py
+++ b/tuning/config/acceleration_configs/quantized_lora_config.py
@@ -18,11 +18,7 @@
 from typing import List
 
 # Local
-from .utils import (
-    EnsureTypes,
-    ensure_nested_dataclasses_initialized,
-    parsable_dataclass,
-)
+from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
 
 
 @parsable_dataclass
@@ -49,9 +45,6 @@ def __post_init__(self):
 
 @dataclass
 class BNBQLoraConfig(List):
 
-    # to help the HfArgumentParser arrive at correct types
-    __args__ = [EnsureTypes(str, bool)]
-
     # type of quantization applied
     quant_type: str = "nf4"
 
@@ -61,9 +54,6 @@ def __post_init__(self):
 
    def __post_init__(self):
 
-        # reset for another parse
-        self.__args__[0].reset()
-
         if self.quant_type not in ["nf4", "fp4"]:
             raise ValueError("quant_type can only be either 'nf4' or 'fp4.")
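
Note on the first hunk: moving the empty-config check to the top of get_framework() means an empty AccelerationFrameworkConfig now returns early (None) instead of merely suppressing the error inside the else branch; a non-empty config with fms-accel missing still raises. A minimal, self-contained sketch of the new control flow, using a hypothetical stand-in class (only is_empty/get_framework mirror the diff; the availability stub and settings dict are simplifications for illustration, not the real API):

    # Hypothetical stand-in for AccelerationFrameworkConfig. The stub below
    # simulates fms-accel being absent; it is not the real import.
    def is_fms_accelerate_available() -> bool:
        return False

    class ConfigSketch:
        def __init__(self, settings=None):
            self.settings = settings or {}

        def is_empty(self) -> bool:
            return not self.settings

        def get_framework(self):
            # after the patch: an empty config is a silent no-op
            if self.is_empty():
                return None
            if is_fms_accelerate_available():
                ...  # would build and return an AccelerationFramework here
            else:
                # a non-empty config without the package still raises
                raise ValueError("No acceleration framework package found.")

    assert ConfigSketch().get_framework() is None    # empty: no error
    try:
        ConfigSketch({"peft": "autogptq"}).get_framework()
    except ValueError:
        pass  # non-empty without fms-accel: error, as before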