feat: make model hparams nested (#52)
tilman151 authored Nov 16, 2023
1 parent 69004bc commit 3683860
Showing 2 changed files with 13 additions and 21 deletions.
25 changes: 8 additions & 17 deletions rul_adapt/approach/abstract.py
@@ -74,24 +74,25 @@ def set_model(
         self.log_model_hyperparameters("feature_extractor", "regressor")

     def log_model_hyperparameters(self, *model_names: str) -> None:
-        hparams_initial = self.hparams_initial
         if not hasattr(self, "_logged_models"):
             self._logged_models = {}
+        hparams_initial = self.hparams_initial
+        hparams_initial["model"] = {}

         for model_name in model_names:
             model_hparams = self._get_model_hparams(model_name)
-            hparams_initial.update(model_hparams)
+            hparams_initial["model"].update(model_hparams)
             self._logged_models[model_name] = set(model_hparams.keys())

         self._hparams_initial = hparams_initial
         self._set_hparams(self._hparams_initial)

     def _get_model_hparams(self, model_name):
-        prefix = f"model_{model_name.lstrip('_')}"
+        prefix = model_name.lstrip("_")
         model = getattr(self, model_name)
-        hparams = {f"{prefix}_type": type(model).__name__}
+        hparams = {prefix: {"type": type(model).__name__}}
         init_args = _get_init_args(model, "logging model hyperparameters")
-        hparams.update({f"{prefix}_{k}": v for k, v in init_args.items()})
+        hparams[prefix].update(init_args)

         return hparams

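
For orientation, a minimal sketch of the hyperparameter dictionary this hunk now builds. Only the nesting under a single "model" key and the per-model "type" entry come from the code above; the class names are taken from the tests below, and the "units" and "lr" entries are made-up placeholders for whatever init args and approach hyperparameters are actually logged.

    # Illustrative sketch only: shape of hparams_initial after log_model_hyperparameters.
    hparams_initial = {
        "lr": 0.01,  # hypothetical hyperparameter of the approach itself
        "model": {
            "feature_extractor": {"type": "LstmExtractor", "units": [16, 16]},
            "regressor": {"type": "FullyConnectedHead", "units": [8, 1]},
        },
    }

    # Before this commit, the same information was flattened into top-level keys
    # such as "model_feature_extractor_type" and "model_feature_extractor_units".
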
@@ -112,22 +113,12 @@ def regressor(self) -> nn.Module:
         raise RuntimeError("Regressor used before 'set_model' was called.")

     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        self._make_model_hparams_storable(checkpoint)
+        del checkpoint["hyper_parameters"]["model"]
+        checkpoint["logged_models"] = list(self._logged_models)
         to_checkpoint = ["_feature_extractor", "_regressor"] + self.CHECKPOINT_MODELS
         configs = {m: _get_hydra_config(getattr(self, m)) for m in to_checkpoint}
         checkpoint["model_configs"] = configs

-    def _make_model_hparams_storable(self, checkpoint: Dict[str, Any]) -> None:
-        excluded_keys = set()
-        for keys in self._logged_models.values():
-            excluded_keys.update(keys)
-        checkpoint["hyper_parameters"] = {
-            k: v
-            for k, v in checkpoint["hyper_parameters"].items()
-            if k not in excluded_keys
-        }
-        checkpoint["logged_models"] = list(self._logged_models)
-
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         for name, config in checkpoint["model_configs"].items():
             setattr(self, name, hydra.utils.instantiate(config))
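
A self-contained sketch of what the new on_save_checkpoint does to the checkpoint dictionary, with the Lightning and hydra machinery stripped away: the nested "model" entry is dropped from the stored hyperparameters (the models themselves are rebuilt from the hydra configs saved under "model_configs"), and only the names of the logged models are kept. All concrete values below are placeholders.

    # Hedged sketch of the trimming step; only the two dictionary operations
    # mirror the diff above, everything else is stand-in data.
    checkpoint = {
        "hyper_parameters": {
            "lr": 0.01,  # hypothetical non-model hyperparameter, kept as-is
            "model": {"feature_extractor": {"type": "LstmExtractor"}},
        },
    }
    _logged_models = {"feature_extractor": {"feature_extractor"}}

    del checkpoint["hyper_parameters"]["model"]         # drop nested model hparams
    checkpoint["logged_models"] = list(_logged_models)  # remember which models were logged

    assert checkpoint == {
        "hyper_parameters": {"lr": 0.01},
        "logged_models": ["feature_extractor"],
    }
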
9 changes: 5 additions & 4 deletions tests/test_approach/test_abstract.py
@@ -132,6 +132,7 @@ def test_checkpointing(tmp_path, approach, models):
     utils.checkpoint(approach, ckpt_path)
     restored = type(approach).load_from_checkpoint(ckpt_path)

+    assert approach.hparams == restored.hparams
     paired_params = zip(approach.parameters(), restored.parameters())
     for org_weight, restored_weight in paired_params:
         assert torch.dist(org_weight, restored_weight) == 0.0
@@ -212,10 +213,10 @@ def test_model_hparams(approach_func):
     approach.set_model(fe, reg)

     assert approach.hparams == approach.hparams_initial
-    assert "model_feature_extractor_type" in approach.hparams_initial
-    assert approach.hparams_initial["model_feature_extractor_type"] == "LstmExtractor"
-    assert "model_regressor_type" in approach.hparams_initial
-    assert approach.hparams_initial["model_regressor_type"] == "FullyConnectedHead"
+    assert "model" in approach.hparams_initial
+    model_hparams = approach.hparams_initial["model"]
+    assert model_hparams["feature_extractor"]["type"] == "LstmExtractor"
+    assert model_hparams["regressor"]["type"] == "FullyConnectedHead"


 @pytest.mark.integration
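
For code that still reads the old flat keys, a small sketch of the migration implied by the assertions above; the dictionary stands in for approach.hparams_initial and contains only the values asserted in the test.

    # Sketch of the key migration (values taken from the test above).
    hparams_initial = {
        "model": {
            "feature_extractor": {"type": "LstmExtractor"},
            "regressor": {"type": "FullyConnectedHead"},
        },
    }

    # Old access pattern (removed): hparams_initial["model_feature_extractor_type"]
    # New access pattern:
    assert hparams_initial["model"]["feature_extractor"]["type"] == "LstmExtractor"
    assert hparams_initial["model"]["regressor"]["type"] == "FullyConnectedHead"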