Skip to content

Commit

Permalink
feat: handling wrong model type
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed Oct 8, 2024
1 parent 409e757 commit 7f396ed
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
9 changes: 8 additions & 1 deletion hfgl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,22 @@ def synthesize(
),
):
"""Given some Mel spectrograms and a trained model, generate some audio. i.e. perform *copy synthesis*"""
import sys

import torch
from pydantic import ValidationError
from scipy.io.wavfile import write

from .utils import load_hifigan_from_checkpoint, synthesize_data

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(generator_path, map_location=device)
data = torch.load(data_path, map_location=device)
vocoder_model, vocoder_config = load_hifigan_from_checkpoint(checkpoint, device)
try:
vocoder_model, vocoder_config = load_hifigan_from_checkpoint(checkpoint, device)
except (TypeError, ValidationError) as e:
logger.error(f"Unable to load {generator_path}: {e}")
sys.exit(1)
wav, sr = synthesize_data(data, vocoder_model, vocoder_config)
logger.info(f"Writing file {data_path}.wav")
write(f"{data_path}.wav", sr, wav)
Expand Down
29 changes: 26 additions & 3 deletions hfgl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
dynamic_range_compression_torch,
get_spectral_transform,
)
from loguru import logger
from pydantic import ValidationError
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
Expand Down Expand Up @@ -454,8 +456,16 @@ class HiFiGANGenerator(pl.LightningModule):

def __init__(self, config: dict | VocoderConfig):
super().__init__()

if not isinstance(config, VocoderConfig):
config = VocoderConfig(**config)
try:
config = VocoderConfig(**config)
except ValidationError as e:
logger.error(f"{e}")
raise TypeError(
"Unable to load config. Possible causes: is it really a VocoderConfig? or the correct version?"
)

self.config = config
self.generator = Generator(self.config)

Expand All @@ -464,7 +474,14 @@ def on_load_checkpoint(self, checkpoint):
Note, this shouldn't fail on different versions of pydantic anymore,
but it will fail on breaking changes to the config. We should catch those exceptions
and handle them appropriately."""
self.config = VocoderConfig(**checkpoint["hyper_parameters"]["config"])
try:
config = VocoderConfig(**checkpoint["hyper_parameters"]["config"])
except ValidationError as e:
logger.error(f"{e}")
raise TypeError(
"Unable to load config. Possible causes: is it really a VocoderConfig? or the correct version?"
)
self.config = config

def on_save_checkpoint(self, checkpoint):
"""Serialize the checkpoint hyperparameters"""
Expand All @@ -479,7 +496,13 @@ def __init__(self, config: dict | VocoderConfig):
# Because we serialize the configurations when saving checkpoints,
# sometimes what is passed is actually just a dict.
if not isinstance(config, VocoderConfig):
config = VocoderConfig(**config)
try:
config = VocoderConfig(**config)
except ValidationError as e:
logger.error(f"{e}")
raise TypeError(
"Unable to load config. Possible causes: is it really a VocoderConfig? or the correct version?"
)
self.config = config
self.mpd = MultiPeriodDiscriminator(config)
self.msd = MultiScaleDiscriminator(config)
Expand Down
11 changes: 10 additions & 1 deletion hfgl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch
from everyvoice.utils import pydantic_validation_error_shortener
from everyvoice.utils.heavy import get_spectral_transform
from loguru import logger

Expand All @@ -23,7 +24,15 @@ def sizeof_fmt(num, suffix="B"):
def load_hifigan_from_checkpoint(ckpt: dict, device) -> Tuple[HiFiGAN, HiFiGANConfig]:
config: dict | HiFiGANConfig = ckpt["hyper_parameters"]["config"]
if isinstance(config, dict):
config = HiFiGANConfig(**config)
from pydantic import ValidationError

try:
config = HiFiGANConfig(**config)
except ValidationError as e:
logger.error(f"{pydantic_validation_error_shortener(e)}")
raise TypeError(
"Unable to load config. Possible causes: is it really a VocoderConfig? or the correct version?"
)
if any(
(
key.startswith("mpd") or key.startswith("msd")
Expand Down

0 comments on commit 7f396ed

Please sign in to comment.