Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config manager supports command line overrides #69

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion run_llama_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}

torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --config_file ${CONFIG_FILE}
train.py --global_config_file ${CONFIG_FILE}
23 changes: 19 additions & 4 deletions test/test_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,25 @@
from torchtrain.config_manager import JobConfig

class TestJobConfig():
def test_job_config(self):
config = JobConfig()
def test_command_line_args(self):
config = JobConfig([])
assert config.model.name == "llama"

def test_file_does_not_exist(self):
def test_command_line_args_with_override(self):
config = JobConfig(["--metrics_log_freq" , "2", "--metrics_enable_tensorboard"])
assert config.metrics.log_freq == 2
assert config.metrics.enable_tensorboard

def test_job_config_file(self):
config = JobConfig(["--global_config_file", "./torchtrain/train_configs/train_config.toml"])
assert config.model.name == "llama"

def test_job_config_file_with_override(self):
config = JobConfig(["--global_config_file",
"./torchtrain/train_configs/train_config.toml",
"--metrics_log_freq" , "2"])
assert config.metrics.log_freq == 2

def test_job_file_does_not_exist(self):
with pytest.raises(FileNotFoundError):
JobConfig("ohno.toml")
JobConfig(["--global_config_file", "ohno.toml"])
212 changes: 202 additions & 10 deletions torchtrain/config_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import argparse
import sys
from collections import defaultdict
from typing import Tuple, Union
try:
import tomllib
except ModuleNotFoundError:
Expand All @@ -9,21 +12,210 @@
class JobConfig:
"""
A helper class to manage the train configuration.
Semantics:
- Default config is loaded from a toml file. If no toml file is provided,
then the default config is loaded from argparse defaults.
- Then, Override config is loaded from command line arguments.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should either accept a config file, or a full set of arguments, wondering do you think override config would be something helpful here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we assume exclusive either or, that simplifies implementation. There are a couple of use cases to consider:

  1. Do we want the ability to use a base config file for a model say LLaMa-13b and submit a bunch of jobs with say different batch size overridden from command line ?

  2. For using it with internal environments, would we be able to pass the full set of arguments through command line or do we want to leverage command line defaults and pass overrides only ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So regarding the cases:

  1. I think often the time if I want to submit a job with a slight different config, I just modify the config file locally, and the torchx launcher is able to pick up the changes, we can also create separate configs for such a case
  2. for internal using, torchx launcher would also do the same thing, load the config options from a config file, and launch with those config options, and for the case where the config is not specified in the config file, the command line defaults is valuable there.

"""

def __init__(self, config_path: str = None):
self._config_path = "./torchtrain/train_configs/train_config.toml" if config_path is None else config_path
def __init__(self, args_list: list = sys.argv[1:]):
self._args_list = args_list
self._import_config()

def _import_config(self):
with open(self._config_path, "rb") as f:
self._config = tomllib.load(f)
for k, v in self._config.items():
class_type = type(k.title(), (), v)
setattr(self, k, class_type())
self._validate_config()
default_args, override_args = self._init_args()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function can happen inside parse_args?

config_path = default_args.global_config_file
args_dict = defaultdict(defaultdict)
override_dict = self._args_to_two_level_dict(override_args)
if config_path is None:
args_dict = self._args_to_two_level_dict(default_args)
else:
with open(config_path, "rb") as f:
args_dict = tomllib.load(f)
for first_level_key , d in override_dict.items():
for second_level_key, v in d.items():
args_dict[first_level_key][second_level_key] = v
for k, v in args_dict.items():
class_type = type(k.title(), (), v)
setattr(self, k, class_type())
self._validate_config()

def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
args_dict = defaultdict(defaultdict)
for k, v in vars(args).items():
first_level_key , second_level_key = k.split("_", 1)
args_dict[first_level_key][second_level_key] = v
return args_dict


def _validate_config(self):
# TODO: Add more mandatory validations
assert self.model.name and self.model.model_conf and self.model.tokenizer_path
assert self.model.name and self.model.config and self.model.tokenizer_path
return True

def _init_args(self) -> Tuple:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can put both: arg parser initialization, and arg parsing into this function, call it sth like parse_args? what I think would be nice in train.py:

