@model_validator(mode="before")
@classmethod
def check_and_upgrade_checkpoint(cls, data: Any) -> Any:
    """Check the config's version and upgrade it to the latest schema.

    This runs before field validation, so ``data`` is the raw input
    handed to the model rather than a validated instance.

    Args:
        data: Raw model input. Only ``dict`` input is inspected; any
            other object is returned unchanged.

    Returns:
        ``data``. When it is a dict whose ``"VERSION"`` is missing or
        older than 1.0, its ``"VERSION"`` key is set to ``"1.0"`` in
        place.

    Raises:
        ValueError: If the config's version is newer than
            ``LATEST_VERSION``.
    """
    from packaging.version import Version

    # A "before" validator is typed to accept Any; calling .get() on a
    # non-dict input would raise AttributeError, so leave such input
    # alone instead of crashing here.
    if not isinstance(data, dict):
        return data

    # Configs written before versioning existed carry no VERSION key
    # and are treated as version 0.0.
    ckpt_version = Version(data.get("VERSION", "0.0"))
    if ckpt_version > Version(LATEST_VERSION):
        raise ValueError(
            "Your config was created with a newer version of EveryVoice, please update your software."
        )

    # Successively convert configs to newer versions.
    if ckpt_version < Version("1.0"):
        # Converting to 1.0 just requires setting the VERSION field
        data["VERSION"] = "1.0"

    return data
def check_and_upgrade_checkpoint(self, checkpoint):
    """Verify that a checkpoint matches this model and bring it up to date.

    The checkpoint's ``"model_info"`` entry is compared against this
    object's class name and ``_VERSION``. A checkpoint without
    ``"model_info"`` is treated as this model type at version 1.0 and is
    returned untouched.

    Args:
        checkpoint: Mapping holding the checkpoint contents.

    Returns:
        The same ``checkpoint`` object, with
        ``checkpoint["model_info"]["version"]`` set to ``"1.0"`` when the
        recorded version was older than 1.0.

    Raises:
        TypeError: If the recorded model name differs from this class's
            name (a ``"model_info"`` without a name counts as a mismatch).
        ValueError: If the recorded version is newer than ``_VERSION``.
    """
    from packaging.version import Version

    expected_type = self.__class__.__name__
    fallback_info = {"name": expected_type, "version": "1.0"}
    info = checkpoint.get("model_info", fallback_info)

    found_type = info.get("name", "MISSING_TYPE")
    if found_type != expected_type:
        raise TypeError(
            f"""Wrong model type ({found_type}), we are expecting a '{ expected_type }' model"""
        )

    # A "model_info" entry lacking a version is read as 0.0.
    found_version = Version(info.get("version", "0.0"))
    supported_version = Version(self._VERSION)
    if found_version > supported_version:
        raise ValueError(
            "Your model was created with a newer version of EveryVoice, please update your software."
        )

    # Upgrade chain: each step lifts the checkpoint to the next version.
    if found_version < Version("1.0"):
        # 0.0 -> 1.0 needs no structural change, only the version stamp.
        checkpoint["model_info"]["version"] = "1.0"

    return checkpoint
def on_save_checkpoint(self, checkpoint):
    """Write the serialized config and model identity into ``checkpoint``.

    Stores the result of ``self.config.model_checkpoint_dump()`` under
    ``checkpoint["hyper_parameters"]["config"]`` and records this
    object's class name and ``_VERSION`` under
    ``checkpoint["model_info"]``.

    Args:
        checkpoint: Checkpoint mapping being saved; modified in place.
            It must already contain a ``"hyper_parameters"`` mapping.
    """
    serialized_config = self.config.model_checkpoint_dump()
    checkpoint["hyper_parameters"]["config"] = serialized_config

    # Name and version let a later load detect incompatible checkpoints.
    model_info = {"name": self.__class__.__name__, "version": self._VERSION}
    checkpoint["model_info"] = model_info