Merge pull request #465 from sourabh2k15/dev
Fixing librispeech dataset setup scripts
priyakasimbeg authored Aug 7, 2023
2 parents 0db8dbb + 1ceb96d commit 43a514d
Showing 3 changed files with 36 additions and 53 deletions.
49 changes: 32 additions & 17 deletions datasets/dataset_setup.py
@@ -169,9 +169,11 @@
'The number of threads to use in parallel when decompressing.')

flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.')
flags.DEFINE_boolean('train_tokenizer', True, 'Train Librispeech tokenizer.')
FLAGS = flags.FLAGS

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
tf.config.set_visible_devices([], 'GPU')


def _maybe_mkdir(d):
if not os.path.exists(d):
@@ -458,17 +460,26 @@ def download_imagenet_v2(data_dir):
data_dir=data_dir).download_and_prepare()


def download_librispeech(dataset_dir, tmp_dir, train_tokenizer):
def download_librispeech(dataset_dir, tmp_dir):
# After extraction the result is a folder named Librispeech containing audio
# files in .flac format along with transcripts containing name of audio file
# and corresponding transcription.
tmp_librispeech_dir = os.path.join(tmp_dir, 'LibriSpeech')
tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech')
extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech')
final_data_dir = os.path.join(dataset_dir, 'librispeech_processed')

_maybe_mkdir(tmp_librispeech_dir)

for split in ['dev', 'test']:
for version in ['clean', 'other']:
wget_cmd = f'wget http://www.openslr.org/resources/12/{split}-{version}.tar.gz -O - | tar xz' # pylint: disable=line-too-long
subprocess.Popen(wget_cmd, shell=True, cwd=tmp_dir).communicate()
wget_cmd = (
f'wget --directory-prefix={tmp_librispeech_dir} '
f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz')
subprocess.Popen(wget_cmd, shell=True).communicate()
tar_path = os.path.join(tmp_librispeech_dir, f'{split}-{version}.tar.gz')
subprocess.Popen(
f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
shell=True).communicate()

tars = [
'raw-metadata.tar.gz',
@@ -477,19 +488,23 @@ def download_librispeech(dataset_dir, tmp_dir, train_tokenizer):
'train-other-500.tar.gz',
]
for tar_filename in tars:
wget_cmd = f'wget http://www.openslr.org/resources/12/{tar_filename} -O - | tar xz ' # pylint: disable=line-too-long
subprocess.Popen(wget_cmd, shell=True, cwd=tmp_dir).communicate()
wget_cmd = (f'wget --directory-prefix={tmp_librispeech_dir} '
f'http://www.openslr.org/resources/12/{tar_filename}')
subprocess.Popen(wget_cmd, shell=True).communicate()
tar_path = os.path.join(tmp_librispeech_dir, tar_filename)
subprocess.Popen(
f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
shell=True).communicate()

tokenizer_vocab_path = os.path.join(extracted_data_dir, 'spm_model.vocab')

if train_tokenizer:
tokenizer_vocab_path = librispeech_tokenizer.run(
train=True, data_dir=tmp_librispeech_dir)
if not os.path.exists(tokenizer_vocab_path):
librispeech_tokenizer.run(train=True, data_dir=extracted_data_dir)

# Preprocess data.
librispeech_dir = os.path.join(dataset_dir, 'librispeech')
librispeech_preprocess.run(
input_dir=tmp_librispeech_dir,
output_dir=librispeech_dir,
tokenizer_vocab_path=tokenizer_vocab_path)
librispeech_preprocess.run(
input_dir=extracted_data_dir,
output_dir=final_data_dir,
tokenizer_vocab_path=tokenizer_vocab_path)


def download_mnist(data_dir):
@@ -577,7 +592,7 @@ def main(_):

if FLAGS.all or FLAGS.librispeech:
logging.info('Downloading Librispeech...')
download_librispeech(data_dir, tmp_dir, train_tokenizer=True)
download_librispeech(data_dir, tmp_dir)

if FLAGS.all or FLAGS.cifar:
logging.info('Downloading CIFAR...')
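
The rewritten `download_librispeech` replaces the earlier `wget ... -O - | tar xz` pipe with a two-step flow: download each archive into a staging directory, then extract it in place, and it only trains the SentencePiece tokenizer when `spm_model.vocab` is not already present. The new hunk at the top of the file also hides GPUs (`CUDA_VISIBLE_DEVICES=-1`, `tf.config.set_visible_devices([], 'GPU')`) so the setup script stays on CPU. A minimal sketch of the download-then-extract pattern follows; the staging path and the single `dev-clean` archive are placeholders for illustration, and `wget` and `tar` are assumed to be on PATH:

```python
import os
import subprocess


def fetch_and_extract(url: str, staging_dir: str) -> None:
  """Download a tarball into staging_dir, then extract it there (sketch)."""
  os.makedirs(staging_dir, exist_ok=True)
  # Step 1: download the archive into the staging directory.
  subprocess.run(['wget', f'--directory-prefix={staging_dir}', url],
                 check=True)
  # Step 2: extract it in place, mirroring `tar xzvf ... --directory ...`.
  tar_path = os.path.join(staging_dir, os.path.basename(url))
  subprocess.run(['tar', 'xzvf', tar_path, '--directory', staging_dir],
                 check=True)


# Placeholder example: stage one small archive under /tmp/librispeech.
fetch_and_extract('http://www.openslr.org/resources/12/dev-clean.tar.gz',
                  '/tmp/librispeech')
```

Keeping the downloaded archives on disk instead of streaming them through `tar` makes a failed download easy to diagnose and restart.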
18 changes: 0 additions & 18 deletions datasets/librispeech_preprocess.py
@@ -9,7 +9,6 @@
import threading
import time

from absl import flags
from absl import logging
import numpy as np
import pandas as pd
@@ -23,15 +22,6 @@
exists = tf.io.gfile.exists
rename = tf.io.gfile.rename

flags.DEFINE_string('raw_input_dir',
'',
'Path to the raw training data directory.')
flags.DEFINE_string('output_dir', '', 'Dir to write the processed data to.')
flags.DEFINE_string('tokenizer_vocab_path',
'',
'Path to sentence piece tokenizer vocab file.')
FLAGS = flags.FLAGS

TRANSCRIPTION_MAX_LENGTH = 256
AUDIO_MAX_LENGTH = 320000

@@ -178,11 +168,3 @@ def run(input_dir, output_dir, tokenizer_vocab_path):
'expected count: {} vs expected {}'.format(
num_entries, librispeech_example_counts[subset]))
example_ids.to_csv(os.path.join(output_dir, f'{subset}.csv'))


def main():
run(FLAGS.input_dir, FLAGS.output_dir, FLAGS.tokenizer_vocab_path)


if __name__ == '__main__':
main()
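
With its flag definitions and `main()` removed, `librispeech_preprocess` is now a library module driven solely through `run()` from `dataset_setup.py`. If the preprocessing step needs to be rerun on its own, a direct call along these lines should suffice; the paths are placeholders and the repository root is assumed to be importable:

```python
from datasets import librispeech_preprocess

# Placeholder paths: the extracted LibriSpeech folder, the output directory,
# and the SentencePiece vocab produced by librispeech_tokenizer.
librispeech_preprocess.run(
    input_dir='/tmp/librispeech/LibriSpeech',
    output_dir='/data/librispeech_processed',
    tokenizer_vocab_path='/tmp/librispeech/LibriSpeech/spm_model.vocab')
```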
22 changes: 4 additions & 18 deletions datasets/librispeech_tokenizer.py
@@ -8,7 +8,6 @@
import tempfile
from typing import Dict

from absl import flags
from absl import logging
import sentencepiece as spm
import tensorflow as tf
@@ -21,13 +20,6 @@

Features = Dict[str, tf.Tensor]

flags.DEFINE_string('input_dir', '', 'Path to training data directory.')
flags.DEFINE_boolean(
'train',
False,
'Whether to train a new tokenizer or load existing one to test.')
FLAGS = flags.FLAGS


def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
char_count = 0
@@ -118,13 +110,15 @@ def load_tokenizer(model_filepath):

def run(train, data_dir):
logging.info('Data dir: %s', data_dir)
vocab_path = os.path.join(data_dir, 'spm_model.vocab')
logging.info('vocab_path = ', vocab_path)

if train:
logging.info('Training...')
splits = ['train-clean-100']
return train_tokenizer(data_dir, splits)
train_tokenizer(data_dir, splits, model_path=vocab_path)
else:
tokenizer = load_tokenizer(os.path.join(data_dir, 'spm_model.vocab'))
tokenizer = load_tokenizer(vocab_path)
test_input = 'OPEN SOURCE ROCKS'
tokens = tokenizer.tokenize(test_input)
detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8')
@@ -135,11 +129,3 @@ def run(train, data_dir):

if detokenized == test_input:
logging.info('Tokenizer working correctly!')


def main():
run(FLAGS.train, FLAGS.data_dir)


if __name__ == '__main__':
main()
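
`librispeech_tokenizer` likewise drops its CLI entry point. `run()` now derives the vocab path from `data_dir` itself: with `train=True` it trains a SentencePiece model there, and with `train=False` it loads `spm_model.vocab` and round-trips a test string. A hedged usage sketch, with a placeholder directory and the repository root assumed importable:

```python
from datasets import librispeech_tokenizer

# Placeholder: the directory containing the extracted LibriSpeech splits
# (training reads from the train-clean-100 split).
extracted_dir = '/tmp/librispeech/LibriSpeech'

# Train a tokenizer; the vocab is written as spm_model.vocab under the dir.
librispeech_tokenizer.run(train=True, data_dir=extracted_dir)

# Load the trained model and round-trip a test string as a sanity check.
librispeech_tokenizer.run(train=False, data_dir=extracted_dir)
```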
