feat: added a version number to Aligner
SamuelLarkin committed Nov 5, 2024
1 parent f2cec61 commit 6cd45f3
Showing 1 changed file with 38 additions and 0 deletions.
dfaligner/model.py

@@ -37,6 +37,8 @@ def forward(self, x):


class Aligner(pl.LightningModule):
    _VERSION: str = "1.0"

    def __init__(
        self,
        config: dict | DFAlignerConfig,
@@ -95,16 +97,52 @@ def forward(self, x):
        x = self.lin(x)
        return x

    def check_and_upgrade_checkpoint(self, checkpoint):
        """
        Check the checkpoint's compatibility with this model and, if it was
        produced by an older version, upgrade it in place.
        """
        model_info = checkpoint.get(
            "model_info",
            {
                "name": self.__class__.__name__,
                "version": "1.0",
            },
        )

        ckpt_model_type = model_info.get("name", "MISSING_TYPE")
        if ckpt_model_type != self.__class__.__name__:
            raise TypeError(
                f"Wrong model type ({ckpt_model_type}); we are expecting a '{self.__class__.__name__}' model"
            )

        ckpt_version = model_info.get("version", "0.0")
        if ckpt_version > self._VERSION:
            raise ValueError(
                "Your model was created with a newer version of EveryVoice; please update your software."
            )
        # Successively convert model checkpoints to newer versions.
        if ckpt_version < "1.0":
            # Upgrading from 0.0 to 1.0 requires no changes; future versions might.
            checkpoint["model_info"]["version"] = "1.0"

        return checkpoint

    def on_load_checkpoint(self, checkpoint):
        """Deserialize the checkpoint hyperparameters.

        Note: this should no longer fail across different versions of pydantic,
        but it will still fail on breaking changes to the config. We should
        catch those exceptions and handle them appropriately."""
        checkpoint = self.check_and_upgrade_checkpoint(checkpoint)

        self.config = AlignerConfig(**checkpoint["hyper_parameters"]["config"])

    def on_save_checkpoint(self, checkpoint):
        """Serialize the checkpoint hyperparameters and record the model's
        name and version."""
        checkpoint["hyper_parameters"]["config"] = self.config.model_checkpoint_dump()
        checkpoint["model_info"] = {
            "name": self.__class__.__name__,
            "version": self._VERSION,
        }

    def configure_optimizers(self):
        optim = torch.optim.AdamW(
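For a concrete picture of how the two new hooks cooperate, here is a rough round-trip sketch using a plain dict in place of a real Lightning checkpoint; the values are illustrative, not taken from the commit.

# A bare dict standing in for a Lightning checkpoint (illustrative only).
checkpoint = {"hyper_parameters": {"config": {}}}

# on_save_checkpoint stamps the model's identity and version...
checkpoint["model_info"] = {"name": "Aligner", "version": "1.0"}

# ...and on_load_checkpoint later routes the dict through
# check_and_upgrade_checkpoint, which verifies the name, rejects
# checkpoints newer than _VERSION, and upgrades older ones in place.
model_info = checkpoint["model_info"]
assert model_info["name"] == "Aligner"
assert model_info["version"] == "1.0"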

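The "successively convert" comment implies a forward-walking migration chain: each guarded block bumps a checkpoint exactly one step, so an old checkpoint passes through every later migration in order. A hypothetical sketch of what a future step might look like; version "1.1" and the renamed hyper-parameter key below are invented for illustration, and only the 1.0 step exists in this commit.

# Hypothetical future upgrade chain (version "1.1" and the key rename
# are invented for illustration).
if ckpt_version < "1.0":
    # 0.0 -> 1.0 requires no changes.
    checkpoint["model_info"]["version"] = "1.0"
if ckpt_version < "1.1":
    # e.g. migrate a hyper-parameter renamed in 1.1.
    hp = checkpoint["hyper_parameters"]["config"]
    if "old_name" in hp:
        hp["new_name"] = hp.pop("old_name")
    checkpoint["model_info"]["version"] = "1.1"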
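One caveat worth keeping in mind: the guard ckpt_version > self._VERSION compares version strings lexicographically, which is correct while components stay single-digit ("0.0", "1.0") but misorders multi-digit components, since Python compares strings character by character. A minimal sketch of a numeric comparison, assuming versions keep the dotted-integer form; the _version_key helper is hypothetical and not part of this commit.

def _version_key(version: str) -> tuple[int, ...]:
    """Split a dotted version string like "1.0" into a tuple of ints,
    so that comparisons are numeric rather than lexicographic."""
    return tuple(int(part) for part in version.split("."))

# Lexicographic string comparison misorders multi-digit components...
assert "10.0" < "9.0"  # True, because "1" < "9"
# ...while integer tuples compare numerically.
assert _version_key("10.0") > _version_key("9.0")  # (10, 0) > (9, 0)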