From 09eca2902ebe9114e768d0672a7bc1e9806f16fe Mon Sep 17 00:00:00 2001 From: ssbuild <462304@qq.cn> Date: Tue, 4 Jul 2023 13:41:53 +0800 Subject: [PATCH] 0.1.11 Signed-off-by: ssbuild <462304@qq.cn> --- README.md | 6 + setup.py | 4 +- src/data_helper/data_helper.py | 99 ++-------- src/nlp/layers/lora_v2/layers.py | 171 ++++++++++++++++ src/nlp/models/chatglm2/modeling_chatglm.py | 30 +-- src/nlp/models/lora/v2/configuration.py | 16 ++ src/nlp/models/lora/v2/lora_model.py | 207 ++++++++++++-------- 7 files changed, 349 insertions(+), 184 deletions(-) diff --git a/README.md b/README.md index ae285091..31f1b827 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,12 @@ ## update +- 2023-07-04 + - 0.1.11 release + - fix baichuan and chatglm2 some bugs + - support conv2d for lora + - support arrow parquet dataset + - 2023-06-06 - 0.1.11 rc0 add baichuan model 完整训练 [baichuan_finetuning](https://github.com/ssbuild/baichuan_finetuning) - 0.1.11 rc1 add chatglm2 model 完整训练 [chatglm2_finetuning](https://github.com/ssbuild/chatglm2_finetuning) diff --git a/setup.py b/setup.py index db739545..bfa16e28 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ ignore = ['test','tests'] setup( name='deep_training', - version='0.1.11rc1', + version='0.1.11', description='an easy training architecture', long_description='torch_training: https://github.com/ssbuild/deep_training.git', license='Apache License 2.0', @@ -18,7 +18,7 @@ author='ssbuild', author_email='9727464@qq.com', install_requires=['lightning>=2', - 'numpy-io>=0.0.3 , < 0.1.0', + 'numpy-io>=0.0.5 , < 0.1.0', 'sentencepiece', 'numpy', 'transformers>=4.22', diff --git a/src/data_helper/data_helper.py b/src/data_helper/data_helper.py index 2d391668..d45d0151 100644 --- a/src/data_helper/data_helper.py +++ b/src/data_helper/data_helper.py @@ -5,16 +5,16 @@ import os import typing import torch -from fastdatasets import memory as MEMORY -from fastdatasets.common.iterable_dataset import IterableDatasetBase -from fastdatasets.common.random_dataset import RandomDatasetBase -from fastdatasets.torch_dataset import IterableDataset as torch_IterableDataset, Dataset as torch_Dataset -from torch.utils.data import DataLoader, IterableDataset +# from fastdatasets import memory as MEMORY +# from fastdatasets.common.iterable_dataset import IterableDatasetBase +# from fastdatasets.common.random_dataset import RandomDatasetBase +# from fastdatasets.torch_dataset import IterableDataset as torch_IterableDataset, Dataset as torch_Dataset +# from torch.utils.data import DataLoader, IterableDataset from transformers import PreTrainedTokenizer, PretrainedConfig from .training_args import ModelArguments, DataArguments, TrainingArguments from ..utils.func import is_chinese_char -from numpy_io.core.writer import DataWriteHelper from numpy_io.pytorch_loader.data_helper import DataHelperBase,load_tokenizer, load_configure +from numpy_io.core.writer import DataWriteHelper __all__ = [ 'DataHelper', @@ -43,24 +43,23 @@ class DataHelper(DataHelperBase): model_args: typing.Optional[ModelArguments] = None training_args: typing.Optional[TrainingArguments] = None data_args: typing.Optional[DataArguments] = None + def __init__(self, model_args: ModelArguments, training_args: typing.Optional[TrainingArguments] = None, data_args: typing.Optional[DataArguments] = None, **kwargs): - super(DataHelper, self).__init__() - - self.train_files = [] - self.eval_files = [] - self.test_files = [] + if data_args: + super(DataHelper, 
self).__init__(data_args.data_backend,data_args.convert_file,data_args.output_dir,data_args.intermediate_name) + else: + super(DataHelper, self).__init__(None, None, None, None) self.label2id = None self.id2label = None self.max_seq_length_dict = {} self._external_kwargs = kwargs - self.backend = data_args.data_backend if data_args else 'record' self.model_args = model_args self.training_args = training_args self.data_args = data_args @@ -253,82 +252,6 @@ def load_tokenizer_and_config(self, return tokenizer, config - # 返回制作特征数据的中间文件 - def get_intermediate_file(self, intermediate_name, mode): - data_args: DataArguments = self.data_args - if data_args.data_backend.startswith('memory'): - # 内存数据: list - intermediate_output = [] - logging.info('make data {} {}...'.format(data_args.output_dir, - intermediate_name + '-' + mode + '.' + self.backend)) - else: - # 本地文件数据: 文件名 - intermediate_output = os.path.join(data_args.output_dir, - intermediate_name + '-' + mode + '.' + self.backend) - logging.info('make data {}...'.format(intermediate_output)) - return intermediate_output - - - def make_dataset_with_args(self, input_files, - mode, - shuffle=False, - num_process_worker: int=0, - overwrite: bool=False, - mixed_data=True, - dupe_factor=1): - ''' - mode: one of [ train , eval , test] - shuffle: whether shuffle data - num_process_worker: the number of mutiprocess - overwrite: whether overwrite data - mixed_data: Whether the mixed data - ''' - logging.info('make_dataset {} {}...'.format(','.join(input_files),mode)) - if mode == 'train': - contain_objs = self.train_files - elif mode == 'eval' or mode == 'val': - contain_objs = self.eval_files - elif mode == 'test' or mode == 'predict': - contain_objs = self.test_files - else: - raise ValueError('{} invalid '.format(mode)) - if not input_files: - logging.info('input_files empty!') - return - - data_args: DataArguments = self.data_args - for i in range(dupe_factor): - - if data_args.convert_file: - if mixed_data: - intermediate_name = data_args.intermediate_name + '_dupe_factor_{}'.format(i) - intermediate_output = self.get_intermediate_file(intermediate_name, mode) - - if isinstance(intermediate_output, list) or not os.path.exists(intermediate_output) or overwrite: - data = self.on_get_corpus(input_files, mode) - self.make_dataset(intermediate_output, - data, - mode, - num_process_worker=num_process_worker, - shuffle=shuffle) - contain_objs.append(intermediate_output) - else: - for fid,input_item in enumerate(input_files): - intermediate_name = data_args.intermediate_name + '_file_{}_dupe_factor_{}'.format(fid,i) - intermediate_output = self.get_intermediate_file(intermediate_name, mode) - - if isinstance(intermediate_output, list) or not os.path.exists(intermediate_output) or overwrite: - data = self.on_get_corpus([input_item], mode) - self.make_dataset(intermediate_output, - data, - mode, - num_process_worker=num_process_worker, - shuffle=shuffle) - contain_objs.append(intermediate_output) - - else: - for input_item in input_files: - contain_objs.append(input_item) diff --git a/src/nlp/layers/lora_v2/layers.py b/src/nlp/layers/lora_v2/layers.py index 611f7f5e..a92a7f87 100644 --- a/src/nlp/layers/lora_v2/layers.py +++ b/src/nlp/layers/lora_v2/layers.py @@ -6,6 +6,8 @@ import math import sys import warnings +from typing import Union, Tuple + import torch from torch import nn from torch.nn import functional as F @@ -60,6 +62,7 @@ def __init__( self, in_features: int, out_features: int, + **kwargs ): self.r = {} self.lora_alpha = {} @@ -75,6 +78,7 @@ 
def __init__( self.disable_adapters = False self.in_features = in_features self.out_features = out_features + self.kwargs = kwargs def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights,dtype=None): self.r[adapter_name] = r @@ -116,6 +120,31 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init self.reset_lora_parameters(adapter_name) self.to(self.weight.device) + def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights,dtype=None): + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + if r > 0: + kernel_size = self.kwargs["kernel_size"] + stride = self.kwargs["stride"] + padding = self.kwargs["padding"] + self.lora_A.update( + nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False,dtype=dtype)}) + ) + self.lora_B.update( + nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False,dtype=dtype)}) + ) + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + self.to(self.weight.device) + def reset_lora_parameters(self, adapter_name): if adapter_name in self.lora_A.keys(): # initialize A the same way as the default for nn.Linear and B to zero @@ -297,6 +326,148 @@ def forward(self, x: torch.Tensor): return nn.Embedding.forward(self, x) + +class Conv2d(nn.Conv2d, LoraLayer): + # Lora implemented in a conv2d layer + def __init__( + self, + adapter_name: str, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + stride: Union[int, Tuple[int]] = 1, + padding: Union[int, Tuple[int]] = 0, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs, + ): + init_lora_weights = kwargs.pop("init_lora_weights", True) + + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding) + LoraLayer.__init__( + self, + in_features=in_channels, + out_features=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + + nn.Conv2d.reset_parameters(self) + self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights,dtype=kwargs.get('dtype',None)) + self.active_adapter = adapter_name + + def merge(self): + if self.active_adapter not in self.lora_A.keys(): + return + if self.merged: + warnings.warn("Already merged. Nothing to do.") + return + if self.r[self.active_adapter] > 0: + # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 + if self.weight.size()[2:4] == (1, 1): + # conv2d 1x1 + self.weight.data += ( + self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2) + @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter] + else: + # conv2d 3x3 + self.weight.data += ( + F.conv2d( + self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3), + self.lora_B[self.active_adapter].weight, + ).permute(1, 0, 2, 3) + * self.scaling[self.active_adapter] + ) + self.merged = True + + def unmerge(self): + if self.active_adapter not in self.lora_A.keys(): + return + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + if self.r[self.active_adapter] > 0: + if self.weight.size()[2:4] == (1, 1): + # conv2d 1x1 + self.weight.data -= ( + self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2) + @ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter] + else: + # conv2d 3x3 + self.weight.data += ( + F.conv2d( + self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3), + self.lora_B[self.active_adapter].weight, + ).permute(1, 0, 2, 3) + * self.scaling[self.active_adapter] + ) + self.merged = False + + def forward(self, x: torch.Tensor): + previous_dtype = x.dtype + + if self.active_adapter not in self.lora_A.keys(): + return F.conv2d( + x, + self.weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + if self.disable_adapters: + if self.r[self.active_adapter] > 0 and self.merged: + self.unmerge() + result = F.conv2d( + x, + self.weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + elif self.r[self.active_adapter] > 0 and not self.merged: + result = F.conv2d( + x, + self.weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + + x = x.to(self.lora_A[self.active_adapter].weight.dtype) + + result += ( + self.lora_B[self.active_adapter]( + self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) + ) + * self.scaling[self.active_adapter] + ) + else: + result = F.conv2d( + x, + self.weight, + bias=self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + + result = result.to(previous_dtype) + + return result + if is_bnb_available(): class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer): diff --git a/src/nlp/models/chatglm2/modeling_chatglm.py b/src/nlp/models/chatglm2/modeling_chatglm.py index e1bc8695..f3dae78a 100644 --- a/src/nlp/models/chatglm2/modeling_chatglm.py +++ b/src/nlp/models/chatglm2/modeling_chatglm.py @@ -418,11 +418,11 @@ def forward( key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) if use_cache: - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) kv_cache = (key_layer, value_layer) else: kv_cache = None @@ -754,7 +754,7 @@ def __init__(self, config: ChatGLMConfig, device=None): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels + self.kv_channels = config.kv_channels # Rotary positional embeddings self.seq_length = config.seq_length @@ -825,6 +825,13 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = 
self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
@@ -836,15 +843,6 @@ def forward(
         else:
             rotary_pos_emb = rotary_pos_emb[None, :seq_length]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
-
-
-        if past_key_values is None:
-            if self.pre_seq_len is not None:
-                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
-                                                  dtype=inputs_embeds.dtype)
-            else:
-                past_key_values = tuple([None] * self.num_layers)
-
 
         # Run encoder.
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
@@ -1019,10 +1017,14 @@ def process_response(self, response):
         return response
 
     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+
+        if history is None:
+            history = []
         prompt = ""
         for i, (old_query, response) in enumerate(history):
             prompt += "[Round {}]\n\n问：{}\n\n答：{}\n\n".format(i + 1, old_query, response)
         prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
+
         inputs = tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.device)
         return inputs
diff --git a/src/nlp/models/lora/v2/configuration.py b/src/nlp/models/lora/v2/configuration.py
index 8a8fe478..dd01471c 100644
--- a/src/nlp/models/lora/v2/configuration.py
+++ b/src/nlp/models/lora/v2/configuration.py
@@ -178,6 +178,19 @@ class LoraConfig(LoraBaseArguments):
         metadata={"help": "Whether to initialize the weights of the Lora layers."},
     )
 
+    layers_to_transform: Optional[Union[List, int]] = field(
+        default=None,
+        metadata={
+            "help": "The layer indexes to transform. If this argument is specified, PEFT will transform only the layer indexes inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
+        },
+    )
+    layers_pattern: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The layer pattern name, used only if `layers_to_transform` is different from None and if the layer pattern is not in the common layers pattern."
+ }, + ) + def __post_init__(self): if self.lora_type is None: self.lora_type = 'lora' @@ -268,3 +281,6 @@ def __post_init__(self): if self.adalora is not None and isinstance(self.adalora, dict): self.adalora = AdaLoraConfig.from_memory(self.adalora) self.with_lora = self.adalora.with_lora | self.with_lora + + +COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"] \ No newline at end of file diff --git a/src/nlp/models/lora/v2/lora_model.py b/src/nlp/models/lora/v2/lora_model.py index f1c3090c..cfb6d543 100644 --- a/src/nlp/models/lora/v2/lora_model.py +++ b/src/nlp/models/lora/v2/lora_model.py @@ -13,9 +13,10 @@ from torch import nn from transformers import Conv1D +from .configuration import COMMON_LAYERS_PATTERN from ...transformer_base import TransformerBase from ....layers.lora_v2.layers import mark_only_lora_as_trainable, is_bnb_available, LoraLayer, Linear, \ - is_bnb_4bit_available, Embedding + is_bnb_4bit_available, Embedding, Conv2d from ....layers.lora_v2.utils import _freeze_adapter, _get_submodules, ModulesToSaveWrapper, \ prepare_model_for_kbit_training, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING @@ -91,8 +92,7 @@ def add_adapter(self, adapter_name, config=None): else: mark_only_lora_as_trainable(self.model, self.lora_config[adapter_name].bias) - def _find_and_replace(self, adapter_name): - lora_config = self.lora_config[adapter_name] + def _check_quantization_dependency(self): loaded_in_4bit = getattr(self.get_transformer_model(), "is_loaded_in_4bit", False) loaded_in_8bit = getattr(self.get_transformer_model(), "is_loaded_in_8bit", False) if (loaded_in_4bit or loaded_in_8bit) and not is_bnb_available(): @@ -100,7 +100,35 @@ def _find_and_replace(self, adapter_name): "To use Lora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. " "You can install it with `pip install bitsandbytes`." 
) - is_target_modules_in_base_model = False + + def _check_target_module_exists(self, lora_config, key): + if isinstance(lora_config.target_modules, str): + target_module_found = re.fullmatch(lora_config.target_modules, key) + else: + target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) + is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None + layer_indexing_pattern = getattr(lora_config, "layers_pattern", None) + + if is_using_layer_indexes and target_module_found: + layers_pattern = COMMON_LAYERS_PATTERN if layer_indexing_pattern is None else layer_indexing_pattern + layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern + + for pattern in layers_pattern: + layer_index = re.match(f".*.{pattern}\.(\d+)\.*", key) + if layer_index is not None: + layer_index = int(layer_index.group(1)) + if isinstance(lora_config.layers_to_transform, int): + target_module_found = layer_index == lora_config.layers_to_transform + else: + target_module_found = layer_index in lora_config.layers_to_transform + + break + else: + target_module_found = False + return target_module_found + + def _create_new_module(self, lora_config, adapter_name, target): + bias = hasattr(target, "bias") and target.bias is not None kwargs = { "r": lora_config.r, "lora_alpha": lora_config.lora_alpha, @@ -108,86 +136,103 @@ def _find_and_replace(self, adapter_name): "fan_in_fan_out": lora_config.fan_in_fan_out, "init_lora_weights": lora_config.init_lora_weights, } + loaded_in_4bit = getattr(self.get_transformer_model(), "is_loaded_in_4bit", False) + loaded_in_8bit = getattr(self.get_transformer_model(), "is_loaded_in_8bit", False) + + if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): + eightbit_kwargs = kwargs.copy() + eightbit_kwargs.update( + { + "has_fp16_weights": target.state.has_fp16_weights, + "memory_efficient_backward": target.state.memory_efficient_backward, + "threshold": target.state.threshold, + "index": target.index, + } + ) + new_module = Linear8bitLt( + adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs + ) + elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target.compute_dtype, + "compress_statistics": target.weight.compress_statistics, + "quant_type": target.weight.quant_type, + } + ) + new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs) + elif isinstance(target, torch.nn.Embedding): + embedding_kwargs = kwargs.copy() + embedding_kwargs.pop("fan_in_fan_out", None) + in_features, out_features = target.num_embeddings, target.embedding_dim + new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs) + elif isinstance(target, torch.nn.Conv2d): + out_channels, in_channels = target.weight.size()[:2] + kernel_size = target.weight.size()[2:] + stride = target.stride + padding = target.padding + new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs) + else: + if isinstance(target, torch.nn.Linear): + in_features, out_features = target.in_features, target.out_features + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." 
+ ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + elif isinstance(target, Conv1D): + in_features, out_features = ( + target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape + ) + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." + ) + new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs) + + return new_module + def _find_and_replace(self, adapter_name): + lora_config = self.lora_config[adapter_name] + self._check_quantization_dependency() + is_target_modules_in_base_model = False + key_list = [key for key, _ in self.model.named_modules()] for key in key_list: - if isinstance(lora_config.target_modules, str): - target_module_found = re.fullmatch(lora_config.target_modules, key) - else: - target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) - if target_module_found: - if not is_target_modules_in_base_model: - is_target_modules_in_base_model = True - parent, target, target_name = _get_submodules(self.model, key) - if hasattr(target, "bias"): - bias = target.bias is not None + if not self._check_target_module_exists(lora_config, key): + continue - if isinstance(target, LoraLayer): - target.update_layer( - adapter_name, - lora_config.r, - lora_config.lora_alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - dtype=kwargs.get('dtype',None) - ) - else: - if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): - eightbit_kwargs = kwargs.copy() - eightbit_kwargs.update( - { - "has_fp16_weights": target.state.has_fp16_weights, - "memory_efficient_backward": target.state.memory_efficient_backward, - "threshold": target.state.threshold, - "index": target.index, - } - ) - new_module = Linear8bitLt( - adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs - ) - elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): - fourbit_kwargs = kwargs.copy() - fourbit_kwargs.update( - { - "compute_dtype": target.compute_dtype, - "compress_statistics": target.weight.compress_statistics, - "quant_type": target.weight.quant_type, - } - ) - new_module = Linear4bit( - adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs - ) - elif isinstance(target, torch.nn.Embedding): - embedding_kwargs = kwargs.copy() - embedding_kwargs.pop("fan_in_fan_out", None) - in_features, out_features = target.num_embeddings, target.embedding_dim - new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs) - else: - if isinstance(target, torch.nn.Linear): - in_features, out_features = target.in_features, target.out_features - if kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " - "Setting fan_in_fan_out to False." - ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False - elif isinstance(target, Conv1D): - in_features, out_features = ( - target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape - ) - if not kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to False but the target module is `Conv1D`. " - "Setting fan_in_fan_out to True." 
- ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True - else: - raise ValueError( - f"Target module {target} is not supported. " - f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." - ) - new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs) + is_target_modules_in_base_model = True + parent, target, target_name = _get_submodules(self.model, key) + + if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d): + target.update_layer_conv2d( + adapter_name, + lora_config.r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + elif isinstance(target, LoraLayer): + target.update_layer( + adapter_name, + lora_config.r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + else: + new_module = self._create_new_module(lora_config, adapter_name, target) + self._replace_module(parent, target_name, new_module, target) - self._replace_module(parent, target_name, new_module, target) if not is_target_modules_in_base_model: raise ValueError( f"Target modules {lora_config.target_modules} not found in the base model. " @@ -209,6 +254,8 @@ def _replace_module(self, parent_module, child_name, new_module, old_module): for name, module in new_module.named_modules(): if "lora_" in name: module.to(old_module.weight.device) + if "ranknum" in name: + module.to(old_module.weight.device) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module."""
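
The new Conv2d LoRA layer in src/nlp/layers/lora_v2/layers.py folds the low-rank update into the frozen kernel on merge(): for 1x1 convolutions the update reduces to a matrix product over the channel dimensions, and for larger kernels it composes lora_A (in -> r, k x k) with lora_B (r -> out, 1x1) through F.conv2d. Below is a minimal standalone sketch of that folding math; the helper name conv_lora_delta and the toy shapes are illustrative only and not part of the patch.

    # Standalone sketch (not from the patch): verify the kernel-folding math used by
    # the Conv2d LoRA layer's merge()/unmerge().
    import torch
    import torch.nn.functional as F

    def conv_lora_delta(lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float) -> torch.Tensor:
        """Delta that merge() adds to the base kernel (and unmerge() must subtract)."""
        if lora_A.shape[2:] == (1, 1):
            # 1x1 kernels: plain matrix product over the channel dimensions
            delta = (lora_B.squeeze(3).squeeze(2) @ lora_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            # k x k kernels: compose A (in -> r, k x k) with B (r -> out, 1x1) via conv2d
            delta = F.conv2d(lora_A.permute(1, 0, 2, 3), lora_B).permute(1, 0, 2, 3)
        return scaling * delta

    in_ch, out_ch, r, k = 8, 16, 4, 3
    weight = torch.randn(out_ch, in_ch, k, k)   # frozen base kernel
    lora_A = torch.randn(r, in_ch, k, k)        # lora_A kernel, as built by update_layer_conv2d
    lora_B = torch.randn(out_ch, r, 1, 1)       # lora_B is a 1x1 conv
    scaling = 2.0 / r                           # lora_alpha / r

    x = torch.randn(2, in_ch, 10, 10)
    delta = conv_lora_delta(lora_A, lora_B, scaling)

    # Un-merged path (base conv + low-rank branch) vs merged path (single folded conv).
    y_unmerged = F.conv2d(x, weight, padding=1) + scaling * F.conv2d(F.conv2d(x, lora_A, padding=1), lora_B)
    y_merged = F.conv2d(x, weight + delta, padding=1)
    print(torch.allclose(y_unmerged, y_merged, rtol=1e-4, atol=1e-4))      # True
    # merge() followed by unmerge() must restore the frozen kernel, hence the subtraction.
    print(torch.allclose((weight + delta) - delta, weight, atol=1e-5))     # True

The second check is why unmerge() has to subtract exactly the quantity that merge() adds.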
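The new layers_to_transform / layers_pattern options in LoraConfig let _find_and_replace restrict LoRA injection to specific block indices: _check_target_module_exists extracts the layer index from the module name using layers_pattern (falling back to the added COMMON_LAYERS_PATTERN) and compares it with layers_to_transform. Below is a small standalone sketch of that matching, assuming llama/gpt-style module names; the function name and example keys are made up, and the regex is a cleaned-up equivalent of the one in the patch.

    # Standalone sketch of the layer-index filtering added in _check_target_module_exists.
    import re

    COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks"]  # as added in configuration.py

    def matches_layer_index(key, layers_to_transform, layers_pattern=None):
        """Return True if `key` belongs to one of the requested layer indices."""
        patterns = layers_pattern if layers_pattern is not None else COMMON_LAYERS_PATTERN
        patterns = [patterns] if isinstance(patterns, str) else patterns
        for pattern in patterns:
            m = re.match(rf".*\.{pattern}\.(\d+)\.", key)
            if m is None:
                continue
            index = int(m.group(1))
            if isinstance(layers_to_transform, int):
                return index == layers_to_transform
            return index in layers_to_transform
        return False

    # Example: only layers 0 and 1 of a decoder stack would be wrapped.
    for key in ["model.layers.0.self_attn.q_proj",
                "model.layers.5.self_attn.q_proj",
                "transformer.h.1.attn.c_attn"]:
        print(key, matches_layer_index(key, [0, 1]))
    # -> True, False, True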
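In modeling_chatglm.py the prefix handling (pre_seq_len) now runs before the attention-mask logic: the learned prefix supplies pre_seq_len extra key/value positions, so the padding mask has to be widened by the same amount before full_attention_mask is built. A tiny sketch of that bookkeeping, with made-up sizes:

    # Illustrative only: the same widening op as the added code in ChatGLMModel.forward.
    import torch

    batch_size, seq_length, pre_seq_len = 2, 5, 8
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

    attention_mask = torch.cat(
        [attention_mask.new_ones((batch_size, pre_seq_len)), attention_mask], dim=-1
    )
    print(attention_mask.shape)  # torch.Size([2, 13]) == (batch_size, pre_seq_len + seq_length)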