Skip to content

Commit

Permalink
final cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jaysonfrancis committed Jan 4, 2025
1 parent de1000e commit da0d8b8
Showing 1 changed file with 25 additions and 30 deletions.
55 changes: 25 additions & 30 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,56 +446,51 @@ class JobConfig:
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

def _update(self, instance) -> None:
def update_instance(self, instance) -> None:
for f in fields(self):
setattr(self, f.name, getattr(instance, f.name, getattr(self, f.name)))

def find_config_file(self, argv):
config_flags = ("--job.config-file", "--job.config_file")
for i, arg in enumerate(argv[:-1]):
if arg in config_flags:
return argv[i + 1]
return None

def parse_args(self, args=None) -> None:
"""
Parse CLI arguments, optionally load from a TOML file,
merge with defaults, and return a JobConfig instance.
"""
config = self.__class__ # initialize with defaults
config_file = self.find_config_file(args if args is not None else sys.argv[1:])
if config_file:
toml_data = self._load_toml(config_file)
config = self._dict_to_dataclass(self.__class__, toml_data)
final_config = tyro.cli(self.__class__, args=args, default=config)
self._update(final_config)
self._validate_config()

@staticmethod
def _load_toml(file_path: str) -> Dict[str, Any]:
try:
with open(file_path, "rb") as f:
return tomllib.load(f)
except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
logger.exception(f"Error while loading config file: {file_path}")
raise e

def _dict_to_dataclass(self, config_class: Callable, data: Dict[str, Any]) -> Any:
toml_data = self.maybe_load_toml(args or sys.argv[1:])
defaults = self.__class__
if toml_data:
defaults = self.dict_to_dataclass(defaults, toml_data)
final_config = tyro.cli(self.__class__, args=args, default=defaults)
self.update_instance(final_config)
self.validate_config()

def maybe_load_toml(self, args):
config_flags = {"--job.config-file", "--job.config_file"}
for i, arg in enumerate(args[:-1]):
if arg in config_flags:
file_path = args[i + 1]
try:
with open(file_path, "rb") as f:
return tomllib.load(f)
except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
logger.exception(f"Error while loading config file: {file_path}")
raise e
return None

def dict_to_dataclass(self, config_class, data) -> Any:
"""Recursively convert dictionaries to nested dataclasses."""
if not is_dataclass(config_class):
return data
kwargs = {}
for f in fields(config_class):
if f.name in data:
value = data[f.name]
# If target field is also a dataclass and value is a dict, recurse
if is_dataclass(f.type) and isinstance(value, dict):
kwargs[f.name] = self._dict_to_dataclass(f.type, value)
kwargs[f.name] = self.dict_to_dataclass(f.type, value)
else:
kwargs[f.name] = value
return config_class(**kwargs)

def _validate_config(self) -> None:
def validate_config(self) -> None:
# TODO: Add more mandatory validations
assert self.model.name, "Model name is required"
assert self.model.flavor, "Model flavor is required"
Expand Down

0 comments on commit da0d8b8

Please sign in to comment.