config = JobConfig()
config.parse_args(args[1:])
main(config)

"""
Each argument starts with <prefix>_ which is the section name in the toml file
followed by name of the option in the toml file. For ex,
model_name translates to:
[model]
name
in the toml file
"""
parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
parser.add_argument(
"--global_config_file",
type=str,
default=None,
help="job config file",
)

# global configs
parser.add_argument(
"--global_dump_folder",
type=str,
default="./torchtrain/outputs",
help="folder to dump job outputs",
)

# profiling configs
parser.add_argument(
"--profiling_run_profiler",
action="store_true",
help="enable pytorch profiler",
)
parser.add_argument(
"--profiling_save_traces_folder",
type=str,
default="profiling/traces",
help="trace file location",
)
parser.add_argument(
"--profiling_profile_every_x_iter",
type=int,
default=10,
help="collect profiler traces every x iterations",
)
# metrics configs
parser.add_argument(
"--metrics_log_freq",
type=int,
default=10,
help="how often to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics_enable_tensorboard",
action="store_true",
help="how often to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics_save_tb_folder",
type=str,
default="tb",
help="folder to dump tensorboard state",
)

# model configs
parser.add_argument(
"--model_name",
type=str,
default="llama",
help="which model to train",
)
parser.add_argument(
"--model_config",
type=str,
default="debugmodel",
help="which model config to train",
)
parser.add_argument(
"--model_tokenizer_path",
type=str,
default="./torchtrain/datasets/tokenizer/tokenizer.model",
help="tokenizer path",
)


# optimizer configs
parser.add_argument(
"--optimizer_name", type=str, default="AdamW", help="optimizer to use"
)
parser.add_argument("--optimizer_lr", type=float, default=8e-4, help="learning rate to use")

# training configs
parser.add_argument("--training_dataset", type=str, default="alpaca", help="dataset to use")
parser.add_argument("--training_batch_size", type=int, default=8, help="batch size")
parser.add_argument("--training_seq_len", type=int, default=2048, help="sequence length")
parser.add_argument(
"--training_warmup_pct",
type=float,
default=0.20,
help="percentage of total training steps to use for warmup",
)
parser.add_argument(
"--training_max_norm",
type=Union[float, int],
default=1.0,
help="max norm for gradient clipping",
)
parser.add_argument(
"--training_steps", type=int, default=-1, help="how many train steps to run"
)
parser.add_argument(
"--training_data_parallel_degree",
type=int,
default=-1,
help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
)
parser.add_argument(
"--training_sequence_parallel_degree",
type=int,
default=1,
help="Sequence Parallelism degree. 1 means disabled.",
)
parser.add_argument(
"--training_pipeline_parallel_degree",
type=int,
default=1,
help="Pipeline Parallelism degree (default of 1 means disabled)",
)
parser.add_argument(
"--training_compile", action="store_true", help="Whether to compile the model."
)
parser.add_argument(
"--training_checkpoint_interval",
type=int,
default=3600,
help=(
"Checkpointing interval. The unit of measurement is in seconds or "
"steps depending on --training_checkpoint-internval-type."
),
)
parser.add_argument(
"--training_checkpoint_interval_type",
type=str,
default="steps",
help=(
"The checkpointing interval unit of measurement."
"The default value is step."
),
)
parser.add_argument(
"--training_checkpoint_folder",
type=str,
default="",
help=(
"The folder to store the checkpoints. If this is not specified or "
"is an empty string, checkpointing is disabled."
),
)
args = parser.parse_args(self._args_list)
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
for arg , val in vars(args).items():
if isinstance(val, bool):
aux_parser.add_argument('--'+arg,
action='store_true' if val else 'store_false')
else:
aux_parser.add_argument('--'+arg, type=type(val))
override_args, _ = aux_parser.parse_known_args(self._args_list)
return args, override_args
2 changes: 1 addition & 1 deletion torchtrain/train_configs/train_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ log_freq = 10

[model]
name = "llama"
model_conf = "debugmodel"
config = "debugmodel"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe call it type to avoid confusion with JobConfig?

tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model"

[optimizer]
Expand Down
Loading
Loading