diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..755684d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "espnet"] + path = espnet + url = https://github.com/espnet/espnet diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..da96fec --- /dev/null +++ b/LICENSE @@ -0,0 +1,32 @@ +SOFTWARE LICENSE AGREEMENT FOR EVALUATION + +This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone Corporation ("NTT"). +READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. + + +BACKGROUND +A. NTT is the owner of all rights, including all patent rights and copyrights, in and to the Software and related documentation listed in Exhibit A to this Agreement. +B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate it, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. +C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. +In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: +1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in "NTT Neural Machine Translation Systems at WAT 2017, Morishita et al., WAT 2017". User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. +2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software. +3. Term. This Agreement is effective upon whichever occurs earlier: (i) User’s acceptance of the Agreement, or (ii) User’s installing, accessing, or using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User. User may terminate this Agreement at any time by communicating User’s decision to terminate to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and destroy all copies of the Software. +4. 
Proprietary Rights +(a) The Software is the valuable and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights and copyrights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software. +(b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; OR (iii) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (ii) ABOVE. +(c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. +5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT AND DEFEND ANY ACTION RELATING TO THE SOFTWARE. +6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. +7. Limitation of Liability. IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. +8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent. +9. General +(a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect. 
+(b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. +(c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. +(d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. +(e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. +(f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT’s obligations set forth under this Agreement due to any cause beyond NTT’s reasonable control. + +EXHIBIT A + diff --git a/README.org b/README.org new file mode 100644 index 0000000..0bcaadf --- /dev/null +++ b/README.org @@ -0,0 +1,74 @@ +* ESPnet extensions for semi-supervised end-to-end speech recognition + +This repository contains evaluation scripts used in our paper +#+begin_quote +Shigeki Karita, Shinji Watanabe, Tomoharu Iwata, Atsunori Ogawa, Marc Delcroix, "Semi-Supervised End-to-End Speech Recognition," INTERSPEECH, 2018 +#+end_quote +A PDF file will be available in the [[https://www.isca-speech.org/iscaweb/index.php/archive/online-archive][ISCA Online Archive]]. + +** how to set up + +#+begin_src bash +$ git clone https://github.com/nttcslab-sp/espnet-semisupervised --recursive +$ cd espnet-semisupervised/espnet/tools; make PYTHON_VERSION=3 -f conda.mk +$ cd ../.. +$ ./run.sh --gpu 0 --wsj0 --wsj1 +#+end_src + +NOTE: you need to install PyTorch 0.3.1. 
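+Newer PyTorch releases remove several 0.3-era APIs this code relies on (e.g. ~Variable~, ~clip_grad_norm~, ~loss.data[0]~), so it is worth verifying the version inside the venv before training; a minimal check, assuming the venv built above is active:
+
+#+begin_src python
+import torch
+# the training loops depend on pytorch 0.3.x behavior
+assert torch.__version__.startswith("0.3"), "expected pytorch 0.3.x, got " + torch.__version__
+#+end_src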
+
+** scripts
+
+in root dir
+
+- run.sh : end-to-end recipe for this experiment (do not forget to set --gpu 0 if you have a GPU)
+- sbatch.sh : slurm job script for several paired/unpaired data ratios and hyperparameter searches (requires a finished run_retrain_wsj.sh expdir for pretrained model params)
+
+in ~shell/~ dir
+
+- show_results.sh : summarizes CER/WER/SER from decoded results of the dev93/eval92 sets (usage: `show_results.sh exp/train_si84_xxx`)
+- decode.sh : a script to decode and evaluate a trained model (usage: `decode.sh --expdir exp/train_si84_xxx`)
+- debug.sh : we recommend running ~source debug.sh~ before using ipython to set paths to everything you need
+
+in ~python/~ dir
+
+- asr_train_loop_th.py : a python script for initial training with the paired dataset (train_si84)
+- retrain_loop_th.py : a python script for re-training with the unpaired dataset (train_si284)
+- unsupervised_recog_th.py : a python script for decoding with the re-trained model
+- unsupervised.py : implements the pytorch model for paired/unpaired learning
+- results.py : implements a chainer-like reporter, used in the training loop without chainer iterators
+
+** results
+
+| train_set | dev93 Acc | dev93 CER | eval92 CER | dev93 WER | eval92 WER | dev93 SER | eval92 SER | path |
+|--------------------------------------------+-----------+-----------+------------+-----------+------------+-----------+------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| train_si84 (7138 utt, 15 hours) | 77.6 | 25.4 | 15.8 | 61.9 | 44.2 | 99.8 | 98.5 | exp/train_si84_blstmp_e6_subsample1_2_2_1_1_unit320_proj320_d1_unit300_location_aconvc10_aconvf100_mtlalpha0.5_adadelta_bs30_mli800_mlo150 |
+| + train_si284 RNNLM | | 19.3 | 16.6 | 51.3 | 47.7 | 99.8 | 99.7 | exp/rnnlm_train_si84_blstmp_e6_subsample1_2_2_1_1_unit320_proj320_d1_unit300_location_aconvc10_aconvf100_mtlalpha0.5_adadelta_bs30_mli800_mlo150_epochs15 |
+|--------------------------------------------+-----------+-----------+------------+-----------+------------+-----------+------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| + unpaired train_si284 retrain | 83.8 | 28.2 | 15.6 | 61.2 | 40.5 | 99.6 | 97.6 | ./exp/train_si84_retrain_None_alpha0.5_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.9 |
+| + RNNLM | | 22.1 | 17.2 | 51.6 | 44.2 | 99.0 | 99.4 | ./exp/train_si84_retrain_None_alpha0.5_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.9/rnnlm0.1 |
+| + unpaired train_si284 retrain w/ GAN-si84 | 83.5 | 26.3 | 15.0 | 59.9 | 40.0 | 99.4 | 97.3 | exp/train_si84_paired_hidden_gan_alpha0.5_bnFalse_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.9_st0.5_train_si84_epochs15 |
+| + unpaired train_si284 retrain w/ KL-si84 | 83.6 | 28.5 | 15.6 | 60.5 | 40.4 | 99.6 | 97.3 | exp/train_si84_paired_hidden_gausslogdet_alpha0.5_bnFalse_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.9_st0.9_train_si84_epochs15 |
+| + unpaired train_si284 retrain w/ GAN | 84.2 | 22.1 | 17.9 | 50.9 | 44.2 | 99.2 | 99.4 | ./exp/train_si84_retrain84_gan_alpha0.5_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.9_st0.9_train_si84_iter5 |
+| + RNNLM | | 22.1 | 17.9 | 50.9 | 44.2 | 99.2 | 99.4 | ./exp/train_si84_retrain84_gan_alpha0.5_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.9_st0.9_train_si84_iter5/rnnlm0.2 
|
+| + unpaired train_si284 retrain w/ KL | 84.0 | 24.8 | 14.4 | 58.1 | 39.5 | 99.6 | 96.4 | ./exp/train_si84_ret3_gausslogdet_alpha0.5_bnFalse_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.9_st0.5_train_si84_epochs30 |
+| + RNNLM | | 20.0 | 16.9 | 48.9 | 42.7 | 99.0 | 99.1 | ./exp/train_si84_retrain84_gausslogdet_alpha0.5_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.99_st0.99_train_si84/rnnlm0.2 |
+| + unpaired train_si284 retrain w/ MMD | 82.9 | 25.9 | 13.9 | 59.7 | 38.4 | 99.2 | 96.7 | ./exp/train_si84_ret3_mmd_alpha0.5_bnFalse_adadelta_lr1.0_bs30_el6_dl1_att_location_batch30_data_loss0.5_st0.99_train_si84_epochs30 |
+|--------------------------------------------+-----------+-----------+------------+-----------+------------+-----------+------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| train_si284 (37416 utt, 81 hours) | 93.9 | 8.1 | 6.3 | 23.8 | 18.9 | 92.4 | 87.4 | exp/train_si284_blstmp_e6_subsample1_2_2_1_1_unit320_proj320_d1_unit300_location_aconvc10_aconvf100_mtlalpha0.5_adadelta_bs30_mli800_mlo150 |
+| + train_si284 RNNLM | | 7.9 | 6.1 | 22.7 | 18.3 | 89.7 | 84.1 | ./exp/rnnlm_train_si284_blstmp_e6_subsample1_2_2_1_1_unit320_proj320_d1_unit300_location_aconvc10_aconvf100_mtlalpha0.5_adadelta_bs30_mli800_mlo150_epochs15 |
+
+
+- Acc: character accuracy during training with forced decoding
+- CER: character error rate (edit distance based error)
+- WER: word error rate (edit distance based error)
+- SER: sentence error rate (exact match error)
+- all the exp paths starting with ~exp/...~ are placed under ~/nfs/kswork/kishin/karita/experiments/espnet-unspervised/egs/wsj/unsupervised~ on the NTT ks-servers
+
+results with smaller paired train data:
+
+[[plot.png]]
+
+** contact
+
+email: karita.shigeki@lab.ntt.co.jp
diff --git a/cmd.sh b/cmd.sh new file mode 120000 index 0000000..f3f2969 --- /dev/null +++ b/cmd.sh @@ -0,0 +1 @@ +./espnet/egs/wsj/asr1/cmd.sh \ No newline at end of file diff --git a/conf b/conf new file mode 120000 index 0000000..a7aa1e1 --- /dev/null +++ b/conf @@ -0,0 +1 @@ +./espnet/egs/wsj/asr1/conf \ No newline at end of file diff --git a/espnet b/espnet new file mode 160000 index 0000000..8bb00b4 --- /dev/null +++ b/espnet @@ -0,0 +1 @@ +Subproject commit 8bb00b4cb3869ebdb39aedebe4f241bc34cd4b2a diff --git a/local b/local new file mode 120000 index 0000000..dcd3f6a --- /dev/null +++ b/local @@ -0,0 +1 @@ +./espnet/egs/wsj/asr1/local \ No newline at end of file diff --git a/path.sh b/path.sh new file mode 100644 index 0000000..c4b6611 --- /dev/null +++ b/path.sh @@ -0,0 +1,16 @@
+MAIN_ROOT=$PWD/espnet
+KALDI_ROOT=$MAIN_ROOT/tools/kaldi
+SPNET_ROOT=$MAIN_ROOT/src
+
+[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
+export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH
+[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1
+. 
$KALDI_ROOT/tools/config/common_path.sh
+export LC_ALL=C
+
+export PATH=$PWD/python:$PWD/shell:$SPNET_ROOT/utils/:$SPNET_ROOT/bin/:$PATH
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/tools/chainer_ctc/ext/warp-ctc/build
+source $MAIN_ROOT/tools/venv/bin/activate
+export PYTHONPATH=$PWD/python:$SPNET_ROOT/lm/:$SPNET_ROOT/asr/:$SPNET_ROOT/nets/:$SPNET_ROOT/utils/:$SPNET_ROOT/bin/:$PYTHONPATH
+
+export OMP_NUM_THREADS=1
diff --git a/plot.png b/plot.png new file mode 100644 index 0000000..25f87aa Binary files /dev/null and b/plot.png differ diff --git a/python/asr_train_loop_th.py b/python/asr_train_loop_th.py new file mode 100755 index 0000000..3b061d3 --- /dev/null +++ b/python/asr_train_loop_th.py @@ -0,0 +1,324 @@
+#!/usr/bin/env python
+import argparse
+import collections
+import contextlib
+import copy
+import json
+import logging
+import math
+import os
+import pickle
+import random
+import six
+
+# spnet related
+from e2e_asr_attctc_th import E2E
+from e2e_asr_attctc_th import Loss
+from asr_train_th import make_batchset, converter_kaldi, delete_feat
+from results import EpochResult, GlobalResult
+
+# third-party libraries
+import lazy_io
+import numpy as np
+import torch
+
+
+@contextlib.contextmanager
+def open_kaldi_feat(batch, reader):
+ try:
+ yield converter_kaldi(batch, reader)
+ finally:
+ delete_feat(batch)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+ # general configuration
+ parser.add_argument('--gpu', '-g', default='-1', type=str,
+ help='GPU ID (negative value indicates CPU)')
+ parser.add_argument('--outdir', type=str, required=True,
+ help='Output directory')
+ parser.add_argument('--debugmode', default=1, type=int,
+ help='Debugmode')
+ parser.add_argument('--dict', required=True,
+ help='Dictionary')
+ parser.add_argument('--seed', default=1, type=int,
+ help='Random seed')
+ parser.add_argument('--debugdir', type=str,
+ help='Output directory for debugging')
+ # TODO(karita): implement resume
+ # parser.add_argument('--resume', '-r', default='',
+ # help='Resume the training from snapshot')
+ parser.add_argument('--minibatches', '-N', type=int, default=-1,
+ help='Process only N minibatches (for debug)')
+ parser.add_argument('--verbose', '-V', default=0, type=int,
+ help='Verbose option')
+ # task related
+ parser.add_argument('--train-feat', type=str, required=True,
+ help='Filename of train feature data (Kaldi scp)')
+ parser.add_argument('--valid-feat', type=str, required=True,
+ help='Filename of validation feature data (Kaldi scp)')
+ parser.add_argument('--train-label', type=str, required=True,
+ help='Filename of train label data (json)')
+ parser.add_argument('--valid-label', type=str, required=True,
+ help='Filename of validation label data (json)')
+ # network architecture
+ # encoder
+ parser.add_argument('--etype', default='blstmp', type=str,
+ choices=['blstm', 'blstmp', 'vggblstmp', 'vggblstm'],
+ help='Type of encoder network architecture')
+ parser.add_argument('--elayers', default=4, type=int,
+ help='Number of encoder layers')
+ parser.add_argument('--eunits', '-u', default=300, type=int,
+ help='Number of encoder hidden units')
+ parser.add_argument('--eprojs', default=320, type=int,
+ help='Number of encoder projection units')
+ parser.add_argument('--subsample', default=1, type=str,
+ help='Subsample input frames x_y_z means subsample every x frame at 1st layer, '
+ 'every y frame at 2nd layer etc.')
+ # attention
+ parser.add_argument('--atype', default='dot', type=str,
+ choices=['dot', 'location', 'noatt'],
+ help='Type of 
attention architecture')
+ parser.add_argument('--adim', default=320, type=int,
+ help='Number of attention transformation dimensions')
+ parser.add_argument('--aconv-chans', default=-1, type=int,
+ help='Number of attention convolution channels \ (negative value indicates no location-aware attention)')
+ parser.add_argument('--aconv-filts', default=100, type=int,
+ help='Number of attention convolution filters \ (negative value indicates no location-aware attention)')
+ # decoder
+ parser.add_argument('--dtype', default='lstm', type=str,
+ choices=['lstm'],
+ help='Type of decoder network architecture')
+ parser.add_argument('--dlayers', default=1, type=int,
+ help='Number of decoder layers')
+ parser.add_argument('--dunits', default=320, type=int,
+ help='Number of decoder hidden units')
+ parser.add_argument('--mtlalpha', default=0.5, type=float,
+ help='Multitask learning coefficient, alpha: alpha*ctc_loss + (1-alpha)*att_loss ')
+ parser.add_argument('--lsm-type', const='', default='', type=str, nargs='?', choices=['', 'unigram'],
+ help='Apply label smoothing with a specified distribution type')
+ parser.add_argument('--lsm-weight', default=0.0, type=float,
+ help='Label smoothing weight')
+
+ # model (parameter) related
+ parser.add_argument('--dropout-rate', default=0.0, type=float,
+ help='Dropout rate')
+ # minibatch related
+ parser.add_argument('--batch-size', '-b', default=50, type=int,
+ help='Batch size')
+ parser.add_argument('--maxlen-in', default=800, type=int, metavar='ML',
+ help='Batch size is reduced if the input sequence length > ML')
+ parser.add_argument('--maxlen-out', default=150, type=int, metavar='ML',
+ help='Batch size is reduced if the output sequence length > ML')
+ # optimization related
+ parser.add_argument('--opt', default='adadelta', type=str,
+ choices=['adadelta', 'adam'],
+ help='Optimizer')
+ parser.add_argument('--lr', default=1.0, type=float,
+ help="Learning rate")
+ parser.add_argument('--eps', default=1e-8, type=float,
+ help='Epsilon constant for optimizer')
+ parser.add_argument('--eps-decay', default=0.01, type=float,
+ help='Decaying ratio of epsilon')
+ parser.add_argument('--criterion', default='acc', type=str,
+ choices=['loss', 'acc'],
+ help='Criterion to perform epsilon decay')
+ parser.add_argument('--threshold', default=1e-4, type=float,
+ help='Threshold to stop iteration')
+ parser.add_argument('--epochs', '-e', default=30, type=int,
+ help='Number of maximum epochs')
+ parser.add_argument('--grad-clip', default=5, type=float,
+ help='Gradient norm threshold to clip')
+ parser.add_argument('--supervised-data-ratio', default=1.0, type=float,
+ help='ratio of supervised training set')
+ return parser
+
+
+def setup_torch(args):
+ # seed setting (chainer seed may not need it)
+ nseed = args.seed
+ random.seed(nseed)
+ np.random.seed(nseed)
+ torch.manual_seed(nseed)
+
+ # debug mode setting
+ # 0 would be fastest, but 1 seems to be reasonable
+ # by considering reproducibility
+ if args.debugmode < 1:
+ torch.backends.cudnn.deterministic = False
+ logging.info('pytorch cudnn deterministic is disabled')
+ else:
+ torch.backends.cudnn.deterministic = True
+
+ # check cuda and cudnn availability
+ if not torch.cuda.is_available():
+ logging.warning('cuda is not available')
+ if not torch.backends.cudnn.enabled:
+ logging.warning('cudnn is not available')
+
+
+def setup(args):
+ # logging info
+ if args.verbose > 0:
+ logging.basicConfig(
+ level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
+ 
else:
+ logging.basicConfig(
+ level=logging.WARN, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
+ logging.warning('Skip DEBUG/INFO messages')
+
+ # display PYTHONPATH
+ logging.info('python path = ' + os.environ['PYTHONPATH'])
+ setup_torch(args)
+
+ # load dictionary for debug log
+ if args.dict is not None:
+ with open(args.dict, 'rb') as f:
+ dictionary = f.readlines()
+ char_list = [entry.decode('utf-8').split(' ')[0]
+ for entry in dictionary]
+ char_list.insert(0, '<blank>')
+ char_list.append('<eos>')
+ args.char_list = char_list
+ else:
+ args.char_list = None
+
+ # get input and output dimension info
+ with open(args.valid_label, 'rb') as f:
+ valid_json = json.load(f)['utts']
+ utts = list(valid_json.keys())
+ idim = int(valid_json[utts[0]]['idim'])
+ odim = int(valid_json[utts[0]]['odim'])
+ logging.info('#input dims : ' + str(idim))
+ logging.info('#output dims: ' + str(odim))
+
+ # write model config
+ if not os.path.exists(args.outdir):
+ os.makedirs(args.outdir)
+ model_conf = args.outdir + '/model.conf'
+ with open(model_conf, 'wb') as f:
+ logging.info('writing a model config file to ' + model_conf)
+ # TODO(watanabe) use others than pickle, possibly json, and save as a text
+ pickle.dump((idim, odim, args), f)
+ for key in sorted(vars(args).keys()):
+ logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))
+
+ # read json data
+ with open(args.train_label, 'rb') as f:
+ train_json = json.load(f)['utts']
+ with open(args.valid_label, 'rb') as f:
+ valid_json = json.load(f)['utts']
+
+ # make minibatch list (variable length)
+ # that contains [ [{"utt-id": { "tokenid": ..., }} x batchsize], ... ]
+ train_batch = make_batchset(train_json, args.batch_size,
+ args.maxlen_in, args.maxlen_out, args.minibatches)
+ valid_batch = make_batchset(valid_json, args.batch_size,
+ args.maxlen_in, args.maxlen_out, args.minibatches)
+ return idim, odim, train_batch, valid_batch
+
+
+if __name__ == "__main__":
+ args = get_parser().parse_args()
+ idim, odim, train_batch, valid_batch = setup(args)
+ if args.supervised_data_ratio != 1.0:
+ n_supervised = int(len(train_batch) * args.supervised_data_ratio)
+ train_batch = train_batch[:n_supervised]
+
+ # specify model architecture
+ e2e = E2E(idim, odim, args)
+ model = Loss(e2e, args.mtlalpha)
+
+ # Set gpu
+ gpu_id = int(args.gpu)
+ logging.info('gpu id: ' + str(gpu_id))
+ if gpu_id >= 0:
+ # Make a specified GPU current
+ model.cuda(gpu_id) # Copy the model to the GPU
+
+ # Setup an optimizer
+ if args.opt == 'adadelta':
+ optimizer = torch.optim.Adadelta(
+ model.parameters(), lr=args.lr, rho=0.95, eps=args.eps)
+ elif args.opt == 'adam':
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+ # prepare Kaldi reader
+ train_reader = lazy_io.read_dict_scp(args.train_feat)
+ valid_reader = lazy_io.read_dict_scp(args.valid_feat)
+
+ best = dict(loss=float("inf"), acc=-float("inf"))
+ opt_key = "eps" if args.opt == "adadelta" else "lr"
+ def get_opt_param():
+ return optimizer.param_groups[0][opt_key]
+
+ # training loop
+ result = GlobalResult(args.epochs, args.outdir)
+ for epoch in range(args.epochs):
+ model.train()
+ with result.epoch("main", train=True) as train_result:
+ for batch in np.random.permutation(train_batch):
+ with open_kaldi_feat(batch, train_reader) as x:
+ # forward
+ loss_ctc, loss_att, acc = model.predictor(x)
+ loss = args.mtlalpha * loss_ctc + (1 - args.mtlalpha) * loss_att
+ # backward
+ optimizer.zero_grad() # Clear the parameter gradients
+ loss.backward() # Backprop
+ loss.detach() # 
Truncate the graph + # compute the gradient norm to check if it is normal or not + grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), args.grad_clip) + logging.info('grad norm={}'.format(grad_norm)) + if math.isnan(grad_norm): + logging.warning('grad norm is nan. Do not update model.') + else: + optimizer.step() + # print/plot stats to args.outdir/results + train_result.report({ + "loss": loss, + "acc": acc, + "loss_ctc": loss_ctc, + "loss_att": loss_att, + "grad_norm": grad_norm, + opt_key: get_opt_param() + }) + + with result.epoch("validation/main", train=False) as valid_result: + model.eval() + for batch in valid_batch: + with open_kaldi_feat(batch, valid_reader) as x: + # forward (without backward) + loss_ctc, loss_att, acc = model.predictor(x) + loss = args.mtlalpha * loss_ctc + (1 - args.mtlalpha) * loss_att + # print/plot stats to args.outdir/results + valid_result.report({ + "loss": loss, + "acc": acc, + "loss_ctc": loss_ctc, + "loss_att": loss_att, + opt_key: get_opt_param() + }) + + # save/load model + valid_avg = valid_result.average() + degrade = False + if best["loss"] > valid_avg["loss"]: + best["loss"] = valid_avg["loss"] + torch.save(model.state_dict(), args.outdir + "/model.loss.best") + elif args.criterion == "loss": + degrade = True + + if best["acc"] < valid_avg["acc"]: + best["acc"] = valid_avg["acc"] + torch.save(model.state_dict(), args.outdir + "/model.acc.best") + elif args.criterion == "acc": + degrade = True + + if degrade: + key = "eps" if args.opt == "adadelta" else "lr" + for p in optimizer.param_groups: + p[key] *= args.eps_decay + model.load_state_dict(torch.load(args.outdir + "/model." + args.criterion + ".best")) diff --git a/python/results.py b/python/results.py new file mode 100644 index 0000000..27822e9 --- /dev/null +++ b/python/results.py @@ -0,0 +1,137 @@ +from __future__ import print_function +import os +import contextlib +import json +import time + +import torch +import matplotlib + +matplotlib.use('Agg') + +from matplotlib import pyplot + + +def to_float(x, name): + if isinstance(x, float): + return x + if isinstance(x, int): + return float(x) + elif isinstance(x, torch.autograd.Variable): + return x.data[0] + else: + raise NotImplementedError("{} is unknown-type: {} of {}".format(name, x, type(x))) + + +def plot_seq(d, path): + fig, ax = pyplot.subplots() + for k, xs in d.items(): + ax.plot(range(len(xs)), xs, label=k, marker="x") + + l = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
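+ # note: the legend is placed outside the axes on purpose; it is passed to
+ # fig.savefig() below via bbox_extra_artists so bbox_inches='tight' does not crop it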
+ ax.grid()
+ fig.savefig(path, bbox_extra_artists=(l,), bbox_inches='tight')
+ pyplot.close()
+
+
+def default_print(*args, **kwargs):
+ print(*args, flush=True, **kwargs)
+
+
+class GlobalResult(object):
+ def __init__(self, max_epoch, outdir=None, float_fmt="{}",
+ report_every=100, logfun=default_print, optional=dict()):
+ self.outdir = outdir
+ if outdir is not None:
+ self.log_path = outdir + "/log"
+ with open(self.log_path, "w") as f:
+ json.dump([], f)
+ else:
+ self.log_path = None
+ self.logfun = logfun
+ self.max_epoch = max_epoch
+ self.report_every = report_every
+ self.float_fmt = float_fmt
+ self.start_time = time.time()
+ self.current_epoch = 0
+ self.plot_dict = dict()
+
+ def elapsed_time(self):
+ return time.time() - self.start_time
+
+ @contextlib.contextmanager
+ def epoch(self, prefix, train):
+ try:
+ if train:
+ self.current_epoch += 1
+ e_result = EpochResult(self, prefix, train)
+ yield e_result
+ finally:
+ self.logfun("[{}] {}-epoch: {}\t{}".format(
+ prefix, "train" if train else "valid",
+ self.current_epoch, e_result.summary()))
+ e_result.dump()
+
+ if self.outdir is not None:
+ avg = e_result.average()
+ for k, v in avg.items():
+ if k not in self.plot_dict.keys():
+ self.plot_dict[k] = dict()
+ if prefix not in self.plot_dict[k].keys():
+ self.plot_dict[k][prefix] = []
+ self.plot_dict[k][prefix].append(avg[k])
+ plot_seq(self.plot_dict[k], self.outdir + "/" + k + ".png")
+
+
+class EpochResult(object):
+ def __init__(self, global_result, prefix, train):
+ self.global_result = global_result
+ self.train = train
+ self.sum_dict = dict()
+ self.iteration = 0
+ self.logfun = global_result.logfun
+ self.log_path = global_result.log_path
+ self.prefix = prefix
+ self.float_fmt = global_result.float_fmt
+
+ def summary(self):
+ s = ""
+ fmt = "{}: " + self.float_fmt + "\t"
+ for k, v in self.average().items():
+ s += fmt.format(k, v)
+ s += "elapsed: " + time.strftime("%X", time.gmtime(self.global_result.elapsed_time()))
+ return s
+
+ def dump(self):
+ if self.log_path is None:
+ return
+
+ with open(self.log_path, "r") as f:
+ d = json.load(f)
+
+ elem = {
+ "epoch": self.global_result.current_epoch,
+ "iteration": self.iteration,
+ "elapsed_time": self.global_result.elapsed_time()
+ }
+
+ for k, v in self.average().items():
+ elem[self.prefix + "/" + k] = v
+
+ d.append(elem)
+ with open(self.log_path, "w") as f:
+ json.dump(d, f, indent=4)
+
+ def report(self, d):
+ for k, v in d.items():
+ if k not in self.sum_dict.keys():
+ self.sum_dict[k] = to_float(v, k)
+ else:
+ self.sum_dict[k] += to_float(v, k)
+ self.iteration += 1
+ if self.train and self.iteration % self.global_result.report_every == 0:
+ self.logfun("train-iter: {}\t{}".format(self.iteration, self.summary()))
+ self.dump()
+
+ def average(self):
+ return {k: v / self.iteration for k, v in self.sum_dict.items()}
diff --git a/python/retrain_loop_th.py b/python/retrain_loop_th.py new file mode 100755 index 0000000..d9c3b1e --- /dev/null +++ b/python/retrain_loop_th.py @@ -0,0 +1,287 @@
+#!/usr/bin/env python
+import argparse
+import collections
+import contextlib
+import copy
+import json
+import logging
+import math
+import os
+import pickle
+import random
+import six
+from distutils.util import strtobool
+
+# spnet related
+from unsupervised import E2E, Discriminator
+from e2e_asr_attctc_th import Loss
+from asr_train_loop_th import get_parser, setup, open_kaldi_feat, make_batchset
+from results import EpochResult, GlobalResult
+
+# third-party libraries
+import lazy_io
+import numpy as np
+import torch
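+# Overview: this script implements the retraining stage of the semi-supervised
+# setup. Each iteration draws one paired batch (speech with matching text) and
+# one unpaired batch, combines the usual CTC/attention loss on the paired data
+# with a text reconstruction loss on the unpaired data, and ties the speech and
+# text encoder outputs together with an inter-domain loss (GAN, Gaussian KL,
+# MMD, ...) selected by --unsupervised-loss.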
+
+def shuffle_pair(batch_list):
+ keys = np.random.permutation([[b[0] for b in batch] for batch in batch_list])
+ vals = np.random.permutation([[b[1] for b in batch] for batch in batch_list])
+ ret = []
+ for ks, vs in zip(keys, vals):
+ batch = []
+ n = min(len(ks), len(vs))
+ for k, v in zip(ks[:n], vs[:n]):
+ batch.append((k, v))
+ ret.append(batch)
+ return ret
+
+def cpu_loader(storage, location):
+ return storage
+
+def load_pretrained(self, src_dict, idim, odim, args, train_batch, train_reader):
+ dst_dict = self.state_dict()
+ for k, v in src_dict.items():
+ assert k in dst_dict, k + " not found"
+ dst_dict[k] = v
+ self.load_state_dict(dst_dict)
+ tgt_dict = self.state_dict()
+ for k, v in src_dict.items():
+ assert (tgt_dict[k] == v).all()
+
+ if args.verbose > 0:
+ import e2e_asr_attctc_th as base
+ init = base.Loss(base.E2E(idim, odim, args), args.mtlalpha)
+ init.load_state_dict(src_dict)
+ init.eval()
+ self.predictor.eval()
+ # test first batch prediction equality
+ with open_kaldi_feat(train_batch[0], train_reader) as data:
+ init_ctc, init_att, init_acc = init.predictor(data)
+ re_ctc, re_att, re_acc = self.predictor(data, supervised=True)
+ print("init: ", init_ctc, init_att, init_acc)
+ print("re: ", re_ctc, re_att, re_acc)
+ np.testing.assert_almost_equal(init_ctc.data[0], re_ctc.data[0])
+ np.testing.assert_almost_equal(init_att.data[0], re_att.data[0])
+ np.testing.assert_almost_equal(init_acc, re_acc)
+ return self
+
+
+def parameters(model, exclude=None):
+ if exclude is None:
+ return model.parameters()
+ assert exclude in model.modules()
+ exclude_params = list(exclude.parameters())
+ model_params = list(model.parameters())
+ ret = []
+ for p in model_params:
+ found = False
+ for e in exclude_params:
+ if p is e:
+ found = True
+ break
+ if not found:
+ ret.append(p)
+ assert len(ret) == (len(model_params) - len(exclude_params))
+ return ret
+
+
+def fully_unpaired(batch_list):
+ """
+ pair each batch's keys with the values of the next batch, so every key-value pair is mismatched (fully unpaired)
+ """
+ ret = []
+ for i, batch in enumerate(batch_list):
+ if i % 2 == 0 and i != len(batch_list) - 1:
+ next_batch = batch_list[i+1]
+ ret.append([(k1, v2) for (k1, v1), (k2, v2) in zip(batch, next_batch)])
+ return ret
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ parser.add_argument("--init-model", type=str, default="None")
+ parser.add_argument("--weight-decay", type=float, default=0.0)
+ parser.add_argument("--discriminator-dim", type=int, default=320)
+ parser.add_argument("--unsupervised-feat", type=str)
+ parser.add_argument("--unsupervised-json", type=str)
+ parser.add_argument('--speech-text-ratio', default=0.5, type=float,
+ help='Weight of the hidden-space loss against the text loss in the unsupervised objective')
+ parser.add_argument('--supervised-loss-ratio', default=0.9, type=float,
+ help='Weight of the supervised loss against the unsupervised loss')
+ parser.add_argument('--unsupervised-loss', choices=["None", "gan", "gauss", "gausslogdet", "variance", "mmd"], default="None", type=str,
+ help='loss for the hidden space')
+ parser.add_argument('--use-batchnorm', default=False, type=strtobool,
+ help="use batchnorm in the output of the encoder")
+ parser.add_argument('--use-smaller-data-size', default=True, type=strtobool,
+ help="iterate min(n_supervised, n_unsupervised) batches per epoch instead of max")
+ parser.add_argument('--lock-encoder', default=False, type=strtobool,
+ help="do not update encoder parameters")
+
+ args = parser.parse_args()
+ idim, odim, supervised_train_batch, valid_batch = setup(args)
+ if args.supervised_data_ratio != 1.0:
+ 
n_supervised = int(len(supervised_train_batch) * args.supervised_data_ratio) + supervised_train_batch = supervised_train_batch[:n_supervised] + + with open(args.unsupervised_json, 'rb') as f: + unsupervised_json = json.load(f)['utts'] + unsupervised_train_batch = make_batchset(unsupervised_json, args.batch_size, + args.maxlen_in, args.maxlen_out, args.minibatches) + unsupervised_train_batch = fully_unpaired(unsupervised_train_batch) + + n_supervised = len(supervised_train_batch) + n_unsupervised = len(unsupervised_train_batch) + + # prepare Kaldi reader + train_reader = lazy_io.read_dict_scp(args.train_feat) + valid_reader = lazy_io.read_dict_scp(args.valid_feat) + unsupervised_reader = lazy_io.read_dict_scp(args.unsupervised_feat) + + # specify model architecture + e2e = E2E(idim, odim, args) + model = Loss(e2e, args.mtlalpha) + if args.init_model != "None": + src_dict = torch.load(args.init_model, map_location=cpu_loader) + model = load_pretrained(model, src_dict, idim, odim, args, supervised_train_batch, train_reader) + if args.unsupervised_loss == "gan": + discriminator = Discriminator(args.eprojs, args.discriminator_dim) + else: + discriminator = None + + # Set gpu + gpu_id = int(args.gpu) + logging.info('gpu id: ' + str(gpu_id)) + if gpu_id >= 0: + # Make a specified GPU current + model.cuda(gpu_id) # Copy the model to the GPU + if discriminator: + discriminator.cuda(gpu_id) + + # Setup an optimizer + if args.lock_encoder: + model_params = parameters(model, model.predictor.enc) + else: + model_params = model.parameters() + if args.opt == 'adadelta': + optimizer = torch.optim.Adadelta( + model_params, lr=args.lr, rho=0.95, eps=args.eps, weight_decay=args.weight_decay) + if discriminator: + d_optimizer = torch.optim.Adadelta( + discriminator.parameters(), lr=args.lr, rho=0.95, eps=args.eps, weight_decay=args.weight_decay) + + elif args.opt == 'adam': + optimizer = torch.optim.Adam(model_params, lr=args.lr, weight_decay=args.weight_decay) + if discriminator: + d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + best = dict(loss=float("inf"), acc=-float("inf")) + opt_key = "eps" if args.opt == "adadelta" else "lr" + def get_opt_param(): + return optimizer.param_groups[0][opt_key] + + # training loop + result = GlobalResult(args.epochs, args.outdir) + for epoch in range(args.epochs): + model.train() + with result.epoch("main", train=True) as train_result: + n_iter_fun = min if args.use_smaller_data_size else max + for i in range(n_iter_fun(n_supervised, n_unsupervised)): + # re-shuffle and repeat smaller one + if i % n_supervised == 0: + supervised_train_batch = np.random.permutation(supervised_train_batch) + if i % n_unsupervised == 0: + unsupervised_train_batch = shuffle_pair(unsupervised_train_batch) + sbatch = supervised_train_batch[i % n_supervised] + ubatch = unsupervised_train_batch[i % n_unsupervised] + + # supervised forward + with open_kaldi_feat(sbatch, train_reader) as sx: + loss_ctc, loss_att, acc_supervised = model.predictor(sx, supervised=True) + loss_supervised = args.mtlalpha * loss_ctc + (1.0 - args.mtlalpha) * loss_att + + # unsupervised forward + with open_kaldi_feat(ubatch, unsupervised_reader) as ux: + loss_text, loss_hidden, acc_text = model.predictor(ux, supervised=False, discriminator=discriminator) + + loss_unsupervised = args.speech_text_ratio * loss_hidden + (1.0 - args.speech_text_ratio) * loss_text + + if discriminator: + loss_discriminator = -loss_hidden + d_optimizer.zero_grad() + 
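+ # note: Discriminator.forward (unsupervised.py) returns the *negated* binary
+ # cross entropy of a speech-vs-text classifier on the encoder outputs, so
+ # loss_hidden decreases as the two encodings become indistinguishable.
+ # Minimizing loss_discriminator = -loss_hidden therefore trains the
+ # discriminator to separate them; retain_graph keeps the graph alive for
+ # the model update below, and d_optimizer.step() only updates discriminator
+ # parameters.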
loss_discriminator.backward(retain_graph=True)
+ loss_discriminator.detach()
+ d_optimizer.step()
+
+ optimizer.zero_grad() # Clear the parameter gradients
+ loss = args.supervised_loss_ratio * loss_supervised + (1.0 - args.supervised_loss_ratio) * loss_unsupervised
+ loss.backward()
+ loss.detach()
+
+ # compute the gradient norm to check if it is normal or not
+ grad_norm = torch.nn.utils.clip_grad_norm(model_params, args.grad_clip)
+ logging.info('grad norm={}'.format(grad_norm))
+ if math.isnan(grad_norm):
+ logging.warning('grad norm is nan. Do not update model.')
+ else:
+ optimizer.step()
+ optimizer.zero_grad() # Clear the parameter gradients
+
+ # print/plot stats to args.outdir/results
+ results = {
+ "loss_att": loss_att.data[0],
+ "loss_ctc": loss_ctc,
+ "loss_supervised": loss_supervised.data[0],
+ "acc_supervised": acc_supervised,
+
+ "loss": loss.data[0],
+ "loss_unsupervised": loss_unsupervised.data[0] / (1.0 - args.supervised_loss_ratio),
+ "acc_text": acc_text,
+ "loss_text": loss_text.data[0],
+ "loss_hidden": loss_hidden,
+
+ "grad_norm": 0.0 if math.isnan(grad_norm) else grad_norm,
+ opt_key: get_opt_param()
+ }
+ train_result.report(results)
+
+ with result.epoch("validation/main", train=False) as valid_result:
+ model.eval()
+ for batch in valid_batch:
+ with open_kaldi_feat(batch, valid_reader) as x:
+ # forward (without backward)
+ loss_ctc, loss_att, acc = model.predictor(x)
+ loss = args.mtlalpha * loss_ctc + (1 - args.mtlalpha) * loss_att
+ # print/plot stats to args.outdir/results
+ valid_result.report({
+ "loss": loss,
+ "acc": acc,
+ "loss_ctc": loss_ctc,
+ "loss_att": loss_att,
+ opt_key: get_opt_param()
+ })
+
+ # save/load model
+ valid_avg = valid_result.average()
+ degrade = False
+ if best["loss"] > valid_avg["loss"]:
+ best["loss"] = valid_avg["loss"]
+ torch.save(model.state_dict(), args.outdir + "/model.loss.best")
+ elif args.criterion == "loss":
+ degrade = True
+
+ if best["acc"] < valid_avg["acc"]:
+ best["acc"] = valid_avg["acc"]
+ torch.save(model.state_dict(), args.outdir + "/model.acc.best")
+ elif args.criterion == "acc":
+ degrade = True
+
+ if degrade:
+ key = "eps" if args.opt == "adadelta" else "lr"
+ for p in optimizer.param_groups:
+ p[key] *= args.eps_decay
+ model.load_state_dict(torch.load(args.outdir + "/model." + args.criterion + ".best"))
+ diff --git a/python/unsupervised.py b/python/unsupervised.py new file mode 100644 index 0000000..746fb4c --- /dev/null +++ b/python/unsupervised.py @@ -0,0 +1,493 @@
+import logging
+import sys
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+from torch.nn.utils.rnn import pack_padded_sequence
+from torch.nn.utils.rnn import pad_packed_sequence, PackedSequence
+import six
+
+import e2e_asr_attctc_th as base
+
+
+def mmd(xs, ys, beta=1.0):
+ # Gaussian-kernel maximum mean discrepancy between two sets of vectors
+ Nx = xs.shape[0]
+ Ny = ys.shape[0]
+ Kxy = torch.matmul(xs, ys.t())
+ dia1 = torch.sum(xs * xs, 1)
+ dia2 = torch.sum(ys * ys, 1)
+ Kxy = Kxy - 0.5 * dia1.unsqueeze(1).expand(Nx, Ny)
+ Kxy = Kxy - 0.5 * dia2.expand(Nx, Ny)
+ Kxy = torch.exp(beta * Kxy).sum() / Nx / Ny
+
+ Kx = torch.matmul(xs, xs.t())
+ Kx = Kx - 0.5 * dia1.unsqueeze(1).expand(Nx, Nx)
+ Kx = Kx - 0.5 * dia1.expand(Nx, Nx)
+ Kx = torch.exp(beta * Kx).sum() / Nx / Nx
+
+ Ky = torch.matmul(ys, ys.t())
+ Ky = Ky - 0.5 * dia2.unsqueeze(1).expand(Ny, Ny)
+ Ky = Ky - 0.5 * dia2.expand(Ny, Ny)
+ Ky = torch.exp(beta * Ky).sum() / Ny / Ny
+
+ return Kx + Ky - 2 * Kxy
+
+
+class _Det(torch.autograd.Function):
+ """
+ Matrix determinant.
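+ The forward pass factorizes the input with Cholesky (potrf), so in practice
+ x must be symmetric positive definite; the backward pass uses
+ d det(X) / dX = det(X) * X^{-T}.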
Input should be a square matrix
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = x.potrf().diag().prod()**2
+ output = x.new([output])
+ ctx.save_for_backward(x, output)
+ # ctx.save_for_backward(u, output)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, output = ctx.saved_variables
+ # u, output = ctx.saved_variables
+ grad_input = None
+
+ if ctx.needs_input_grad[0]:
+ # TODO TEST
+ grad_input = grad_output * output * x.inverse().t()
+ # grad_input = grad_output * output * torch.potrf(u).t()
+
+ return grad_input
+
+def det(x):
+ # u = torch.potrf(x)
+ return _Det.apply(x)
+
+
+class LogDet(torch.autograd.Function):
+ """
+ Matrix log determinant. Input should be a square matrix
+ """
+
+ @staticmethod
+ def forward(ctx, x, eps=0.0):
+ output = torch.log(x.potrf().diag() + eps).sum() * 2
+ output = x.new([output])
+ ctx.save_for_backward(x, output)
+ # ctx.save_for_backward(u, output)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, output = ctx.saved_variables
+ # u, output = ctx.saved_variables
+ grad_input = None
+
+ if ctx.needs_input_grad[0]:
+ # TODO TEST
+ grad_input = grad_output * x.inverse().t()
+ # grad_input = grad_output * torch.potrf(u).t()
+ return grad_input
+
+def logdet(x):
+ # u = torch.potrf(x)
+ return LogDet.apply(x)
+
+
+def test_det():
+ x = Variable(torch.rand(3, 3) / 10.0 + torch.eye(3).float(), requires_grad=True)
+ torch.autograd.gradcheck(det, (x,), eps=1e-4, atol=0.1, rtol=0.1)
+
+def test_logdet():
+ x = Variable(torch.rand(3, 3) + torch.eye(3).float() * 3, requires_grad=True)
+ d = det(x).log()
+ d.backward()
+ gd = x.grad.clone()
+ ld = logdet(x)
+ x.grad = None
+ ld.backward()
+ gld = x.grad
+ np.testing.assert_allclose(d.data.numpy(), ld.data.numpy())
+ np.testing.assert_allclose(gd.data.numpy(), gld.data.numpy())
+
+def cov(xs, m=None):
+ assert xs.dim() == 2
+ if m is None:
+ m = xs.mean(0, keepdim=True)
+ assert m.size() == (1, xs.size(1))
+ return (xs - m).t().mm(xs - m) / xs.size(0)
+
+threshold = torch.nn.functional.threshold
+
+def unclamp_(x, eps):
+ """
+ >>> a = torch.FloatTensor([0.0, 1.0, -0.1, 0.1])
+ >>> unclamp_(a, 0.5)
+ [0.5, 1.0, -0.5, 0.5]
+ """
+ ng = x.abs() < eps
+ sign = x.sign()
+ fill_value = sign.float() * eps + (sign == 0).float() * eps
+ return x.masked_fill_(ng, 0) + ng.float() * fill_value
+
+def gauss_kld(xs, ys, use_logdet=False, eps=float(np.finfo(np.float32).eps)):
+ # KL(N(xm, xcov) || N(ym, ycov)) = 0.5 * (log(|ycov| / |xcov|)
+ # + tr(ycov^-1 xcov) + (xm - ym)^T ycov^-1 (xm - ym) - n_hidden)
+ n_batch, n_hidden = xs.size()
+ xm = xs.mean(0, keepdim=True)
+ ym = ys.mean(0, keepdim=True)
+ xcov = cov(xs, xm)
+ ycov = cov(ys, ym)
+ xcov += torch.diag(xcov.diag() + eps)
+ ycov += torch.diag(ycov.diag() + eps)
+ if use_logdet:
+ log_ratio = logdet(ycov) - logdet(xcov)
+ else:
+ log_ratio = torch.log(threshold(det(ycov), eps, eps)) - torch.log(threshold(det(xcov), eps, eps))
+ ycovi = ycov.inverse()
+ xym = xm - ym # (1, n_hidden)
+ hess = xym.mm(ycovi).mm(xym.t())
+ tr = torch.trace(ycovi.mm(xcov))
+ return 0.5 * (log_ratio + tr + hess - n_hidden).squeeze()
+
+
+class EmbedRNN(torch.nn.Module):
+ def __init__(self, n_in, n_out, n_layers=1):
+ super(EmbedRNN, self).__init__()
+ self.embed = torch.nn.Embedding(n_in, n_out)
+ self.rnn = torch.nn.LSTM(n_out, n_out, n_layers,
+ bidirectional=True, batch_first=True)
+ self.merge = torch.nn.Linear(n_out * 2, n_out)
+
+ def forward(self, xpad, xlen):
+ """
+ :param xpad: (batchsize x max(xlen)) LongTensor
+ :return hpad: (batchsize x max(xlen) x n_out) FloatTensor
+ :return hlen: length list of int. 
hlen == xlen
+ """
+ h = self.embed(xpad)
+ hpack = pack_padded_sequence(h, xlen, batch_first=True)
+ hpack, states = self.rnn(hpack)
+ hpad, hlen = pad_packed_sequence(hpack, batch_first=True)
+ b, t, o = hpad.shape
+ hpad = self.merge(hpad.contiguous().view(b * t, o)).view(b, t, -1)
+ return hpad, hlen
+
+
+class MMSEDecoder(torch.nn.Module):
+ """
+ hidden-to-speech decoder with an MMSE criterion
+
+ TODO(karita): use Tacotron-like structure
+ """
+ def __init__(self, eprojs, odim, dlayers, dunits, att, verbose=0):
+ super(MMSEDecoder, self).__init__()
+ self.dunits = dunits
+ self.dlayers = dlayers
+ self.in_linear = torch.nn.Linear(odim, dunits)
+ self.decoder = torch.nn.ModuleList()
+ self.decoder += [torch.nn.LSTMCell(dunits + eprojs, dunits)]
+ for l in six.moves.range(1, self.dlayers):
+ self.decoder += [torch.nn.LSTMCell(dunits, dunits)]
+ self.output = torch.nn.Linear(dunits, odim)
+
+ self.loss = None
+ self.att = att
+ self.dunits = dunits
+ self.verbose = verbose
+
+ def zero_state(self, hpad):
+ return Variable(hpad.data.new(hpad.size(0), self.dunits).zero_())
+
+ def forward(self, hpad, hlen, ypad, ylen):
+ '''Decoder forward
+
+ :param hpad: padded encoder hidden state sequences
+ :param hlen: lengths of hpad
+ :param ypad: padded target feature sequences
+ :param ylen: lengths of ypad
+ :return: MMSE loss and attention weights
+ '''
+ hpad = base.mask_by_length(hpad, hlen, 0)
+ self.loss = None
+
+ # get dim, length info
+ batch = ypad.size(0)
+ olength = ypad.size(1)
+
+ # initialization
+ c_list = [self.zero_state(hpad)]
+ z_list = [self.zero_state(hpad)]
+ for l in six.moves.range(1, self.dlayers):
+ c_list.append(self.zero_state(hpad))
+ z_list.append(self.zero_state(hpad))
+ att_w = None
+ z_all = []
+ self.att.reset() # reset pre-computation of h
+ att_weight_all = [] # for debugging
+
+ # pre-computation of embedding
+ eys = self.in_linear(ypad.view(batch * olength, -1)).view(batch, olength, -1) # utt x olen x zdim
+
+ # loop for an output sequence
+ for i in six.moves.range(olength):
+ att_c, att_w = self.att(hpad, hlen, z_list[0], att_w)
+ ey = torch.cat((eys[:, i, :], att_c), dim=1) # utt x (zdim + hdim)
+ z_list[0], c_list[0] = self.decoder[0](ey, (z_list[0], c_list[0]))
+ for l in six.moves.range(1, self.dlayers):
+ z_list[l], c_list[l] = self.decoder[l](
+ z_list[l - 1], (z_list[l], c_list[l]))
+ z_all.append(z_list[-1])
+ att_weight_all.append(att_w.data) # for debugging
+
+ z_all = torch.stack(z_all, dim=1).view(batch * olength, self.dunits)
+ # compute loss
+ y_all = self.output(z_all).view(batch, olength, -1)
+ ym = base.mask_by_length(y_all, ylen)
+ tm = base.mask_by_length(ypad, ylen)
+ self.loss = torch.sum((ym - tm) ** 2)
+ self.loss *= (np.mean(ylen))
+ logging.info('att loss:' + str(self.loss.data))
+ return self.loss, att_weight_all
+
+
+class Discriminator(torch.nn.Module):
+ def __init__(self, idim, odim):
+ super(Discriminator, self).__init__()
+ self.seq = torch.nn.Sequential(
+ torch.nn.Linear(idim, odim),
+ torch.nn.ReLU(),
+ torch.nn.Linear(odim, odim),
+ torch.nn.ReLU(),
+ torch.nn.Linear(odim, 1)
+ )
+
+ def forward(self, spack, tpack):
+ ns = spack.size(0)
+ nt = tpack.size(0)
+ input = torch.cat((spack, tpack), dim=0)
+ predict = self.seq(input)
+ target = input.data.new(ns + nt, 1)
+ target[:ns] = 0
+ target[ns:] = 1
+ target = Variable(target)
+ return -torch.nn.functional.binary_cross_entropy_with_logits(predict, target)
+
+
+class E2E(torch.nn.Module):
+ def __init__(self, idim, odim, args):
+ super(E2E, self).__init__()
+ self.etype = args.etype
+ self.verbose = args.verbose
+ self.char_list = args.char_list
+ self.outdir = args.outdir
+
+ if hasattr(args, "unsupervised_loss"):
+ 
self.unsupervised_loss = args.unsupervised_loss
+ else:
+ self.unsupervised_loss = None
+ if hasattr(args, "use_batchnorm") and args.use_batchnorm:
+ self.batchnorm = torch.nn.BatchNorm1d(args.eprojs)
+ else:
+ self.batchnorm = None
+
+ # below means the last number becomes eos/sos ID
+ # note that sos/eos IDs are identical
+ self.sos = odim - 1
+ self.eos = odim - 1
+
+ # subsample info
+ # +1 means input (+1) and layers outputs (args.elayers)
+ subsample = np.ones(args.elayers + 1, dtype=np.int)
+ if args.etype == 'blstmp':
+ ss = args.subsample.split("_")
+ for j in range(min(args.elayers + 1, len(ss))):
+ subsample[j] = int(ss[j])
+ else:
+ logging.warning(
+ 'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
+ logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
+ self.subsample = subsample
+
+ # encoder
+ self.enc_t = EmbedRNN(odim, args.eprojs)
+ self.enc = base.Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs,
+ self.subsample, args.dropout_rate)
+ self.enc_common_rnn = getattr(self.enc.enc1, "bilstm%d" % (args.elayers-1))
+ self.enc_common_merge = getattr(self.enc.enc1, "bt%d" % (args.elayers-1))
+
+ # ctc
+ self.ctc = base.CTC(odim, args.eprojs, args.dropout_rate)
+
+ # attention
+ if args.atype == 'dot':
+ self.att = base.AttDot(args.eprojs, args.dunits, args.adim)
+ elif args.atype == 'location': #
+ self.att = base.AttLoc(args.eprojs, args.dunits,
+ args.adim, args.aconv_chans, args.aconv_filts)
+ elif args.atype == 'noatt':
+ self.att = base.NoAtt()
+ else:
+ logging.error(
+ "Error: need to specify an appropriate attention architecture")
+ sys.exit()
+ # if args.tied_attention:
+ # self.att_s = self.att
+
+ # decoder
+ self.dec = base.Decoder(args.eprojs, odim, args.dlayers, args.dunits,
+ self.sos, self.eos, self.att, self.verbose, self.char_list)
+ # self.dec_s = MMSEDecoder(args.eprojs, idim, args.dlayers, args.dunits,
+ # self.att_s, self.verbose)
+ # if args.tied_decoder:
+ # self.dec_s.decoder = self.dec.decoder
+
+ # weight initialization
+ self.init_like_chainer()
+
+ def init_like_chainer(self):
+ """Initialize weight like chainer
+
+ chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
+ pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
+
+ however, there are two exceptions as far as I know. 
+ - EmbedID.W ~ Normal(0, 1)
+ - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
+ """
+ base.lecun_normal_init_parameters(self)
+
+ # exceptions
+ # embed weight ~ Normal(0, 1)
+ self.dec.embed.weight.data.normal_(0, 1)
+ self.enc_t.embed.weight.data.normal_(0, 1)
+ # forget-bias = 1.0
+ # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
+ for l in six.moves.range(len(self.dec.decoder)):
+ base.set_forget_bias_to_one(self.dec.decoder[l].bias_ih)
+
+ def sort_variables(self, xs, sorted_index):
+ xs = [xs[i] for i in sorted_index]
+ xs = [base.to_cuda(self, Variable(torch.from_numpy(xx))) for xx in xs]
+ xlens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
+ return xs, xlens
+
+ def forward_common(self, xpad, xlen):
+ # hpad, hlen = self.enc_common_rnn(xpad, xlen)
+ xpack = pack_padded_sequence(xpad, xlen, batch_first=True)
+ hpack, states = self.enc_common_rnn(xpack)
+ hpad, hlen = pad_packed_sequence(hpack, batch_first=True)
+ b, t, o = hpad.shape
+ hpad = torch.tanh(self.enc_common_merge(hpad.contiguous().view(b * t, o)).view(b, t, -1))
+ return hpad, hlen
+
+ def forward(self, data, supervised=False, discriminator=None, only_encoder=False):
+ '''E2E forward (unsupervised)
+
+ :param data: list of (utt_id, utt_info) pairs, where utt_info holds 'feat' and 'tokenid'
+ :param supervised: compute the paired CTC/attention losses
+ :param discriminator: optional Discriminator for the GAN hidden-space loss
+ :return: losses and accuracy
+ '''
+ # utt list of frame x dim
+ xs = [d[1]['feat'] for d in data]
+ tids = [d[1]['tokenid'].split() for d in data]
+ ys = [np.fromiter(map(int, t), dtype=np.int64) for t in tids]
+
+ # sort by length
+ sorted_index = sorted(range(len(xs)), key=lambda i: -len(xs[i]))
+ xs, xlens = self.sort_variables(xs, sorted_index)
+ ys, ylens = self.sort_variables(ys, sorted_index)
+
+ # ys = [base.to_cuda(self, Variable(torch.from_numpy(y))) for y in ys]
+ if supervised or not self.training:
+ # forward encoder for speech
+ xpad = base.pad_list(xs)
+ hxpad, hxlens = self.enc(xpad, xlens)
+ if self.batchnorm:
+ hxpack = pack_padded_sequence(hxpad, hxlens, batch_first=True)
+ hxpack = PackedSequence(self.batchnorm(hxpack.data), hxpack.batch_sizes)
+ hxpad, hxlens = pad_packed_sequence(hxpack, batch_first=True)
+
+ # CTC loss
+ loss_ctc = self.ctc(hxpad, hxlens, ys)
+
+ # forward decoders
+ loss_att, acc, att_t = self.dec(hxpad, hxlens, ys)
+ return loss_ctc, loss_att, acc
+
+ # loss_speech, att_s = self.dec_s(hxpad, hxlens, xpad, xlens)
+ else:
+ # forward encoder for text
+ y_sorted_index = sorted(range(len(ys)), key=lambda i: -len(ys[i]))
+ ys = [ys[i] for i in y_sorted_index]
+ ylens = [ylens[i] for i in y_sorted_index]
+ ypad = base.pad_list(ys, 0)
+ hypad, hylens = self.enc_t(ypad, ylens)
+
+ # forward common encoder
+ hypad, hylens = self.forward_common(hypad, hylens)
+ hypack = pack_padded_sequence(hypad, hylens, batch_first=True)
+
+ if self.unsupervised_loss is not None and self.unsupervised_loss != "None":
+ xpad = base.pad_list(xs)
+ hxpad, hxlens = self.enc(xpad, xlens)
+ hxpack = pack_padded_sequence(hxpad, hxlens, batch_first=True)
+ if self.batchnorm:
+ hxpack = PackedSequence(self.batchnorm(hxpack.data), hxpack.batch_sizes)
+ hypack = PackedSequence(self.batchnorm(hypack.data), hypack.batch_sizes)
+
+ if only_encoder:
+ return hxpack, hypack
+
+ if self.unsupervised_loss == "variance":
+ loss_unsupervised = torch.cat((hxpack.data, hypack.data), dim=0).var(1).mean()
+ if self.unsupervised_loss == "gauss":
+ loss_unsupervised = gauss_kld(hxpack.data, hypack.data)
+ if self.unsupervised_loss == "gausslogdet":
+ loss_unsupervised = gauss_kld(hxpack.data, hypack.data, use_logdet=True)
+ if self.unsupervised_loss == "mmd":
+ loss_unsupervised = 
mmd(hxpack.data, hypack.data)
+ if self.unsupervised_loss == "gan":
+ loss_unsupervised = discriminator(hxpack.data, hypack.data)
+ else:
+ loss_unsupervised = 0.0
+ if only_encoder:
+ xpad = base.pad_list(xs)
+ hxpad, hxlens = self.enc(xpad, xlens)
+ hxpack = pack_padded_sequence(hxpad, hxlens, batch_first=True)
+ return hxpack, hypack
+
+ # 3. forward decoders
+ loss_text, acc, att_t = self.dec(hypad, hylens, ys)
+ # loss_speech, att_s = self.dec_s(hxpad, hxlens, xpad, xlens)
+ return loss_text, loss_unsupervised, acc
+
+
+ def recognize(self, x, recog_args, char_list):
+ '''E2E greedy/beam search
+
+ :param x: input acoustic feature (frames x dim)
+ :param recog_args: argparse namespace with beam_size, penalty, etc.
+ :param char_list: list of output characters
+ :return: decoding result
+ '''
+ prev = self.training
+ self.eval()
+ # subsample frame
+ x = x[::self.subsample[0], :]
+ xlen = [x.shape[0]]
+ xpad = base.to_cuda(self, Variable(torch.from_numpy(
+ np.array(x, dtype=np.float32)), volatile=True))
+
+ # 1. encoder
+ # make a utt list (1) to use the same interface for encoder
+ h, hlen = self.enc(xpad.unsqueeze(0), xlen)
+ # h, hlen = self.forward_common(h, hlen)
+ lpz = None
+
+ # 2. decoder
+ # decode the first utterance
+ if recog_args.beam_size == 1:
+ y = self.dec.recognize(h[0], recog_args)
+ else:
+ y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list)
+
+ if prev:
+ self.train()
+ return y
+ diff --git a/python/unsupervised_recog_th.py b/python/unsupervised_recog_th.py new file mode 100755 index 0000000..f9b0e61 --- /dev/null +++ b/python/unsupervised_recog_th.py @@ -0,0 +1,141 @@
+#!/usr/bin/env python
+import argparse
+import json
+import logging
+import os
+import pickle
+import random
+
+# chainer related
+import chainer
+import numpy as np
+import torch
+
+# spnet related
+from unsupervised import E2E
+from e2e_asr_attctc_th import Loss
+
+# for kaldi io
+import kaldi_io_py
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ # general configuration
+ parser.add_argument('--gpu', '-g', default='-1', type=str,
+ help='GPU ID (negative value indicates CPU)')
+ parser.add_argument('--debugmode', default=1, type=int,
+ help='Debugmode')
+ parser.add_argument('--seed', default=1, type=int,
+ help='Random seed')
+ parser.add_argument('--verbose', '-V', default=1, type=int,
+ help='Verbose option')
+ # task related
+ parser.add_argument('--recog-feat', type=str, required=True,
+ help='Filename of recognition feature data (Kaldi scp)')
+ parser.add_argument('--recog-label', type=str, required=True,
+ help='Filename of recognition label data (json)')
+ parser.add_argument('--result-label', type=str, required=True,
+ help='Filename of result label data (json)')
+ # model (parameter) related
+ parser.add_argument('--model', type=str, required=True,
+ help='Model file parameters to read')
+ parser.add_argument('--model-conf', type=str, required=True,
+ help='Model config file')
+ # search related
+ parser.add_argument('--beam-size', type=int, default=1,
+ help='Beam size')
+ parser.add_argument('--penalty', default=0.0, type=float,
+ help='Insertion penalty')
+ parser.add_argument('--maxlenratio', default=0.0, type=float,
+ help='Input length ratio to obtain max output length.' 
+    def recognize(self, x, recog_args, char_list):
+        '''E2E greedy/beam search
+
+        :param x: input acoustic features (frames x dim)
+        :param recog_args: argparse namespace with beam_size, penalty, etc.
+        :param char_list: list of output characters
+        :return: decoding result
+        '''
+        prev = self.training
+        self.eval()
+        # subsample frames
+        x = x[::self.subsample[0], :]
+        xlen = [x.shape[0]]
+        xpad = base.to_cuda(self, Variable(torch.from_numpy(
+            np.array(x, dtype=np.float32)), volatile=True))
+
+        # 1. encoder
+        # make a utt list (1) to use the same interface for encoder
+        h, hlen = self.enc(xpad.unsqueeze(0), xlen)
+        # h, hlen = self.forward_common(h, hlen)
+        lpz = None
+
+        # 2. decoder
+        # decode the first utterance
+        if recog_args.beam_size == 1:
+            y = self.dec.recognize(h[0], recog_args)
+        else:
+            y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list)
+
+        if prev:
+            self.train()
+        return y
diff --git a/python/unsupervised_recog_th.py b/python/unsupervised_recog_th.py
new file mode 100755
index 0000000..f9b0e61
--- /dev/null
+++ b/python/unsupervised_recog_th.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python
+import argparse
+import json
+import logging
+import os
+import pickle
+import random
+
+# chainer related
+import chainer
+import numpy as np
+import torch
+
+# espnet related
+from unsupervised import E2E
+from e2e_asr_attctc_th import Loss
+
+# for kaldi io
+import kaldi_io_py
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    # general configuration
+    parser.add_argument('--gpu', '-g', default='-1', type=str,
+                        help='GPU ID (negative value indicates CPU)')
+    parser.add_argument('--debugmode', default=1, type=int,
+                        help='Debugmode')
+    parser.add_argument('--seed', default=1, type=int,
+                        help='Random seed')
+    parser.add_argument('--verbose', '-V', default=1, type=int,
+                        help='Verbose option')
+    # task related
+    parser.add_argument('--recog-feat', type=str, required=True,
+                        help='Filename of recognition feature data (Kaldi scp)')
+    parser.add_argument('--recog-label', type=str, required=True,
+                        help='Filename of recognition label data (json)')
+    parser.add_argument('--result-label', type=str, required=True,
+                        help='Filename of result label data (json)')
+    # model (parameter) related
+    parser.add_argument('--model', type=str, required=True,
+                        help='Model file parameters to read')
+    parser.add_argument('--model-conf', type=str, required=True,
+                        help='Model config file')
+    # search related
+    parser.add_argument('--beam-size', type=int, default=1,
+                        help='Beam size')
+    parser.add_argument('--penalty', default=0.0, type=float,
+                        help='Insertion penalty')
+    parser.add_argument('--maxlenratio', default=0.0, type=float,
+                        help='Input length ratio to obtain max output length. '
+                        'If maxlenratio=0.0 (default), it uses an end-detect function '
+                        'to automatically find maximum hypothesis lengths')
+    parser.add_argument('--minlenratio', default=0.0, type=float,
+                        help='Input length ratio to obtain min output length')
+    parser.add_argument('--ctc-weight', default=0.0, type=float,
+                        help='CTC weight in joint decoding')
+    args = parser.parse_args()
+
+    # logging info
+    if args.verbose == 1:
+        logging.basicConfig(
+            level=logging.INFO, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+    elif args.verbose == 2:
+        logging.basicConfig(level=logging.DEBUG,
+                            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+    else:
+        logging.basicConfig(
+            level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
+        logging.warning("Skip DEBUG/INFO messages")
+
+    # display PYTHONPATH
+    logging.info('python path = ' + os.environ['PYTHONPATH'])
+
+    # display chainer version
+    logging.info('chainer version = ' + chainer.__version__)
+
+    # seed setting (chainer seed may not need it)
+    nseed = args.seed
+    random.seed(nseed)
+    np.random.seed(nseed)
+    os.environ["CHAINER_SEED"] = str(nseed)
+    logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])
+
+    # read training config
+    with open(args.model_conf, "rb") as f:
+        logging.info('reading a model config file from ' + args.model_conf)
+        idim, odim, train_args = pickle.load(f)
+
+    for key in sorted(vars(args).keys()):
+        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))
+
+    # specify model architecture
+    logging.info('reading model parameters from ' + args.model)
+    e2e = E2E(idim, odim, train_args)
+    model = Loss(e2e, train_args.mtlalpha)
+
+    def cpu_loader(storage, location):
+        return storage
+
+    try:
+        model.load_state_dict(torch.load(args.model, map_location=cpu_loader))
+    except Exception:
+        # fall back to a whole-model pickle
+        model = torch.load(args.model + ".pkl", map_location=cpu_loader)
+
+    # prepare Kaldi reader
+    reader = kaldi_io_py.read_mat_ark(args.recog_feat)
+
+    # read json data
+    with open(args.recog_label, 'rb') as f:
+        recog_json = json.load(f)['utts']
+
+    new_json = {}
+    for name, feat in reader:
+        y_hat = e2e.recognize(feat, args, train_args.char_list)
+        y_true = map(int, recog_json[name]['tokenid'].split())
+
+        # print out decoding result
+        seq_hat = [train_args.char_list[int(idx)] for idx in y_hat]
+        seq_true = [train_args.char_list[int(idx)] for idx in y_true]
+        seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
+        seq_true_text = "".join(seq_true).replace('<space>', ' ')
+        logging.info("groundtruth[%s]: " + seq_true_text, name)
+        logging.info("prediction [%s]: " + seq_hat_text, name)
+
+        # copy old json info
+        new_json[name] = recog_json[name]
+
+        # add recognition results to json
+        logging.debug("dump token id")
+        new_json[name]['rec_tokenid'] = " ".join([str(idx) for idx in y_hat])
+        logging.debug("dump token")
+        new_json[name]['rec_token'] = " ".join(seq_hat)
+        logging.debug("dump text")
+        new_json[name]['rec_text'] = seq_hat_text
+
+    with open(args.result_label, 'wb') as f:
+        f.write(json.dumps({'utts': new_json}, indent=4).encode('utf_8'))
+
+
+if __name__ == '__main__':
+    main()
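(The `cpu_loader` closure above forces every saved tensor onto the CPU regardless of the device it was trained on. A minimal equivalent sketch — the checkpoint path `model.acc.best` is illustrative:)

```python
import torch

# same effect as cpu_loader: ignore the saved device and keep tensors on CPU
state = torch.load("model.acc.best", map_location=lambda storage, loc: storage)

# in recent PyTorch the same thing is simply:
state = torch.load("model.acc.best", map_location="cpu")
```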
diff --git a/run.sh b/run.sh
new file mode 100755
index 0000000..5314fc0
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,354 @@
+#!/bin/bash
+
+. ./path.sh
+. ./cmd.sh
+
+sup_data_ratio=1.0
+
+# general configuration
+init=""
+backend=pytorch
+stage=0        # start from 0 if you need to start from data preparation
+gpu=-1         # use 0 when using GPU on slurm/grid engine, otherwise -1
+debugmode=1
+dumpdir=dump   # directory to dump full features
+N=0            # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
+verbose=0      # verbose option
+
+# feature configuration
+do_delta=false # true when using CNN
+
+# network architecture
+# encoder related
+etype=blstmp   # encoder architecture type
+elayers=6
+eunits=320
+eprojs=320
+subsample=1_2_2_1_1 # skip every n frame from input to nth layers
+# decoder related
+dlayers=1
+dunits=300
+# attention related
+atype=location
+aconv_chans=10
+aconv_filts=100
+
+use_batchnorm=False
+
+# loss
+unsupervised_loss=mmd
+sup_loss_ratio=0.5
+st_ratio=0.5
+
+# speech/text unsupervised loss ratio
+mtlalpha=0.5
+paired_hidden=False
+
+# minibatch related
+batchsize=30
+maxlen_in=800  # if input length > maxlen_in, batchsize is automatically reduced
+maxlen_out=150 # if output length > maxlen_out, batchsize is automatically reduced
+
+# optimization related
+opt=adadelta
+epochs=15
+lr=1.0
+weight_decay=0.0
+
+# decoding parameter
+beam_size=20
+penalty=0.1
+maxlenratio=0.0
+minlenratio=0.0
+ctc_weight=0.3
+recog_model=acc.best # set a model to be used for decoding: 'acc.best' or 'loss.best'
+
+# data
+# wsj0=/export/corpora5/LDC/LDC93S6B
+# wsj1=/export/corpora5/LDC/LDC94S13B
+wsj0=/nfs/kswork/kishin/karita/datasets/LDC93S6A
+wsj1=/nfs/kswork/kishin/karita/datasets/LDC94S13A
+
+# exp tag
+tag="" # tag for managing experiments.
+
+train_set=train_si84
+unpaired_set=train_si284
+
+. utils/parse_options.sh || exit 1;
+
+. ./path.sh
+. ./cmd.sh
+
+# Set bash to 'debug' mode: it will exit on
+# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline', -x 'print commands'
+set -e
+set -u
+set -o pipefail
+
+train_dev=test_dev93
+recog_set="test_dev93 test_eval92"
+
+if [ ${stage} -le 0 ]; then
+    ### Task dependent. You have to make the following data preparation part by yourself.
+    ### But you can utilize Kaldi recipes in most cases
+    echo "stage 0: Data preparation"
+    local/wsj_data_prep.sh ${wsj0}/??-{?,??}.? ${wsj1}/??-{?,??}.?
+    ./shell/wsj_format_data_with_si84.sh
+fi
+
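(The `subsample` spec `1_2_2_1_1` above encodes espnet-style per-layer frame skipping in the `blstmp` encoder: the first factor is applied to the raw input — cf. `x[::self.subsample[0], :]` in `recognize` — and the remaining factors between encoder layers, a 4x overall rate reduction here. A rough sketch of that arithmetic, under that assumption:)

```python
import numpy as np

subsample = [int(s) for s in "1_2_2_1_1".split("_")]  # per-layer factors

x = np.zeros((800, 83))   # frames x feature dim
x = x[::subsample[0], :]  # input-level subsampling

t = x.shape[0]
for factor in subsample[1:]:
    t = (t + factor - 1) // factor  # frames surviving each BLSTMP layer
print(t)  # 200 for 800 input frames, i.e. a total factor of 4
```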
+feat_tr_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${feat_tr_dir}
+feat_us_dir=${dumpdir}/${unpaired_set}/delta${do_delta}; mkdir -p ${feat_us_dir}
+feat_dt_dir=${dumpdir}/${train_dev}/delta${do_delta}; mkdir -p ${feat_dt_dir}
+if [ ${stage} -le 1 ]; then
+    ### Task dependent. You have to design training and dev sets by yourself.
+    ### But you can utilize Kaldi recipes in most cases
+    echo "stage 1: Feature Generation"
+    fbankdir=fbank
+    # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame
+    for x in ${train_set} ${unpaired_set} ${recog_set}; do
+        steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 10 data/${x} exp/make_fbank/${x} ${fbankdir}
+    done
+
+    # compute global CMVN
+    compute-cmvn-stats scp:data/${unpaired_set}/feats.scp data/${unpaired_set}/cmvn.ark
+
+    # dump features for training
+    dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \
+        data/${train_set}/feats.scp data/${unpaired_set}/cmvn.ark exp/dump_feats/train ${feat_tr_dir}
+    dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \
+        data/${unpaired_set}/feats.scp data/${unpaired_set}/cmvn.ark exp/dump_feats/train ${feat_us_dir}
+    dump.sh --cmd "$train_cmd" --nj 4 --do_delta $do_delta \
+        data/${train_dev}/feats.scp data/${unpaired_set}/cmvn.ark exp/dump_feats/dev ${feat_dt_dir}
+fi
+
+dict=data/lang_1char/${unpaired_set}_units.txt
+nlsyms=data/lang_1char/non_lang_syms.txt
+
+echo "dictionary: ${dict}"
+if [ ${stage} -le 2 ]; then
+    ### Task dependent. You have to check non-linguistic symbols used in the corpus.
+    echo "stage 2: Dictionary and Json Data Preparation"
+    mkdir -p data/lang_1char/
+
+    echo "make a non-linguistic symbol list"
+    cut -f 2- data/${unpaired_set}/text | tr " " "\n" | sort | uniq | grep "<" > ${nlsyms}
+    cat ${nlsyms}
+
+    echo "make a dictionary"
+    echo "<unk> 1" > ${dict} # <unk> must be 1, 0 will be used for "blank" in CTC
+    text2token.py -s 1 -n 1 -l ${nlsyms} data/${unpaired_set}/text | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -v -e '^\s*$' | awk '{print $0 " " NR+1}' >> ${dict}
+    wc -l ${dict}
+
+    echo "make json files"
+    data2json.sh --feat ${feat_tr_dir}/feats.scp --nlsyms ${nlsyms} \
+        data/${train_set} ${dict} > ${feat_tr_dir}/data.json
+    data2json.sh --feat ${feat_us_dir}/feats.scp --nlsyms ${nlsyms} \
+        data/${unpaired_set} ${dict} > ${feat_us_dir}/data.json
+    data2json.sh --feat ${feat_dt_dir}/feats.scp --nlsyms ${nlsyms} \
+        data/${train_dev} ${dict} > ${feat_dt_dir}/data.json
+fi
+
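(The dictionary written in stage 2 is a plain two-column file: one token per line with a 1-based integer id, id 0 being reserved for the CTC blank and `<unk>` pinned to 1. A small sketch of reading it back into the id maps the Python side needs — the path is just the default produced above:)

```python
def load_dict(path="data/lang_1char/train_si284_units.txt"):
    """Read a '<token> <id>' file; id 0 is implicitly the CTC blank."""
    char2id = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            token, idx = line.rstrip().split()
            char2id[token] = int(idx)
    id2char = {i: c for c, i in char2id.items()}
    return char2id, id2char
```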
+if [ -z ${tag} ]; then
+    expdir=exp/semi_data${sup_data_ratio}_${unsupervised_loss}_loss${sup_loss_ratio}_${train_set}_${etype}_e${elayers}_subsample${subsample}_unit${eunits}_proj${eprojs}_d${dlayers}_unit${dunits}_${atype}_aconvc${aconv_chans}_aconvf${aconv_filts}_mtlalpha${mtlalpha}_${opt}_lr${lr}_wd_${weight_decay}_bs${batchsize}_mli${maxlen_in}_mlo${maxlen_out}_epochs${epochs}
+    if ${do_delta}; then
+        expdir=${expdir}_delta
+    fi
+else
+    expdir=exp/${train_set}_${tag}
+fi
+mkdir -p ${expdir}
+
+init_train_script=asr_train_loop_th.py
+retrain_script=retrain_loop_th.py
+decode_script=unsupervised_recog_th.py
+
+if [ ${stage} -le 3 ]; then
+    echo "stage 3: Network Init-Training"
+
+    ${cuda_cmd} ${expdir}/init_train.log \
+        ${init_train_script} \
+        --gpu ${gpu} \
+        --outdir ${expdir}/init_results \
+        --debugmode ${debugmode} \
+        --dict ${dict} \
+        --debugdir ${expdir} \
+        --minibatches ${N} \
+        --verbose ${verbose} \
+        --train-feat scp:${feat_tr_dir}/feats.scp \
+        --valid-feat scp:${feat_dt_dir}/feats.scp \
+        --train-label ${feat_tr_dir}/data.json \
+        --valid-label ${feat_dt_dir}/data.json \
+        --etype ${etype} \
+        --elayers ${elayers} \
+        --eunits ${eunits} \
+        --eprojs ${eprojs} \
+        --subsample ${subsample} \
+        --dlayers ${dlayers} \
+        --dunits ${dunits} \
+        --atype ${atype} \
+        --aconv-chans ${aconv_chans} \
+        --aconv-filts ${aconv_filts} \
+        --mtlalpha ${mtlalpha} \
+        --batch-size ${batchsize} \
+        --maxlen-in ${maxlen_in} \
+        --maxlen-out ${maxlen_out} \
+        --opt ${opt} \
+        --supervised-data-ratio ${sup_data_ratio} \
+        --epochs ${epochs}
+fi
+
+if [ -z $init ]; then
+    init=${expdir}/init_results/model.${recog_model}
+fi
+
+
+if [ ${stage} -le 4 ]; then
+    echo "stage 4: Network Re-Training"
+
+    ${cuda_cmd} ${expdir}/train.log \
+        ${retrain_script} \
+        --init-model ${init} \
+        --gpu ${gpu} \
+        --outdir ${expdir}/results \
+        --debugmode ${debugmode} \
+        --dict ${dict} \
+        --debugdir ${expdir} \
+        --minibatches ${N} \
+        --verbose ${verbose} \
+        --train-feat scp:${feat_tr_dir}/feats.scp \
+        --valid-feat scp:${feat_dt_dir}/feats.scp \
+        --unsupervised-feat scp:${feat_us_dir}/feats.scp \
+        --train-label ${feat_tr_dir}/data.json \
+        --valid-label ${feat_dt_dir}/data.json \
+        --unsupervised-json ${feat_us_dir}/data.json \
+        --etype ${etype} \
+        --elayers ${elayers} \
+        --eunits ${eunits} \
+        --eprojs ${eprojs} \
+        --subsample ${subsample} \
+        --dlayers ${dlayers} \
+        --dunits ${dunits} \
+        --atype ${atype} \
+        --aconv-chans ${aconv_chans} \
+        --aconv-filts ${aconv_filts} \
+        --mtlalpha ${mtlalpha} \
+        --batch-size ${batchsize} \
+        --maxlen-in ${maxlen_in} \
+        --maxlen-out ${maxlen_out} \
+        --opt ${opt} \
+        --lr ${lr} \
+        --weight-decay ${weight_decay} \
+        --unsupervised-loss ${unsupervised_loss} \
+        --supervised-loss-ratio ${sup_loss_ratio} \
+        --supervised-data-ratio ${sup_data_ratio} \
+        --speech-text-ratio ${st_ratio} \
+        --use-batchnorm ${use_batchnorm} \
+        --epochs ${epochs}
+fi
+
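(How `mtlalpha`, `sup_loss_ratio`, and `st_ratio` are actually combined lives in `retrain_loop_th.py`, which this diff does not include. One plausible reading, consistent with the flag names and the values `forward` returns, is a nested convex combination — a hypothetical sketch, not the confirmed formula:)

```python
def total_loss(loss_ctc, loss_att, loss_text, loss_unsupervised,
               mtlalpha=0.5, sup_loss_ratio=0.5, st_ratio=0.5):
    # hypothetical combination; the actual formula is in retrain_loop_th.py
    loss_sup = mtlalpha * loss_ctc + (1 - mtlalpha) * loss_att
    loss_unsup = st_ratio * loss_text + (1 - st_ratio) * loss_unsupervised
    return sup_loss_ratio * loss_sup + (1 - sup_loss_ratio) * loss_unsup
```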
+if [ ${stage} -le 5 ]; then
+    echo "stage 5: Decoding retrained model"
+    nj=32
+
+    for rtask in ${recog_set}; do
+    (
+        decode_dir=decode_${rtask}_beam${beam_size}_e${recog_model}_p${penalty}_len${minlenratio}-${maxlenratio}_ctcw${ctc_weight}
+
+        # split data
+        data=data/${rtask}
+        sdata=${data}/split${nj}utt;
+        if [ ! -d $sdata ]; then
+            split_data.sh --per-utt ${data} ${nj};
+        fi
+
+        # feature extraction
+        feats="ark,s,cs:apply-cmvn --norm-vars=true data/train_si284/cmvn.ark scp:${sdata}/JOB/feats.scp ark:- |"
+        if ${do_delta}; then
+            feats="$feats add-deltas ark:- ark:- |"
+        fi
+
+        if [ ! -e ${data}/data.json ]; then
+            # make json labels for recognition
+            data2json.sh --nlsyms ${nlsyms} ${data} ${dict} > ${data}/data.json
+        fi
+
+        #### use CPU for decoding
+        gpu=-1
+
+        ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
+            ${decode_script} \
+            --gpu ${gpu} \
+            --recog-feat "$feats" \
+            --recog-label ${data}/data.json \
+            --result-label ${expdir}/${decode_dir}/data.JOB.json \
+            --model ${expdir}/results/model.${recog_model} \
+            --model-conf ${expdir}/results/model.conf \
+            --beam-size ${beam_size} \
+            --penalty ${penalty} \
+            --maxlenratio ${maxlenratio} \
+            --minlenratio ${minlenratio} \
+            --ctc-weight ${ctc_weight} &
+        wait
+
+        score_sclite.sh --wer true --nlsyms ${nlsyms} ${expdir}/${decode_dir} ${dict}
+    ) &
+    done
+    wait
+    echo "Finished"
+fi
+
+
+if [ ${stage} -le 6 ]; then
+    echo "stage 6: Decoding init model"
+    nj=32
+
+    for rtask in ${recog_set}; do
+    (
+        decode_dir=init_decode_${rtask}_beam${beam_size}_e${recog_model}_p${penalty}_len${minlenratio}-${maxlenratio}_ctcw${ctc_weight}
+
+        # split data
+        data=data/${rtask}
+        sdata=${data}/split${nj}utt;
+        if [ ! -d $sdata ]; then
+            split_data.sh --per-utt ${data} ${nj};
+        fi
+
+        # feature extraction
+        feats="ark,s,cs:apply-cmvn --norm-vars=true data/train_si284/cmvn.ark scp:${sdata}/JOB/feats.scp ark:- |"
+        if ${do_delta}; then
+            feats="$feats add-deltas ark:- ark:- |"
+        fi
+
+        if [ ! -e ${data}/data.json ]; then
+            # make json labels for recognition
+            data2json.sh --nlsyms ${nlsyms} ${data} ${dict} > ${data}/data.json
+        fi
+
+        #### use CPU for decoding
+        gpu=-1
+
+        ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \
+            asr_recog_th.py \
+            --gpu ${gpu} \
+            --recog-feat "$feats" \
+            --recog-label ${data}/data.json \
+            --result-label ${expdir}/${decode_dir}/data.JOB.json \
+            --model ${expdir}/init_results/model.${recog_model} \
+            --model-conf ${expdir}/init_results/model.conf \
+            --beam-size ${beam_size} \
+            --penalty ${penalty} \
+            --maxlenratio ${maxlenratio} \
+            --minlenratio ${minlenratio} \
+            --ctc-weight ${ctc_weight} &
+        wait
+
+        score_sclite.sh --wer true --nlsyms ${nlsyms} ${expdir}/${decode_dir} ${dict}
+    ) &
+    done
+    wait
+    echo "Finished"
+fi
diff --git a/sbatch.sh b/sbatch.sh
new file mode 100755
index 0000000..4bb9e07
--- /dev/null
+++ b/sbatch.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env zsh
+
+base=$(dirname $0)
+mkdir -p ./slurm/log
+jobid_list=()
+
+n_parallel=1
+
+submit() {
+    num_gpu=$1
+    logname=$2
+    command=${@:3}
+    script="./slurm/${logname}.sh"
+    echo "#!/usr/bin/env zsh" > $script
+    echo "${command}" '|| {echo
+dead-jobid: $SLURM_JOB_ID; echo
+command:' "$command" "; echo
+logfile: \
+    $logname; tail $logname } | mattersend -c karita-exp " >> $script
+    msg=$(sbatch -p gpu --gres gpu:$num_gpu -c 1 -N 1 -o $logname -e $logname $script)
+    jobid=$(echo $msg | awk '{print $NF}')
+    jobid_list+=($jobid)
+
+    echo "${command}"
+    echo "${msg}"
+}
+
+mkdir -p log model
+
+
+opt=adadelta
+alpha=0.5
+batch=30
+elayers=6
+batchnorm=False
+dlayers=1
+atype=location
+epochs=15
+for unsupervised_loss in mmd gausslogdet gan None; do
+for sup_data_ratio in 1.00 0.75 0.50 0.25; do
+for sup_loss_ratio in 0.9 0.5 0.1; do
+for wd in 0.0; do
+for st_ratio in 0.9 0.5 0.1; do
+for lr in 1.0; do
+    exp_name=sbatch_${unsupervised_loss}_alpha${alpha}_bn${batchnorm}_${opt}_lr${lr}_bs${batch}_el${elayers}_dl${dlayers}_att_${atype}_batch${batch}_data${sup_data_ratio}_loss${sup_loss_ratio}_st${st_ratio}_epochs${epochs}
+    log_name=log/${train_set}_${exp_name}.log
+    ngpu=1
+
+    echo $exp_name
+    submit $ngpu $log_name OMP_NUM_THREADS=1 ./run.sh \
+        --stage 3 --tag $exp_name \
+        --gpu 0 \
+        --unsupervised_loss ${unsupervised_loss} \
+        --mtlalpha $alpha \
+        --batchsize $batch \
+        --elayers $elayers \
+        --dlayers $dlayers \
+        --opt $opt \
+        --epochs $epochs \
+        --backend pytorch \
+        --etype blstmp \
+        --atype $atype \
+        --weight_decay $wd \
+        --st_ratio $st_ratio \
+        --sup_loss_ratio $sup_loss_ratio \
+        --sup_data_ratio $sup_data_ratio \
+        --use_batchnorm $batchnorm
+    sleep 3
+done
+done
+done
+done
+done
+done
+
+echo "============================="
+echo "to delete batch jobs: scancel ${jobid_list[@]}"
+echo "============================="
+
+ch=./scancel.sh
+echo "#!/bin/sh" > $ch
+echo scancel ${jobid_list[@]} >> $ch
+chmod +x $ch
diff --git a/shell/debug.sh b/shell/debug.sh
new file mode 100755
index 0000000..bc46c72
--- /dev/null
+++ b/shell/debug.sh
@@ -0,0 +1,12 @@
+unset LC_ALL
+
+MAIN_ROOT=$PWD/espnet
+KALDI_ROOT=$MAIN_ROOT/tools/kaldi
+SPNET_ROOT=$MAIN_ROOT/src
+
+export PATH=$PWD/python:$PWD/shell:$SPNET_ROOT/utils/:$SPNET_ROOT/bin/:$PATH
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:$MAIN_ROOT/tools/chainer_ctc/ext/warp-ctc/build
+source $MAIN_ROOT/tools/venv/bin/activate
+export PYTHONPATH=$PWD/python:$SPNET_ROOT/lm/:$SPNET_ROOT/asr/:$SPNET_ROOT/nets/:$SPNET_ROOT/utils/:$SPNET_ROOT/bin/:$PYTHONPATH
+
+export OMP_NUM_THREADS=1
diff --git a/shell/decode.sh b/shell/decode.sh
new file mode 100755
index 0000000..12ba5c3
--- /dev/null
+++ b/shell/decode.sh
@@ -0,0 +1,148 @@
+#!/bin/bash
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+. ./path.sh
+. ./cmd.sh
+
+expdir=""
+
+# general configuration
+init=""
+backend=pytorch
+stage=0        # start from 0 if you need to start from data preparation
+gpu=-1         # use 0 when using GPU on slurm/grid engine, otherwise -1
+debugmode=1
+dumpdir=dump   # directory to dump full features
+N=0            # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
+verbose=1      # verbose option
+
+# feature configuration
+do_delta=false # true when using CNN
+
+# network architecture
+# encoder related
+etype=blstmp   # encoder architecture type
+elayers=6
+eunits=320
+eprojs=320
+subsample=1_2_2_1_1 # skip every n frame from input to nth layers
+# decoder related
+dlayers=1
+dunits=300
+# attention related
+atype=location
+aconv_chans=10
+aconv_filts=100
+
+
+# loss
+unsupervised_loss=None
+sup_loss_ratio=0.5
+# speech/text unsupervised loss ratio
+mtlalpha=0.5
+
+# minibatch related
+batchsize=30
+maxlen_in=800  # if input length > maxlen_in, batchsize is automatically reduced
+maxlen_out=150 # if output length > maxlen_out, batchsize is automatically reduced
+
+# optimization related
+opt=adadelta
+epochs=15
+lr=1.0
+weight_decay=0.0
+
+# rnnlm related
+lm_weight=0.1
+
+# decoding parameter
+beam_size=20
+penalty=0.1
+maxlenratio=0.0
+minlenratio=0.0
+ctc_weight=0.3
+recog_model=acc.best # set a model to be used for decoding: 'acc.best' or 'loss.best'
+
+# data
+# wsj0=/export/corpora5/LDC/LDC93S6B
+# wsj1=/export/corpora5/LDC/LDC94S13B
+wsj0=/nfs/kswork/kishin/karita/datasets/LDC93S6A
+wsj1=/nfs/kswork/kishin/karita/datasets/LDC94S13A
+
+# wsj0=/data/rigel1/corpora/LDC93S6A
+# wsj1=/data/rigel1/corpora/LDC94S13A
+
+lmexpdir="None"
+# exp tag
+tag="" # tag for managing experiments.
+
+train_set=train_si84
+unpaired_set=train_si284
+decode_script=unsupervised_recog_th.py
+
+. utils/parse_options.sh || exit 1;
+
+. ./path.sh
+. ./cmd.sh
+
+dict=data/lang_1char/train_si284_units.txt
+nlsyms=data/lang_1char/non_lang_syms.txt
+recog_set="test_dev93 test_eval92"
+
+
+if [ $lmexpdir = "None" ]; then
+    rnnlmopt=""
+else
+    rnnlmopt="--rnnlm ${lmexpdir}/rnnlm.model.best --lm-weight ${lm_weight} "
+fi
+
+echo "stage 5: Decoding"
+nj=32
+
+rnnexpdir=${expdir}/rnnlm${lm_weight}
+mkdir -p $rnnexpdir
+for rtask in ${recog_set}; do
+(
+    decode_dir=decode_${rtask}_beam${beam_size}_e${recog_model}_p${penalty}_len${minlenratio}-${maxlenratio}_ctcw${ctc_weight}_rnnlm${lm_weight}
+
+    # split data
+    data=data/${rtask}
+    # split_data.sh --per-utt ${data} ${nj};
+    sdata=${data}/split${nj}utt;
+
+    # feature extraction
+    feats="ark,s,cs:apply-cmvn --norm-vars=true data/train_si284/cmvn.ark scp:${sdata}/JOB/feats.scp ark:- |"
+    if ${do_delta}; then
+        feats="$feats add-deltas ark:- ark:- |"
+    fi
+
+    # make json labels for recognition
+    # data2json.sh --nlsyms ${nlsyms} ${data} ${dict} > ${data}/data.json
+
+    #### use CPU for decoding
+    gpu=-1
+
+    ${decode_cmd} JOB=1:${nj} ${rnnexpdir}/${decode_dir}/log/decode.JOB.log \
+        ${decode_script} \
+        ${rnnlmopt} --gpu ${gpu} \
+        --recog-feat "$feats" \
+        --recog-label ${data}/data.json \
+        --result-label ${rnnexpdir}/${decode_dir}/data.JOB.json \
+        --model ${expdir}/results/model.${recog_model} \
+        --model-conf ${expdir}/results/model.conf \
+        --beam-size ${beam_size} \
+        --penalty ${penalty} \
+        --maxlenratio ${maxlenratio} \
+        --minlenratio ${minlenratio} \
+        --verbose ${verbose} \
+        --ctc-weight ${ctc_weight} &
+    wait
+
+    score_sclite.sh --wer true --nlsyms ${nlsyms} ${rnnexpdir}/${decode_dir} ${dict}
+
+) &
+done
+wait
+echo "Finished"
diff --git a/shell/show_result.sh b/shell/show_result.sh
new file mode 100755
index 0000000..c4337e9
--- /dev/null
+++ b/shell/show_result.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+dev_cer=$(grep -e Avg -e SPKR -m 2 $1/decode_test_dev*/result.txt | awk '{print $11}' | tail -n 1)
+eval_cer=$(grep -e Avg -e SPKR -m 2 $1/decode_test_eval*/result.txt | awk '{print $11}' | tail -n 1)
+
+dev_wer=$(grep -e Avg -e SPKR -m 2 $1/decode_test_dev*/result.wrd.txt | awk '{print $11}' | tail -n 1)
+eval_wer=$(grep -e Avg -e SPKR -m 2 $1/decode_test_eval*/result.wrd.txt | awk '{print $11}' | tail -n 1)
+echo WER dev $dev_wer % eval $eval_wer %
+
+dev_ser=$(grep -e Avg -e SPKR -m 2 $1/decode_test_dev*/result.wrd.txt | awk '{print $12}' | tail -n 1)
+eval_ser=$(grep -e Avg -e SPKR -m 2 $1/decode_test_eval*/result.wrd.txt | awk '{print $12}' | tail -n 1)
+echo SER dev $dev_ser % eval $eval_ser %
+
+
+echo "| dev-CER | eval-CER | dev-WER | eval-WER | dev-SER | eval-SER | path "
+echo "| $dev_cer | $eval_cer | $dev_wer | $eval_wer | $dev_ser | $eval_ser | $1 "
diff --git a/shell/wsj_format_data_with_si84.sh b/shell/wsj_format_data_with_si84.sh
new file mode 100755
index 0000000..33f8d31
--- /dev/null
+++ b/shell/wsj_format_data_with_si84.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Copyright 2012 Microsoft Corporation, Johns Hopkins University (Author: Daniel Povey)
+#           2015 Guoguo Chen
+# Apache 2.0
+
+# This script takes data prepared in a corpus-dependent way
+# in data/local/, and converts it into the "canonical" form,
+# in various subdirectories of data/, e.g. data/lang, data/lang_test_ug,
+# data/train_si284, data/train_si84, etc.
+
+# Don't bother doing train_si84 separately (although we have the file lists
+# in data/local/) because it's just the first 7138 utterances in train_si284.
+# We'll create train_si84 after doing the feature extraction.
+
+lang_suffix=
+
+echo "$0 $@"  # Print the command line for logging
+. utils/parse_options.sh || exit 1;
+
+. ./path.sh || exit 1;
+
+echo "Preparing train and test data"
+srcdir=data/local/data
+
+for x in train_si284 train_si84 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do
+    mkdir -p data/$x
+    cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1;
+    cp $srcdir/$x.txt data/$x/text || exit 1;
+    cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1;
+    cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1;
+    utils/filter_scp.pl data/$x/spk2utt $srcdir/spk2gender > data/$x/spk2gender || exit 1;
+done
+
+echo "Succeeded in formatting data."
diff --git a/steps b/steps
new file mode 120000
index 0000000..905a13a
--- /dev/null
+++ b/steps
@@ -0,0 +1 @@
+./espnet/egs/wsj/asr1/steps
\ No newline at end of file
diff --git a/utils b/utils
new file mode 120000
index 0000000..caed981
--- /dev/null
+++ b/utils
@@ -0,0 +1 @@
+./espnet/egs/wsj/asr1/utils
\ No newline at end of file