Migrate configs to toml
[ghstack-poisoned]
gnadathur committed Feb 23, 2024
1 parent 3ebb5d7 commit bc9ea0d
Showing 5 changed files with 54 additions and 40 deletions.
8 changes: 3 additions & 5 deletions run_llama_train.sh
@@ -23,10 +23,8 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
# Please adjust this to a longer interval period. The unit of measurement is in steps.
CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}

+CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}

torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 \
---model ${MODEL} --model_conf ${MODEL_CONF} \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
---compile \
---checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
+train.py --config_file ${CONFIG_FILE}
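Because CONFIG_FILE is read from the environment with a default, the launcher can now be pointed at a different job config without editing the script, e.g. CONFIG_FILE=./my_experiment.toml ./run_llama_train.sh (the path here is only an example).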
7 changes: 4 additions & 3 deletions torchtrain/lr_scheduling.py
@@ -2,6 +2,7 @@
# All rights reserved.

from torch.optim.lr_scheduler import LambdaLR
+from torchtrain.config_manager import JobConfig

# global states for scheduling
# these are needed as LambdaLR does not support argument passing
@@ -29,11 +30,11 @@ def linear_warmup_linear_decay(current_step: int) -> float:
return curr_adjustment


-def get_lr_scheduler(optimizer, args):
+def get_lr_scheduler(optimizer, job_config: JobConfig):
"""Build a linear warmup and linear decay scheduler"""
global _warmup_steps, _decay_steps
-_warmup_steps = max(int(args.steps * args.warmup_pct), 2)
-_decay_steps = float(max(1, args.steps - _warmup_steps))
+_warmup_steps = max(int(job_config.training.steps * job_config.training.warmup_pct), 2)
+_decay_steps = float(max(1, job_config.training.steps - _warmup_steps))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
return warmup_scheduler
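The body of linear_warmup_linear_decay lies outside this hunk. Purely as a sketch of how such a schedule typically works (not the verbatim implementation in torchtrain/lr_scheduling.py), the lambda ramps the multiplier up linearly over _warmup_steps and then decays it linearly over _decay_steps:

# placeholder values; in the real module these globals are set by get_lr_scheduler
_warmup_steps, _decay_steps = 2, 10.0

def linear_warmup_linear_decay(current_step: int) -> float:
    if current_step < _warmup_steps:
        # warmup: grow linearly from a small positive value up to 1.0
        curr_adjustment = float((current_step + 1) / (_warmup_steps + 1))
    else:
        # decay: shrink linearly from 1.0 toward 0 over the remaining steps
        curr_adjustment = 1 - (current_step - _warmup_steps) / _decay_steps
    return curr_adjustment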
10 changes: 6 additions & 4 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -34,6 +34,7 @@
)

from torchtrain.logging_utils import rank0_log
+from torchtrain.config_manager import JobConfig

logger = logging.getLogger(__name__)

@@ -67,13 +68,14 @@ def partition_fn(name, module, device_mesh):


# Uses PTD FSDP AC wrapper
-def checkpoint_wrapper(module, config):
+# TODO: why is config needed here?
+def checkpoint_wrapper(module, job_config: JobConfig):
return ptd_checkpoint_wrapper(
module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
)


-def parallelize_llama(model, world_mesh, parallel_dims, args):
+def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
@@ -87,7 +89,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
if parallel_dims.sp_enabled:
# First we apply Sequence Parallelism if it's enabled
tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
-sp_degree = args.sp_degree
+sp_degree = job_config.training.sequence_parallel_degree
# First:
# 1. parallelize the first embedding and the last linear proj layer
# 2. shard the first layer of transformer block
@@ -175,7 +177,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
# apply AC to each layer
# before wrapping with FSDP, we need to make sure the layer is on GPU
transformer_block = transformer_block.cuda()
-transformer_block = checkpoint_wrapper(transformer_block, args)
+transformer_block = checkpoint_wrapper(transformer_block, job_config)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)
6 changes: 3 additions & 3 deletions torchtrain/train_configs/train_config.toml
@@ -28,12 +28,12 @@ batch_size = 8
seq_len = 2048
warmup_pct = 0.20
max_norm = 1.0
-steps = -1
+steps = 10
data_parallel_degree = -1
sequence_parallel_degree = 1
pipeline_parallel_degree = 1
-compile = true
+compile = false
checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
-data_set = "alpaca"
+dataset = "alpaca"
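torchtrain/config_manager.py is not part of this diff, so the JobConfig implementation itself is not shown. Purely as an illustration of the access pattern the rest of the change relies on (job_config.training.steps, job_config.model.name, job_config.optimizer.lr, ...), a minimal TOML-backed config object could look roughly like the sketch below; the section and key names come from this diff, everything else is an assumption:

import tomllib  # Python 3.11+ standard library; older interpreters would need a third-party TOML parser
from types import SimpleNamespace


class JobConfig:
    """Hypothetical sketch: load a TOML job config and expose each table as an attribute namespace."""

    def __init__(self, config_file: str):
        with open(config_file, "rb") as f:
            data = tomllib.load(f)
        # e.g. data["training"]["steps"] becomes self.training.steps
        for section, values in data.items():
            setattr(self, section, SimpleNamespace(**values))


# usage sketch against the config shipped in this commit:
# job_config = JobConfig("./torchtrain/train_configs/train_config.toml")
# job_config.training.steps    -> 10
# job_config.training.compile  -> False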
63 changes: 38 additions & 25 deletions train.py
@@ -16,6 +16,7 @@
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from torchtrain.checkpoint import CheckpointManager, IntervalType
+from torchtrain.config_manager import JobConfig

# torchtrain related
from torchtrain.datasets import create_tokenizer, dataloader_fn
@@ -49,14 +50,16 @@ def load_state_dict(self, state_dict) -> None:
self.losses = state_dict["losses"].tolist()


-def build_optimizer(model, args):
+def build_optimizer(model, job_config: JobConfig):
# build optimizer
-if args.optimizer == "Adam":
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-elif args.optimizer == "AdamW":
-optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+name = job_config.optimizer.name
+lr = job_config.optimizer.lr
+if name == "Adam":
+optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+elif name == "AdamW":
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
else:
-raise NotImplementedError(f"optimizer {args.optimizer} not added")
+raise NotImplementedError(f"optimizer {name} not added")

return optimizer
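build_optimizer now takes its settings from an [optimizer] table, presumably defined higher up in train_config.toml than the hunk shown above. A self-contained usage sketch with an illustrative table (the concrete name/lr values are made up, not taken from the repo):

from types import SimpleNamespace
import torch

# stand-in for the job_config.optimizer namespace, as if the TOML contained:
#   [optimizer]
#   name = "AdamW"
#   lr = 8e-4
fake_job_config = SimpleNamespace(optimizer=SimpleNamespace(name="AdamW", lr=8e-4))
tiny_model = torch.nn.Linear(4, 4)
optimizer = build_optimizer(tiny_model, fake_job_config)  # -> torch.optim.AdamW instance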

@@ -75,65 +78,69 @@ def build_grad_scaler(model):

def main(args):
init_logger()
+job_config = JobConfig(args.config_file)
# init world mesh
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
-dp=args.dp_degree, sp=args.sp_degree, pp=args.pp_degree, world_size=world_size
+dp=job_config.training.data_parallel_degree,
+sp=job_config.training.sequence_parallel_degree,
+pp=job_config.training.pipeline_parallel_degree,
+world_size=world_size
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")

-model_name = args.model
+model_name = job_config.model.name
rank0_log(f"Building {model_name}")
# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
-tokenizer = create_tokenizer(tokenizer_type, args.tokenizer_path)
+tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# build dataloader
# need dp world size and rank
# TODO: dp might not always be 0 so we need to handle that more carefully
dp_degree = world_mesh.size(0)
dp_rank = world_mesh.get_local_rank(0)
-build_dataloader_fn = dataloader_fn[args.dataset]
+build_dataloader_fn = dataloader_fn[job_config.training.dataset]
data_loader = build_dataloader_fn(
tokenizer,
-args.batch_size,
-args.seq_len,
+job_config.training.batch_size,
+job_config.training.seq_len,
dp_degree,
dp_rank,
)

# build model
# TODO: add meta initialization
model_cls = model_name_to_cls[model_name]
-model_config = models_config[model_name][args.model_conf]
+model_config = models_config[model_name][job_config.model.model_conf]
model_config.vocab_size = tokenizer.n_words

model = model_cls.from_model_args(model_config)

# log model size
model_param_count = get_num_params(model)
rank0_log(
-f"Model {model_name} {args.model_conf} size: {model_param_count:,} total parameters"
+f"Model {model_name} {job_config.model.model_conf} size: {model_param_count:,} total parameters"
)
gpu_metrics = GPUMemoryMonitor("cuda")
rank0_log(f"GPU memory usage: {gpu_metrics}")

# apply PTD parallelisms + AC
-model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args)
+model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)

# to use FSDP-customized gradient scaler and gradient clipping solutions
assert isinstance(model, FSDP)

# build optimizer after apply parallelisms to the model
-optimizer = build_optimizer(model, args)
-scheduler = get_lr_scheduler(optimizer, args)
+optimizer = build_optimizer(model, job_config)
+scheduler = get_lr_scheduler(optimizer, job_config)

scaler = build_grad_scaler(model)

metric_logger = build_metric_logger()

# torch.compile model for improved performance
-if args.compile:
+if job_config.training.compile:
rank0_log(f"Compiling model {model_name} with torch.compile...")
model = torch.compile(
model,
@@ -148,13 +155,13 @@ def main(args):
model=model,
optimizer=optimizer,
states={"train_state": train_state},
-folder=args.checkpoint_folder,
+folder=job_config.training.checkpoint_folder,
interval_type=(
IntervalType.SECONDS
-if args.checkpoint_interval_type == "seconds"
+if job_config.training.checkpoint_interval_type == "seconds"
else IntervalType.STEPS
),
-interval=args.checkpoint_interval,
+interval=job_config.training.checkpoint_interval,
)
checkpoint.load()

@@ -164,7 +171,7 @@ def main(args):
losses_since_last_log: List[float] = []
nwords_since_last_log = 0
time_last_log = timer()
-while train_state.step < args.steps or args.steps == -1:
+while train_state.step < job_config.training.steps or job_config.training.steps == -1:
train_state.step += 1
# get batch
batch = next(iter(data_loader))
@@ -187,7 +194,7 @@ def main(args):

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)
-model.clip_grad_norm_(args.max_norm)
+model.clip_grad_norm_(job_config.training.max_norm)

# optimizer step
# If gradients don't contain infs/NaNs, optimizer.step() is then called;
@@ -206,7 +213,7 @@ def main(args):
losses_since_last_log.append(train_state.current_loss)

# log metrics
-if (train_state.step - 1) % args.log_freq == 0:
+if (train_state.step - 1) % job_config.metrics.log_freq == 0:
avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
losses_since_last_log
)
@@ -243,7 +250,7 @@ def main(args):
)
scheduler.step()

-checkpoint.save(train_state.step, force=(train_state.step == args.steps))
+checkpoint.save(train_state.step, force=(train_state.step == job_config.training.steps))

metric_logger.close()
rank0_log(f"{gpu_metrics.get_current_stats()}")
@@ -253,6 +260,12 @@ def main(args):
parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])

+parser.add_argument(
+"--config_file",
+type=str,
+default="./torchtrain/train_configs/train_config.toml",
+help="job config file",
+)
parser.add_argument(
"--model", type=str, default="llama", help="which model to train"
)
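The remainder of train.py is collapsed in this view. Presumably the file still ends by parsing the arguments and handing them to main, roughly like this (a sketch of the standard pattern, not the verbatim tail of the file):

# the other flags (--model, --model_conf, ...) are still registered above,
# but after this change main() itself appears to read only args.config_file
args = parser.parse_args()
main(args)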
