Skip to content

Commit

Permalink
Fill missing options in toml file wih argparse defaults (#91)
Browse files Browse the repository at this point in the history
Summary:

Summary:
Follow up on config unification, options not available in config file
are picked from command line defaults.

Test Plan:
============================= test session starts
============================== platform linux -- Python 3.10.13,
pytest-8.0.1, pluggy-1.4.0 --
/home/gnadathur/local/a/pytorch-env/bin/python cachedir: .pytest_cache
rootdir: /data/users/gnadathur/a/torchtrain
configfile: pyproject.toml
plugins: cov-4.1.0
collecting ... collected 3 items

test/test_job_config.py::TestJobConfig::test_command_line_args PASSED [
33%] test/test_job_config.py::TestJobConfig::test_job_config_file PASSED
[ 66%]
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
PASSED [100%]

---------- coverage: platform linux, python 3.10.13-final-0 ----------
Coverage XML written to file coverage.xml

============================= slowest 20 durations
============================= 0.00s call
test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s call
test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s
call
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup
test/test_job_config.py::TestJobConfig::test_command_line_args 0.00s
teardown test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s setup test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s setup
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s teardown
test/test_job_config.py::TestJobConfig::test_job_config_file 0.00s
teardown
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
============================== 3 passed in 0.06s
===============================

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: gnadathur <[email protected]>
  • Loading branch information
gnadathur and gnadathur authored Feb 26, 2024
1 parent ae85e97 commit 97aa4bc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions test/test_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ class TestJobConfig:
def test_command_line_args(self):
config = JobConfig()
config.parse_args([])
assert config.model.name == "llama"
assert config.training.steps == -1

def test_job_config_file(self):
config = JobConfig()
config.parse_args(["--job.config_file", "./train_configs/debug_model.toml"])
assert config.model.name == "llama"
assert config.training.steps == 10

def test_job_file_does_not_exist(self):
with pytest.raises(FileNotFoundError):
Expand Down
11 changes: 7 additions & 4 deletions torchtrain/config_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import argparse
Expand All @@ -17,16 +20,16 @@ class JobConfig:
Semantics:
- Default config is loaded from a toml file. If no toml file is provided,
then the default config is loaded from argparse defaults.
- if toml file has missing keys, they are filled with argparse defaults.
"""

def parse_args(self, args_list: list = sys.argv[1:]):
args = JobConfig.init_args_from_command_line(args_list)
config_file = getattr(args, "job.config_file", None)
if config_file is None:
args_dict = self._args_to_two_level_dict(args)
else:
args_dict = self._args_to_two_level_dict(args)
if config_file is not None:
with open(config_file, "rb") as f:
args_dict = tomllib.load(f)
args_dict |= tomllib.load(f)
for k, v in args_dict.items():
class_type = type(k.title(), (), v)
setattr(self, k, class_type())
Expand Down

0 comments on commit 97aa4bc

Please sign in to comment.