Merge pull request #67 from ssbuild/dev
patch new bug
ssbuild authored Oct 10, 2023
2 parents 1073340 + 8ae4e65 commit 86b3369
Showing 6 changed files with 89 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -21,6 +21,7 @@ pip install -U git+https://github.com/ssbuild/deep_training.git --no-deps --forc
- 0.2.5 support colossalai training, strategies: ddp, gemini, gemini_auto, zero2, zero2_cpu, 3d
- 0.2.5.post0 fix model deepcopy
- 0.2.5.post2 support accelerator training, fix some bugs in accelerator and hf trainer
- 0.2.5.post3 fix some trainer bugs

- <strong>2023-09-26</strong>
- 0.2.4 support transformers trainer and the new qwen-7b and qwen-14b; the old versions are no longer supported, install deep_training <= 0.2.3 for them
2 changes: 1 addition & 1 deletion setup.py
@@ -20,7 +20,7 @@
]
setup(
name='deep_training',
version='0.2.5.post2',
version='0.2.5.post3',
description='an easy training architecture',
long_description='torch_training: https://github.com/ssbuild/deep_training.git',
license='Apache License 2.0',
15 changes: 10 additions & 5 deletions src/deep_training/nlp/models/transformer_base.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# @Time : 2023/4/11 14:35

import dataclasses
import sys
from functools import partial
from typing import Any, IO, Union, Optional, Dict
@@ -12,6 +12,7 @@
from transformers import (
PretrainedConfig,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from ..utils import configure_optimizers, get_value_from_args_assert, get_value_from_args
from ..utils.adversarial import AdversarialMethods
@@ -569,7 +570,9 @@ def forward_fn(*args, **kwargs):
return loss

def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
if isinstance(outputs, dict):
if dataclasses.is_dataclass(outputs):
self.log('loss', outputs.loss, prog_bar=True)
elif isinstance(outputs, dict):
self.log_dict(outputs, prog_bar=True)
else:
self.log('loss', outputs, prog_bar=True)
@@ -578,7 +581,9 @@ def training_step(self, batch):
if not isinstance(batch, dict):
batch = dict(batch)
outputs = self.compute_loss(**batch)
if isinstance(outputs,tuple):
if dataclasses.is_dataclass(outputs):
return outputs.loss
if isinstance(outputs,(tuple,list)):
return outputs[0]
return outputs

@@ -587,7 +592,7 @@ def validation_step(self, batch, batch_idx, **kwargs):
batch = dict(batch)
outputs = self.compute_loss(**batch)
outputs = apply_to_collection(outputs,dtype=torch.Tensor, function=lambda x: x.detach().numpy())
if isinstance(outputs, tuple):
if isinstance(outputs, (tuple, list)):
outputs = {
"loss": outputs[0],
"outputs": outputs[1:]
@@ -599,7 +604,7 @@ def test_step(self, batch, batch_idx):
batch = dict(batch)
outputs = self.compute_loss(**batch)
outputs = apply_to_collection(outputs,dtype=torch.Tensor, function=lambda x: x.detach().numpy())
if isinstance(outputs,tuple):
if isinstance(outputs, (tuple, list)):
outputs = {
"outputs": outputs
}
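The hunks above in transformer_base.py all make the same change: model outputs may now arrive as a transformers ModelOutput dataclass (e.g. CausalLMOutputWithPast) rather than only a dict or tuple, and the loss is pulled out accordingly. A minimal standalone sketch of that dispatch order (not code from the repo; the helper name extract_loss is made up for illustration):

import dataclasses
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

def extract_loss(outputs):
    # ModelOutput subclasses are dataclasses (and dict-like), so check the
    # dataclass case before the plain-dict and tuple/list cases.
    if dataclasses.is_dataclass(outputs):
        return outputs.loss
    if isinstance(outputs, dict):
        return outputs["loss"]
    if isinstance(outputs, (tuple, list)):
        return outputs[0]          # by convention the loss comes first
    return outputs                 # already a bare tensor

out = CausalLMOutputWithPast(loss=torch.tensor(0.7), logits=torch.zeros(1, 2))
assert extract_loss(out) == extract_loss((torch.tensor(0.7), "other outputs"))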
46 changes: 44 additions & 2 deletions src/deep_training/trainer/ac/trainer.py
@@ -2,6 +2,8 @@
# @Time: 0:37
# @Author: tk
# @File:trainer
import contextlib
import dataclasses
import importlib
import json
import argparse
@@ -143,6 +145,7 @@ def __init__(self,
self.optimizer, self.lr_scheduler = optimizers

self.current_flos = 0
self.use_cpu_amp = False

# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
@@ -567,19 +570,43 @@ def train(self,start_epoch=0,start_step=0, trial: Union["optuna.Trial", Dict[str
**kwargs,):
self._train_loop(start_epoch=start_epoch,start_step=start_step,trial=trial,ignore_keys_for_eval=ignore_keys_for_eval,**kwargs)

def compute_loss_context_manager(self):
"""
A helper wrapper to group together context managers.
"""
return self.autocast_smart_context_manager()

def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
"""
if self.use_cpu_amp:
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
else:
ctx_manager = contextlib.nullcontext()

return ctx_manager

def training_step(self, model: nn.Module, inputs: Dict[ str, Union[ torch.Tensor, Any ] ]) -> torch.Tensor:
device = torch.cuda.current_device()
batch = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
loss_obj = model(**batch)
with self.compute_loss_context_manager():
loss_obj = model(**batch)

if isinstance(loss_obj, (list, tuple)):
if dataclasses.is_dataclass(loss_obj):
loss_obj = loss_obj.loss
elif isinstance(loss_obj, (list, tuple)):
loss_obj = loss_obj[ 0 ]

if isinstance(loss_obj, dict):
tr_loss_step = loss_obj[ "loss" ]
else:
tr_loss_step = loss_obj

if self.args.n_gpu > 1:
tr_loss_step = tr_loss_step.mean() # mean() to average on multi-gpu parallel training

self.accelerator.backward(loss=tr_loss_step)
return tr_loss_step.detach() / self.args.gradient_accumulation_steps

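For context, the compute_loss_context_manager / autocast_smart_context_manager pair added above simply selects between CPU autocast and a no-op context around the forward pass. A self-contained sketch of the same idea, assuming plain PyTorch and bfloat16 as the AMP dtype (the function name autocast_ctx is illustrative, not from the repo):

import contextlib
import torch

def autocast_ctx(use_cpu_amp: bool, amp_dtype: torch.dtype = torch.bfloat16):
    # Return either a real autocast region or a do-nothing context, so callers
    # can always write `with autocast_ctx(...):` around the forward pass.
    if use_cpu_amp:
        return torch.cpu.amp.autocast(cache_enabled=True, dtype=amp_dtype)
    return contextlib.nullcontext()

model = torch.nn.Linear(4, 2)
with autocast_ctx(use_cpu_amp=True):
    out = model(torch.randn(8, 4))
print(out.dtype)  # torch.bfloat16 under CPU autocast, torch.float32 otherwise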
@@ -779,6 +806,21 @@ def _train_loop(self,start_epoch=0,start_step=0,
is_last_step_and_steps_less_than_grad_acc
):

# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping

if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(args.max_grad_norm)
else:
self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
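The gradient-clipping block added above tries optimizer- and model-specific clip methods before falling back to a global norm clip through the accelerator. A rough standalone equivalent in plain PyTorch (torch.nn.utils.clip_grad_norm_ stands in for accelerator.clip_grad_norm_; the helper name clip_gradients is made up):

import torch
from torch import nn

def clip_gradients(model: nn.Module, optimizer: torch.optim.Optimizer, max_grad_norm: float):
    if hasattr(optimizer, "clip_grad_norm"):
        # sharded optimizers expose their own clipping
        return optimizer.clip_grad_norm(max_grad_norm)
    if hasattr(model, "clip_grad_norm_"):
        # FSDP-style wrappers clip across shards themselves
        return model.clip_grad_norm_(max_grad_norm)
    return nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(16, 4)).sum().backward()
total_norm = clip_gradients(model, opt, max_grad_norm=1.0)
opt.step()  # gradients are clipped to norm <= 1.0 before the step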
10 changes: 7 additions & 3 deletions src/deep_training/trainer/cl/trainer.py
@@ -2,6 +2,7 @@
# @Time: 0:37
# @Author: tk
# @File:trainer
import dataclasses
import importlib
import json
import argparse
@@ -18,7 +19,7 @@
from typing import Union, Optional, Callable, List, Tuple, Dict, Any

import numpy as np
from lightning_utilities import apply_to_collection
from lightning_utilities.core.apply_func import apply_to_collection
from packaging import version
from datasets import Dataset
from peft import PeftModel
@@ -608,7 +609,7 @@ def train(self,start_epoch=0,start_step=0, trial: Union["optuna.Trial", Dict[str
**kwargs,):
self._train_loop(start_epoch=start_epoch,start_step=start_step,trial=trial,ignore_keys_for_eval=ignore_keys_for_eval,**kwargs)

def training_step(self, model: nn.Module, inputs: Dict[ str, Union[ torch.Tensor, Any ] ]) -> torch.Tensor:
def training_step(self, model: nn.Module, inputs: Dict[ str, Union[ torch.Tensor, Any ] ]) -> Union[torch.Tensor,Dict,Any]:
device = get_current_device()
batch = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
loss = model(**batch)
@@ -747,7 +748,10 @@ def _train_loop(self,start_epoch=0,start_step=0,
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

loss_obj = self.training_step(model, batch)
if isinstance(loss_obj, (list, tuple)):

if dataclasses.is_dataclass(loss_obj):
loss_obj = loss_obj.loss
elif isinstance(loss_obj, (list, tuple)):
loss_obj = loss_obj[0]

if isinstance(loss_obj, dict):
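Among this file's changes is switching the apply_to_collection import to its canonical module lightning_utilities.core.apply_func. For reference, a small example of what that helper does in the validation/test paths shown earlier (the example data is made up):

import torch
from lightning_utilities.core.apply_func import apply_to_collection

outputs = {"loss": torch.tensor(0.5), "logits": [torch.randn(2, 3)]}
# Walk the nested container and apply the function to every torch.Tensor leaf.
detached = apply_to_collection(outputs, dtype=torch.Tensor,
                               function=lambda t: t.detach().numpy())
print(type(detached["loss"]).__name__, type(detached["logits"][0]).__name__)  # ndarray ndarray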
32 changes: 26 additions & 6 deletions src/deep_training/trainer/hf/trainer.py
@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
import dataclasses
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -8,7 +9,6 @@
import safetensors
import torch
from accelerate.utils import save_fsdp_model
from peft import PeftModel
from torch import nn
from torch.nn import functional as F
from datasets import Dataset
@@ -21,7 +21,7 @@
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer import IS_SAGEMAKER_MP_POST_1_10, TRAINING_ARGS_NAME
from transformers.trainer_pt_utils import remove_dummy_checkpoint, get_parameter_names
from transformers.trainer_pt_utils import get_parameter_names
from transformers.trainer_utils import ShardedDDPOption, FSDPOption, IntervalStrategy
from transformers.utils import is_peft_available, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, is_sagemaker_mp_enabled, \
is_accelerate_available
@@ -33,6 +33,8 @@

from transformers.trainer import logger

if is_peft_available():
from peft import PeftModel

if is_accelerate_available():
from accelerate import __version__ as accelerate_version
@@ -44,6 +46,17 @@
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp


try:
from transformers.trainer_pt_utils import remove_dummy_checkpoint
except ImportError:
def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
if is_main_process:
for filename in filenames:
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
os.remove(file)

class TrainerHF(Trainer):
def __init__(self,
model: Union[PreTrainedModel, nn.Module] = None,
@@ -105,7 +118,12 @@ def compute_loss(self, model, inputs, return_outputs=False):
self._past = outputs[self.args.past_index]

if labels is not None:
if is_peft_available() and isinstance(model, (PeftModel,PetlModel,PromptModel)):
supported_classes = (PetlModel, PromptModel)
if is_peft_available():
supported_classes += (PeftModel,)

if isinstance(model, supported_classes):
model_name = unwrap_model(model.base_model)._get_name()
else:
model_name = unwrap_model(model)._get_name()
@@ -114,10 +132,10 @@
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs,tuple):
if dataclasses.is_dataclass(outputs):
loss = outputs.loss
elif isinstance(outputs,(tuple,list)):
loss = outputs[0]
if isinstance(loss,dict):
loss = loss["loss"]
elif isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
@@ -127,6 +145,8 @@
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

if isinstance(loss, dict):
loss = loss["loss"]
return (loss, outputs) if return_outputs else loss

def _save(self, output_dir: Optional[str] = None, state_dict=None):
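Two compatibility patterns recur in this file's changes: importing peft only when transformers reports it available, and shimming remove_dummy_checkpoint when the installed transformers does not export it. A compressed sketch of the optional-import pattern, under the assumption that only PeftModel matters (deep_training's PetlModel/PromptModel are omitted here; the helper name needs_unwrapping is illustrative):

from transformers.utils import is_peft_available

supported_classes = ()            # the repo also includes PetlModel and PromptModel here
if is_peft_available():
    from peft import PeftModel
    supported_classes += (PeftModel,)

def needs_unwrapping(model) -> bool:
    # isinstance() with an empty tuple is simply False, so the check stays
    # safe even when peft is not installed.
    return isinstance(model, supported_classes)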
