From 4166fe0792de447afb57b5f776987ff89781ccb9 Mon Sep 17 00:00:00 2001
From: ssbuild <462304@qq.cn>
Date: Fri, 3 Mar 2023 15:31:36 +0800
Subject: [PATCH] 0.0.16
Signed-off-by: ssbuild <462304@qq.cn>
---
README.md | 2 +
nlp/models/lora/__init__.py | 191 +++++-------------
.../lora/{config.py => configuration.py} | 34 ++--
nlp/models/transformer.py | 24 +--
setup.py | 2 +-
utils/trainer.py | 35 ++--
6 files changed, 104 insertions(+), 184 deletions(-)
rename nlp/models/lora/{config.py => configuration.py} (86%)
diff --git a/README.md b/README.md
index 47df97e8..8663487c 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,8 @@
- [poetry_training](https://github.com/ssbuild/poetry_training)
## Updates
+- 2023-03-02
+ - Add LoRA training and the Lion optimizer; for a complete training example see [chatyuan_finetuning](https://github.com/ssbuild/chatyuan_finetuning)
- 2023-02-15
- Add the PaLM poetry pre-training model
- 2023-02-13
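A minimal sketch of how the LoRA pieces introduced below might be wired together. `LoraArguments`, `LoraModel`, `with_lora`, `r`, and `target_modules` come from this patch; the install path `deep_training` follows setup.py, while the base model, the module-name suffixes, and fields not visible in the diff (`lora_alpha`, `lora_dropout`, `bias`) are assumptions:

```python
# Hedged sketch: base model, target module names, and fields not shown in this
# patch (lora_alpha, lora_dropout, bias) are assumptions, not confirmed API.
from transformers import AutoModelForSeq2SeqLM

from deep_training.nlp.models.lora import LoraArguments, LoraModel  # repo-layout assumption

lora_config = LoraArguments(
    with_lora=True,             # flag added in this patch
    r=8,                        # LoRA attention dimension
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q", "v"],  # assumption: suffixes of the attention projections
    bias="none",                # keep the original biases frozen
)

base_model = AutoModelForSeq2SeqLM.from_pretrained("ClueAI/ChatYuan-large-v1")  # placeholder checkpoint
model = LoraModel(base_model, lora_config)   # replaces matched Linear layers with LoRA layers
model.print_trainable_parameters()           # only lora_* parameters remain trainable
```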
diff --git a/nlp/models/lora/__init__.py b/nlp/models/lora/__init__.py
index 71151c0f..e9a9ee52 100644
--- a/nlp/models/lora/__init__.py
+++ b/nlp/models/lora/__init__.py
@@ -10,33 +10,33 @@
from typing import Optional, Union, List
import torch
-from torch.nn import Linear
from transformers import Conv1D
from transformers.utils import PushToHubMixin
-from .config import PeftConfig, LoraConfig, WEIGHTS_NAME
-from ...layers.lora.layers import MergedLinear, Linear8bitLt, is_bnb_available, LoraLayer
+from .configuration import LoraArguments, WEIGHTS_NAME
+from ...layers.lora.layers import MergedLinear, is_bnb_available, LoraLayer, Linear
from ...layers.lora.utils import mark_only_lora_as_trainable
+__all__ = [
+ 'LoraArguments',
+ 'LoraModel',
+ 'LoraLayer'
+]
if is_bnb_available():
import bitsandbytes as bnb
+ from ...layers.lora.layers import Linear8bitLt
-def _set_trainable(model):
- if model.modules_to_save is not None:
- for name, param in model.named_parameters():
- if any(module_name in name for module_name in model.modules_to_save):
- param.requires_grad = True
-def get_peft_model_state_dict(model, state_dict=None):
+def get_lora_model_state_dict(model, state_dict=None):
"""
- Get the state dict of the Peft model.
+ Get the state dict of the Lora model.
Args:
- model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
+ model ([`LoraModel`]): The Lora model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
the model should be the underlying model/unwrapped model (i.e. model.module).
state_dict (`dict`, *optional*, defaults to `None`):
The state dict of the model. If not provided, the state dict of the model
@@ -45,10 +45,10 @@ def get_peft_model_state_dict(model, state_dict=None):
if state_dict is None:
state_dict = model.state_dict()
- # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+ # to_return = lora_state_dict(model, bias=model.lora_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to be used directly with the state dict, which is necessary when using DeepSpeed or FSDP
- bias = model.peft_config.bias
+ bias = model.lora_config.bias
if bias == "none":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
elif bias == "all":
@@ -71,29 +71,28 @@ def get_peft_model_state_dict(model, state_dict=None):
return to_return
-def set_peft_model_state_dict(model, peft_model_state_dict):
+def set_lora_model_state_dict(model, lora_model_state_dict):
"""
- Set the state dict of the Peft model.
+ Set the state dict of the Lora model.
Args:
- model ([`PeftModel`]): The Peft model.
- peft_model_state_dict (`dict`): The state dict of the Peft model.
+ model ([`LoraModel`]): The Lora model.
+ lora_model_state_dict (`dict`): The state dict of the Lora model.
"""
- model.load_state_dict(peft_model_state_dict, strict=False)
+ model.load_state_dict(lora_model_state_dict, strict=False)
return model
-class LoraModel(torch.nn.Module):
-
-
- def __init__(self, config, model):
- super().__init__()
- self.peft_config = config
+class LoraModel(torch.nn.Module,PushToHubMixin):
+ def __init__(self, model, config):
+ torch.nn.Module.__init__(self)
+ PushToHubMixin.__init__(self)
+ self.lora_config = config
self.model = model
self._find_and_replace()
- mark_only_lora_as_trainable(self.model, self.peft_config.bias)
+ mark_only_lora_as_trainable(self.model, self.lora_config.bias)
self.forward = self.model.forward
def _find_and_replace(self):
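The renamed helpers above pair up for adapter-only checkpointing; a hedged sketch of that round trip, assuming a model already wrapped in `LoraModel` and the repo's published import layout:

```python
import os
import torch

# Helper names come from this module; the import paths are a repo-layout assumption.
from deep_training.nlp.models.lora import get_lora_model_state_dict, set_lora_model_state_dict
from deep_training.nlp.models.lora.configuration import WEIGHTS_NAME

def save_adapter_only(lora_model, save_dir):
    """Persist only the lora_* tensors (plus biases, depending on lora_config.bias)."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(get_lora_model_state_dict(lora_model), os.path.join(save_dir, WEIGHTS_NAME))

def load_adapter_only(lora_model, save_dir):
    """Load adapter tensors back into an already-wrapped LoraModel (strict=False)."""
    weights = torch.load(os.path.join(save_dir, WEIGHTS_NAME), map_location="cpu")
    return set_lora_model_state_dict(lora_model, weights)
```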
@@ -105,24 +104,24 @@ def _find_and_replace(self):
# )
is_target_modules_in_base_model = False
kwargs = {
- "r": self.peft_config.r,
- "lora_alpha": self.peft_config.lora_alpha,
- "lora_dropout": self.peft_config.lora_dropout,
- "fan_in_fan_out": self.peft_config.fan_in_fan_out,
- "merge_weights": self.peft_config.merge_weights or self.peft_config.inference_mode,
+ "r": self.lora_config.r,
+ "lora_alpha": self.lora_config.lora_alpha,
+ "lora_dropout": self.lora_config.lora_dropout,
+ "fan_in_fan_out": self.lora_config.fan_in_fan_out,
+ "merge_weights": self.lora_config.merge_weights or self.lora_config.inference_mode,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
- if isinstance(self.peft_config.target_modules, str):
- target_module_found = re.fullmatch(self.peft_config.target_modules, key)
+ if isinstance(self.lora_config.target_modules, str):
+ target_module_found = re.fullmatch(self.lora_config.target_modules, key)
else:
- target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
+ target_module_found = any(key.endswith(target_key) for target_key in self.lora_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = self._get_submodules(key)
bias = target.bias is not None
- if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt) and self.peft_config.enable_lora is None:
+ if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt) and self.lora_config.enable_lora is None:
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
@@ -132,10 +131,10 @@ def _find_and_replace(self):
}
)
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
- elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
+ elif isinstance(target, torch.nn.Linear) and self.lora_config.enable_lora is None:
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
- elif self.peft_config.enable_lora is not None:
- kwargs.update({"enable_lora": self.peft_config.enable_lora})
+ elif self.lora_config.enable_lora is not None:
+ kwargs.update({"enable_lora": self.lora_config.enable_lora})
if isinstance(target, Conv1D):
in_features, out_features = target.weight.shape
else:
@@ -150,7 +149,7 @@ def _find_and_replace(self):
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
- f"Target modules {self.peft_config.target_modules} not found in the base model. "
+ f"Target modules {self.lora_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
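The matching rule above treats a string `target_modules` as a full regex and a list as a set of module-name suffixes; a small standalone sketch of that selection logic (the module names are invented for illustration):

```python
import re

# Mirrors the selection in _find_and_replace: a str is matched with re.fullmatch,
# a list matches module-name suffixes. The module names below are made up.
def select_targets(module_names, target_modules):
    selected = []
    for key in module_names:
        if isinstance(target_modules, str):
            found = re.fullmatch(target_modules, key) is not None
        else:
            found = any(key.endswith(t) for t in target_modules)
        if found:
            selected.append(key)
    return selected

names = ["encoder.block.0.layer.0.SelfAttention.q",
         "encoder.block.0.layer.0.SelfAttention.k",
         "encoder.block.0.layer.0.SelfAttention.v"]
print(select_targets(names, ["q", "v"]))                  # suffix match -> q and v projections
print(select_targets(names, r".*SelfAttention\.(q|v)"))   # regex full match, same selection
```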
@@ -180,8 +179,8 @@ def __getattr__(self, name: str):
def modules_to_save(self):
return None
- def get_peft_config_as_dict(self, inference: bool = False):
- config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
+ def get_lora_config_as_dict(self, inference: bool = False):
+ config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.lora_config).items()}
if inference:
config["inference_mode"] = True
return config
@@ -197,44 +196,6 @@ def enable_adapter_layers(self):
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)
-
-class PeftModel(PushToHubMixin, torch.nn.Module):
- """
- Parameter-Efficient Fine-Tuning Model. Base model encompassing various Peft methods.
-
- Args:
- model ([`PreTrainedModel`]): The base transformer model used for Peft.
- peft_config ([`PeftConfig`]): The configuration of the Peft model.
-
-
- **Attributes**:
- - **base_model** ([`PreTrainedModel`]) -- The base transformer model used for Peft.
- - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.
- - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when
- saving the model.
- - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if
- `isinstance(self.peft_config, PromptLearningConfig)`.
- - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if
- `isinstance(self.peft_config, PromptLearningConfig)`.
- - **transformer_backbone_name** (`str`) -- The name of the transformer
- backbone in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
- - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone
- in the base model if `isinstance(self.peft_config, PromptLearningConfig)`.
- """
-
- def __init__(self, model, peft_config: PeftConfig):
- super().__init__()
- self.peft_config = peft_config
- self.base_model = model
- self.config = self.base_model.config
- self.modules_to_save = None
-
- self.base_model = LoraModel(peft_config, model)
- if getattr(self.peft_config, "modules_to_save", None) is not None:
- self.modules_to_save = self.peft_config.modules_to_save
- _set_trainable(self)
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
def save_pretrained(self, save_directory, **kwargs):
r"""
Args:
@@ -252,21 +213,17 @@ def save_pretrained(self, save_directory, **kwargs):
os.makedirs(save_directory, exist_ok=True)
# save only the trainable weights
- output_state_dict = get_peft_model_state_dict(self, kwargs.get("state_dict", None))
+ output_state_dict = get_lora_model_state_dict(self, kwargs.get("state_dict", None))
torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))
# save the config and change the inference mode to `True`
- if self.peft_config.base_model_name_or_path is None:
- self.peft_config.base_model_name_or_path = (
- self.base_model.model.__dict__.get("name_or_path", None)
- )
- inference_mode = self.peft_config.inference_mode
- self.peft_config.inference_mode = True
- self.peft_config.save_pretrained(save_directory)
- self.peft_config.inference_mode = inference_mode
+ inference_mode = self.lora_config.inference_mode
+ self.lora_config.inference_mode = True
+ self.lora_config.save_pretrained(save_directory)
+ self.lora_config.inference_mode = inference_mode
@classmethod
- def from_pretrained(cls, model, model_id, **kwargs):
+ def from_pretrained(cls, model, pretrained_model_name_or_path, **kwargs):
r"""
Args:
Instantiate a `LoraModel` from a pretrained Lora configuration and weights.
@@ -280,40 +237,24 @@ def from_pretrained(cls, model, model_id, **kwargs):
- A path to a directory containing a Lora configuration file saved using the
`save_pretrained` method, e.g., ``./my_lora_config_directory/``.
"""
- from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING
-
- # load the config
- config = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig.from_pretrained(model_id).peft_type].from_pretrained(model_id)
-
-
-
- if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
- model = cls(model, config)
- else:
- model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)
-
+ lora_config: LoraArguments = LoraArguments.from_pretrained(pretrained_model_name_or_path)
+ model = cls(model, lora_config)
# load weights if any
- if os.path.exists(os.path.join(model_id, WEIGHTS_NAME)):
- filename = os.path.join(model_id, WEIGHTS_NAME)
+ if os.path.exists(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ filename = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
else:
- try:
- raise ValueError()
- # filename = hf_hub_download(model_id, WEIGHTS_NAME)
- except: # noqa
- raise ValueError(
- f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
- f"Please check that the file {WEIGHTS_NAME} is present at {model_id}."
- )
+ raise ValueError(
+ f"Can't find weights for {pretrained_model_name_or_path} in {pretrained_model_name_or_path} or in the Hugging Face Hub. "
+ f"Please check that the file {WEIGHTS_NAME} is present at {pretrained_model_name_or_path}."
+ )
adapters_weights = torch.load(
filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# load the weights into the model
- model = set_peft_model_state_dict(model, adapters_weights)
+ model = set_lora_model_state_dict(model, adapters_weights)
return model
-
-
def print_trainable_parameters(self):
"""
Prints the number of trainable parameters in the model.
@@ -338,26 +279,4 @@ def __getattr__(self, name: str):
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
- return getattr(self.base_model, name)
-
- def forward(self, *args, **kwargs):
- """
- Forward pass of the model.
- """
- return self.get_base_model()(*args, **kwargs)
-
- @contextmanager
- def disable_adapter(self):
- """
- Disables the adapter module.
- """
-
- self.base_model.disable_adapter_layers()
- yield
- self.base_model.enable_adapter_layers()
-
- def get_base_model(self):
- """
- Returns the base model.
- """
- return self.base_model.model
\ No newline at end of file
+ return getattr(self.model, name)
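Taken together, `save_pretrained` / `from_pretrained` above give a round trip for the adapter weights plus `adapter_config.json`; a hedged sketch using a toy backbone (the backbone, config values, and directory are placeholders, and fields not shown in the diff are assumed to exist):

```python
import torch

# Hedged round-trip sketch; LoraModel.save_pretrained / from_pretrained come from this
# file, while the toy backbone, config values, and directory are placeholders.
from deep_training.nlp.models.lora import LoraArguments, LoraModel  # repo-layout assumption

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(16, 16)
        self.v = torch.nn.Linear(16, 16)
    def forward(self, x):
        return self.v(self.q(x))

cfg = LoraArguments(r=4, lora_alpha=16, lora_dropout=0.0, target_modules=["q", "v"], bias="none")
wrapped = LoraModel(Toy(), cfg)

wrapped.save_pretrained("./toy_lora")                      # adapter weights + adapter_config.json
restored = LoraModel.from_pretrained(Toy(), "./toy_lora")  # rebuilds the wrapper, loads lora_* tensors
```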
diff --git a/nlp/models/lora/config.py b/nlp/models/lora/configuration.py
similarity index 86%
rename from nlp/models/lora/config.py
rename to nlp/models/lora/configuration.py
index 58ca860f..78596efd 100644
--- a/nlp/models/lora/config.py
+++ b/nlp/models/lora/configuration.py
@@ -12,7 +12,7 @@
CONFIG_NAME = "adapter_config.json"
@dataclass
-class PeftConfigMixin(PushToHubMixin):
+class LoraConfigMixin(PushToHubMixin):
r"""
This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
PEFT adapter models. This class inherits from `transformers.utils.PushToHubMixin` which contains the methods to
@@ -67,11 +67,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
- try:
- ...
- # config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
- except:
- raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")
+ raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")
loaded_attributes = cls.from_json_file(config_file)
@@ -98,27 +94,16 @@ def from_json_file(cls, path_json_file, **kwargs):
return json_object
-@dataclass
-class PeftConfig(PeftConfigMixin):
- """
- This is the base configuration class to store the configuration of a :class:`~peft.PeftModel`.
- Args:
- peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
- task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
- inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
- """
-
- base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
- inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
@dataclass
-class LoraConfig(PeftConfig):
+class LoraArguments(LoraConfigMixin):
"""
This is the configuration class to store the configuration of a [`~peft.Lora`].
Args:
+ inference_mode (`bool`, defaults to `False`): Whether to use the Lora model in inference mode.
r (`int`): Lora attention dimension
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
lora_alpha (`float`): The alpha parameter for Lora scaling.
@@ -131,7 +116,10 @@ class LoraConfig(PeftConfig):
modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
and saved in the final checkpoint.
"""
+ lora_model_name_or_path: str = field(default=None, metadata={"help": "The name or path of the Lora model to use."})
+ inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
+ with_lora: bool = field(default=False, metadata={"help": "Whether to use LoRA"})
r: int = field(default=8, metadata={"help": "Lora attention dimension"})
target_modules: Optional[Union[List[str], str]] = field(
default=None,
@@ -160,4 +148,10 @@ class LoraConfig(PeftConfig):
},
)
- def __post_init__(self): ...
\ No newline at end of file
+ def __post_init__(self):
+ if self.inference_mode:
+ self.merge_weights = True
+
+ if self.target_modules is not None and len(self.target_modules) == 1:
+ self.fan_in_fan_out = True
+ self.enable_lora = [True, False, True]
\ No newline at end of file
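The new `__post_init__` couples `inference_mode` to weight merging and routes single-target configs into the merged-linear path; a small sketch of the config round trip and those side effects (`save_pretrained` on the config is assumed from the mixin, and the directory plus the fused `query_key_value` target are placeholders):

```python
# Hedged sketch; the directory name and "query_key_value" target are placeholders,
# and LoraArguments.save_pretrained is assumed from the PushToHubMixin-style mixin above.
from deep_training.nlp.models.lora.configuration import LoraArguments

cfg = LoraArguments(r=8, lora_alpha=32, target_modules=["q", "v"])
cfg.save_pretrained("./my_adapter")                     # writes adapter_config.json
loaded = LoraArguments.from_pretrained("./my_adapter")
print(loaded.r, loaded.target_modules)                  # 8 ['q', 'v']

fused = LoraArguments(r=8, target_modules=["query_key_value"], inference_mode=True)
print(fused.merge_weights)    # True  -> inference_mode implies merged weights
print(fused.enable_lora)      # [True, False, True] -> single fused qkv module uses MergedLinear
print(fused.fan_in_fan_out)   # True
```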
diff --git a/nlp/models/transformer.py b/nlp/models/transformer.py
index 49023f27..28b4d44a 100644
--- a/nlp/models/transformer.py
+++ b/nlp/models/transformer.py
@@ -233,18 +233,14 @@ def model(self):
def model(self, model):
self.set_model(model)
- def set_model(self, model):
- # keep_keys = [
- # 'config_class','load_tf_weights','base_model_prefix','supports_gradient_checkpointing','_init_weights','_set_gradient_checkpointing',
- # '_keys_to_ignore_on_load_missing','_keys_to_ignore_on_load_unexpected','_no_split_modules','is_parallelizable','_shift_right','main_input_name',
- # '_get_feat_extract_output_lengths','_get_feature_vector_attention_mask',#dummy_inputs
- # ]
- keep_keys = ['config_class','base_model_prefix']
- for k in keep_keys:
- o = getattr(model,k,None)
- if o is None:
- continue
- setattr(self,k,o)
+ def set_model(self, model , copy_attr=True):
+ if copy_attr:
+ keep_keys = ['config_class','base_model_prefix']
+ for k in keep_keys:
+ o = getattr(model,k,None)
+ if o is None:
+ continue
+ setattr(self,k,o)
assert self.base_model_prefix is not None, ValueError('base_model_prefix is not allow empty')
setattr(self, self.base_model_prefix, model)
@@ -350,9 +346,11 @@ def model(self):
def model(self, model):
self.set_model(model)
- def set_model(self, model):
+ def set_model(self, model , copy_attr=True):
assert model is not None
self.__backbone = model
+ if not copy_attr:
+ return
copy_attr = [
'log','log_dict'
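The new `copy_attr` switch lets callers attach a backbone without propagating attributes such as `config_class` / `base_model_prefix`; a standalone illustration of the guarded copy pattern, with both classes invented for the demonstration:

```python
# Standalone illustration of the guarded attribute copy that set_model performs when
# copy_attr=True; Backbone and Wrapper are invented here, not classes from this repo.
class Backbone:
    config_class = dict
    base_model_prefix = "model"

class Wrapper:
    base_model_prefix = None

    def set_model(self, model, copy_attr=True):
        if copy_attr:
            for k in ("config_class", "base_model_prefix"):
                o = getattr(model, k, None)
                if o is not None:          # copy only attributes the backbone actually defines
                    setattr(self, k, o)
        assert self.base_model_prefix is not None, "base_model_prefix must not be empty"
        setattr(self, self.base_model_prefix, model)

w = Wrapper()
w.set_model(Backbone())                     # copies the attributes, then attaches the backbone
print(w.base_model_prefix, type(w.model))   # model <class '__main__.Backbone'>
```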
diff --git a/setup.py b/setup.py
index 50c78eb9..e58adbfc 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
ignore = ['test','tests']
setup(
name='deep_training',
- version='0.0.15@post2',
+ version='0.0.16',
description='an easy training architecture',
long_description='torch_training: https://github.com/ssbuild/deep_training.git',
license='Apache License 2.0',
diff --git a/utils/trainer.py b/utils/trainer.py
index ace710c6..e6f66161 100644
--- a/utils/trainer.py
+++ b/utils/trainer.py
@@ -62,6 +62,22 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
def on_get_metric( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
return {}
+ def update_best(self,val):
+ flag = False
+ if isinstance(val, torch.Tensor):
+ if self.monitor not in self.best:
+ flag = True
+ self.best[self.monitor] = val
+ else:
+ monitor_op = torch.le if self.mode.lower() == 'min' else torch.ge
+ if monitor_op(val, self.best[self.monitor]).bool().cpu().item():
+ flag = True
+ else:
+ warnings.warn('monitor {} is not a tensor'.format(self.monitor))
+
+ if flag:
+ self.best[self.monitor] = val
+ return flag
def on_save_model(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
@@ -71,17 +87,8 @@ def on_save_model(
monitor_candidates.update(self.on_get_metric(trainer,pl_module))
val = monitor_candidates.get(self.monitor,None)
if val is not None:
- flag = False
- if isinstance(val,torch.Tensor):
- if self.monitor not in self.best:
- self.best[self.monitor] = val
- monitor_op = torch.le if self.mode.lower() == 'min' else torch.ge
- if monitor_op(val ,self.best[self.monitor]).bool().cpu().item():
- flag = True
- else:
- warnings.warn('monitor {} is not tensor'.format(self.monitor))
+ flag = self.update_best(val)
if flag:
- self.best[self.monitor] = val
logging.info('epoch {} ,step {} , save best {}, {}\n'.format(monitor_candidates['epoch'],
monitor_candidates['step'],
self.best[self.monitor],
@@ -90,16 +97,16 @@ def on_save_model(
if self.last_weight_file is not None:
logging.info('epoch {} ,step {} , save {}\n'.format(monitor_candidates['epoch'],
- monitor_candidates['step'],
- self.last_weight_file))
+ monitor_candidates['step'],
+ self.last_weight_file))
trainer.save_checkpoint(self.last_weight_file)
else:
warnings.warn('monitor {} is not in metric , save latest checkpoint!'.format(self.monitor))
logging.info('epoch {} ,step {} , save {}\n'.format(monitor_candidates['epoch'],
- monitor_candidates['step'],
- self.weight_file))
+ monitor_candidates['step'],
+ self.weight_file))
trainer.save_checkpoint(self.weight_file)
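The refactored `update_best` compares with `torch.le` for a `'min'` monitor and `torch.ge` for `'max'`; a small standalone sketch of that comparison rule (the monitor values are made up):

```python
import torch

# Mirrors update_best's comparison rule; the values below are illustrative.
def is_better(val, best, mode="min"):
    monitor_op = torch.le if mode.lower() == "min" else torch.ge
    return bool(monitor_op(val, best).cpu().item())

best = torch.tensor(0.42)                                  # e.g. current best eval loss
print(is_better(torch.tensor(0.40), best, mode="min"))     # True  -> new best, checkpoint is saved
print(is_better(torch.tensor(0.45), best, mode="min"))     # False -> keep the previous best
```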