ReIntroduce Package for FMS Accel #223

Merged · 3 commits · Jul 9, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -25,7 +25,7 @@ pip install fms-hf-tuning[aim]

If you wish to use [fms-acceleration](https://github.com/foundation-model-stack/fms-acceleration), you need to install it.
```
pip install git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework
pip install fms-hf-tuning[fms-accel]
```
`fms-acceleration` is a collection of plugin packages that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details, see [this section below](#fms-acceleration).

@@ -389,7 +389,7 @@ Equally you can pass in a JSON configuration for running tuning. See [build doc]

To access `fms-acceleration` features the `[fms-accel]` dependency must first be installed:
```
$ pip install https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework
$ pip install fms-hf-tuning[fms-accel]
```

Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show available plugins:
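A sketch of the expected CLI invocations is shown below; the `plugins` subcommand name follows the upstream `fms-acceleration` documentation, and the plugin package name is purely an illustrative placeholder:
```
# list the acceleration plugins known to the framework
$ python -m fms_acceleration.cli plugins

# install one of the listed plugins (placeholder name shown)
$ python -m fms_acceleration.cli install fms_acceleration_<plugin-name>
```
The `install` subcommand matches the corrected hint (`python -m fms_acceleration.cli install ...`) further down in this diff.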
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -44,6 +44,8 @@ dependencies = [
dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<24", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"]
flash-attn = ["flash-attn>=2.5.3,<3.0"]
aim = ["aim>=3.19.0,<4.0"]
fms-accel = ["fms-acceleration>=0.1"]


[tool.setuptools.packages.find]
exclude = ["tests", "tests.*"]
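Since `fms-accel` is declared here as an optional extra alongside `aim`, the two can be combined at install time using standard pip extras syntax; the combination below is only an illustration:
```
pip install fms-hf-tuning[aim,fms-accel]
```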
@@ -172,16 +172,29 @@ def get_framework(self):
NamedTemporaryFile,
)

with NamedTemporaryFile("w") as f:
self.to_yaml(f.name)
return AccelerationFramework(f.name)
try:
with NamedTemporaryFile("w") as f:
self.to_yaml(f.name)
return AccelerationFramework(f.name)
except ValueError as e:
(msg,) = e.args

# AccelerationFramework raises ValueError if it
# fails to configure any plugin
if self.is_empty() and msg.startswith("No plugins could be configured"):
# in the case when the error was thrown when
# the acceleration framework config was empty
# then this is expected.
return None

raise e
else:
if not self.is_empty():
raise ValueError(
"No acceleration framework package found. To use, first ensure that "
"'pip install git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework' " # pylint: disable=line-too-long
"is done first to obtain the acceleration framework dependency. Additional "
"acceleration plugins make be required depending on the requested "
"No acceleration framework package found. To use, first "
"ensure that 'pip install fms-hf-tuning[fms-accel]' is done first to "
"obtain the acceleration framework dependency. Additional "
"acceleration plugins make be required depending on the requsted "
"acceleration. See README.md for instructions."
)

@@ -244,7 +257,7 @@ def _descend_and_set(path: List[str], d: Dict):
"to be installed. Please do:\n"
+ "\n".join(
[
"- python -m fms_acceleration install "
"- python -m fms_acceleration.cli install "
f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}"
for x in annotate.required_packages
]
18 changes: 1 addition & 17 deletions tuning/config/acceleration_configs/fused_ops_and_kernels.py
@@ -18,20 +18,13 @@
from typing import List

# Local
from .utils import (
EnsureTypes,
ensure_nested_dataclasses_initialized,
parsable_dataclass,
)
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@dataclass
class FusedLoraConfig(List):

# to help the HfArgumentParser arrive at correct types
__args__ = [EnsureTypes(str, bool)]

# load unsloth optimizations for these 4bit base layer weights.
# currently only support "auto_gptq" and "bitsandbytes"
base_layer: str = None
@@ -41,9 +34,6 @@ class FusedLoraConfig(List):

def __post_init__(self):

# reset for another parse
self.__args__[0].reset()

if self.base_layer is not None and self.base_layer not in {
"auto_gptq",
"bitsandbytes",
@@ -60,9 +50,6 @@ def __post_init__(self):
@dataclass
class FastKernelsConfig(List):

# to help the HfArgumentParser arrive at correct types
__args__ = [EnsureTypes(bool, bool, bool)]

# fast loss triton kernels
fast_loss: bool = False

@@ -74,9 +61,6 @@ class FastKernelsConfig(List):

def __post_init__(self):

# reset for another parse
self.__args__[0].reset()

if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings:
raise ValueError(
"fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled "
12 changes: 1 addition & 11 deletions tuning/config/acceleration_configs/quantized_lora_config.py
@@ -18,11 +18,7 @@
from typing import List

# Local
from .utils import (
EnsureTypes,
ensure_nested_dataclasses_initialized,
parsable_dataclass,
)
from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@@ -49,9 +45,6 @@ def __post_init__(self):
@dataclass
class BNBQLoraConfig(List):

# to help the HfArgumentParser arrive at correct types
__args__ = [EnsureTypes(str, bool)]

# type of quantization applied
quant_type: str = "nf4"

@@ -61,9 +54,6 @@ class BNBQLoraConfig(List):

def __post_init__(self):

# reset for another parse
self.__args__[0].reset()

if self.quant_type not in ["nf4", "fp4"]:
raise ValueError("quant_type can only be either 'nf4' or 'fp4.")
