Commit 0.1.4
Signed-off-by: tk <[email protected]>
ssbuild committed May 2, 2023
1 parent f307987 commit 388c9e4
Showing 7 changed files with 35 additions and 28 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -18,6 +18,9 @@
- RL PPO wrapper based on lightning.fabric, in progress...

## Updates
- <strong>May 2, 2023</strong>
  - 0.1.4 adds prompt_tuning, p_tuning, prefix_tuning, adaption_prompt

- <strong>April 28, 2023</strong>
  - 0.1.3@post0: new release based on lightning
  - renaming of pytorch-lightning to lightning completed
2 changes: 1 addition & 1 deletion nlp/models/moss/quantization.py
@@ -5,7 +5,7 @@
import math
import triton
import triton.language as tl
from models.custom_autotune import *
from .custom_autotune import *


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
2 changes: 1 addition & 1 deletion nlp/models/prompt/configuration.py
@@ -126,7 +126,7 @@ class PromptBaseArguments(PromptConfigMixin):
"""
base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
prompt_type: str = field(default='prefix_tuning', metadata={"help": "one of prompt_tuning,p_tuning,prefix_tuning"})
prompt_type: str = field(default='prefix_tuning', metadata={"help": "one of prompt_tuning,p_tuning,prefix_tuning,adaption_prompt"})
with_prompt: bool = field(default=False, metadata={"help": "whether use lora"})
task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type, one of seq_cls,seq_2_seq_lm,causal_lm,token_cls"})
target_dtype: Optional[Union[int, str]] = field(
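
For orientation, a minimal sketch of how the widened prompt_type choice might be used when constructing these arguments. The field names and accepted values come from the diff above; the import path and checkpoint name are assumptions.

```python
# Illustrative sketch only -- the import path and checkpoint name are assumptions;
# the field names and prompt_type values come from the dataclass shown above.
from deep_training.nlp.models.prompt.configuration import PromptBaseArguments

prompt_args = PromptBaseArguments(
    base_model_name_or_path="bigscience/bloom-560m",  # hypothetical base checkpoint
    prompt_type="adaption_prompt",  # newly accepted alongside prompt_tuning, p_tuning, prefix_tuning
    task_type="causal_lm",
    with_prompt=True,
    inference_mode=False,
)
```
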
39 changes: 21 additions & 18 deletions nlp/models/prompt/prompt_model.py
@@ -30,7 +30,7 @@ def get_prompt_model(model, prompt_config):
Args:
model ([`transformers.PreTrainedModel`]): Model to be wrapped.
peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Prompt model.
prompt_config ([`PromptConfig`]): Configuration object containing the parameters of the Prompt model.
"""

model_config = model.config.to_dict()
@@ -66,16 +66,18 @@ def __init__(self, model, prompt_config: PromptLearningConfig, adapter_name="def
super().__init__()

self.base_model = model
self.transformer_model = model if model is PreTrainedModel else model.model
self.config = self.transformer_model.config
self.config = self.get_transformer_model().config
self.modules_to_save = None
self.prompt_config = {}
self.active_adapter = adapter_name
self.prompt_type = prompt_config.prompt_type
self.base_model_torch_dtype = getattr(self.transformer_model, "dtype", None)
self.base_model_torch_dtype = getattr(self.get_transformer_model(), "dtype", None)

self.add_adapter(adapter_name, prompt_config)

def get_transformer_model(self):
return self.base_model if isinstance(self.base_model,PreTrainedModel) else self.base_model.model

def save_pretrained(self, save_directory, **kwargs):
r"""
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
@@ -291,7 +293,8 @@ def get_base_model(self):
"""
Returns the base model.
"""
return self.base_model if isinstance(self.active_peft_config, PromptLearningConfig) else self.base_model.model
return self.base_model.model
# return self.base_model if isinstance(self.active_peft_config, PromptLearningConfig) else self.base_model.model

def add_adapter(self, adapter_name, prompt_config):
if prompt_config.prompt_type != self.prompt_type:
@@ -590,7 +593,7 @@ class PromptModelForCausalLM(PromptModel):

def __init__(self, model, prompt_config: PromptLearningConfig, adapter_name="default"):
super().__init__(model, prompt_config, adapter_name)
self.base_model_prepare_inputs_for_generation = self.transformer_model.prepare_inputs_for_generation
self.base_model_prepare_inputs_for_generation = self.get_transformer_model().prepare_inputs_for_generation

def forward(
self,
@@ -655,7 +658,7 @@ def forward(

def generate(self, **kwargs):
prompt_config = self.active_peft_config
self.transformer_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
self.get_transformer_model().prepare_inputs_for_generation = self.prepare_inputs_for_generation
try:
if not isinstance(prompt_config, PromptLearningConfig):
outputs = self.base_model.generate(**kwargs)
@@ -688,10 +691,10 @@ def generate(self, **kwargs):

outputs = self.base_model.generate(**kwargs)
except:
self.transformer_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.get_transformer_model().prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
raise
else:
self.transformer_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.get_transformer_model().prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
return outputs

def prepare_inputs_for_generation(self, *args, **kwargs):
@@ -775,9 +778,9 @@ class PromptModelForSeq2SeqLM(PromptModel):

def __init__(self, model, prompt_config: PromptLearningConfig, adapter_name="default"):
super().__init__(model, prompt_config, adapter_name)
self.base_model_prepare_inputs_for_generation = self.transformer_model.prepare_inputs_for_generation
self.base_model_prepare_inputs_for_generation = self.get_transformer_model().prepare_inputs_for_generation
self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
self.transformer_model._prepare_encoder_decoder_kwargs_for_generation
self.get_transformer_model()._prepare_encoder_decoder_kwargs_for_generation
)

def forward(
@@ -873,8 +876,8 @@ def forward(

def generate(self, **kwargs):
prompt_config = self.active_peft_config
self.transformer_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
self.transformer_model._prepare_encoder_decoder_kwargs_for_generation = (
self.get_transformer_model().prepare_inputs_for_generation = self.prepare_inputs_for_generation
self.get_transformer_model()._prepare_encoder_decoder_kwargs_for_generation = (
self._prepare_encoder_decoder_kwargs_for_generation
)
try:
@@ -899,14 +902,14 @@ def generate(self, **kwargs):
else:
raise NotImplementedError
except:
self.transformer_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.transformer_model._prepare_encoder_decoder_kwargs_for_generation = (
self.get_transformer_model().prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.get_transformer_model()._prepare_encoder_decoder_kwargs_for_generation = (
self.base_model_prepare_encoder_decoder_kwargs_for_generation
)
raise
else:
self.transformer_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.transformer_model._prepare_encoder_decoder_kwargs_for_generation = (
self.get_transformer_model().prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
self.get_transformer_model()._prepare_encoder_decoder_kwargs_for_generation = (
self.base_model_prepare_encoder_decoder_kwargs_for_generation
)
return outputs
@@ -1080,7 +1083,7 @@ def _prefix_tuning_forward(
if "past_key_values" in fwd_params:
return self.base_model(labels=labels, **kwargs)
else:
transformer_backbone_name = self.transformer_model.get_submodule(self.transformer_backbone_name)
transformer_backbone_name = self.get_transformer_model().get_submodule(self.transformer_backbone_name)
fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
if "past_key_values" not in fwd_params:
raise ValueError("Model does not support past key values which are required for prefix tuning.")
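
To make the transformer_model -> get_transformer_model() refactor concrete, a hedged usage sketch of the wrapping and generation flow follows. get_prompt_model and the generate() override come from the diff; the import paths, checkpoint, and argument construction are assumptions.

```python
# Usage sketch under assumptions: import paths and the checkpoint are illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer
from deep_training.nlp.models.prompt.configuration import PromptBaseArguments
from deep_training.nlp.models.prompt.prompt_model import get_prompt_model

model_name = "bigscience/bloom-560m"  # hypothetical causal-LM checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name)

prompt_args = PromptBaseArguments(prompt_type="prefix_tuning", task_type="causal_lm", with_prompt=True)
wrapped = get_prompt_model(base_model, prompt_args)  # e.g. a PromptModelForCausalLM

inputs = tokenizer("Hello, world", return_tensors="pt")
# generate() temporarily installs the prompt-aware prepare_inputs_for_generation on the
# underlying transformers model (via get_transformer_model()) and restores the original
# in both the except and else branches, so the base model is left untouched.
output_ids = wrapped.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
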
8 changes: 4 additions & 4 deletions nlp/models/prompt/save_and_load.py
@@ -27,11 +27,11 @@ def get_prompt_model_state_dict(model, state_dict=None, adapter_name="default"):
The state dict of the model. If not provided, the state dict of the model
will be used.
"""
config = model.peft_config[adapter_name]
config = model.prompt_config[adapter_name]
if state_dict is None:
state_dict = model.state_dict()

if config.peft_type == PromptType.ADAPTION_PROMPT:
if config.prompt_type == PromptType.ADAPTION_PROMPT:
to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")}
elif isinstance(config, PromptLearningConfig):
to_return = {}
@@ -59,7 +59,7 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul
model ([`PeftModel`]): The Peft model.
peft_model_state_dict (`dict`): The state dict of the Peft model.
"""
config = model.peft_config[adapter_name]
config = model.prompt_config[adapter_name]
state_dict = {}
if model.modules_to_save is not None:
for key, value in peft_model_state_dict.items():
@@ -73,7 +73,7 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul
state_dict = peft_model_state_dict


if isinstance(config, PromptLearningConfig) or config.peft_type == PromptType.ADAPTION_PROMPT:
if isinstance(config, PromptLearningConfig) or config.prompt_type == PromptType.ADAPTION_PROMPT:
peft_model_state_dict = state_dict
else:
raise NotImplementedError
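
A short sketch of the intended save/restore round trip for the two helpers above. Their names, signatures, and the prompt_config lookup come from the diff; the import path, file name, and the `wrapped` model (the prompt-wrapped model from the earlier sketch) are assumptions.

```python
# Sketch under assumptions -- `wrapped` is a prompt-wrapped model as built in the
# earlier sketch; the import path and file name are illustrative.
import torch
from deep_training.nlp.models.prompt.save_and_load import (
    get_prompt_model_state_dict,
    set_peft_model_state_dict,
)

adapter_state = get_prompt_model_state_dict(wrapped, adapter_name="default")
torch.save(adapter_state, "adapter_model.bin")  # only the adapter/prompt weights

# later, after rebuilding the same wrapped model:
restored = torch.load("adapter_model.bin", map_location="cpu")
set_peft_model_state_dict(wrapped, restored, adapter_name="default")
```
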
7 changes: 4 additions & 3 deletions nlp/models/transformer_base.py
@@ -52,7 +52,7 @@ def load_from_checkpoint(
return super(MyLightningModule, cls).load_from_checkpoint(checkpoint_path,map_location,hparams_file,strict,**kwargs)

@property
def backbone(self):
def backbone(self) -> nn.Module:
return self.__model

@property
@@ -227,7 +227,8 @@ def get_model_lr(self,model=None,lr=None):
lr = lr if lr is not None else self.config.task_specific_params['learning_rate']
if model is not None:
return [(model,lr)]
return [(self.model if self.base_model_prefix is not None else self , lr), ]
# return [(self.model if self.base_model_prefix is not None else self , lr), ]
return [(self, lr), ]



@@ -317,7 +318,7 @@ def get_embeddings_module(self):
return tmp_obj

@property
def backbone(self):
def backbone(self) -> nn.Module:
return self.__backbone

@property
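
The get_model_lr change above now always pairs the module itself with its learning rate. Below is a minimal, self-contained sketch of how such (module, lr) pairs could feed optimizer parameter groups; TinyModule and the optimizer choice are illustrative, not part of the library.

```python
# Standalone sketch; only the (module, lr) pair convention mirrors get_model_lr above.
import torch
from torch import nn

class TinyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def get_model_lr(self, model=None, lr=1e-4):
        # mirror the new behaviour: return the module paired with its learning rate
        if model is not None:
            return [(model, lr)]
        return [(self, lr)]

m = TinyModule()
param_groups = [{"params": module.parameters(), "lr": lr} for module, lr in m.get_model_lr()]
optimizer = torch.optim.AdamW(param_groups)
```
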
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
ignore = ['test','tests']
setup(
name='deep_training',
version='0.1.4rc0',
version='0.1.4',
description='an easy training architecture',
long_description='torch_training: https://github.com/ssbuild/deep_training.git',
license='Apache License 2.0',
