diff --git a/run_llama_train.sh b/run_llama_train.sh
index c9f79fea..d40945e8 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -27,4 +27,4 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}

 torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --config_file ${CONFIG_FILE}
+train.py --global_config_file ${CONFIG_FILE}
diff --git a/test/test_job_config.py b/test/test_job_config.py
index f9b1f3cd..2d03af69 100644
--- a/test/test_job_config.py
+++ b/test/test_job_config.py
@@ -2,10 +2,25 @@
 from torchtrain.config_manager import JobConfig

 class TestJobConfig():
-    def test_job_config(self):
-        config = JobConfig()
+    def test_command_line_args(self):
+        config = JobConfig([])
         assert config.model.name == "llama"

-    def test_file_does_not_exist(self):
+    def test_command_line_args_with_override(self):
+        config = JobConfig(["--metrics_log_freq", "2", "--metrics_enable_tensorboard"])
+        assert config.metrics.log_freq == 2
+        assert config.metrics.enable_tensorboard
+
+    def test_job_config_file(self):
+        config = JobConfig(["--global_config_file", "./torchtrain/train_configs/train_config.toml"])
+        assert config.model.name == "llama"
+
+    def test_job_config_file_with_override(self):
+        config = JobConfig(["--global_config_file",
+                            "./torchtrain/train_configs/train_config.toml",
+                            "--metrics_log_freq", "2"])
+        assert config.metrics.log_freq == 2
+
+    def test_job_file_does_not_exist(self):
         with pytest.raises(FileNotFoundError):
-            JobConfig("ohno.toml")
+            JobConfig(["--global_config_file", "ohno.toml"])
diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
index df2886c2..15e6956a 100644
--- a/torchtrain/config_manager.py
+++ b/torchtrain/config_manager.py
@@ -1,6 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-
+import argparse
+import sys
+from collections import defaultdict
+from typing import Tuple
 try:
     import tomllib
 except ModuleNotFoundError:
@@ -9,21 +12,210 @@ class JobConfig:
     """
     A helper class to manage the train configuration.
+
+    Semantics:
+    - Default config is loaded from a toml file. If no toml file is provided,
+      the default config is loaded from the argparse defaults.
+    - Override config is then loaded from command line arguments.
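+
+    For example (mirroring the new unit tests), an invocation such as
+        train.py --global_config_file ./torchtrain/train_configs/train_config.toml --metrics_log_freq 2
+    loads the toml file as the base config and then applies the command line
+    override, so metrics.log_freq ends up as 2 while options not passed on the
+    command line keep their toml (or argparse default) values.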
""" - def __init__(self, config_path: str = None): - self._config_path = "./torchtrain/train_configs/train_config.toml" if config_path is None else config_path + def __init__(self, args_list: list = sys.argv[1:]): + self._args_list = args_list self._import_config() def _import_config(self): - with open(self._config_path, "rb") as f: - self._config = tomllib.load(f) - for k, v in self._config.items(): - class_type = type(k.title(), (), v) - setattr(self, k, class_type()) - self._validate_config() + default_args, override_args = self._init_args() + config_path = default_args.global_config_file + args_dict = defaultdict(defaultdict) + override_dict = self._args_to_two_level_dict(override_args) + if config_path is None: + args_dict = self._args_to_two_level_dict(default_args) + else: + with open(config_path, "rb") as f: + args_dict = tomllib.load(f) + for first_level_key , d in override_dict.items(): + for second_level_key, v in d.items(): + args_dict[first_level_key][second_level_key] = v + for k, v in args_dict.items(): + class_type = type(k.title(), (), v) + setattr(self, k, class_type()) + self._validate_config() + + def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict: + args_dict = defaultdict(defaultdict) + for k, v in vars(args).items(): + first_level_key , second_level_key = k.split("_", 1) + args_dict[first_level_key][second_level_key] = v + return args_dict + def _validate_config(self): # TODO: Add more mandatory validations - assert self.model.name and self.model.model_conf and self.model.tokenizer_path + assert self.model.name and self.model.config and self.model.tokenizer_path return True + + def _init_args(self) -> Tuple: + """ + Each argument starts with _ which is the section name in the toml file + followed by name of the option in the toml file. 
+        For example, model_name translates to:
+            [model]
+            name
+        in the toml file.
+        """
+        parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
+        parser.add_argument(
+            "--global_config_file",
+            type=str,
+            default=None,
+            help="job config file",
+        )
+
+        # global configs
+        parser.add_argument(
+            "--global_dump_folder",
+            type=str,
+            default="./torchtrain/outputs",
+            help="folder to dump job outputs",
+        )
+
+        # profiling configs
+        parser.add_argument(
+            "--profiling_run_profiler",
+            action="store_true",
+            help="enable pytorch profiler",
+        )
+        parser.add_argument(
+            "--profiling_save_traces_folder",
+            type=str,
+            default="profiling/traces",
+            help="trace file location",
+        )
+        parser.add_argument(
+            "--profiling_profile_every_x_iter",
+            type=int,
+            default=10,
+            help="collect profiler traces every x iterations",
+        )
+
+        # metrics configs
+        parser.add_argument(
+            "--metrics_log_freq",
+            type=int,
+            default=10,
+            help="how often to log metrics to TensorBoard",
+        )
+        parser.add_argument(
+            "--metrics_enable_tensorboard",
+            action="store_true",
+            help="whether to log metrics to TensorBoard",
+        )
+        parser.add_argument(
+            "--metrics_save_tb_folder",
+            type=str,
+            default="tb",
+            help="folder to dump tensorboard state",
+        )
+
+        # model configs
+        parser.add_argument(
+            "--model_name",
+            type=str,
+            default="llama",
+            help="which model to train",
+        )
+        parser.add_argument(
+            "--model_config",
+            type=str,
+            default="debugmodel",
+            help="which model config to train",
+        )
+        parser.add_argument(
+            "--model_tokenizer_path",
+            type=str,
+            default="./torchtrain/datasets/tokenizer/tokenizer.model",
+            help="tokenizer path",
+        )
+
+        # optimizer configs
+        parser.add_argument(
+            "--optimizer_name", type=str, default="AdamW", help="optimizer to use"
+        )
+        parser.add_argument("--optimizer_lr", type=float, default=8e-4, help="learning rate to use")
+
+        # training configs
+        parser.add_argument("--training_dataset", type=str, default="alpaca", help="dataset to use")
+        parser.add_argument("--training_batch_size", type=int, default=8, help="batch size")
+        parser.add_argument("--training_seq_len", type=int, default=2048, help="sequence length")
+        parser.add_argument(
+            "--training_warmup_pct",
+            type=float,
+            default=0.20,
+            help="percentage of total training steps to use for warmup",
+        )
+        parser.add_argument(
+            "--training_max_norm",
+            type=float,
+            default=1.0,
+            help="max norm for gradient clipping",
+        )
+        parser.add_argument(
+            "--training_steps", type=int, default=-1, help="how many train steps to run"
+        )
+        parser.add_argument(
+            "--training_data_parallel_degree",
+            type=int,
+            default=-1,
+            help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
+        )
+        parser.add_argument(
+            "--training_sequence_parallel_degree",
+            type=int,
+            default=1,
+            help="Sequence Parallelism degree. 1 means disabled.",
+        )
+        parser.add_argument(
+            "--training_pipeline_parallel_degree",
+            type=int,
+            default=1,
+            help="Pipeline Parallelism degree (default of 1 means disabled)",
+        )
+        parser.add_argument(
+            "--training_compile", action="store_true", help="Whether to compile the model."
+        )
+        parser.add_argument(
+            "--training_checkpoint_interval",
+            type=int,
+            default=3600,
+            help=(
+                "Checkpointing interval. The unit of measurement is in seconds or "
+                "steps depending on --training_checkpoint_interval_type."
+            ),
+        )
+        parser.add_argument(
+            "--training_checkpoint_interval_type",
+            type=str,
+            default="steps",
+            help=(
+                "The checkpointing interval unit of measurement. "
+ "The default value is step." + ), + ) + parser.add_argument( + "--training_checkpoint_folder", + type=str, + default="", + help=( + "The folder to store the checkpoints. If this is not specified or " + "is an empty string, checkpointing is disabled." + ), + ) + args = parser.parse_args(self._args_list) + aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS) + for arg , val in vars(args).items(): + if isinstance(val, bool): + aux_parser.add_argument('--'+arg, + action='store_true' if val else 'store_false') + else: + aux_parser.add_argument('--'+arg, type=type(val)) + override_args, _ = aux_parser.parse_known_args(self._args_list) + return args, override_args diff --git a/torchtrain/train_configs/train_config.toml b/torchtrain/train_configs/train_config.toml index 6458b5a4..e614e275 100644 --- a/torchtrain/train_configs/train_config.toml +++ b/torchtrain/train_configs/train_config.toml @@ -15,7 +15,7 @@ log_freq = 10 [model] name = "llama" -model_conf = "debugmodel" +config = "debugmodel" tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model" [optimizer] diff --git a/train.py b/train.py index fbe815ef..f5c52077 100644 --- a/train.py +++ b/train.py @@ -76,9 +76,8 @@ def build_grad_scaler(model): return ShardedGradScaler(enabled=enable_grad_scaling) -def main(args): +def main(job_config: JobConfig): init_logger() - job_config = JobConfig(args.config_file) # init world mesh world_size = int(os.environ["WORLD_SIZE"]) parallel_dims = ParallelDims( @@ -112,7 +111,7 @@ def main(args): # build model # TODO: add meta initialization model_cls = model_name_to_cls[model_name] - model_config = models_config[model_name][job_config.model.model_conf] + model_config = models_config[model_name][job_config.model.config] model_config.vocab_size = tokenizer.n_words model = model_cls.from_model_args(model_config) @@ -120,7 +119,7 @@ def main(args): # log model size model_param_count = get_num_params(model) rank0_log( - f"Model {model_name} {job_config.model.model_conf} size: {model_param_count:,} total parameters" + f"Model {model_name} {job_config.model.config} size: {model_param_count:,} total parameters" ) gpu_metrics = GPUMemoryMonitor("cuda") rank0_log(f"GPU memory usage: {gpu_metrics}") @@ -255,108 +254,6 @@ def main(args): metric_logger.close() rank0_log(f"{gpu_metrics.get_current_stats()}") - if __name__ == "__main__": - parser = argparse.ArgumentParser(description="TorchTrain arg parser.") - LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) - - parser.add_argument( - "--config_file", - type=str, - default="./torchtrain/train_configs/train_config.toml", - help="job config file", - ) - parser.add_argument( - "--model", type=str, default="llama", help="which model to train" - ) - parser.add_argument( - "--model_conf", - type=str, - default="debugmodel", - help="which model config to train", - ) - parser.add_argument("--dataset", type=str, default="alpaca", help="dataset to use") - parser.add_argument( - "--tokenizer_path", - type=str, - default="./torchtrain/datasets/tokenizer/tokenizer.model", - help="tokenizer path", - ) - parser.add_argument("--batch_size", type=int, default=8, help="batch size") - parser.add_argument("--seq_len", type=int, default=2048, help="sequence length") - parser.add_argument( - "--optimizer", type=str, default="AdamW", help="optimizer to use" - ) - parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use") - parser.add_argument( - "--warmup_pct", - type=float, - default=0.20, - help="percentage of total training steps 
-    parser.add_argument(
-        "--max_norm",
-        type=Union[float, int],
-        default=1.0,
-        help="max norm for gradient clipping",
-    )
-    parser.add_argument(
-        "--steps", type=int, default=-1, help="how many train steps to run"
-    )
-    parser.add_argument(
-        "--dp_degree",
-        type=int,
-        default=-1,
-        help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
-    )
-    parser.add_argument(
-        "--sp_degree",
-        type=int,
-        default=1,
-        help="Sequence Parallelism degree. 1 means disabled.",
-    )
-    parser.add_argument(
-        "--pp_degree",
-        type=int,
-        default=1,
-        help="Pipeline Parallelism degree (default of 1 means disabled)",
-    )
-    parser.add_argument(
-        "--compile", action="store_true", help="Whether to compile the model."
-    )
-    parser.add_argument(
-        "--checkpoint-interval",
-        type=int,
-        default=3600,
-        help=(
-            "Checkpointing interval. The unit of measurement is in seconds or "
-            "steps depending on --checkpoint-internval-type."
-        ),
-    )
-    parser.add_argument(
-        "--checkpoint-interval-type",
-        type=str,
-        default="steps",
-        help=(
-            "The checkpointing interval unit of measurement."
-            "The default value is step."
-        ),
-    )
-    parser.add_argument(
-        "--checkpoint-folder",
-        type=str,
-        default="",
-        help=(
-            "The folder to store the checkpoints. If this is not specified or "
-            "is an empty string, checkpointing is disabled."
-        ),
-    )
-    parser.add_argument(
-        "--log_freq",
-        type=int,
-        default=10,
-        help="how often to log metrics to TensorBoard",
-    )
-
-    args = parser.parse_args()
-    main(args)
+    config = JobConfig()
+    main(config)
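
A minimal sketch of the resulting config flow, mirroring the assertions in test/test_job_config.py (assumes the repo root as the working directory so the relative toml path resolves):

    from torchtrain.config_manager import JobConfig

    # Base values come from the toml file; the trailing flags override them.
    config = JobConfig(
        [
            "--global_config_file",
            "./torchtrain/train_configs/train_config.toml",
            "--metrics_log_freq",
            "2",
        ]
    )
    assert config.model.name == "llama"  # from the [model] section of the toml
    assert config.metrics.log_freq == 2  # command line override wins

Calling JobConfig() with no arguments, as train.py now does, parses sys.argv[1:] through the same path.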