-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathrun_train.py
executable file
·106 lines (91 loc) · 4.87 KB
/
run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = 'Scott H. Hawley'
__version__ = '0.0.2'
# imports
import numpy as np
import torch
import os
import sys
import glob
import argparse
import matplotlib
matplotlib.use('Agg')
import signaltrain as st
if __name__ == "__main__":
    # Seed all RNGs (numpy, torch CPU, torch CUDA) for reproducibility,
    # then decide which device we'll run on.
    np.random.seed(218)
    torch.manual_seed(218)
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.manual_seed(218)
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        device = torch.device("cpu")
        torch.set_default_tensor_type('torch.FloatTensor')

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Trains neural network to reproduce input-output transformations.",\
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--apex', help="optimization setting to use with NVIDIA apex", default="O0")
    parser.add_argument('-b', '--batch', type=int, help="batch size", default=200)
    parser.add_argument('--checkpoint', help='Name of model checkpoint .tar file', default="modelcheckpoint.tar")
    parser.add_argument('-c','--compand', help='Turn on to use companded/decompanded audio', action='store_true')
    parser.add_argument('--effect', help='Name of effect to use. ("files" = search for "target_" and effect_info.ini files in path)', default="comp_4c")
    parser.add_argument('--epochs', type=int, help='Number of epochs to run', default=1000)
    parser.add_argument('--lrmax', type=float, help="max learning rate", default=1e-4) # Note: lrmax should be obtained by running lr_finder in learningrate.py
    parser.add_argument('-n', '--num', type=int, help='Number of "data points" (audio clips) per epoch', default=200000)
    parser.add_argument('--path', help='Directory to pull input (and maybe target) data from (default: None, means only synthesized-on-the-fly data)', default=None)
    parser.add_argument('--sr', type=int, help='Sampling rate', default=44100)
    parser.add_argument('--scale', type=float, help='Scale factor (of input size & whole model)', default=1.0)
    parser.add_argument('--shrink', type=int, help='Shrink output chunk relative to input by this divisor', default=4)
    parser.add_argument('-t','--target', help="type of target: chunk or stream", default="stream")
    args = parser.parse_args()

    # print command line as it was invoked (for reading nohup.out later)
    print("Command line: "," ".join(sys.argv[:]))

    # Check arguments before beginning to train....
    # Map effect-name strings to zero-arg constructors; 'files' means target
    # outputs are given as files rather than 'live' 'plugins', and needs the path.
    # TODO: for 'files', check to make sure there are a suitable number of 'target' files in path
    effect_constructors = {
        'files':      lambda: st.audio.FileEffect(args.path),
        'comp_4c':    st.audio.Compressor_4c,
        'comp':       st.audio.Compressor,
        'comp_t':     st.audio.Comp_Just_Thresh,
        'comp_large': st.audio.Compressor_4c_Large,
        'comp_one':   st.audio.Compressor_4c_OneSetting,
        'denoise':    st.audio.Denoise,
        'lowpass':    st.audio.LowPass,
    }
    e = args.effect
    if e in effect_constructors:
        effect = effect_constructors[e]()
    elif 'VST' in e:
        print("VST plugins not integrated yet, but that would be great.")
        print("Feel free to grab Igor Gadelha' VSTRender lib to help implement this.")
        print("See https://github.com/igorgad/dpm")
        sys.exit(1)
    else:
        print(f"Effect option '{e}' is not yet added")
        sys.exit(1)

    # this is just to avoid confusion: the datagenerator class will/should trap for this also.
    if (args.path is None) or (not glob.glob(args.path+"/Train/input*")) \
        or (not glob.glob(args.path+"/Val/input*")): # no input files = 100% probability of synth'ing input data
        args.synthprob = 1.0  # this isn't used yet, but passed anyway for later use
    # BUGFIX: original compared the instance to the class with 'is' (always False);
    # isinstance() is the correct check here.
    if isinstance(effect, st.audio.FileEffect):
        args.synthprob = 0.0  # can't run pre-recorded effects post-facto

    if args.target not in ["chunk","stream"]:
        print(f"Error, invalid target type: {args.target}")
        sys.exit(1)

    # Finished parsing/checking arguments, ready to run
    st.misc.print_choochoo(__version__)  # ascii art is the hallmark of professionalism
    print("Running with args =",args)

    # call the training routine
    model = st.train.train(epochs=args.epochs, n_data_points=args.num, batch_size=args.batch, device=device, sr=args.sr,\
        effect=effect, datapath=args.path, scale_factor=args.scale, shrink_factor=args.shrink,
        apex_opt=args.apex, target_type=args.target, lr_max=args.lrmax, in_checkpointname=args.checkpoint,
        compand=args.compand)
    print("run_train.py: Execution completed.")