diff --git a/test/test_job_config.py b/test/test_job_config.py index ccdaf206..0e3d9c63 100644 --- a/test/test_job_config.py +++ b/test/test_job_config.py @@ -9,12 +9,12 @@ class TestJobConfig: def test_command_line_args(self): config = JobConfig() config.parse_args([]) - assert config.model.name == "llama" + assert config.training.steps == -1 def test_job_config_file(self): config = JobConfig() config.parse_args(["--job.config_file", "./train_configs/debug_model.toml"]) - assert config.model.name == "llama" + assert config.training.steps == 10 def test_job_file_does_not_exist(self): with pytest.raises(FileNotFoundError): diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py index ff8afe8f..613f9411 100644 --- a/torchtrain/config_manager.py +++ b/torchtrain/config_manager.py @@ -1,3 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. import argparse @@ -17,16 +20,16 @@ class JobConfig: Semantics: - Default config is loaded from a toml file. If no toml file is provided, then the default config is loaded from argparse defaults. + - If toml file has missing top-level sections, they are filled with argparse defaults. """ def parse_args(self, args_list: list = sys.argv[1:]): args = JobConfig.init_args_from_command_line(args_list) config_file = getattr(args, "job.config_file", None) - if config_file is None: - args_dict = self._args_to_two_level_dict(args) - else: + args_dict = self._args_to_two_level_dict(args) + if config_file is not None: with open(config_file, "rb") as f: - args_dict = tomllib.load(f) + args_dict |= tomllib.load(f) for k, v in args_dict.items(): class_type = type(k.title(), (), v) setattr(self, k, class_type())