Skip to content

Commit

Permalink
feat: display a spinner while importing expensive things
Browse files Browse the repository at this point in the history
  • Loading branch information
joanise committed Oct 9, 2024
1 parent 817ace4 commit e564bda
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions hfgl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
preprocess_base_command_interface,
train_base_command_interface,
)
from everyvoice.utils import spinner
from loguru import logger
from merge_args import merge_args

Expand All @@ -33,9 +34,10 @@ def preprocess(
),
**kwargs,
):
from everyvoice.base_cli.helpers import preprocess_base_command
with spinner():
from everyvoice.base_cli.helpers import preprocess_base_command

from .config import HiFiGANConfig
from .config import HiFiGANConfig

preprocess_base_command(
model_config=HiFiGANConfig,
Expand All @@ -47,11 +49,12 @@ def preprocess(
@app.command()
@merge_args(train_base_command_interface)
def train(**kwargs):
from everyvoice.base_cli.helpers import train_base_command
with spinner():
from everyvoice.base_cli.helpers import train_base_command

from .config import HiFiGANConfig
from .dataset import HiFiGANDataModule
from .model import HiFiGAN
from .config import HiFiGANConfig
from .dataset import HiFiGANDataModule
from .model import HiFiGAN

train_base_command(
model_config=HiFiGANConfig,
Expand Down Expand Up @@ -86,9 +89,10 @@ def export(
):
import os

import torch
with spinner():
import torch

from .utils import sizeof_fmt
from .utils import sizeof_fmt

orig_size = sizeof_fmt(os.path.getsize(model_path))
vocoder_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
Expand Down Expand Up @@ -132,11 +136,12 @@ def synthesize(
"""Given some Mel spectrograms and a trained model, generate some audio. i.e. perform *copy synthesis*"""
import sys

import torch
from pydantic import ValidationError
from scipy.io.wavfile import write
with spinner():
import torch
from pydantic import ValidationError
from scipy.io.wavfile import write

from .utils import load_hifigan_from_checkpoint, synthesize_data
from .utils import load_hifigan_from_checkpoint, synthesize_data

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(generator_path, map_location=device)
Expand Down

0 comments on commit e564bda

Please sign in to comment.