Extracted Subset of AutoGPTQ library into Accelerated-Peft Plugin #48

Merged: 20 commits, Jul 15, 2024
2 changes: 1 addition & 1 deletion plugins/accelerated-peft/.pylintrc
@@ -52,7 +52,7 @@ ignore=CVS,protobufs
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
ignore-paths=.*gptqmodel/,tests/test_q4_triton.py,tests/test_triton.py

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
1 change: 1 addition & 0 deletions plugins/accelerated-peft/pyproject.toml
@@ -26,6 +26,7 @@ classifiers=[

[project.optional-dependencies]
flash-attn = ["flash-attn"]
auto_gptq = ["auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@ea829c7bbe83561c2b1de26795b6592992373ef7"] # known working commitid

[tool.hatch.metadata.hooks.requirements_txt]
files = ["requirements.txt"]
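With `auto_gptq` demoted to an optional extra, installing the plugin no longer pulls the external library in by default; a caller that wants the external path can probe for the package before enabling `use_external_lib`. A minimal sketch, using only the standard library; the helper name is illustrative and not part of the plugin:

```python
# Sketch: check whether the optional `auto_gptq` extra was installed
# (e.g. via `pip install .[auto_gptq]`) before opting into the
# external-library path. The helper name is illustrative.
import importlib.util

def external_autogptq_available() -> bool:
    # find_spec returns None when the package cannot be imported
    return importlib.util.find_spec("auto_gptq") is not None

use_external_lib = external_autogptq_available()
```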
12 changes: 6 additions & 6 deletions plugins/accelerated-peft/requirements.txt
@@ -1,13 +1,13 @@
# decided not to have this as a requirement for now
# fms_acceleration @ git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework

# put this in here because there is a breaking FSDP api change that
# is fixed after peft > 0.10
accelerate < 0.29
# Needs a lower bound due to the `accelerate.load_checkpoint_in_model` function used in gptqmodel
accelerate >= 0.29

# bitsandbytes for the BNB plugin
bitsandbytes

# Installing from repository because "auto_gptq > 0.7.1" is not yet available
# Specifying the commit id here as recent commits to the main branch have introduced additional dependencies
auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@ea829c7bbe83561c2b1de26795b6592992373ef7
# Used to manage the thread limit in functions for converting old
# GPTQ models to the new GPTQ model format that supports symmetrical=False
# https://github.com/AutoGPTQ/AutoGPTQ/pull/640
threadpoolctl
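`threadpoolctl` is pulled in to cap native thread pools while old GPTQ checkpoints are converted to the newer format that supports symmetrical=False (see the linked AutoGPTQ PR). A minimal sketch of that pattern, assuming a placeholder conversion function rather than the plugin's actual converter:

```python
# Sketch: limit BLAS/OpenMP worker threads for the duration of a
# CPU-heavy checkpoint-format conversion, then restore the previous
# limits on exit. `convert_fn` is a placeholder, not the plugin's API.
from threadpoolctl import threadpool_limits

def run_single_threaded(convert_fn, *args, **kwargs):
    with threadpool_limits(limits=1):
        return convert_fn(*args, **kwargs)
```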
@@ -28,15 +28,16 @@
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.import_utils import _is_package_available
import torch
import torch.distributed


class AutoGPTQAccelerationPlugin(AccelerationPlugin):

require_packages = ["auto_gptq"]
require_packages = []

def __init__(self, configurations: Dict[str, Dict]):
def __init__(self, configurations: Dict[str, Dict], use_external_lib: bool = False):
super().__init__(configurations)

# just do checking, nothing to configure at this point
@@ -47,18 +48,31 @@ def __init__(self, configurations: Dict[str, Dict]):
self._check_config_equal(
key="peft.quantization.auto_gptq.from_quantized", value=True
)
self.use_external_lib = use_external_lib

if self.use_external_lib:
assert (
_is_package_available("auto_gptq") is True
), "Unable to use external library, autogptq module not found."

def model_loader(self, model_name: str, **kwargs):
# guarded imports
# Third Party
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM,
BaseQuantizeConfig,
)
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)

if self.use_external_lib:
# Third Party
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM as GPTQModel,
)
from auto_gptq import BaseQuantizeConfig as QuantizeConfig # pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
else:
from .gptqmodel import GPTQModel, QuantizeConfig # pylint: disable=import-outside-toplevel,import-error
from .gptqmodel.utils import Backend # pylint: disable=import-outside-toplevel,import-error
from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
PATCH_FOR_FSDP_TRITON_V2,
@@ -85,11 +99,11 @@ def model_loader(self, model_name: str, **kwargs):
# switching to cuda/cuda_old/triton backend."
# assume model_name points to a quantized checkpoint. Thus we load the quantization
# config directly from the checkpoint.
quantize_config = BaseQuantizeConfig.from_pretrained(model_name)
quantize_config = QuantizeConfig.from_pretrained(model_name)

# get additional parameters
torch_dtype = kwargs.get("torch_dtype", torch.float32)
low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage", False)
attn_implementation = kwargs.get("attn_implementation")

# there are some kwargs that won't be passed to AutoModel, so we need
@@ -101,54 +115,68 @@
)
AutoModelForCausalLM.from_config = _from_config # patch

if self.use_external_lib:
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage,
"use_marlin": False, # disable, cannot be used for training (no forward+backward)
"disable_exllama": True, # disable, cannot be used for training (no backward)
"use_tritonv2": True,
"trainable": True, # only support trainable mode
}
else:
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage, # this is only used for device map
"backend": Backend.TRITON,
}

# this is a HF method that checks if the low_cpu_mem mode is enabled
# via HF accelerate
if is_fsdp_enabled():
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)
low_cpu_mem_usage = True

# NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
# device_map is for inference only
# https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
# For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu"
# to avoid gpu consumption before train
# This approach will divert consumption to cpu memory;
# a better approach would be to load the checkpoints to the meta device
# QLoRA is currently implemented by the former approach and will encounter the same issue.
# see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262
device_map = {
"": (
(torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
if torch.cuda.is_available()
else None
)
}
kwargs["low_cpu_mem_usage"] = True
if self.use_external_lib:
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)

# NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
# For low_cpu_mem_usage = True, we have to set the device map to load checkpoints
# to "cpu" to avoid gpu consumption before train
# This approach will divert consumption to cpu memory;
# a better approach would be to load the checkpoints to the meta device
# QLoRA is currently implemented by the former approach and
# will encounter the same issue.
# see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262

kwargs["device_map"] = {
"": (
(
torch.cuda.current_device()
if not kwargs["low_cpu_mem_usage"]
else "cpu"
)
if torch.cuda.is_available()
else None
)
}

# currently only enable triton_v2, because the triton kernels are the only ones
# that support a backward pass
model = AutoGPTQForCausalLM.from_quantized(
model = GPTQModel.from_quantized(
model_name,
quantize_config=quantize_config,
torch_dtype=torch_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
use_marlin=False, # disable, cannot be used for training (no forward+backward)
disable_exllama=True, # disable, cannot be used for training (no backward)
warmup_triton=False, # disable for now as it will try to run the warmup while on CPU
use_tritonv2=True,
trainable=True, # only support trainable mode
device_map=device_map,
**kwargs,
)

# https://github.com/foundation-model-stack/fms-acceleration/pull/15
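The device-map logic above is compact but easy to misread: under FSDP the checkpoint is forced onto CPU (`low_cpu_mem_usage`) so no GPU memory is consumed before training, otherwise it goes to the current CUDA device, and on a CPU-only machine the map value is left as `None`. A standalone sketch of that decision, mirroring the diff rather than quoting a plugin helper (the function name is illustrative):

```python
# Sketch of the device-map decision in model_loader; mirrors the diff,
# not a public helper of the plugin.
import torch
from transformers.modeling_utils import is_fsdp_enabled

def resolve_device_map(low_cpu_mem_usage: bool = False):
    if is_fsdp_enabled():
        # under FSDP, load checkpoints to CPU so GPU memory is not
        # consumed before training starts
        low_cpu_mem_usage = True
    if not torch.cuda.is_available():
        return {"": None}
    return {"": "cpu" if low_cpu_mem_usage else torch.cuda.current_device()}
```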
@@ -219,19 +247,24 @@ def augmentation(
):
# guarded imports
# Third Party
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
GPTQLoraModel,
get_gptq_peft_model,
)
if self.use_external_lib:
# Third Party
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
GPTQLoraModel,
get_gptq_peft_model,
)

# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
replace_module_peft,
)
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
replace_module_peft,
)
else:
# Local
from .gptqmodel.utils.peft import get_gptq_peft_model # pylint: disable=import-outside-toplevel,import-error

(peft_config,) = modifiable_args # unpack modifiable args

@@ -249,31 +282,35 @@ def augmentation(
gradient_checkpointing_kwargs=train_args.gradient_checkpointing_kwargs,
)

# These functions need to be replaced due to some incompatibilities
# with newer PEFT packages.
# - on augmentation we call auto_gptq.utils.peft_utils.get_gptq_peft_model
# - this internally calls peft.utils.other.get_peft_model
# - however the problem is that the peft API moves very fast, and there are incompatibilities
#
# During peft wrapping there are two key operations
# 1. LoraModel._create_new_module is called to create a LoraLinear layer that is
# compatible with the base layer. For quantized base layers, the LoraLinear
# may be different.
# 2. GPTQLoraModel._replace_module to replace the existing Linear with the LoraLinear.
# Also move to device (which may depend on how base layer is implemented)

# NOTE: GPTQLoraModel inherits from LoraModel, and the _create_new_module method is called
# on the parent. Hence _create_new_module is patched on the parent

# FIXME:
# 1. investigate using BaseGPTQForCausalLM.make_sure_compatible_with_peft
# to see if we can get around the patching

_old_create_new_module = LoraModel._create_new_module
_old_replace_module = GPTQLoraModel._replace_module
_create_new_module = partial(create_new_module_peft, target_cls=QuantLinear)
LoraModel._create_new_module = staticmethod(_create_new_module)
GPTQLoraModel._replace_module = MethodType(replace_module_peft, GPTQLoraModel)
if self.use_external_lib:
# These functions need to be replaced due to some incompatibilities
# with newer PEFT packages.
# - on augmentation we call auto_gptq.utils.peft_utils.get_gptq_peft_model
# - this internally calls peft.utils.other.get_peft_model
# - however the problem is that the peft API moves very fast, and there are incompatibilities
#
# During peft wrapping there are two key operations
# 1. LoraModel._create_new_module is called to create a LoraLinear layer that is
# compatible with the base layer. For quantized base layers, the LoraLinear
# may be different.
# 2. GPTQLoraModel._replace_module to replace the existing Linear with the LoraLinear.
# Also move to device (which may depend on how base layer is implemented)

# NOTE: GPTQLoraModel inherits from LoraModel,
# and the _create_new_module method is called
# on the parent. Hence _create_new_module is patched on the parent

# FIXME:
# 1. investigate using BaseGPTQForCausalLM.make_sure_compatible_with_peft
# to see if we can get around the patching

_old_create_new_module = LoraModel._create_new_module
_old_replace_module = GPTQLoraModel._replace_module
_create_new_module = partial(create_new_module_peft, target_cls=QuantLinear)
LoraModel._create_new_module = staticmethod(_create_new_module)
GPTQLoraModel._replace_module = MethodType(
replace_module_peft, GPTQLoraModel
)

# Install GPTQ adapters using the AutoGPTQ package (with the above patches)
model = get_gptq_peft_model(
@@ -284,9 +321,12 @@ def augmentation(
)
modifiable_args = (None,) # return a None for peft_config

# undo the patching for hygiene
LoraModel._create_new_module = staticmethod(_old_create_new_module)
GPTQLoraModel._replace_module = MethodType(_old_replace_module, GPTQLoraModel)
if self.use_external_lib:
# undo the patching for hygiene
LoraModel._create_new_module = staticmethod(_old_create_new_module)
GPTQLoraModel._replace_module = MethodType(
_old_replace_module, GPTQLoraModel
)

return model, modifiable_args
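Putting the pieces of this file together, the expected call sequence is: build the plugin from the acceleration config, load the quantized checkpoint through `model_loader`, then hand the model plus a LoRA config to `augmentation` to install the GPTQ adapters. A hedged end-to-end sketch; the import path, configuration nesting, checkpoint id, and the augmentation argument order are inferred from this diff and the surrounding framework, not spelled out here:

```python
# End-to-end usage sketch. Import path, configuration nesting and
# checkpoint id are assumptions for illustration; the call sequence
# follows the diff above.
import torch
from peft import LoraConfig
from transformers import TrainingArguments

# assumed import location of the class defined in the file above
from fms_acceleration_peft.framework_plugin_autogptq import (
    AutoGPTQAccelerationPlugin,
)

configurations = {  # assumed nesting behind peft.quantization.auto_gptq.*
    "peft": {
        "quantization": {
            "auto_gptq": {"kernel": "triton_v2", "from_quantized": True}
        }
    }
}

plugin = AutoGPTQAccelerationPlugin(configurations)  # vendored gptqmodel path
# plugin = AutoGPTQAccelerationPlugin(configurations, use_external_lib=True)

model = plugin.model_loader(
    "some-org/some-gptq-checkpoint",  # placeholder: any GPTQ-quantized repo/path
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)

peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
train_args = TrainingArguments(output_dir="./out", gradient_checkpointing=True)

# argument order inferred from the diff: (model, train_args, (peft_config,))
model, modifiable_args = plugin.augmentation(model, train_args, (peft_config,))
```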

@@ -0,0 +1,19 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Local
from .models import GPTQModel
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import Backend, get_backend
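This new package re-exports the pieces the plugin consumes (`GPTQModel`, `QuantizeConfig`, `Backend`). For orientation, a sketch of how `model_loader` drives them; the absolute package prefix is an assumption (inside the plugin these imports are relative), and this is not meant as a supported public surface:

```python
# Sketch mirroring model_loader's use of the vendored subset. The
# absolute package prefix is an assumption; inside the plugin the
# imports are relative (`from .gptqmodel import ...`).
import torch
from fms_acceleration_peft.gptqmodel import GPTQModel, QuantizeConfig
from fms_acceleration_peft.gptqmodel.utils import Backend

model_name = "some-org/some-gptq-checkpoint"  # placeholder
quantize_config = QuantizeConfig.from_pretrained(model_name)
model = GPTQModel.from_quantized(
    model_name,
    quantize_config=quantize_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
    backend=Backend.TRITON,  # triton kernels are the only ones with a backward pass
)
```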
@@ -0,0 +1,26 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Local
from .auto import MODEL_MAP, GPTQModel
from .base import BaseGPTQModel
from .dbrx import DbrxGPTQ
from .dbrx_converted import DbrxConvertedGPTQ
from .gemma import GemmaGPTQ
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
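Only a small set of architectures is carried over from GPTQModel to keep the vendored subset lean. A quick way to see what `GPTQModel.from_quantized` can dispatch to is to inspect `MODEL_MAP`; that it is keyed by HF `model_type` strings is an assumption here, as is the package prefix:

```python
# Sketch: list the architectures carried over into the vendored subset.
# The MODEL_MAP key format and the absolute package prefix are assumptions.
from fms_acceleration_peft.gptqmodel.models import MODEL_MAP

print(sorted(MODEL_MAP))
```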