diff --git a/tests/unit_tests/test_job_config.py b/tests/unit_tests/test_job_config.py
index aed007ad..284decd3 100644
--- a/tests/unit_tests/test_job_config.py
+++ b/tests/unit_tests/test_job_config.py
@@ -48,10 +48,7 @@ def test_job_config_file_cmd_overrides(self):
 
     def test_parse_pp_split_points(self):
         toml_splits = ["layers.2", "layers.4", "layers.6"]
-        toml_split_str = ",".join(toml_splits)
         cmdline_splits = ["layers.1", "layers.3", "layers.5"]
-        cmdline_split_str = ",".join(cmdline_splits)
-        # no split points specified
         config = JobConfig()
         config.parse_args(
             [
@@ -68,7 +65,7 @@ def test_parse_pp_split_points(self):
                 "--job.config_file",
                 "./train_configs/debug_model.toml",
                 "--experimental.pipeline_parallel_split_points",
-                f"{cmdline_split_str}",
+                *cmdline_splits,
             ]
         )
         assert (
@@ -81,7 +78,7 @@
                 tomli_w.dump(
                     {
                         "experimental": {
-                            "pipeline_parallel_split_points": toml_split_str,
+                            "pipeline_parallel_split_points": toml_splits,
                         }
                     },
                     f,
@@ -98,7 +95,7 @@
                 tomli_w.dump(
                     {
                         "experimental": {
-                            "pipeline_parallel_split_points": toml_split_str,
+                            "pipeline_parallel_split_points": toml_splits,
                         }
                     },
                     f,
@@ -109,7 +106,7 @@
                     "--job.config_file",
                     fp.name,
                     "--experimental.pipeline_parallel_split_points",
-                    f"{cmdline_split_str}",
+                    *cmdline_splits,
                 ]
             )
             assert (
@@ -117,6 +114,7 @@
             ), config.experimental.pipeline_parallel_split_points
 
     def test_print_help(self):
-        config = JobConfig()
-        parser = config.parser
+        from tyro.extras import get_parser
+
+        parser = get_parser(JobConfig)
         parser.print_help()
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 400d86a4..28c9ddf3 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
+import sys
 from dataclasses import asdict, dataclass, field, fields, is_dataclass
 from typing import Any, Callable, Dict, List, Optional, Union
 
@@ -221,7 +223,7 @@ class Experimental:
     """
     Specify comma-separated names of modules to use as the beginning of a split point.
 
-    e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
+    e.g. "layers.0" "layers.2" will cause the model to be split into 3 stages,
     the first containing all the layers up to layers.0,
     the second containing layers.0 and up to layers.2,
     the third containing layers.2 and all the remaining layers.
@@ -444,25 +446,28 @@ class JobConfig:
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
 
-    def _update(self, instance: "JobConfig") -> None:
+    def _update(self, instance) -> None:
         for f in fields(self):
             setattr(self, f.name, getattr(instance, f.name, getattr(self, f.name)))
 
-    def parse_args(self) -> None:
+    def find_config_file(self, argv):
+        config_flags = ("--job.config-file", "--job.config_file")
+        for i, arg in enumerate(argv[:-1]):
+            if arg in config_flags:
+                return argv[i + 1]
+        return None
+
+    def parse_args(self, args=None) -> None:
         """
         Parse CLI arguments, optionally load from a TOML file, merge with defaults,
         and return a JobConfig instance.
         """
-        defaults = tyro.cli(self.__class__)
-        config_file = defaults.job.config_file
+        config = self.__class__()  # initialize with defaults
+        config_file = self.find_config_file(args if args is not None else sys.argv[1:])
         if config_file:
             toml_data = self._load_toml(config_file)
-            toml_config = self._dict_to_dataclass(self.__class__, toml_data)
-            merged_config = self._merge_with_defaults(toml_config, defaults)
-            # TODO: find a way to make this work without two calls
-            final_config = tyro.cli(self.__class__, default=merged_config)
-        else:
-            final_config = defaults
+            config = self._dict_to_dataclass(self.__class__, toml_data)
 
+        final_config = tyro.cli(self.__class__, args=args, default=config)
         self._update(final_config)
         self._validate_config()
@@ -490,22 +495,6 @@ def _dict_to_dataclass(self, config_class: Callable, data: Dict[str, Any]) -> An
                 kwargs[f.name] = value
         return config_class(**kwargs)
 
-    def _merge_with_defaults(self, target, defaults) -> Any:
-        """Recursively merge two dataclass instances (source overrides defaults)."""
-        merged_kwargs = {}
-        for f in fields(target):
-            target_val = getattr(target, f.name)
-            default_val = getattr(defaults, f.name)
-            if is_dataclass(target_val) and is_dataclass(default_val):
-                merged_kwargs[f.name] = self._merge_with_defaults(
-                    target_val, default_val
-                )
-            else:
-                merged_kwargs[f.name] = (
-                    target_val if target_val is not None else default_val
-                )
-        return type(target)(**merged_kwargs)
-
     def _validate_config(self) -> None:
         # TODO: Add more mandatory validations
         assert self.model.name, "Model name is required"