
Commit 8ae4e65: update
Signed-off-by: ssbuild <[email protected]>
ssbuild committed Oct 10, 2023
1 parent 9db23a9 commit 8ae4e65
Showing 2 changed files with 112 additions and 86 deletions.
15 changes: 15 additions & 0 deletions src/deep_training/trainer/ac/trainer.py
@@ -806,6 +806,21 @@ def _train_loop(self,start_epoch=0,start_step=0,
                    is_last_step_and_steps_less_than_grad_acc
                ):

                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
                        # deepspeed does its own clipping

                        if hasattr(self.optimizer, "clip_grad_norm"):
                            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                            self.optimizer.clip_grad_norm(args.max_grad_norm)
                        elif hasattr(model, "clip_grad_norm_"):
                            # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                            model.clip_grad_norm_(args.max_grad_norm)
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
                                args.max_grad_norm,
                            )
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
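
The added block delegates clipping to whichever component owns the gradients: a sharded optimizer that exposes clip_grad_norm, a model wrapper (such as FullyShardedDDP) that exposes clip_grad_norm_, or otherwise Accelerate's accelerator.clip_grad_norm_. A minimal sketch of the same dispatch, assuming a plain PyTorch loop without Accelerate (the clip_gradients helper and the torch.nn.utils fallback are illustrative, not part of this commit):

import torch

def clip_gradients(model, optimizer, max_grad_norm):
    # Skip clipping entirely when no positive threshold is configured.
    if max_grad_norm is None or max_grad_norm <= 0:
        return
    if hasattr(optimizer, "clip_grad_norm"):
        # Sharded optimizers (e.g. fairscale OSS) clip across their own shards.
        optimizer.clip_grad_norm(max_grad_norm)
    elif hasattr(model, "clip_grad_norm_"):
        # FullyShardedDDP exposes clipping that accounts for sharded parameters.
        model.clip_grad_norm_(max_grad_norm)
    else:
        # Plain case: clip the locally visible parameters in place.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

In the trainer itself the final branch goes through self.accelerator, which keeps clipping correct under mixed precision and distributed wrappers.
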
183 changes: 97 additions & 86 deletions src/deep_training/trainer/hf/trainer.py
@@ -21,7 +21,7 @@
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer import IS_SAGEMAKER_MP_POST_1_10, TRAINING_ARGS_NAME
from transformers.trainer_pt_utils import remove_dummy_checkpoint, get_parameter_names
from transformers.trainer_pt_utils import get_parameter_names
from transformers.trainer_utils import ShardedDDPOption, FSDPOption, IntervalStrategy
from transformers.utils import is_peft_available, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, is_sagemaker_mp_enabled, \
is_accelerate_available
@@ -46,6 +46,17 @@
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp


try:
    from transformers.trainer_pt_utils import remove_dummy_checkpoint
except:
    def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
        if is_main_process:
            for filename in filenames:
                file = os.path.join(output_dir, filename)
                if os.path.isfile(file):
                    os.remove(file)
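
The try/except above keeps the trainer importable on transformers versions where remove_dummy_checkpoint is not available in trainer_pt_utils; the fallback simply deletes placeholder weight files on the main process (it relies on os being imported at module scope). A hypothetical call, with an illustrative output directory, would look like:

remove_dummy_checkpoint(
    is_main_process=True,
    output_dir="./outputs/checkpoint-1000",  # illustrative path, not from this commit
    filenames=[WEIGHTS_NAME, SAFE_WEIGHTS_NAME],
)
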

class TrainerHF(Trainer):
def __init__(self,
model: Union[PreTrainedModel, nn.Module] = None,
@@ -236,88 +247,88 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")

    # def get_decay_parameter_names(self, model) -> List[str]:
    #     """
    #     Get all parameter names that weight decay will be applied to
    #
    #     Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
    #     apply to those modules since this function only filter out instance of nn.LayerNorm
    #     """
    #     parameters = []
    #     for m, lr in model.get_model_lr():
    #         decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS)
    #         decay_parameters = [name for name in decay_parameters if "bias" not in name]
    #         parameters.extend(decay_parameters)
    #     return parameters
    #
    # def get_parameter_names(self, model) -> List[str]:
    #     """
    #     Get all parameter names that weight decay will be applied to
    #
    #     Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
    #     apply to those modules since this function only filter out instance of nn.LayerNorm
    #     """
    #     parameters = []
    #     for m, lr in model.get_model_lr():
    #         parameter = []
    #         for n,p in m.named_parameters():
    #             parameter.append((n,p))
    #         parameters += parameter
    #     return parameters
    #
    # def get_optimizer_grouped_parameters(self,opt_model):
    #     decay_parameters = self.get_decay_parameter_names(opt_model)
    #     parameters = self.get_parameter_names(opt_model)
    #     optimizer_grouped_parameters = [
    #         {
    #             "params": [
    #                 p for n, p in parameters if (n in decay_parameters and p.requires_grad)
    #             ],
    #             "weight_decay": self.args.weight_decay,
    #         },
    #         {
    #             "params": [
    #                 p for n, p in parameters if (n not in decay_parameters and p.requires_grad)
    #             ],
    #             "weight_decay": 0.0,
    #         },
    #     ]
    #     return optimizer_grouped_parameters
    # def create_optimizer(self):
    #     """
    #     Setup the optimizer.
    #
    #     We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    #     Trainer's init through `optimizers`, or subclass and override this method in a subclass.
    #     """
    #     opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
    #     if self.optimizer is None:
    #         optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model)
    #         optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
    #
    #         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
    #             self.optimizer = OSS(
    #                 params=optimizer_grouped_parameters,
    #                 optim=optimizer_cls,
    #                 **optimizer_kwargs,
    #             )
    #         else:
    #             self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    #             if optimizer_cls.__name__ == "Adam8bit":
    #                 import bitsandbytes
    #
    #                 manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
    #
    #                 skipped = 0
    #                 for module in opt_model.modules():
    #                     if isinstance(module, nn.Embedding):
    #                         skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
    #                         logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
    #                         manager.register_module_override(module, "weight", {"optim_bits": 32})
    #                         logger.debug(f"bitsandbytes: will optimize {module} in fp32")
    #                 logger.info(f"skipped: {skipped / 2 ** 20}M params")
    #
    #     if is_sagemaker_mp_enabled():
    #         self.optimizer = smp.DistributedOptimizer(self.optimizer)
    #
    #     return self.optimizer
    def get_decay_parameter_names(self, model) -> List[str]:
        """
        Get all parameter names that weight decay will be applied to.

        Note that some models implement their own layernorm instead of calling nn.LayerNorm; weight decay could
        still apply to those modules, since this function only filters out instances of nn.LayerNorm.
        """
        parameters = []
        for m, lr in model.get_model_lr():
            decay_parameters = get_parameter_names(m, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            parameters.extend(decay_parameters)
        return parameters

    def get_parameter_names(self, model) -> List[str]:
        """
        Collect all (name, parameter) pairs across the modules returned by model.get_model_lr(),
        so they can be split into weight-decay and no-weight-decay groups.
        """
        parameters = []
        for m, lr in model.get_model_lr():
            parameter = []
            for n, p in m.named_parameters():
                parameter.append((n, p))
            parameters += parameter
        return parameters

    def get_optimizer_grouped_parameters(self, opt_model):
        decay_parameters = self.get_decay_parameter_names(opt_model)
        parameters = self.get_parameter_names(opt_model)
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in parameters if (n in decay_parameters and p.requires_grad)
                ],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [
                    p for n, p in parameters if (n not in decay_parameters and p.requires_grad)
                ],
                "weight_decay": 0.0,
            },
        ]
        return optimizer_grouped_parameters

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:
            optimizer_grouped_parameters = self.get_optimizer_grouped_parameters(opt_model)
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                            logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
                            manager.register_module_override(module, "weight", {"optim_bits": 32})
                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                    logger.info(f"skipped: {skipped / 2 ** 20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
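
The uncommented methods above split parameters per (module, lr) group returned by model.get_model_lr(): parameters whose names contain "bias" or that belong to LayerNorm modules go into the weight_decay=0.0 group, while everything else gets args.weight_decay. A minimal, self-contained sketch of that split, using a hypothetical ToyModel (the decay_parameter_names helper approximates transformers' get_parameter_names and is illustrative, not part of this commit):

import torch.nn as nn

def decay_parameter_names(module: nn.Module) -> list:
    # Approximation of get_parameter_names(m, ALL_LAYERNORM_LAYERS):
    # skip parameters owned by LayerNorm modules, then drop biases.
    names = []
    for mod_name, mod in module.named_modules():
        if isinstance(mod, nn.LayerNorm):
            continue
        for param_name, _ in mod.named_parameters(recurse=False):
            full = f"{mod_name}.{param_name}" if mod_name else param_name
            if "bias" not in full:
                names.append(full)
    return names

class ToyModel(nn.Module):
    # Hypothetical model exposing get_model_lr() the way this trainer expects.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.norm = nn.LayerNorm(16)

    def get_model_lr(self):
        return [(self, 1e-4)]  # (module, lr) pairs; a single group is enough here

model = ToyModel()
decay = set()
for m, _lr in model.get_model_lr():
    decay.update(decay_parameter_names(m))
# decay == {"proj.weight"}; "proj.bias", "norm.weight" and "norm.bias"
# therefore land in the weight_decay=0.0 group built above.
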
