diff --git a/src/deep_training/trainer/ac/trainer.py b/src/deep_training/trainer/ac/trainer.py index e0deebc8..e93f02aa 100644 --- a/src/deep_training/trainer/ac/trainer.py +++ b/src/deep_training/trainer/ac/trainer.py @@ -806,6 +806,21 @@ def _train_loop(self,start_epoch=0,start_step=0, is_last_step_and_steps_less_than_grad_acc ): + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if hasattr(self.optimizer, "clip_grad_norm"): + # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping + self.optimizer.clip_grad_norm(args.max_grad_norm) + elif hasattr(model, "clip_grad_norm_"): + # Some models (like FullyShardedDDP) have a specific way to do gradient clipping + model.clip_grad_norm_(args.max_grad_norm) + else: + self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) optimizer.step() lr_scheduler.step() optimizer.zero_grad() diff --git a/src/deep_training/trainer/hf/trainer.py b/src/deep_training/trainer/hf/trainer.py index cf210a14..ab19ee5c 100644 --- a/src/deep_training/trainer/hf/trainer.py +++ b/src/deep_training/trainer/hf/trainer.py @@ -21,7 +21,7 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.trainer import IS_SAGEMAKER_MP_POST_1_10, TRAINING_ARGS_NAME -from transformers.trainer_pt_utils import remove_dummy_checkpoint, get_parameter_names +from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_utils import ShardedDDPOption, FSDPOption, IntervalStrategy from transformers.utils import is_peft_available, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, is_sagemaker_mp_enabled, \ is_accelerate_available @@ -46,6 +46,17 @@ if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp + +try: + from transformers.trainer_pt_utils import remove_dummy_checkpoint +except: + def remove_dummy_checkpoint(is_main_process, output_dir, filenames): + if is_main_process: + for filename in filenames: + file = os.path.join(output_dir, filename) + if os.path.isfile(file): + os.remove(file) + class TrainerHF(Trainer): def __init__(self, model: Union[PreTrainedModel, nn.Module] = None, @@ -236,88 +247,88 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa if self.args.push_to_hub and not _internal_call: self.push_to_hub(commit_message="Model save") - # def get_decay_parameter_names(self, model) -> List[str]: - # """ - # Get all parameter names that weight decay will be applied to - # - # Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still - # apply to those modules since this function only filter out instance of nn.LayerNorm - # """ - # parameters = [] - # for m, lr in model.get_model_lr(): - # decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS) - # decay_parameters = [name for name in decay_parameters if "bias" not in name] - # parameters.extend(decay_parameters) - # return parameters - # - # def get_parameter_names(self, model) -> List[str]: - # """ - # Get all parameter names that weight decay will be applied to - # - # Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still - # apply to those modules since this function only filter out instance of nn.LayerNorm - # """ - # parameters = [] - # for m, lr in model.get_model_lr(): - # parameter = [] - # for n,p in m.named_parameters(): - # parameter.append((n,p)) - # parameters += parameter - # return parameters - # - # def get_optimizer_grouped_parameters(self,opt_model): - # decay_parameters = self.get_decay_parameter_names(opt_model) - # parameters = self.get_parameter_names(opt_model) - # optimizer_grouped_parameters = [ - # { - # "params": [ - # p for n, p in parameters if (n in decay_parameters and p.requires_grad) - # ], - # "weight_decay": self.args.weight_decay, - # }, - # { - # "params": [ - # p for n, p in parameters if (n not in decay_parameters and p.requires_grad) - # ], - # "weight_decay": 0.0, - # }, - # ] - # return optimizer_grouped_parameters - # def create_optimizer(self): - # """ - # Setup the optimizer. - # - # We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the - # Trainer's init through `optimizers`, or subclass and override this method in a subclass. - # """ - # opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - # if self.optimizer is None: - # optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model) - # optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) - # - # if self.sharded_ddp == ShardedDDPOption.SIMPLE: - # self.optimizer = OSS( - # params=optimizer_grouped_parameters, - # optim=optimizer_cls, - # **optimizer_kwargs, - # ) - # else: - # self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - # if optimizer_cls.__name__ == "Adam8bit": - # import bitsandbytes - # - # manager = bitsandbytes.optim.GlobalOptimManager.get_instance() - # - # skipped = 0 - # for module in opt_model.modules(): - # if isinstance(module, nn.Embedding): - # skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) - # logger.info(f"skipped {module}: {skipped / 2 ** 20}M params") - # manager.register_module_override(module, "weight", {"optim_bits": 32}) - # logger.debug(f"bitsandbytes: will optimize {module} in fp32") - # logger.info(f"skipped: {skipped / 2 ** 20}M params") - # - # if is_sagemaker_mp_enabled(): - # self.optimizer = smp.DistributedOptimizer(self.optimizer) - # - # return self.optimizer \ No newline at end of file + def get_decay_parameter_names(self, model) -> List[str]: + """ + Get all parameter names that weight decay will be applied to + + Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still + apply to those modules since this function only filter out instance of nn.LayerNorm + """ + parameters = [] + for m, lr in model.get_model_lr(): + decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + parameters.extend(decay_parameters) + return parameters + + def get_parameter_names(self, model) -> List[str]: + """ + Get all parameter names that weight decay will be applied to + + Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still + apply to those modules since this function only filter out instance of nn.LayerNorm + """ + parameters = [] + for m, lr in model.get_model_lr(): + parameter = [] + for n,p in m.named_parameters(): + parameter.append((n,p)) + parameters += parameter + return parameters + + def get_optimizer_grouped_parameters(self,opt_model): + decay_parameters = self.get_decay_parameter_names(opt_model) + parameters = self.get_parameter_names(opt_model) + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in parameters if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in parameters if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.optimizer is None: + optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer = OSS( + params=optimizer_grouped_parameters, + optim=optimizer_cls, + **optimizer_kwargs, + ) + else: + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped / 2 ** 20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped / 2 ** 20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer \ No newline at end of file