From 09eca2902ebe9114e768d0672a7bc1e9806f16fe Mon Sep 17 00:00:00 2001
From: ssbuild <462304@qq.cn>
Date: Tue, 4 Jul 2023 13:41:53 +0800
Subject: [PATCH] 0.1.11
Signed-off-by: ssbuild <462304@qq.cn>
---
README.md | 6 +
setup.py | 4 +-
src/data_helper/data_helper.py | 99 ++--------
src/nlp/layers/lora_v2/layers.py | 171 ++++++++++++++++
src/nlp/models/chatglm2/modeling_chatglm.py | 30 +--
src/nlp/models/lora/v2/configuration.py | 16 ++
src/nlp/models/lora/v2/lora_model.py | 207 ++++++++++++--------
7 files changed, 349 insertions(+), 184 deletions(-)
diff --git a/README.md b/README.md
index ae285091..31f1b827 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,12 @@
## update
+- 2023-07-04
+ - 0.1.11 release
+ - fix some bugs in baichuan and chatglm2
+ - support Conv2d for LoRA
+ - support arrow and parquet datasets
+
- 2023-06-06
- 0.1.11 rc0 add baichuan model, full training: [baichuan_finetuning](https://github.com/ssbuild/baichuan_finetuning)
- 0.1.11 rc1 add chatglm2 model, full training: [chatglm2_finetuning](https://github.com/ssbuild/chatglm2_finetuning)
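Note: the "support Conv2d for LoRA" entry above refers to the new Conv2d LoRA layer added in src/nlp/layers/lora_v2/layers.py below. A minimal usage sketch, assuming the installed package path and the target module names (both are illustrative, not taken from this patch):

# Hypothetical sketch: wrap nn.Conv2d modules with LoRA via the v2 LoraConfig.
# The import path and the names in target_modules are assumptions.
from deep_training.nlp.models.lora.v2 import LoraConfig  # assumed path

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["conv_in", "conv_out"],  # hypothetical nn.Conv2d module names
)
# Matching modules that are nn.Conv2d instances are now replaced by the Conv2d
# LoRA layer instead of raising a "not supported" error.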
diff --git a/setup.py b/setup.py
index db739545..bfa16e28 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
ignore = ['test','tests']
setup(
name='deep_training',
- version='0.1.11rc1',
+ version='0.1.11',
description='an easy training architecture',
long_description='torch_training: https://github.com/ssbuild/deep_training.git',
license='Apache License 2.0',
@@ -18,7 +18,7 @@
author='ssbuild',
author_email='9727464@qq.com',
install_requires=['lightning>=2',
- 'numpy-io>=0.0.3 , < 0.1.0',
+ 'numpy-io>=0.0.5 , < 0.1.0',
'sentencepiece',
'numpy',
'transformers>=4.22',
diff --git a/src/data_helper/data_helper.py b/src/data_helper/data_helper.py
index 2d391668..d45d0151 100644
--- a/src/data_helper/data_helper.py
+++ b/src/data_helper/data_helper.py
@@ -5,16 +5,16 @@
import os
import typing
import torch
-from fastdatasets import memory as MEMORY
-from fastdatasets.common.iterable_dataset import IterableDatasetBase
-from fastdatasets.common.random_dataset import RandomDatasetBase
-from fastdatasets.torch_dataset import IterableDataset as torch_IterableDataset, Dataset as torch_Dataset
-from torch.utils.data import DataLoader, IterableDataset
+# from fastdatasets import memory as MEMORY
+# from fastdatasets.common.iterable_dataset import IterableDatasetBase
+# from fastdatasets.common.random_dataset import RandomDatasetBase
+# from fastdatasets.torch_dataset import IterableDataset as torch_IterableDataset, Dataset as torch_Dataset
+# from torch.utils.data import DataLoader, IterableDataset
from transformers import PreTrainedTokenizer, PretrainedConfig
from .training_args import ModelArguments, DataArguments, TrainingArguments
from ..utils.func import is_chinese_char
-from numpy_io.core.writer import DataWriteHelper
from numpy_io.pytorch_loader.data_helper import DataHelperBase,load_tokenizer, load_configure
+from numpy_io.core.writer import DataWriteHelper
__all__ = [
'DataHelper',
@@ -43,24 +43,23 @@ class DataHelper(DataHelperBase):
model_args: typing.Optional[ModelArguments] = None
training_args: typing.Optional[TrainingArguments] = None
data_args: typing.Optional[DataArguments] = None
+
def __init__(self,
model_args: ModelArguments,
training_args: typing.Optional[TrainingArguments] = None,
data_args: typing.Optional[DataArguments] = None,
**kwargs):
- super(DataHelper, self).__init__()
-
- self.train_files = []
- self.eval_files = []
- self.test_files = []
+ if data_args:
+ super(DataHelper, self).__init__(data_args.data_backend,data_args.convert_file,data_args.output_dir,data_args.intermediate_name)
+ else:
+ super(DataHelper, self).__init__(None, None, None, None)
self.label2id = None
self.id2label = None
self.max_seq_length_dict = {}
self._external_kwargs = kwargs
- self.backend = data_args.data_backend if data_args else 'record'
self.model_args = model_args
self.training_args = training_args
self.data_args = data_args
@@ -253,82 +252,6 @@ def load_tokenizer_and_config(self,
return tokenizer, config
- # return the intermediate file used to build feature data
- def get_intermediate_file(self, intermediate_name, mode):
- data_args: DataArguments = self.data_args
- if data_args.data_backend.startswith('memory'):
- # in-memory data: list
- intermediate_output = []
- logging.info('make data {} {}...'.format(data_args.output_dir,
- intermediate_name + '-' + mode + '.' + self.backend))
- else:
- # local file data: file name
- intermediate_output = os.path.join(data_args.output_dir,
- intermediate_name + '-' + mode + '.' + self.backend)
- logging.info('make data {}...'.format(intermediate_output))
- return intermediate_output
-
-
- def make_dataset_with_args(self, input_files,
- mode,
- shuffle=False,
- num_process_worker: int=0,
- overwrite: bool=False,
- mixed_data=True,
- dupe_factor=1):
- '''
- mode: one of [train, eval, test]
- shuffle: whether to shuffle the data
- num_process_worker: the number of worker processes
- overwrite: whether to overwrite existing data
- mixed_data: whether to mix all input files into a single intermediate file
- '''
- logging.info('make_dataset {} {}...'.format(','.join(input_files),mode))
- if mode == 'train':
- contain_objs = self.train_files
- elif mode == 'eval' or mode == 'val':
- contain_objs = self.eval_files
- elif mode == 'test' or mode == 'predict':
- contain_objs = self.test_files
- else:
- raise ValueError('{} invalid '.format(mode))
- if not input_files:
- logging.info('input_files empty!')
- return
-
- data_args: DataArguments = self.data_args
- for i in range(dupe_factor):
-
- if data_args.convert_file:
- if mixed_data:
- intermediate_name = data_args.intermediate_name + '_dupe_factor_{}'.format(i)
- intermediate_output = self.get_intermediate_file(intermediate_name, mode)
-
- if isinstance(intermediate_output, list) or not os.path.exists(intermediate_output) or overwrite:
- data = self.on_get_corpus(input_files, mode)
- self.make_dataset(intermediate_output,
- data,
- mode,
- num_process_worker=num_process_worker,
- shuffle=shuffle)
- contain_objs.append(intermediate_output)
- else:
- for fid,input_item in enumerate(input_files):
- intermediate_name = data_args.intermediate_name + '_file_{}_dupe_factor_{}'.format(fid,i)
- intermediate_output = self.get_intermediate_file(intermediate_name, mode)
-
- if isinstance(intermediate_output, list) or not os.path.exists(intermediate_output) or overwrite:
- data = self.on_get_corpus([input_item], mode)
- self.make_dataset(intermediate_output,
- data,
- mode,
- num_process_worker=num_process_worker,
- shuffle=shuffle)
- contain_objs.append(intermediate_output)
-
- else:
- for input_item in input_files:
- contain_objs.append(input_item)
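Note: with this change DataHelper no longer keeps its own intermediate-file logic; the removed get_intermediate_file / make_dataset_with_args now live in numpy_io's DataHelperBase, and __init__ forwards data_backend, convert_file, output_dir and intermediate_name from DataArguments. A minimal construction sketch, assuming the installed package path and the shown field values (the 'parquet' backend value is only inferred from the changelog entry about arrow/parquet datasets):

# Hypothetical sketch; import path and argument values are assumptions.
from deep_training.data_helper import DataHelper, ModelArguments, DataArguments  # assumed path

model_args = ModelArguments(model_name_or_path="bert-base-chinese")  # hypothetical
data_args = DataArguments(
    data_backend="parquet",          # forwarded to DataHelperBase.__init__
    convert_file=True,
    output_dir="./output",
    intermediate_name="intermediate",
)

class MyDataHelper(DataHelper):
    ...  # in practice a subclass implements on_get_corpus / on_data_process

helper = MyDataHelper(model_args, training_args=None, data_args=data_args)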
diff --git a/src/nlp/layers/lora_v2/layers.py b/src/nlp/layers/lora_v2/layers.py
index 611f7f5e..a92a7f87 100644
--- a/src/nlp/layers/lora_v2/layers.py
+++ b/src/nlp/layers/lora_v2/layers.py
@@ -6,6 +6,8 @@
import math
import sys
import warnings
+from typing import Union, Tuple
+
import torch
from torch import nn
from torch.nn import functional as F
@@ -60,6 +62,7 @@ def __init__(
self,
in_features: int,
out_features: int,
+ **kwargs
):
self.r = {}
self.lora_alpha = {}
@@ -75,6 +78,7 @@ def __init__(
self.disable_adapters = False
self.in_features = in_features
self.out_features = out_features
+ self.kwargs = kwargs
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights,dtype=None):
self.r[adapter_name] = r
@@ -116,6 +120,31 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)
+ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights,dtype=None):
+ self.r[adapter_name] = r
+ self.lora_alpha[adapter_name] = lora_alpha
+ if lora_dropout > 0.0:
+ lora_dropout_layer = nn.Dropout(p=lora_dropout)
+ else:
+ lora_dropout_layer = nn.Identity()
+
+ self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
+ # Actual trainable parameters
+ if r > 0:
+ kernel_size = self.kwargs["kernel_size"]
+ stride = self.kwargs["stride"]
+ padding = self.kwargs["padding"]
+ self.lora_A.update(
+ nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False,dtype=dtype)})
+ )
+ self.lora_B.update(
+ nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False,dtype=dtype)})
+ )
+ self.scaling[adapter_name] = lora_alpha / r
+ if init_lora_weights:
+ self.reset_lora_parameters(adapter_name)
+ self.to(self.weight.device)
+
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
# initialize A the same way as the default for nn.Linear and B to zero
@@ -297,6 +326,148 @@ def forward(self, x: torch.Tensor):
return nn.Embedding.forward(self, x)
+
+class Conv2d(nn.Conv2d, LoraLayer):
+ # Lora implemented in a conv2d layer
+ def __init__(
+ self,
+ adapter_name: str,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]],
+ stride: Union[int, Tuple[int]] = 1,
+ padding: Union[int, Tuple[int]] = 0,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ **kwargs,
+ ):
+ init_lora_weights = kwargs.pop("init_lora_weights", True)
+
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding)
+ LoraLayer.__init__(
+ self,
+ in_features=in_channels,
+ out_features=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+
+ nn.Conv2d.reset_parameters(self)
+ self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights,dtype=kwargs.get('dtype',None))
+ self.active_adapter = adapter_name
+
+ def merge(self):
+ if self.active_adapter not in self.lora_A.keys():
+ return
+ if self.merged:
+ warnings.warn("Already merged. Nothing to do.")
+ return
+ if self.r[self.active_adapter] > 0:
+ # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117
+ if self.weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ self.weight.data += (
+ self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+ @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+ ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+ else:
+ # conv2d 3x3
+ self.weight.data += (
+ F.conv2d(
+ self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+ self.lora_B[self.active_adapter].weight,
+ ).permute(1, 0, 2, 3)
+ * self.scaling[self.active_adapter]
+ )
+ self.merged = True
+
+ def unmerge(self):
+ if self.active_adapter not in self.lora_A.keys():
+ return
+ if not self.merged:
+ warnings.warn("Already unmerged. Nothing to do.")
+ return
+ if self.r[self.active_adapter] > 0:
+ if self.weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ self.weight.data -= (
+ self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2)
+ @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2)
+ ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter]
+ else:
+ # conv2d 3x3
+ self.weight.data -= (
+ F.conv2d(
+ self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3),
+ self.lora_B[self.active_adapter].weight,
+ ).permute(1, 0, 2, 3)
+ * self.scaling[self.active_adapter]
+ )
+ self.merged = False
+
+ def forward(self, x: torch.Tensor):
+ previous_dtype = x.dtype
+
+ if self.active_adapter not in self.lora_A.keys():
+ return F.conv2d(
+ x,
+ self.weight,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+ if self.disable_adapters:
+ if self.r[self.active_adapter] > 0 and self.merged:
+ self.unmerge()
+ result = F.conv2d(
+ x,
+ self.weight,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+ elif self.r[self.active_adapter] > 0 and not self.merged:
+ result = F.conv2d(
+ x,
+ self.weight,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+
+ x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+
+ result += (
+ self.lora_B[self.active_adapter](
+ self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
+ )
+ * self.scaling[self.active_adapter]
+ )
+ else:
+ result = F.conv2d(
+ x,
+ self.weight,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ )
+
+ result = result.to(previous_dtype)
+
+ return result
+
if is_bnb_available():
class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
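Note: Conv2d.merge() above folds the LoRA branch into the frozen conv weight. For the non-1x1 case it relies on the identity that composing lora_A (k x k) with lora_B (1 x 1) equals a single k x k convolution whose weight is F.conv2d(lora_A.weight.permute(1, 0, 2, 3), lora_B.weight).permute(1, 0, 2, 3). A standalone sanity-check sketch in plain PyTorch (not this repo's classes):

# Verify the merge formula against the unmerged two-conv composition.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_c, out_c, r = 4, 6, 2
lora_A = nn.Conv2d(in_c, r, kernel_size=3, stride=1, padding=1, bias=False)
lora_B = nn.Conv2d(r, out_c, kernel_size=1, stride=1, bias=False)

delta_w = F.conv2d(
    lora_A.weight.permute(1, 0, 2, 3),  # (in_c, r, 3, 3)
    lora_B.weight,                      # (out_c, r, 1, 1)
).permute(1, 0, 2, 3)                   # -> (out_c, in_c, 3, 3)

x = torch.randn(1, in_c, 8, 8)
composed = lora_B(lora_A(x))                        # LoRA branch, unmerged
merged = F.conv2d(x, delta_w, stride=1, padding=1)  # same stride/padding as lora_A
print(torch.allclose(composed, merged, atol=1e-5))  # expected: True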
diff --git a/src/nlp/models/chatglm2/modeling_chatglm.py b/src/nlp/models/chatglm2/modeling_chatglm.py
index e1bc8695..f3dae78a 100644
--- a/src/nlp/models/chatglm2/modeling_chatglm.py
+++ b/src/nlp/models/chatglm2/modeling_chatglm.py
@@ -418,11 +418,11 @@ def forward(
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
# adjust key and value for inference
+ if kv_cache is not None:
+ cache_k, cache_v = kv_cache
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
if use_cache:
- if kv_cache is not None:
- cache_k, cache_v = kv_cache
- key_layer = torch.cat((cache_k, key_layer), dim=0)
- value_layer = torch.cat((cache_v, value_layer), dim=0)
kv_cache = (key_layer, value_layer)
else:
kv_cache = None
@@ -754,7 +754,7 @@ def __init__(self, config: ChatGLMConfig, device=None):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
- self.kv_channels = config.kv_channels
+ self.kv_channels = config.kv_channels
# Rotary positional embeddings
self.seq_length = config.seq_length
@@ -825,6 +825,13 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
+ dtype=inputs_embeds.dtype)
+ if attention_mask is not None:
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask], dim=-1)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
@@ -836,15 +843,6 @@ def forward(
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
-
-
- if past_key_values is None:
- if self.pre_seq_len is not None:
- past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
- dtype=inputs_embeds.dtype)
- else:
- past_key_values = tuple([None] * self.num_layers)
-
# Run encoder.
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
@@ -1019,10 +1017,14 @@ def process_response(self, response):
return response
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+
+ if history is None:
+ history = []
prompt = ""
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
return inputs
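Note: the build_inputs change above defaults history to an empty list, so a first-turn call no longer needs to pass one. A hypothetical call sketch (model and tokenizer loading elided):

# Before this patch, a direct call without history failed because the method
# iterated over history=None; now both forms behave the same.
inputs = model.build_inputs(tokenizer, "你好")
inputs = model.build_inputs(tokenizer, "你好", history=[])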
diff --git a/src/nlp/models/lora/v2/configuration.py b/src/nlp/models/lora/v2/configuration.py
index 8a8fe478..dd01471c 100644
--- a/src/nlp/models/lora/v2/configuration.py
+++ b/src/nlp/models/lora/v2/configuration.py
@@ -178,6 +178,19 @@ class LoraConfig(LoraBaseArguments):
metadata={"help": "Whether to initialize the weights of the Lora layers."},
)
+ layers_to_transform: Optional[Union[List, int]] = field(
+ default=None,
+ metadata={
+ "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
+ },
+ )
+ layers_pattern: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
+ },
+ )
+
def __post_init__(self):
if self.lora_type is None:
self.lora_type = 'lora'
@@ -268,3 +281,6 @@ def __post_init__(self):
if self.adalora is not None and isinstance(self.adalora, dict):
self.adalora = AdaLoraConfig.from_memory(self.adalora)
self.with_lora = self.adalora.with_lora | self.with_lora
+
+
+COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]
\ No newline at end of file
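Note: a minimal sketch of how layers_to_transform and COMMON_LAYERS_PATTERN select layers by index, mirroring _check_target_module_exists in lora_model.py below; the module key strings are hypothetical:

# Only keys whose layer index is in layers_to_transform keep their LoRA match.
import re

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]
layers_to_transform = [0, 1]

for key in [
    "transformer.encoder.layers.0.self_attention.query_key_value",
    "transformer.encoder.layers.5.self_attention.query_key_value",
]:
    selected = False
    for pattern in COMMON_LAYERS_PATTERN:
        m = re.match(rf".*.{pattern}\.(\d+)\.*", key)
        if m is not None:
            selected = int(m.group(1)) in layers_to_transform
            break
    print(key, "->", selected)
# -> True for the "layers.0" key, False for the "layers.5" key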
diff --git a/src/nlp/models/lora/v2/lora_model.py b/src/nlp/models/lora/v2/lora_model.py
index f1c3090c..cfb6d543 100644
--- a/src/nlp/models/lora/v2/lora_model.py
+++ b/src/nlp/models/lora/v2/lora_model.py
@@ -13,9 +13,10 @@
from torch import nn
from transformers import Conv1D
+from .configuration import COMMON_LAYERS_PATTERN
from ...transformer_base import TransformerBase
from ....layers.lora_v2.layers import mark_only_lora_as_trainable, is_bnb_available, LoraLayer, Linear, \
- is_bnb_4bit_available, Embedding
+ is_bnb_4bit_available, Embedding, Conv2d
from ....layers.lora_v2.utils import _freeze_adapter, _get_submodules, ModulesToSaveWrapper, \
prepare_model_for_kbit_training, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
@@ -91,8 +92,7 @@ def add_adapter(self, adapter_name, config=None):
else:
mark_only_lora_as_trainable(self.model, self.lora_config[adapter_name].bias)
- def _find_and_replace(self, adapter_name):
- lora_config = self.lora_config[adapter_name]
+ def _check_quantization_dependency(self):
loaded_in_4bit = getattr(self.get_transformer_model(), "is_loaded_in_4bit", False)
loaded_in_8bit = getattr(self.get_transformer_model(), "is_loaded_in_8bit", False)
if (loaded_in_4bit or loaded_in_8bit) and not is_bnb_available():
@@ -100,7 +100,35 @@ def _find_and_replace(self, adapter_name):
"To use Lora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
- is_target_modules_in_base_model = False
+
+ def _check_target_module_exists(self, lora_config, key):
+ if isinstance(lora_config.target_modules, str):
+ target_module_found = re.fullmatch(lora_config.target_modules, key)
+ else:
+ target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
+ is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None
+ layer_indexing_pattern = getattr(lora_config, "layers_pattern", None)
+
+ if is_using_layer_indexes and target_module_found:
+ layers_pattern = COMMON_LAYERS_PATTERN if layer_indexing_pattern is None else layer_indexing_pattern
+ layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern
+
+ for pattern in layers_pattern:
+ layer_index = re.match(f".*.{pattern}\.(\d+)\.*", key)
+ if layer_index is not None:
+ layer_index = int(layer_index.group(1))
+ if isinstance(lora_config.layers_to_transform, int):
+ target_module_found = layer_index == lora_config.layers_to_transform
+ else:
+ target_module_found = layer_index in lora_config.layers_to_transform
+
+ break
+ else:
+ target_module_found = False
+ return target_module_found
+
+ def _create_new_module(self, lora_config, adapter_name, target):
+ bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
@@ -108,86 +136,103 @@ def _find_and_replace(self, adapter_name):
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
}
+ loaded_in_4bit = getattr(self.get_transformer_model(), "is_loaded_in_4bit", False)
+ loaded_in_8bit = getattr(self.get_transformer_model(), "is_loaded_in_8bit", False)
+
+ if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
+ eightbit_kwargs = kwargs.copy()
+ eightbit_kwargs.update(
+ {
+ "has_fp16_weights": target.state.has_fp16_weights,
+ "memory_efficient_backward": target.state.memory_efficient_backward,
+ "threshold": target.state.threshold,
+ "index": target.index,
+ }
+ )
+ new_module = Linear8bitLt(
+ adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
+ )
+ elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
+ fourbit_kwargs = kwargs.copy()
+ fourbit_kwargs.update(
+ {
+ "compute_dtype": target.compute_dtype,
+ "compress_statistics": target.weight.compress_statistics,
+ "quant_type": target.weight.quant_type,
+ }
+ )
+ new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs)
+ elif isinstance(target, torch.nn.Embedding):
+ embedding_kwargs = kwargs.copy()
+ embedding_kwargs.pop("fan_in_fan_out", None)
+ in_features, out_features = target.num_embeddings, target.embedding_dim
+ new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
+ elif isinstance(target, torch.nn.Conv2d):
+ out_channels, in_channels = target.weight.size()[:2]
+ kernel_size = target.weight.size()[2:]
+ stride = target.stride
+ padding = target.padding
+ new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
+ else:
+ if isinstance(target, torch.nn.Linear):
+ in_features, out_features = target.in_features, target.out_features
+ if kwargs["fan_in_fan_out"]:
+ warnings.warn(
+ "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
+ "Setting fan_in_fan_out to False."
+ )
+ kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
+ elif isinstance(target, Conv1D):
+ in_features, out_features = (
+ target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
+ )
+ if not kwargs["fan_in_fan_out"]:
+ warnings.warn(
+ "fan_in_fan_out is set to False but the target module is `Conv1D`. "
+ "Setting fan_in_fan_out to True."
+ )
+ kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
+ else:
+ raise ValueError(
+ f"Target module {target} is not supported. "
+ f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
+ )
+ new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
+
+ return new_module
+ def _find_and_replace(self, adapter_name):
+ lora_config = self.lora_config[adapter_name]
+ self._check_quantization_dependency()
+ is_target_modules_in_base_model = False
+
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
- if isinstance(lora_config.target_modules, str):
- target_module_found = re.fullmatch(lora_config.target_modules, key)
- else:
- target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
- if target_module_found:
- if not is_target_modules_in_base_model:
- is_target_modules_in_base_model = True
- parent, target, target_name = _get_submodules(self.model, key)
- if hasattr(target, "bias"):
- bias = target.bias is not None
+ if not self._check_target_module_exists(lora_config, key):
+ continue
- if isinstance(target, LoraLayer):
- target.update_layer(
- adapter_name,
- lora_config.r,
- lora_config.lora_alpha,
- lora_config.lora_dropout,
- lora_config.init_lora_weights,
- dtype=kwargs.get('dtype',None)
- )
- else:
- if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
- eightbit_kwargs = kwargs.copy()
- eightbit_kwargs.update(
- {
- "has_fp16_weights": target.state.has_fp16_weights,
- "memory_efficient_backward": target.state.memory_efficient_backward,
- "threshold": target.state.threshold,
- "index": target.index,
- }
- )
- new_module = Linear8bitLt(
- adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
- )
- elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
- fourbit_kwargs = kwargs.copy()
- fourbit_kwargs.update(
- {
- "compute_dtype": target.compute_dtype,
- "compress_statistics": target.weight.compress_statistics,
- "quant_type": target.weight.quant_type,
- }
- )
- new_module = Linear4bit(
- adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
- )
- elif isinstance(target, torch.nn.Embedding):
- embedding_kwargs = kwargs.copy()
- embedding_kwargs.pop("fan_in_fan_out", None)
- in_features, out_features = target.num_embeddings, target.embedding_dim
- new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
- else:
- if isinstance(target, torch.nn.Linear):
- in_features, out_features = target.in_features, target.out_features
- if kwargs["fan_in_fan_out"]:
- warnings.warn(
- "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
- "Setting fan_in_fan_out to False."
- )
- kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
- elif isinstance(target, Conv1D):
- in_features, out_features = (
- target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
- )
- if not kwargs["fan_in_fan_out"]:
- warnings.warn(
- "fan_in_fan_out is set to False but the target module is `Conv1D`. "
- "Setting fan_in_fan_out to True."
- )
- kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
- else:
- raise ValueError(
- f"Target module {target} is not supported. "
- f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
- )
- new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
+ is_target_modules_in_base_model = True
+ parent, target, target_name = _get_submodules(self.model, key)
+
+ if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
+ target.update_layer_conv2d(
+ adapter_name,
+ lora_config.r,
+ lora_config.lora_alpha,
+ lora_config.lora_dropout,
+ lora_config.init_lora_weights,
+ )
+ elif isinstance(target, LoraLayer):
+ target.update_layer(
+ adapter_name,
+ lora_config.r,
+ lora_config.lora_alpha,
+ lora_config.lora_dropout,
+ lora_config.init_lora_weights,
+ )
+ else:
+ new_module = self._create_new_module(lora_config, adapter_name, target)
+ self._replace_module(parent, target_name, new_module, target)
- self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {lora_config.target_modules} not found in the base model. "
@@ -209,6 +254,8 @@ def _replace_module(self, parent_module, child_name, new_module, old_module):
for name, module in new_module.named_modules():
if "lora_" in name:
module.to(old_module.weight.device)
+ if "ranknum" in name:
+ module.to(old_module.weight.device)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""