Skip to content

Commit

Permalink
feat: added a version number to HiFiGANConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed Oct 25, 2024
1 parent d9af916 commit 8dfea08
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion hfgl/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from typing import Annotated, Any, Optional

from everyvoice.config.preprocessing_config import PreprocessingConfig
from everyvoice.config.shared_types import (
Expand All @@ -27,6 +27,9 @@
model_validator,
)

# HiFiGANConfig's latest version number
LATEST_VERSION: str = "1.0"


# NOTE: We need to derive from both str and Enum if we want `HiFiGANResblock.one == "one"` to be True.
# Otherwise, `HiFiGANResblock.one == "one"` will be false.
Expand Down Expand Up @@ -155,6 +158,14 @@ def convert_to_HiFiGANTrainTypes(


class HiFiGANConfig(BaseModelWithContact):
VERSION: Annotated[
str,
Field(
default_factory=lambda: LATEST_VERSION,
init_var=False,
),
]

model: HiFiGANModelConfig = Field(
default_factory=HiFiGANModelConfig,
description="The model configuration settings.",
Expand Down Expand Up @@ -243,3 +254,21 @@ def check_upsample_rate_consistency(self) -> "HiFiGANConfig":
)

return self

@model_validator(mode="before")
@classmethod
def check_and_upgrade_checkpoint(cls, data: Any) -> Any:
"""
Check model's compatibility and possibly upgrade.
"""
ckpt_version = data.get("VERSION", "0.0")
if ckpt_version > LATEST_VERSION:
raise ValueError(
"Your config was created with a newer version of EveryVoice, please update your software."
)
# Successively convert model checkpoints to newer version.
if ckpt_version < "1.0":
# TODO: Write code to convert model to version 1.0.
data["VERSION"] = "1.0"

return data

0 comments on commit 8dfea08

Please sign in to comment.