Config manager supports command line overrides
Summary:
This PR implements the following enhancements to the config manager:
1. Command-line args and TOML file args are now unified.
2. Defaults can be loaded from either source.
3. Defaults can be overridden through the command line. Overrides are
applied irrespective of where the defaults were loaded from.
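
For illustration, the intended semantics in terms of the new JobConfig API (a minimal sketch mirroring the tests added below; the TOML path and flag values are just examples):

from torchtrain.config_manager import JobConfig

# No TOML file: defaults come from the argparse defaults.
config = JobConfig([])
assert config.model.name == "llama"

# A TOML file provides the defaults; explicitly passed command-line flags
# override whatever the file (or argparse) provided.
config = JobConfig([
    "--global_config_file", "./torchtrain/train_configs/train_config.toml",
    "--metrics_log_freq", "2",
])
assert config.metrics.log_freq == 2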

Test Plan:
============================= test session starts ==============================
platform linux -- Python 3.10.13, pytest-8.0.1, pluggy-1.4.0 -- /home/gnadathur/local/a/pytorch-env/bin/python
cachedir: .pytest_cache
rootdir: /data/users/gnadathur/a/torchtrain
configfile: pyproject.toml
plugins: cov-4.1.0
collecting ... collected 5 items

test/test_job_config.py::TestJobConfig::test_command_line_args PASSED    [ 20%]
test/test_job_config.py::TestJobConfig::test_command_line_args_with_override PASSED [ 40%]
test/test_job_config.py::TestJobConfig::test_job_config_file PASSED      [ 60%]
test/test_job_config.py::TestJobConfig::test_job_config_file_with_override PASSED [ 80%]
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist PASSED [100%]

---------- coverage: platform linux, python 3.10.13-final-0 ----------
Coverage XML written to file coverage.xml

============================= slowest 20 durations =============================
0.01s call     test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
0.00s call     test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s call     test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s call     test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s call     test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup    test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup    test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
============================== 5 passed in 0.10s ===============================

Successfully completed a 10-iteration llama.sh run.

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9a0f588849f520652a431df438318be3ab0f9578
Pull Request resolved: #69
gnadathur committed Feb 23, 2024
1 parent e2d7408 commit 555c2ba
Showing 5 changed files with 228 additions and 124 deletions.
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -27,4 +27,4 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}

 torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --config_file ${CONFIG_FILE}
+train.py --global_config_file ${CONFIG_FILE}
23 changes: 19 additions & 4 deletions test/test_job_config.py
@@ -2,10 +2,25 @@
 from torchtrain.config_manager import JobConfig
 
 class TestJobConfig():
-    def test_job_config(self):
-        config = JobConfig()
+    def test_command_line_args(self):
+        config = JobConfig([])
         assert config.model.name == "llama"
 
-    def test_file_does_not_exist(self):
+    def test_command_line_args_with_override(self):
+        config = JobConfig(["--metrics_log_freq" , "2", "--metrics_enable_tensorboard"])
+        assert config.metrics.log_freq == 2
+        assert config.metrics.enable_tensorboard
+
+    def test_job_config_file(self):
+        config = JobConfig(["--global_config_file", "./torchtrain/train_configs/train_config.toml"])
+        assert config.model.name == "llama"
+
+    def test_job_config_file_with_override(self):
+        config = JobConfig(["--global_config_file",
+            "./torchtrain/train_configs/train_config.toml",
+            "--metrics_log_freq" , "2"])
+        assert config.metrics.log_freq == 2
+
+    def test_job_file_does_not_exist(self):
         with pytest.raises(FileNotFoundError):
-            JobConfig("ohno.toml")
+            JobConfig(["--global_config_file", "ohno.toml"])
212 changes: 202 additions & 10 deletions torchtrain/config_manager.py
@@ -1,6 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-
+import argparse
+import sys
+from collections import defaultdict
+from typing import Tuple, Union
 try:
     import tomllib
 except ModuleNotFoundError:
@@ -9,21 +12,210 @@
 class JobConfig:
     """
     A helper class to manage the train configuration.
+    Semantics:
+    - Default config is loaded from a toml file. If no toml file is provided,
+    then the default config is loaded from argparse defaults.
+    - Then, Override config is loaded from command line arguments.
     """
 
-    def __init__(self, config_path: str = None):
-        self._config_path = "./torchtrain/train_configs/train_config.toml" if config_path is None else config_path
+    def __init__(self, args_list: list = sys.argv[1:]):
+        self._args_list = args_list
         self._import_config()
 
     def _import_config(self):
-        with open(self._config_path, "rb") as f:
-            self._config = tomllib.load(f)
-        for k, v in self._config.items():
-            class_type = type(k.title(), (), v)
-            setattr(self, k, class_type())
-        self._validate_config()
+        default_args, override_args = self._init_args()
+        config_path = default_args.global_config_file
+        args_dict = defaultdict(defaultdict)
+        override_dict = self._args_to_two_level_dict(override_args)
+        if config_path is None:
+            args_dict = self._args_to_two_level_dict(default_args)
+        else:
+            with open(config_path, "rb") as f:
+                args_dict = tomllib.load(f)
+        for first_level_key , d in override_dict.items():
+            for second_level_key, v in d.items():
+                args_dict[first_level_key][second_level_key] = v
+        for k, v in args_dict.items():
+            class_type = type(k.title(), (), v)
+            setattr(self, k, class_type())
+        self._validate_config()
+
+    def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
+        args_dict = defaultdict(defaultdict)
+        for k, v in vars(args).items():
+            first_level_key , second_level_key = k.split("_", 1)
+            args_dict[first_level_key][second_level_key] = v
+        return args_dict
+
 
     def _validate_config(self):
         # TODO: Add more mandatory validations
-        assert self.model.name and self.model.model_conf and self.model.tokenizer_path
+        assert self.model.name and self.model.config and self.model.tokenizer_path
         return True
+
+    def _init_args(self) -> Tuple:
+        """
+        Each argument starts with <prefix>_ which is the section name in the toml file
+        followed by name of the option in the toml file. For ex,
+        model_name translates to:
+            [model]
+            name
+        in the toml file
+        """
+        parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
+        parser.add_argument(
+            "--global_config_file",
+            type=str,
+            default=None,
+            help="job config file",
+        )
+
+        # global configs
+        parser.add_argument(
+            "--global_dump_folder",
+            type=str,
+            default="./torchtrain/outputs",
+            help="folder to dump job outputs",
+        )
+
+        # profiling configs
+        parser.add_argument(
+            "--profiling_run_profiler",
+            action="store_true",
+            help="enable pytorch profiler",
+        )
+        parser.add_argument(
+            "--profiling_save_traces_folder",
+            type=str,
+            default="profiling/traces",
+            help="trace file location",
+        )
+        parser.add_argument(
+            "--profiling_profile_every_x_iter",
+            type=int,
+            default=10,
+            help="collect profiler traces every x iterations",
+        )
+        # metrics configs
+        parser.add_argument(
+            "--metrics_log_freq",
+            type=int,
+            default=10,
+            help="how often to log metrics to TensorBoard",
+        )
+        parser.add_argument(
+            "--metrics_enable_tensorboard",
+            action="store_true",
+            help="how often to log metrics to TensorBoard",
+        )
+        parser.add_argument(
+            "--metrics_save_tb_folder",
+            type=str,
+            default="tb",
+            help="folder to dump tensorboard state",
+        )
+
+        # model configs
+        parser.add_argument(
+            "--model_name",
+            type=str,
+            default="llama",
+            help="which model to train",
+        )
+        parser.add_argument(
+            "--model_config",
+            type=str,
+            default="debugmodel",
+            help="which model config to train",
+        )
+        parser.add_argument(
+            "--model_tokenizer_path",
+            type=str,
+            default="./torchtrain/datasets/tokenizer/tokenizer.model",
+            help="tokenizer path",
+        )
+
+
+        # optimizer configs
+        parser.add_argument(
+            "--optimizer_name", type=str, default="AdamW", help="optimizer to use"
+        )
+        parser.add_argument("--optimizer_lr", type=float, default=8e-4, help="learning rate to use")
+
+        # training configs
+        parser.add_argument("--training_dataset", type=str, default="alpaca", help="dataset to use")
+        parser.add_argument("--training_batch_size", type=int, default=8, help="batch size")
+        parser.add_argument("--training_seq_len", type=int, default=2048, help="sequence length")
+        parser.add_argument(
+            "--training_warmup_pct",
+            type=float,
+            default=0.20,
+            help="percentage of total training steps to use for warmup",
+        )
+        parser.add_argument(
+            "--training_max_norm",
+            type=Union[float, int],
+            default=1.0,
+            help="max norm for gradient clipping",
+        )
+        parser.add_argument(
+            "--training_steps", type=int, default=-1, help="how many train steps to run"
+        )
+        parser.add_argument(
+            "--training_data_parallel_degree",
+            type=int,
+            default=-1,
+            help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
+        )
+        parser.add_argument(
+            "--training_sequence_parallel_degree",
+            type=int,
+            default=1,
+            help="Sequence Parallelism degree. 1 means disabled.",
+        )
+        parser.add_argument(
+            "--training_pipeline_parallel_degree",
+            type=int,
+            default=1,
+            help="Pipeline Parallelism degree (default of 1 means disabled)",
+        )
+        parser.add_argument(
+            "--training_compile", action="store_true", help="Whether to compile the model."
+        )
+        parser.add_argument(
+            "--training_checkpoint_interval",
+            type=int,
+            default=3600,
+            help=(
+                "Checkpointing interval. The unit of measurement is in seconds or "
+                "steps depending on --training_checkpoint-internval-type."
+            ),
+        )
+        parser.add_argument(
+            "--training_checkpoint_interval_type",
+            type=str,
+            default="steps",
+            help=(
+                "The checkpointing interval unit of measurement."
+                "The default value is step."
+            ),
+        )
+        parser.add_argument(
+            "--training_checkpoint_folder",
+            type=str,
+            default="",
+            help=(
+                "The folder to store the checkpoints. If this is not specified or "
+                "is an empty string, checkpointing is disabled."
+            ),
+        )
+        args = parser.parse_args(self._args_list)
+        aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
+        for arg , val in vars(args).items():
+            if isinstance(val, bool):
+                aux_parser.add_argument('--'+arg,
+                    action='store_true' if val else 'store_false')
+            else:
+                aux_parser.add_argument('--'+arg, type=type(val))
+        override_args, _ = aux_parser.parse_known_args(self._args_list)
+        return args, override_args
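
Since override detection and the section/attribute mapping are the heart of this change, here is a minimal standalone sketch of the same two techniques as they appear in the diff above: (a) a second parser created with argument_default=argparse.SUPPRESS, so only flags the user actually passed show up as overrides, and (b) splitting "<section>_<option>" argument names into a two-level dict and exposing each section as attributes via a dynamically created class. The flag names, Holder class, and values below are illustrative, not part of the PR.

import argparse
from collections import defaultdict

parser = argparse.ArgumentParser()
parser.add_argument("--metrics_log_freq", type=int, default=10)
parser.add_argument("--model_name", type=str, default="llama")

cmd = ["--metrics_log_freq", "2"]
args = parser.parse_args(cmd)          # defaults filled in: model_name == "llama"

# (a) SUPPRESS means attributes are only set for flags that were actually passed.
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
aux_parser.add_argument("--metrics_log_freq", type=int)
aux_parser.add_argument("--model_name", type=str)
overrides, _ = aux_parser.parse_known_args(cmd)
print(vars(overrides))                 # {'metrics_log_freq': 2} -- only the explicit override

# (b) "metrics_log_freq" -> section "metrics", option "log_freq".
two_level = defaultdict(dict)
for k, v in vars(args).items():
    section, option = k.split("_", 1)
    two_level[section][option] = v

class Holder:
    pass

holder = Holder()
for section, options in two_level.items():
    section_cls = type(section.title(), (), options)   # e.g. a class named "Metrics"
    setattr(holder, section, section_cls())

print(holder.metrics.log_freq)         # 2
print(holder.model.name)               # "llama"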
2 changes: 1 addition & 1 deletion torchtrain/train_configs/train_config.toml
@@ -15,7 +15,7 @@ log_freq = 10

 [model]
 name = "llama"
-model_conf = "debugmodel"
+config = "debugmodel"
 tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model"
 
 [optimizer]
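
For reference, a small sketch of how the renamed key surfaces on the config object (values mirror the train_config.toml excerpt above; the print calls are only illustrative):

from torchtrain.config_manager import JobConfig

config = JobConfig(["--global_config_file", "./torchtrain/train_configs/train_config.toml"])
print(config.model.name)            # "llama"
print(config.model.config)          # "debugmodel" (previously exposed as model.model_conf)
print(config.model.tokenizer_path)  # "./torchtrain/datasets/tokenizer/tokenizer.model"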