Commit

0.0.16
Signed-off-by: ssbuild <[email protected]>
ssbuild committed Mar 3, 2023
1 parent be253d1 commit 4166fe0
Showing 6 changed files with 104 additions and 184 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -8,6 +8,8 @@
- [poetry_training](https://github.com/ssbuild/poetry_training)

## Updates
- <strong>2023-03-02</strong>
  - Added LoRA training and the Lion optimizer; for a complete training example, see [chatyuan_finetuning](https://github.com/ssbuild/chatyuan_finetuning)
- <strong>2023-02-15</strong>
  - Added a PaLM pretrained model for poetry
- <strong>2023-02-13</strong>
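The update note above adds LoRA training and the Lion optimizer. Below is a minimal sketch of how the renamed LoRA API from this commit might be wired into a fine-tuning script. It is not part of the commit: the `deep_training` import path, the `LoraArguments` constructor call, the example model id, and the `target_modules` choice are assumptions; only the attribute names (`r`, `lora_alpha`, `lora_dropout`, `target_modules`, `bias`, `inference_mode`) appear in the diff that follows.

```python
# Minimal sketch, not part of this commit. Assumptions: the package imports as
# `deep_training`, LoraArguments accepts the fields seen in the diff as keyword
# arguments, and "ClueAI/ChatYuan-large-v1" stands in for the base model.
import torch
from transformers import AutoModelForSeq2SeqLM

from deep_training.nlp.models.lora import LoraArguments, LoraModel  # assumed path

base = AutoModelForSeq2SeqLM.from_pretrained("ClueAI/ChatYuan-large-v1")

lora_args = LoraArguments(
    r=8,                        # LoRA rank
    lora_alpha=32,              # scaling factor
    lora_dropout=0.1,
    target_modules=["q", "v"],  # suffixes matched by _find_and_replace (assumed for a T5-style model)
    bias="none",                # passed to mark_only_lora_as_trainable
    inference_mode=False,
)

# LoraModel(model, config) swaps matching Linear layers for LoRA layers and
# freezes everything except the lora_* parameters.
model = LoraModel(base, lora_args)

# Any optimizer over the remaining trainable parameters works; the Lion
# optimizer mentioned in the update note could be dropped in here instead.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=2e-4
)
```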
191 changes: 55 additions & 136 deletions nlp/models/lora/__init__.py
@@ -10,33 +10,33 @@
from typing import Optional, Union, List

import torch
from torch.nn import Linear
from transformers import Conv1D
from transformers.utils import PushToHubMixin

from .config import PeftConfig, LoraConfig, WEIGHTS_NAME
from ...layers.lora.layers import MergedLinear, Linear8bitLt, is_bnb_available, LoraLayer
from .configuration import LoraArguments, WEIGHTS_NAME
from ...layers.lora.layers import MergedLinear, is_bnb_available, LoraLayer, Linear
from ...layers.lora.utils import mark_only_lora_as_trainable

__all__ = [
'LoraArguments',
'LoraModel',
'LoraLayer'
]

if is_bnb_available():
import bitsandbytes as bnb
from ...layers.lora.layers import Linear8bitLt

def _set_trainable(model):
if model.modules_to_save is not None:
for name, param in model.named_parameters():
if any(module_name in name for module_name in model.modules_to_save):
param.requires_grad = True




def get_peft_model_state_dict(model, state_dict=None):
def get_lora_model_state_dict(model, state_dict=None):
"""
Get the state dict of the Peft model.
Get the state dict of the Lora model.
Args:
model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
model ([`LoraModel`]): The Lora model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
the model should be the underlying model/unwrapped model (i.e. model.module).
state_dict (`dict`, *optional*, defaults to `None`):
The state dict of the model. If not provided, the state dict of the model
@@ -45,10 +45,10 @@ def get_peft_model_state_dict(model, state_dict=None):
if state_dict is None:
state_dict = model.state_dict()

# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# to_return = lora_state_dict(model, bias=model.lora_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
    # to be used directly with the state dict, which is necessary when using DeepSpeed or FSDP
bias = model.peft_config.bias
bias = model.lora_config.bias
if bias == "none":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
elif bias == "all":
@@ -71,29 +71,28 @@ def get_peft_model_state_dict(model, state_dict=None):
return to_return


def set_peft_model_state_dict(model, peft_model_state_dict):
def set_lora_model_state_dict(model, lora_model_state_dict):
"""
Set the state dict of the Peft model.
Set the state dict of the Lora model.
Args:
model ([`PeftModel`]): The Peft model.
peft_model_state_dict (`dict`): The state dict of the Peft model.
model ([`LoraModel`]): The Lora model.
lora_model_state_dict (`dict`): The state dict of the Lora model.
"""

model.load_state_dict(peft_model_state_dict, strict=False)
model.load_state_dict(lora_model_state_dict, strict=False)
return model



class LoraModel(torch.nn.Module):


def __init__(self, config, model):
super().__init__()
self.peft_config = config
class LoraModel(torch.nn.Module,PushToHubMixin):
def __init__(self, model, config):
torch.nn.Module.__init__(self)
PushToHubMixin.__init__(self)
self.lora_config = config
self.model = model
self._find_and_replace()
mark_only_lora_as_trainable(self.model, self.peft_config.bias)
mark_only_lora_as_trainable(self.model, self.lora_config.bias)
self.forward = self.model.forward

def _find_and_replace(self):
@@ -105,24 +104,24 @@ def _find_and_replace(self):
# )
is_target_modules_in_base_model = False
kwargs = {
"r": self.peft_config.r,
"lora_alpha": self.peft_config.lora_alpha,
"lora_dropout": self.peft_config.lora_dropout,
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
"merge_weights": self.peft_config.merge_weights or self.peft_config.inference_mode,
"r": self.lora_config.r,
"lora_alpha": self.lora_config.lora_alpha,
"lora_dropout": self.lora_config.lora_dropout,
"fan_in_fan_out": self.lora_config.fan_in_fan_out,
"merge_weights": self.lora_config.merge_weights or self.lora_config.inference_mode,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(self.peft_config.target_modules, str):
target_module_found = re.fullmatch(self.peft_config.target_modules, key)
if isinstance(self.lora_config.target_modules, str):
target_module_found = re.fullmatch(self.lora_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
target_module_found = any(key.endswith(target_key) for target_key in self.lora_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = self._get_submodules(key)
bias = target.bias is not None
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt) and self.peft_config.enable_lora is None:
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt) and self.lora_config.enable_lora is None:
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
@@ -132,10 +131,10 @@ def _find_and_replace(self):
}
)
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
elif isinstance(target, torch.nn.Linear) and self.lora_config.enable_lora is None:
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
elif self.peft_config.enable_lora is not None:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
elif self.lora_config.enable_lora is not None:
kwargs.update({"enable_lora": self.lora_config.enable_lora})
if isinstance(target, Conv1D):
in_features, out_features = target.weight.shape
else:
@@ -150,7 +149,7 @@ def _find_and_replace(self):
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {self.peft_config.target_modules} not found in the base model. "
f"Target modules {self.lora_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)

@@ -180,8 +179,8 @@ def __getattr__(self, name: str):
def modules_to_save(self):
return None

def get_peft_config_as_dict(self, inference: bool = False):
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
def get_lora_config_as_dict(self, inference: bool = False):
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.lora_config).items()}
if inference:
config["inference_mode"] = True
return config
@@ -197,44 +196,6 @@ def enable_adapter_layers(self):
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)


class PeftModel(PushToHubMixin, torch.nn.Module):
"""
Parameter-Efficient Fine-Tuning Model. Base model encompassing various Peft methods.
Args:
model ([`PreTrainedModel`]): The base transformer model used for Peft.
peft_config ([`PeftConfig`]): The configuration of the Peft model.
**Attributes**:
- **base_model** ([`PreTrainedModel`]) -- The base transformer model used for Peft.
- **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.
- **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when
saving the model.
- **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if
`isinstance(self.peft_config, PromptLearningConfig)`.
- **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if
`isinstance(self.peft_config, PromptLearningConfig)`.
- **transformer_backbone_name** (`str`) -- The name of the transformer
backbone in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
- **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone
in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
"""

def __init__(self, model, peft_config: PeftConfig):
super().__init__()
self.peft_config = peft_config
self.base_model = model
self.config = self.base_model.config
self.modules_to_save = None

self.base_model = LoraModel(peft_config, model)
if getattr(self.peft_config, "modules_to_save", None) is not None:
self.modules_to_save = self.peft_config.modules_to_save
_set_trainable(self)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def save_pretrained(self, save_directory, **kwargs):
r"""
Args:
@@ -252,21 +213,17 @@ def save_pretrained(self, save_directory, **kwargs):
os.makedirs(save_directory, exist_ok=True)

# save only the trainable weights
output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
output_state_dict = get_lora_model_state_dict(self, kwargs.get("state_dict", None))
torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))

# save the config and change the inference mode to `True`
if self.peft_config.base_model_name_or_path is None:
self.peft_config.base_model_name_or_path = (
self.base_model.model.__dict__.get("name_or_path", None)
)
inference_mode = self.peft_config.inference_mode
self.peft_config.inference_mode = True
self.peft_config.save_pretrained(save_directory)
self.peft_config.inference_mode = inference_mode
inference_mode = self.lora_config.inference_mode
self.lora_config.inference_mode = True
self.lora_config.save_pretrained(save_directory)
self.lora_config.inference_mode = inference_mode

@classmethod
def from_pretrained(cls, model, model_id, **kwargs):
def from_pretrained(cls, model, pretrained_model_name_or_path, **kwargs):
r"""
Args:
Instantiate a `LoraModel` from a pretrained Lora configuration and weights.
@@ -280,40 +237,24 @@ def from_pretrained(cls, model, model_id, **kwargs):
- A path to a directory containing a Lora configuration file saved using the
`save_pretrained` method, e.g., ``./my_lora_config_directory/``.
"""
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING

# load the config
config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)



if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
model = cls(model, config)
else:
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)

lora_config: LoraArguments = LoraArguments.from_pretrained(pretrained_model_name_or_path)
model = cls(model, lora_config)
# load weights if any
if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
filename = os.path.join(model_id, WEIGHTS_NAME)
if os.path.exists(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
filename = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
try:
raise ValueError()
# filename = hf_hub_download(model_id, WEIGHTS_NAME)
except: # noqa
raise ValueError(
f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
)
raise ValueError(
f"Can't find weights for {pretrained_model_name_or_path} in {pretrained_model_name_or_path} or in the Hugging Face Hub. "
f"Please check that the file {WEIGHTS_NAME} is present at {pretrained_model_name_or_path}."
)

adapters_weights = torch.load(
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# load the weights into the model
model = set_peft_model_state_dict(model, adapters_weights)
model = set_lora_model_state_dict(model, adapters_weights)

return model



def print_trainable_parameters(self):
"""
Prints the number of trainable parameters in the model.
@@ -338,26 +279,4 @@ def __getattr__(self, name: str):
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.base_model, name)

def forward(self, *args, **kwargs):
"""
Forward pass of the model.
"""
return self.get_base_model()(*args, **kwargs)

@contextmanager
def disable_adapter(self):
"""
Disables the adapter module.
"""

self.base_model.disable_adapter_layers()
yield
self.base_model.enable_adapter_layers()

def get_base_model(self):
"""
Returns the base model.
"""
return self.base_model.model
return getattr(self.model, name)
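For illustration, here is a self-contained sketch of the save/load round trip that the `save_pretrained` and `from_pretrained` methods above imply. It is not from the commit: the `deep_training` import path, the `TinyNet` model, and the `LoraArguments` call are hypothetical; the actual weights file name comes from `WEIGHTS_NAME` in `.configuration`.

```python
# Hedged sketch, assuming the package imports as `deep_training` and that
# LoraArguments accepts the field names visible in this diff as keyword arguments.
import torch

from deep_training.nlp.models.lora import LoraArguments, LoraModel  # assumed path


class TinyNet(torch.nn.Module):
    """Toy base model; its `proj` Linear is the LoRA target."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


args = LoraArguments(r=4, lora_alpha=16, lora_dropout=0.0,
                     target_modules=["proj"], bias="none", inference_mode=False)
wrapped = LoraModel(TinyNet(), args)

# Writes only the lora_* tensors (via get_lora_model_state_dict) plus the
# config, with inference_mode forced to True in the saved config.
wrapped.save_pretrained("./tiny_lora")

# Re-attach the adapter to a fresh base model: LoraArguments.from_pretrained
# reads the config, and the weights are loaded with strict=False via
# set_lora_model_state_dict.
restored = LoraModel.from_pretrained(TinyNet(), "./tiny_lora")
print(sum(p.numel() for p in restored.parameters() if p.requires_grad))
```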
