diff --git a/bin/import_timit.py b/bin/import_timit.py
new file mode 100755
index 00000000..0191eba6
--- /dev/null
+++ b/bin/import_timit.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+
+'''
+  NAME    : LDC TIMIT Dataset
+  URL     : https://catalog.ldc.upenn.edu/ldc93s1
+  HOURS   : 5
+  TYPE    : Read - English
+  AUTHORS : Garofolo, John, et al.
+  TYPE    : LDC Membership
+  LICENCE : LDC User Agreement
+'''
+
+import errno
+import fnmatch
+import os
+import subprocess
+import sys
+import tarfile
+from os import path
+
+import pandas as pd
+
+def clean(word):
+    # Lowercase the word and strip punctuation that is not needed in the transcript
+    new = word.lower()
+    for punctuation in ['.', ',', ';', '"', '!', '?', ':', '-']:
+        new = new.replace(punctuation, '')
+    return new
+
+def _preprocess_data(args):
+
+    # Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
+
+    # The two SA sentences are read by every speaker, so they are excluded by
+    # default: having the same sentences in both train and test would skew WER.
+    ignoreSASentences = True
+
+    if ignoreSASentences:
+        print("Using recommended setting: ignoring SA sentences")
+        print("(the 2 sentences which are repeated by all speakers)")
+    else:
+        print("Using unrecommended setting: including SA sentences")
+
+    datapath = args
+    target = path.join(datapath, "TIMIT")
+    print("Checking to see if data has already been extracted in given argument: %s" % target)
+
+    if not path.isdir(target):
+        print("Could not find extracted data, trying to find TIMIT-LDC93S1.tgz in:", datapath)
+        filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
+        if path.isfile(filepath):
+            print("File found, extracting")
+            tar = tarfile.open(filepath)
+            tar.extractall(target)
+            tar.close()
+        else:
+            print("File should be downloaded from LDC and placed at:", filepath)
+            raise IOError(errno.ENOENT, "File not found", filepath)
+    else:
+        # Path exists, so the data has already been extracted
+        print("Found extracted data in:", target)
+
+    print("Preprocessing data")
+    # Convert each .WAV file (NIST SPHERE format) into a Microsoft RIFF .wav,
+    # written alongside the original as *_rif.wav
+    for root, dirnames, filenames in os.walk(target):
+        for filename in fnmatch.filter(filenames, "*.WAV"):
+            sph_file = os.path.join(root, filename)
+            wav_file = os.path.join(root, filename)[:-4] + "_rif.wav"
+            print("converting {} to {}".format(sph_file, wav_file))
+            subprocess.check_call(["sox", sph_file, wav_file])
+
+    print("Preprocessing Complete")
+    print("Building CSVs")
+
+    # Lists to build CSV files
+    train_list_wavs, train_list_trans, train_list_size = [], [], []
+    test_list_wavs, test_list_trans, test_list_size = [], [], []
+
+    for root, dirnames, filenames in os.walk(target):
+        for filename in fnmatch.filter(filenames, "*_rif.wav"):
+            full_wav = os.path.join(root, filename)
+            wav_filesize = path.getsize(full_wav)
+
+            # Remove _rif.wav (8 chars) and add .TXT to locate the transcript file
+            trans_file = full_wav[:-8] + ".TXT"
+            with open(trans_file, "r") as f:
+                for line in f:
+                    # The first two fields are start/end sample offsets, which we
+                    # don't need; the rest of the line is the transcript itself
+                    t_list = line.split()[2:]
+                    trans = " ".join(clean(t) for t in t_list)
+
+                    # if ignoreSASentences we only want those without SA in the name
+                    # OR
+                    # if not ignoreSASentences we want all to be added
+                    if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences):
+                        if 'train' in full_wav.lower():
+                            train_list_wavs.append(full_wav)
+                            train_list_trans.append(trans)
+                            train_list_size.append(wav_filesize)
+                        elif 'test' in full_wav.lower():
+                            test_list_wavs.append(full_wav)
+                            test_list_trans.append(trans)
+                            test_list_size.append(wav_filesize)
+                        else:
+                            raise IOError("Wav file is in neither a train nor a test directory: %s" % full_wav)
+
+    train_data = {'wav_filename': train_list_wavs,
+                  'wav_filesize': train_list_size,
+                  'transcript': train_list_trans
+                  }
+
+    test_data = {'wav_filename': test_list_wavs,
+                 'wav_filesize': test_list_size,
+                 'transcript': test_list_trans
+                 }
+
+    all_data = {'wav_filename': train_list_wavs + test_list_wavs,
+                'wav_filesize': train_list_size + test_list_size,
+                'transcript': train_list_trans + test_list_trans
+                }
+
+    df_all = pd.DataFrame(all_data, columns=['wav_filename', 'wav_filesize', 'transcript'])
+    df_train = pd.DataFrame(train_data, columns=['wav_filename', 'wav_filesize', 'transcript'])
+    df_test = pd.DataFrame(test_data, columns=['wav_filename', 'wav_filesize', 'transcript'])
+
+    df_all.to_csv(target + "/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
+    df_train.to_csv(target + "/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
+    df_test.to_csv(target + "/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
+
+if __name__ == "__main__":
+    _preprocess_data(sys.argv[1])
+    print("Completed")
diff --git a/bin/run-timit.sh b/bin/run-timit.sh
new file mode 100755
index 00000000..4e3afcb6
--- /dev/null
+++ b/bin/run-timit.sh
@@ -0,0 +1,29 @@
+#!/bin/sh
+set -xe
+if [ ! -f DeepSpeech.py ]; then
+    echo "Please make sure you run this from DeepSpeech's top level directory."
+    exit 1
+fi;
+
+if [ ! -f "data/TIMIT/timit_train.csv" ]; then
+    echo "Trying to preprocess TIMIT data, saving in ./data/TIMIT"
+    python -u bin/import_timit.py ./data/
+fi;
+
+if [ -d "${COMPUTE_KEEP_DIR}" ]; then
+    checkpoint_dir=$COMPUTE_KEEP_DIR
+else
+    checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/timit"))')
+fi
+
+python -u DeepSpeech.py \
+  --train_files data/TIMIT/timit_train.csv \
+  --dev_files data/TIMIT/timit_test.csv \
+  --test_files data/TIMIT/timit_test.csv \
+  --train_batch_size 8 \
+  --dev_batch_size 8 \
+  --test_batch_size 8 \
+  --n_hidden 494 \
+  --epoch 50 \
+  --checkpoint_dir "$checkpoint_dir" \
+  "$@"
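
For reviewers who want to sanity-check the importer's output, here is a minimal sketch (not part of the patch) that loads the generated training CSV and verifies the invariants the script is supposed to guarantee: the three-column schema, that every referenced wav exists at its recorded size, and that no SA utterances slipped through. It assumes import_timit.py was run as in run-timit.sh above, so the CSVs live under data/TIMIT/; the file name sanity_check_timit.py is only illustrative.

    # sanity_check_timit.py -- a quick consistency check, assuming
    # `python -u bin/import_timit.py ./data/` has already been run.
    import os

    import pandas as pd

    df = pd.read_csv("data/TIMIT/timit_train.csv")

    # The importer writes exactly these three columns, in this order.
    assert list(df.columns) == ["wav_filename", "wav_filesize", "transcript"]

    # Every referenced wav should exist on disk and match its recorded size.
    for row in df.itertuples():
        assert os.path.isfile(row.wav_filename), row.wav_filename
        assert os.path.getsize(row.wav_filename) == row.wav_filesize

    # With the default ignoreSASentences = True, no SA utterances should appear.
    assert not df["wav_filename"].map(os.path.basename).str.contains("SA").any()

    print("{} training utterances look consistent".format(len(df)))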