diff --git a/dfaligner/config/__init__.py b/dfaligner/config/__init__.py index cf82f12..f64b89e 100644 --- a/dfaligner/config/__init__.py +++ b/dfaligner/config/__init__.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import Any, Optional +from typing import Annotated, Any, Optional from everyvoice.config.preprocessing_config import PreprocessingConfig from everyvoice.config.shared_types import ( @@ -17,6 +17,9 @@ from everyvoice.utils import load_config_from_json_or_yaml_path from pydantic import Field, FilePath, ValidationInfo, field_serializer, model_validator +# DFAlignerConfig's latest version number +LATEST_VERSION: str = "1.0" + class DFAlignerExtractionMethod(Enum): beam = "beam" @@ -61,6 +64,14 @@ def convert_extraction_method_enum( class DFAlignerConfig(BaseModelWithContact): + VERSION: Annotated[ + str, + Field( + default_factory=lambda: LATEST_VERSION, + init_var=False, + ), + ] + # TODO FastSpeech2Config and DFAlignerConfig are almost identical. model: DFAlignerModelConfig = Field( default_factory=DFAlignerModelConfig, @@ -108,4 +119,22 @@ def load_config_from_path(path: Path) -> "DFAlignerConfig": config = DFAlignerConfig(**config) return config + @model_validator(mode="before") + @classmethod + def check_and_upgrade_checkpoint(cls, data: Any) -> Any: + """ + Check model's compatibility and possibly upgrade. + """ + ckpt_version = data.get("VERSION", "0.0") + if ckpt_version > LATEST_VERSION: + raise ValueError( + "Your config was created with a newer version of EveryVoice, please update your software." + ) + # Successively convert model checkpoints to newer version. + if ckpt_version < "1.0": + # Converting to 1.0 just requires setting the VERSION field + data["VERSION"] = "1.0" + + return data + # INPUT_TODO: initialize text with union of symbols from dataset