Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding checkpoint and config version #38

Merged
merged 3 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion hfgl/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from typing import Annotated, Any, Optional

from everyvoice.config.preprocessing_config import PreprocessingConfig
from everyvoice.config.shared_types import (
Expand All @@ -27,6 +27,9 @@
model_validator,
)

# HiFiGANConfig's latest version number
LATEST_VERSION: str = "1.0"


# NOTE: We need to derive from both str and Enum if we want `HiFiGANResblock.one == "one"` to be True.
# Otherwise, `HiFiGANResblock.one == "one"` will be false.
Expand Down Expand Up @@ -155,6 +158,14 @@ def convert_to_HiFiGANTrainTypes(


class HiFiGANConfig(BaseModelWithContact):
VERSION: Annotated[
str,
Field(
default_factory=lambda: LATEST_VERSION,
init_var=False,
),
]

model: HiFiGANModelConfig = Field(
default_factory=HiFiGANModelConfig,
description="The model configuration settings.",
Expand Down Expand Up @@ -243,3 +254,23 @@ def check_upsample_rate_consistency(self) -> "HiFiGANConfig":
)

return self

@model_validator(mode="before")
@classmethod
def check_and_upgrade_checkpoint(cls, data: Any) -> Any:
"""
Check model's compatibility and possibly upgrade.
"""
from packaging.version import Version

ckpt_version = Version(data.get("VERSION", "0.0"))
if ckpt_version > Version(LATEST_VERSION):
raise ValueError(
"Your config was created with a newer version of EveryVoice, please update your software."
)
# Successively convert model checkpoints to newer version.
if ckpt_version < Version("1.0"):
# Converting to 1.0 just requires setting the VERSION field
data["VERSION"] = "1.0"

return data
40 changes: 40 additions & 0 deletions hfgl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ class HiFiGANGenerator(pl.LightningModule):
for low-requirement model storage and inference.
"""

_VERSION: str = "1.0"

def __init__(self, config: dict | VocoderConfig):
super().__init__()

Expand All @@ -469,11 +471,45 @@ def __init__(self, config: dict | VocoderConfig):
self.config = config
self.generator = Generator(self.config)

def check_and_upgrade_checkpoint(self, checkpoint):
"""
Check model's compatibility and possibly upgrade.
"""
from packaging.version import Version

model_info = checkpoint.get(
"model_info",
{
"name": self.__class__.__name__,
"version": "1.0",
},
)

ckpt_model_type = model_info.get("name", "MISSING_TYPE")
if ckpt_model_type != self.__class__.__name__:
raise TypeError(
f"""Wrong model type ({ckpt_model_type}), we are expecting a '{ self.__class__.__name__ }' model"""
)

ckpt_version = Version(model_info.get("version", "0.0"))
if ckpt_version > Version(self._VERSION):
raise ValueError(
"Your model was created with a newer version of EveryVoice, please update your software."
)
# Successively convert model checkpoints to newer version.
if ckpt_version < Version("1.0"):
# Upgrading from 0.0 to 1.0 requires no changes; future versions might require changes
checkpoint["model_info"]["version"] = "1.0"

return checkpoint

def on_load_checkpoint(self, checkpoint):
"""Deserialize the checkpoint hyperparameters.
Note, this shouldn't fail on different versions of pydantic anymore,
but it will fail on breaking changes to the config. We should catch those exceptions
and handle them appropriately."""
checkpoint = self.check_and_upgrade_checkpoint(checkpoint)

try:
config = VocoderConfig(**checkpoint["hyper_parameters"]["config"])
except ValidationError as e:
Expand All @@ -486,6 +522,10 @@ def on_load_checkpoint(self, checkpoint):
def on_save_checkpoint(self, checkpoint):
"""Serialize the checkpoint hyperparameters"""
checkpoint["hyper_parameters"]["config"] = self.config.model_checkpoint_dump()
checkpoint["model_info"] = {
"name": self.__class__.__name__,
"version": self._VERSION,
}


class HiFiGAN(HiFiGANGenerator):
Expand Down
3 changes: 3 additions & 0 deletions hfgl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ def load_hifigan_from_checkpoint(ckpt: dict, device) -> Tuple[HiFiGAN, HiFiGANCo
model = HiFiGAN(config).to(device)
else:
model = HiFiGANGenerator(config).to(device)

ckpt = model.check_and_upgrade_checkpoint(ckpt)
model.load_state_dict(ckpt["state_dict"])
model.generator.eval()
model.generator.remove_weight_norm()

return model, config


Expand Down