Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Jul 6, 2024
1 parent 9496f2f commit c859361
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def from_dataclasses(*dataclasses: Type):
return config

def get_framework(self):
if self.is_empty():
return

if is_fms_accelerate_available():

Expand All @@ -176,14 +178,13 @@ def get_framework(self):
self.to_yaml(f.name)
return AccelerationFramework(f.name)
else:
if not self.is_empty():
raise ValueError(
"No acceleration framework package found. To use, first "
"ensure that 'pip install fms-hf-tuning[fms-accel]' is done first to "
"obtain the acceleration framework dependency. Additional "
"acceleration plugins may be required depending on the requested "
"acceleration. See README.md for instructions."
)
raise ValueError(
"No acceleration framework package found. To use, first "
"ensure that 'pip install fms-hf-tuning[fms-accel]' is done first to "
"obtain the acceleration framework dependency. Additional "
"acceleration plugins may be required depending on the requested "
"acceleration. See README.md for instructions."
)

def is_empty(self):
"check if the configuration is empty"
Expand Down Expand Up @@ -244,7 +245,7 @@ def _descend_and_set(path: List[str], d: Dict):
"to be installed. Please do:\n"
+ "\n".join(
[
"- python -m fms_acceleration install "
"- python -m fms_acceleration.cli install "
f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}"
for x in annotate.required_packages
]
Expand Down
18 changes: 1 addition & 17 deletions tuning/config/acceleration_configs/fused_ops_and_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,13 @@
from typing import List

# Local
from .utils import (
EnsureTypes,
ensure_nested_dataclasses_initialized,
parsable_dataclass,
)
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@dataclass
class FusedLoraConfig(List):

# to help the HfArgumentParser arrive at correct types
__args__ = [EnsureTypes(str, bool)]

# load unsloth optimizations for these 4bit base layer weights.
# currently only support "auto_gptq" and "bitsandbytes"
base_layer: str = None
Expand All @@ -41,9 +34,6 @@ class FusedLoraConfig(List):

def __post_init__(self):

# reset for another parse
self.__args__[0].reset()

if self.base_layer is not None and self.base_layer not in {
"auto_gptq",
"bitsandbytes",
Expand All @@ -60,9 +50,6 @@ def __post_init__(self):
@dataclass
class FastKernelsConfig(List):

# to help the HfArgumentParser arrive at correct types
__args__ = [EnsureTypes(bool, bool, bool)]

# fast loss triton kernels
fast_loss: bool = False

Expand All @@ -74,9 +61,6 @@ class FastKernelsConfig(List):

def __post_init__(self):

# reset for another parse
self.__args__[0].reset()

if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings:
raise ValueError(
"fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled "
Expand Down
12 changes: 1 addition & 11 deletions tuning/config/acceleration_configs/quantized_lora_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
from typing import List

# Local
from .utils import (
EnsureTypes,
ensure_nested_dataclasses_initialized,
parsable_dataclass,
)
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
Expand All @@ -49,9 +45,6 @@ def __post_init__(self):
@dataclass
class BNBQLoraConfig(List):

# to help the HfArgumentParser arrive at correct types
__args__ = [EnsureTypes(str, bool)]

# type of quantization applied
quant_type: str = "nf4"

Expand All @@ -61,9 +54,6 @@ class BNBQLoraConfig(List):

def __post_init__(self):

# reset for another parse
self.__args__[0].reset()

if self.quant_type not in ["nf4", "fp4"]:
raise ValueError("quant_type can only be either 'nf4' or 'fp4'.")

Expand Down

0 comments on commit c859361

Please sign in to comment.