Skip to content

Commit

Permalink
Rename var
Browse files Browse the repository at this point in the history
  • Loading branch information
jaysonfrancis committed Jan 3, 2025
1 parent eb3b94c commit 0553659
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import tyro
Expand All @@ -24,10 +24,6 @@
}


def string_list(raw_arg):
return raw_arg.split(",")


@dataclass
class Job:
config_file: Optional[str] = None
Expand Down Expand Up @@ -452,7 +448,7 @@ def _update(self, instance: "JobConfig") -> None:
for f in fields(self):
setattr(self, f.name, getattr(instance, f.name, getattr(self, f.name)))

def parse_args(self):
def parse_args(self) -> None:
"""
Parse CLI arguments, optionally load from a TOML file,
merge with defaults, and return a JobConfig instance.
Expand All @@ -479,7 +475,7 @@ def _load_toml(file_path: str) -> Dict[str, Any]:
logger.exception(f"Error while loading config file: {file_path}")
raise e

def _dict_to_dataclass(self, config_class, data: Dict[str, Any]) -> Any:
def _dict_to_dataclass(self, config_class: Callable, data: Dict[str, Any]) -> Any:
"""Recursively convert dictionaries to nested dataclasses."""
if not is_dataclass(config_class):
return data
Expand All @@ -494,23 +490,21 @@ def _dict_to_dataclass(self, config_class, data: Dict[str, Any]) -> Any:
kwargs[f.name] = value
return config_class(**kwargs)

def _merge_with_defaults(
self, source: "JobConfig", defaults: "JobConfig"
) -> "JobConfig":
def _merge_with_defaults(self, target, defaults) -> Any:
"""Recursively merge two dataclass instances (source overrides defaults)."""
merged_kwargs = {}
for f in fields(source):
source_val = getattr(source, f.name)
for f in fields(target):
target_val = getattr(target, f.name)
default_val = getattr(defaults, f.name)
if is_dataclass(source_val) and is_dataclass(default_val):
if is_dataclass(target_val) and is_dataclass(default_val):
merged_kwargs[f.name] = self._merge_with_defaults(
source_val, default_val
target_val, default_val
)
else:
merged_kwargs[f.name] = (
source_val if source_val is not None else default_val
target_val if target_val is not None else default_val
)
return type(source)(**merged_kwargs)
return type(target)(**merged_kwargs)

def _validate_config(self) -> None:
# TODO: Add more mandatory validations
Expand Down

0 comments on commit 0553659

Please sign in to comment.