From 9db23a969299ab0945b30e2675820c197ae4aed4 Mon Sep 17 00:00:00 2001
From: ssbuild <462304@qq.cn>
Date: Tue, 10 Oct 2023 10:25:57 +0800
Subject: [PATCH 1/2] fix some new bug

Signed-off-by: ssbuild <462304@qq.cn>
---
 README.md                                |   1 +
 setup.py                                 |   2 +-
 .../nlp/models/transformer_base.py       |  15 +-
 src/deep_training/trainer/ac/trainer.py  |  31 ++-
 src/deep_training/trainer/cl/trainer.py  |  10 +-
 src/deep_training/trainer/hf/trainer.py  | 189 +++++++++---------
 6 files changed, 147 insertions(+), 101 deletions(-)

diff --git a/README.md b/README.md
index 98b961fd..18a0d82d 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc
   - 0.2.5 support colossalai 训练 ,策略 ddp ,gemini,gemini_auto,zero2,zero2_cpu,3d
   - 0.2.5.post0 fix model deepcopy
   - 0.2.5.post2 support accelerator 训练 , fix some bug in accelerator and hf trainer
+  - 0.2.5.post3 fix trainer some bug
 - 2023-09-26
   - 0.2.4 support transformers trainer and qwen-7b 新版 和 qwen-14b , 旧版不再支持,旧版可以安装 deep_training <= 0.2.3

diff --git a/setup.py b/setup.py
index 0a54ab42..0fdb07ff 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
 ]
 setup(
     name='deep_training',
-    version='0.2.5.post2',
+    version='0.2.5.post3',
     description='an easy training architecture',
     long_description='torch_training: https://github.com/ssbuild/deep_training.git',
     license='Apache License 2.0',
diff --git a/src/deep_training/nlp/models/transformer_base.py b/src/deep_training/nlp/models/transformer_base.py
index fcef0e55..185068ff 100644
--- a/src/deep_training/nlp/models/transformer_base.py
+++ b/src/deep_training/nlp/models/transformer_base.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # @Time : 2023/4/11 14:35
-
+import dataclasses
 import sys
 from functools import partial
 from typing import Any, IO, Union, Optional, Dict
@@ -12,6 +12,7 @@
 from transformers import (
     PretrainedConfig,
 )
+from transformers.modeling_outputs import CausalLMOutputWithPast
 from ..utils import configure_optimizers, get_value_from_args_assert, get_value_from_args
 from ..utils.adversarial import AdversarialMethods
@@ -569,7 +570,9 @@ def forward_fn(*args, **kwargs):
         return loss

     def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
-        if isinstance(outputs, dict):
+        if dataclasses.is_dataclass(outputs):
+            self.log('loss', outputs.loss, prog_bar=True)
+        elif isinstance(outputs, dict):
             self.log_dict(outputs, prog_bar=True)
         else:
             self.log('loss', outputs, prog_bar=True)
@@ -578,7 +581,9 @@ def training_step(self, batch):
         if not isinstance(batch, dict):
             batch = dict(batch)
         outputs = self.compute_loss(**batch)
-        if isinstance(outputs,tuple):
+        if dataclasses.is_dataclass(outputs):
+            return outputs.loss
+        if isinstance(outputs,(tuple,list)):
             return outputs[0]
         return outputs

@@ -587,7 +592,7 @@ def validation_step(self, batch, batch_idx, **kwargs):
             batch = dict(batch)
         outputs = self.compute_loss(**batch)
         outputs = apply_to_collection(outputs,dtype=torch.Tensor, function=lambda x: x.detach().numpy())
-        if isinstance(outputs, tuple):
+        if isinstance(outputs, (tuple, list)):
             outputs = {
                 "loss": outputs[0],
                 "outputs": outputs[1:]
             }
@@ -599,7 +604,7 @@ def test_step(self, batch, batch_idx):
             batch = dict(batch)
         outputs = self.compute_loss(**batch)
         outputs = apply_to_collection(outputs,dtype=torch.Tensor, function=lambda x: x.detach().numpy())
-        if isinstance(outputs,tuple):
+        if isinstance(outputs, (tuple, list)):
             outputs = {
                 "outputs": outputs
             }
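# Illustrative sketch for the transformer_base.py hunks above (toy values
# assumed, not taken from the repository): HuggingFace forwards usually return
# a ModelOutput dataclass such as CausalLMOutputWithPast, so
# dataclasses.is_dataclass(outputs) is what lets the loss be read off .loss
# instead of assuming a tuple or dict.
import dataclasses
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

outputs = CausalLMOutputWithPast(loss=torch.tensor(1.23), logits=torch.zeros(1, 2, 10))
if dataclasses.is_dataclass(outputs):
    loss = outputs.loss              # dataclass branch introduced above
elif isinstance(outputs, (tuple, list)):
    loss = outputs[0]                # legacy tuple-style return
else:
    loss = outputs                   # already a bare tensor
print(loss)                          # tensor(1.2300)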
diff --git a/src/deep_training/trainer/ac/trainer.py b/src/deep_training/trainer/ac/trainer.py
index 9af0005a..e0deebc8 100644
--- a/src/deep_training/trainer/ac/trainer.py
+++ b/src/deep_training/trainer/ac/trainer.py
@@ -2,6 +2,8 @@
 # @Time: 0:37
 # @Author: tk
 # @File:trainer
+import contextlib
+import dataclasses
 import importlib
 import json
 import argparse
@@ -143,6 +145,7 @@ def __init__(self,
             self.optimizer, self.lr_scheduler = optimizers

         self.current_flos = 0
+        self.use_cpu_amp = False

         # Activate gradient checkpointing if needed
         if args.gradient_checkpointing:
@@ -567,12 +570,33 @@ def train(self,start_epoch=0,start_step=0, trial: Union["optuna.Trial", Dict[str
              **kwargs,):
        self._train_loop(start_epoch=start_epoch,start_step=start_step,trial=trial,ignore_keys_for_eval=ignore_keys_for_eval,**kwargs)

+    def compute_loss_context_manager(self):
+        """
+        A helper wrapper to group together context managers.
+        """
+        return self.autocast_smart_context_manager()
+
+    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
+        """
+        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
+        arguments, depending on the situation.
+        """
+        if self.use_cpu_amp:
+            ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+        else:
+            ctx_manager = contextlib.nullcontext()
+
+        return ctx_manager
+
     def training_step(self, model: nn.Module, inputs: Dict[ str, Union[ torch.Tensor, Any ] ]) -> torch.Tensor:
         device = torch.cuda.current_device()
         batch = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
-        loss_obj = model(**batch)
+        with self.compute_loss_context_manager():
+            loss_obj = model(**batch)

-        if isinstance(loss_obj, (list, tuple)):
+        if dataclasses.is_dataclass(loss_obj):
+            loss_obj = loss_obj.loss
+        elif isinstance(loss_obj, (list, tuple)):
             loss_obj = loss_obj[ 0 ]

         if isinstance(loss_obj, dict):
@@ -580,6 +604,9 @@ def training_step(self, model: nn.Module, inputs: Dict[ str, Union[ torch.Tensor
         else:
             tr_loss_step = loss_obj

+        if self.args.n_gpu > 1:
+            tr_loss_step = tr_loss_step.mean()  # mean() to average on multi-gpu parallel training
+
         self.accelerator.backward(loss=tr_loss_step)
         return tr_loss_step.detach() / self.args.gradient_accumulation_steps
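# Illustrative sketch for the accelerator trainer hunks above (assumed values,
# standalone instead of trainer state): autocast_smart_context_manager() only
# returns a real CPU autocast context when CPU AMP is enabled, otherwise a
# no-op nullcontext, so the forward pass reads the same either way; the loss
# is then averaged over GPUs and divided by gradient_accumulation_steps.
import contextlib
import torch

use_cpu_amp = False                  # mirrors self.use_cpu_amp = False from __init__
amp_dtype = torch.bfloat16           # assumed dtype, only used when CPU AMP is on
gradient_accumulation_steps = 4      # assumed value

ctx = torch.cpu.amp.autocast(cache_enabled=True, dtype=amp_dtype) if use_cpu_amp \
    else contextlib.nullcontext()

model = torch.nn.Linear(4, 1)
batch = torch.randn(8, 4)
with ctx:
    loss = model(batch).mean()
reported = loss.detach() / gradient_accumulation_steps   # what training_step() hands back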
diff --git a/src/deep_training/trainer/cl/trainer.py b/src/deep_training/trainer/cl/trainer.py
index 1b758bcc..399be07f 100644
--- a/src/deep_training/trainer/cl/trainer.py
+++ b/src/deep_training/trainer/cl/trainer.py
@@ -2,6 +2,7 @@
 # @Time: 0:37
 # @Author: tk
 # @File:trainer
+import dataclasses
 import importlib
 import json
 import argparse
@@ -18,7 +19,7 @@
 from typing import Union, Optional, Callable, List, Tuple, Dict, Any

 import numpy as np
-from lightning_utilities import apply_to_collection
+from lightning_utilities.core.apply_func import apply_to_collection
 from packaging import version
 from datasets import Dataset
 from peft import PeftModel
@@ -608,7 +609,7 @@ def train(self,start_epoch=0,start_step=0, trial: Union["optuna.Trial", Dict[str
              **kwargs,):
        self._train_loop(start_epoch=start_epoch,start_step=start_step,trial=trial,ignore_keys_for_eval=ignore_keys_for_eval,**kwargs)

-    def training_step(self, model: nn.Module, inputs: Dict[ str, Union[ torch.Tensor, Any ] ]) -> torch.Tensor:
+    def training_step(self, model: nn.Module, inputs: Dict[ str, Union[ torch.Tensor, Any ] ]) -> Union[torch.Tensor,Dict,Any]:
         device = get_current_device()
         batch = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
         loss = model(**batch)
@@ -747,7 +748,10 @@ def _train_loop(self,start_epoch=0,start_step=0,
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                 loss_obj = self.training_step(model, batch)
-
+
+                if dataclasses.is_dataclass(loss_obj):
+                    loss_obj = loss_obj.loss
+                elif isinstance(loss_obj, (list, tuple)):
                     loss_obj = loss_obj[0]

                 if isinstance(loss_obj, dict):
diff --git a/src/deep_training/trainer/hf/trainer.py b/src/deep_training/trainer/hf/trainer.py
index 1289cb4e..cf210a14 100644
--- a/src/deep_training/trainer/hf/trainer.py
+++ b/src/deep_training/trainer/hf/trainer.py
@@ -1,5 +1,6 @@
 # coding=utf-8
 # Copyright 2020-present the HuggingFace Inc. team.
+import dataclasses
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -8,7 +9,6 @@
 import safetensors
 import torch
 from accelerate.utils import save_fsdp_model
-from peft import PeftModel
 from torch import nn
 from torch.nn import functional as F
 from datasets import Dataset
@@ -33,6 +33,8 @@
 from transformers.trainer import logger

+if is_peft_available():
+    from peft import PeftModel

 if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
@@ -105,7 +107,12 @@ def compute_loss(self, model, inputs, return_outputs=False):
                 self._past = outputs[self.args.past_index]

         if labels is not None:
-            if is_peft_available() and isinstance(model, (PeftModel,PetlModel,PromptModel)):
+            # always check the petl/prompt wrappers; add PeftModel only when peft is installed
+            supported_classes = (PetlModel, PromptModel)
+            if is_peft_available():
+                supported_classes += (PeftModel,)
+
+            if isinstance(model, supported_classes):
                 model_name = unwrap_model(model.base_model)._get_name()
             else:
                 model_name = unwrap_model(model)._get_name()
@@ -114,10 +121,10 @@
             else:
                 loss = self.label_smoother(outputs, labels)
         else:
-            if isinstance(outputs,tuple):
+            if dataclasses.is_dataclass(outputs):
+                loss = outputs.loss
+            elif isinstance(outputs,(tuple,list)):
                 loss = outputs[0]
-                if isinstance(loss,dict):
-                    loss = loss["loss"]
             elif isinstance(outputs, dict) and "loss" not in outputs:
                 raise ValueError(
                     "The model did not return a loss from the inputs, only the following keys: "
@@ -127,6 +134,8 @@
             # We don't use .loss here since the model may return tuples instead of ModelOutput.
             loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        if isinstance(loss, dict):
+            loss = loss["loss"]
         return (loss, outputs) if return_outputs else loss

     def _save(self, output_dir: Optional[str] = None, state_dict=None):
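# Illustrative sketch for the compute_loss() hunk above (PetlModel/PromptModel
# are placeholders standing in for the project's wrapper classes): the tuple of
# wrapper types is always built, and PeftModel is only appended when peft can
# actually be imported, so label smoothing still works without peft installed.
from transformers.utils import is_peft_available

class PetlModel: ...                 # placeholder for deep_training's wrapper
class PromptModel: ...               # placeholder for deep_training's wrapper

supported_classes = (PetlModel, PromptModel)
if is_peft_available():
    from peft import PeftModel
    supported_classes += (PeftModel,)

def needs_base_model_unwrap(model) -> bool:
    # wrapper classes keep the real model on .base_model, so unwrap one level deeper
    return isinstance(model, supported_classes)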
@@ -227,88 +236,88 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
         if self.args.push_to_hub and not _internal_call:
             self.push_to_hub(commit_message="Model save")

-    def get_decay_parameter_names(self, model) -> List[str]:
-        """
-        Get all parameter names that weight decay will be applied to
-
-        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
-        apply to those modules since this function only filter out instance of nn.LayerNorm
-        """
-        parameters = []
-        for m, lr in model.get_model_lr():
-            decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS)
-            decay_parameters = [name for name in decay_parameters if "bias" not in name]
-            parameters.extend(decay_parameters)
-        return parameters
-
-    def get_parameter_names(self, model) -> List[str]:
-        """
-        Get all parameter names that weight decay will be applied to
-
-        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
-        apply to those modules since this function only filter out instance of nn.LayerNorm
-        """
-        parameters = []
-        for m, lr in model.get_model_lr():
-            parameter = []
-            for n,p in m.named_parameters():
-                parameter.append((n,p))
-            parameters += parameter
-        return parameters
-
-    def get_optimizer_grouped_parameters(self,opt_model):
-        decay_parameters = self.get_decay_parameter_names(opt_model)
-        parameters = self.get_parameter_names(opt_model)
-        optimizer_grouped_parameters = [
-            {
-                "params": [
-                    p for n, p in parameters if (n in decay_parameters and p.requires_grad)
-                ],
-                "weight_decay": self.args.weight_decay,
-            },
-            {
-                "params": [
-                    p for n, p in parameters if (n not in decay_parameters and p.requires_grad)
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
-        return optimizer_grouped_parameters
-    def create_optimizer(self):
-        """
-        Setup the optimizer.
-
-        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
-        """
-        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:
-            optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model)
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
-
-            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-                self.optimizer = OSS(
-                    params=optimizer_grouped_parameters,
-                    optim=optimizer_cls,
-                    **optimizer_kwargs,
-                )
-            else:
-                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-                if optimizer_cls.__name__ == "Adam8bit":
-                    import bitsandbytes
-
-                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
-
-                    skipped = 0
-                    for module in opt_model.modules():
-                        if isinstance(module, nn.Embedding):
-                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
-                            logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
-                            manager.register_module_override(module, "weight", {"optim_bits": 32})
-                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
-                    logger.info(f"skipped: {skipped / 2 ** 20}M params")
-
-        if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(self.optimizer)
-
-        return self.optimizer
\ No newline at end of file
+    # def get_decay_parameter_names(self, model) -> List[str]:
+    #     """
+    #     Get all parameter names that weight decay will be applied to
+    #
+    #     Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
+    #     apply to those modules since this function only filter out instance of nn.LayerNorm
+    #     """
+    #     parameters = []
+    #     for m, lr in model.get_model_lr():
+    #         decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS)
+    #         decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    #         parameters.extend(decay_parameters)
+    #     return parameters
+    #
+    # def get_parameter_names(self, model) -> List[str]:
+    #     """
+    #     Get all parameter names that weight decay will be applied to
+    #
+    #     Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
+    #     apply to those modules since this function only filter out instance of nn.LayerNorm
+    #     """
+    #     parameters = []
+    #     for m, lr in model.get_model_lr():
+    #         parameter = []
+    #         for n,p in m.named_parameters():
+    #             parameter.append((n,p))
+    #         parameters += parameter
+    #     return parameters
+    #
+    # def get_optimizer_grouped_parameters(self,opt_model):
+    #     decay_parameters = self.get_decay_parameter_names(opt_model)
+    #     parameters = self.get_parameter_names(opt_model)
+    #     optimizer_grouped_parameters = [
+    #         {
+    #             "params": [
+    #                 p for n, p in parameters if (n in decay_parameters and p.requires_grad)
+    #             ],
+    #             "weight_decay": self.args.weight_decay,
+    #         },
+    #         {
+    #             "params": [
+    #                 p for n, p in parameters if (n not in decay_parameters and p.requires_grad)
+    #             ],
+    #             "weight_decay": 0.0,
+    #         },
+    #     ]
+    #     return optimizer_grouped_parameters
+    # def create_optimizer(self):
+    #     """
+    #     Setup the optimizer.
+    #
+    #     We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+    #     Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+    #     """
+    #     opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+    #     if self.optimizer is None:
+    #         optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model)
+    #         optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+    #
+    #         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+    #             self.optimizer = OSS(
+    #                 params=optimizer_grouped_parameters,
+    #                 optim=optimizer_cls,
+    #                 **optimizer_kwargs,
+    #             )
+    #         else:
+    #             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+    #             if optimizer_cls.__name__ == "Adam8bit":
+    #                 import bitsandbytes
+    #
+    #                 manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+    #
+    #                 skipped = 0
+    #                 for module in opt_model.modules():
+    #                     if isinstance(module, nn.Embedding):
+    #                         skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+    #                         logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
+    #                         manager.register_module_override(module, "weight", {"optim_bits": 32})
+    #                         logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+    #                 logger.info(f"skipped: {skipped / 2 ** 20}M params")
+    #
+    #     if is_sagemaker_mp_enabled():
+    #         self.optimizer = smp.DistributedOptimizer(self.optimizer)
+    #
+    #     return self.optimizer
\ No newline at end of file

From 8ae4e656a058b136db6a2cb8bff1b78efa073b3f Mon Sep 17 00:00:00 2001
From: ssbuild <462304@qq.cn>
Date: Tue, 10 Oct 2023 11:18:31 +0800
Subject: [PATCH 2/2] update

Signed-off-by: ssbuild <462304@qq.cn>
---
 src/deep_training/trainer/ac/trainer.py |  15 ++
 src/deep_training/trainer/hf/trainer.py | 183 +++++++++++++-----------
 2 files changed, 112 insertions(+), 86 deletions(-)

diff --git a/src/deep_training/trainer/ac/trainer.py b/src/deep_training/trainer/ac/trainer.py
index e0deebc8..e93f02aa 100644
--- a/src/deep_training/trainer/ac/trainer.py
+++ b/src/deep_training/trainer/ac/trainer.py
@@ -806,6 +806,21 @@ def _train_loop(self,start_epoch=0,start_step=0,
                     is_last_step_and_steps_less_than_grad_acc
                 ):
+                    # Gradient clipping
+                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
+                        # deepspeed does its own clipping
+
+                        if hasattr(self.optimizer, "clip_grad_norm"):
+                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+                            self.optimizer.clip_grad_norm(args.max_grad_norm)
+                        elif hasattr(model, "clip_grad_norm_"):
+                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+                            model.clip_grad_norm_(args.max_grad_norm)
+                        else:
+                            self.accelerator.clip_grad_norm_(
+                                model.parameters(),
+                                args.max_grad_norm,
+                            )
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad()
diff --git a/src/deep_training/trainer/hf/trainer.py b/src/deep_training/trainer/hf/trainer.py
index cf210a14..ab19ee5c 100644
--- a/src/deep_training/trainer/hf/trainer.py
+++ b/src/deep_training/trainer/hf/trainer.py
@@ -21,7 +21,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer import IS_SAGEMAKER_MP_POST_1_10, TRAINING_ARGS_NAME
-from transformers.trainer_pt_utils import remove_dummy_checkpoint, get_parameter_names
+from transformers.trainer_pt_utils import get_parameter_names
 from transformers.trainer_utils import ShardedDDPOption, FSDPOption, IntervalStrategy
 from transformers.utils import is_peft_available, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, is_sagemaker_mp_enabled, \
     is_accelerate_available
@@ -46,6 +46,17 @@
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp

+
+try:
+    from transformers.trainer_pt_utils import remove_dummy_checkpoint
+except ImportError:
+    def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
+        if is_main_process:
+            for filename in filenames:
+                file = os.path.join(output_dir, filename)
+                if os.path.isfile(file):
+                    os.remove(file)
+
 class TrainerHF(Trainer):
     def __init__(self,
                  model: Union[PreTrainedModel, nn.Module] = None,
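# Illustrative usage of the compatibility shim above (temporary directory and
# file names are assumptions for the example): older transformers releases do
# not export remove_dummy_checkpoint, so the local fallback keeps save_model()
# working with the same (is_main_process, output_dir, filenames) signature.
import os
import tempfile

try:
    from transformers.trainer_pt_utils import remove_dummy_checkpoint
except ImportError:
    def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
        if is_main_process:
            for filename in filenames:
                file = os.path.join(output_dir, filename)
                if os.path.isfile(file):
                    os.remove(file)

output_dir = tempfile.mkdtemp()
open(os.path.join(output_dir, "pytorch_model.bin"), "wb").close()   # dummy placeholder file
remove_dummy_checkpoint(True, output_dir, ["pytorch_model.bin", "model.safetensors"])
assert not os.path.exists(os.path.join(output_dir, "pytorch_model.bin"))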
@@ -236,88 +247,88 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
         if self.args.push_to_hub and not _internal_call:
             self.push_to_hub(commit_message="Model save")

-    # def get_decay_parameter_names(self, model) -> List[str]:
-    #     """
-    #     Get all parameter names that weight decay will be applied to
-    #
-    #     Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
-    #     apply to those modules since this function only filter out instance of nn.LayerNorm
-    #     """
-    #     parameters = []
-    #     for m, lr in model.get_model_lr():
-    #         decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS)
-    #         decay_parameters = [name for name in decay_parameters if "bias" not in name]
-    #         parameters.extend(decay_parameters)
-    #     return parameters
-    #
-    # def get_parameter_names(self, model) -> List[str]:
-    #     """
-    #     Get all parameter names that weight decay will be applied to
-    #
-    #     Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
-    #     apply to those modules since this function only filter out instance of nn.LayerNorm
-    #     """
-    #     parameters = []
-    #     for m, lr in model.get_model_lr():
-    #         parameter = []
-    #         for n,p in m.named_parameters():
-    #             parameter.append((n,p))
-    #         parameters += parameter
-    #     return parameters
-    #
-    # def get_optimizer_grouped_parameters(self,opt_model):
-    #     decay_parameters = self.get_decay_parameter_names(opt_model)
-    #     parameters = self.get_parameter_names(opt_model)
-    #     optimizer_grouped_parameters = [
-    #         {
-    #             "params": [
-    #                 p for n, p in parameters if (n in decay_parameters and p.requires_grad)
-    #             ],
-    #             "weight_decay": self.args.weight_decay,
-    #         },
-    #         {
-    #             "params": [
-    #                 p for n, p in parameters if (n not in decay_parameters and p.requires_grad)
-    #             ],
-    #             "weight_decay": 0.0,
-    #         },
-    #     ]
-    #     return optimizer_grouped_parameters
-    # def create_optimizer(self):
-    #     """
-    #     Setup the optimizer.
-    #
-    #     We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
-    #     Trainer's init through `optimizers`, or subclass and override this method in a subclass.
-    #     """
-    #     opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-    #     if self.optimizer is None:
-    #         optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model)
-    #         optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
-    #
-    #         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-    #             self.optimizer = OSS(
-    #                 params=optimizer_grouped_parameters,
-    #                 optim=optimizer_cls,
-    #                 **optimizer_kwargs,
-    #             )
-    #         else:
-    #             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-    #             if optimizer_cls.__name__ == "Adam8bit":
-    #                 import bitsandbytes
-    #
-    #                 manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
-    #
-    #                 skipped = 0
-    #                 for module in opt_model.modules():
-    #                     if isinstance(module, nn.Embedding):
-    #                         skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
-    #                         logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
-    #                         manager.register_module_override(module, "weight", {"optim_bits": 32})
-    #                         logger.debug(f"bitsandbytes: will optimize {module} in fp32")
-    #                 logger.info(f"skipped: {skipped / 2 ** 20}M params")
-    #
-    #     if is_sagemaker_mp_enabled():
-    #         self.optimizer = smp.DistributedOptimizer(self.optimizer)
-    #
-    #     return self.optimizer
\ No newline at end of file
+    def get_decay_parameter_names(self, model) -> List[str]:
+        """
+        Get all parameter names that weight decay will be applied to
+
+        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
+        apply to those modules since this function only filter out instance of nn.LayerNorm
+        """
+        parameters = []
+        for m, lr in model.get_model_lr():
+            decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS)
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            parameters.extend(decay_parameters)
+        return parameters
+
+    def get_parameter_names(self, model) -> List[str]:
+        """
+        Get all parameter names that weight decay will be applied to
+
+        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
+        apply to those modules since this function only filter out instance of nn.LayerNorm
+        """
+        parameters = []
+        for m, lr in model.get_model_lr():
+            parameter = []
+            for n,p in m.named_parameters():
+                parameter.append((n,p))
+            parameters += parameter
+        return parameters
+
+    def get_optimizer_grouped_parameters(self,opt_model):
+        decay_parameters = self.get_decay_parameter_names(opt_model)
+        parameters = self.get_parameter_names(opt_model)
+        optimizer_grouped_parameters = [
+            {
+                "params": [
+                    p for n, p in parameters if (n in decay_parameters and p.requires_grad)
+                ],
+                "weight_decay": self.args.weight_decay,
+            },
+            {
+                "params": [
+                    p for n, p in parameters if (n not in decay_parameters and p.requires_grad)
+                ],
+                "weight_decay": 0.0,
+            },
+        ]
+        return optimizer_grouped_parameters
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+        """
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:
+            optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model)
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+                self.optimizer = OSS(
+                    params=optimizer_grouped_parameters,
+                    optim=optimizer_cls,
+                    **optimizer_kwargs,
+                )
+            else:
+                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+                if optimizer_cls.__name__ == "Adam8bit":
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    skipped = 0
+                    for module in opt_model.modules():
+                        if isinstance(module, nn.Embedding):
+                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                            logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
+                            manager.register_module_override(module, "weight", {"optim_bits": 32})
+                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+                    logger.info(f"skipped: {skipped / 2 ** 20}M params")
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+        return self.optimizer
\ No newline at end of file