Reformat importers with black

Reuben Morais 2020-03-31 13:43:30 +02:00
parent b7e6b8c3e6
commit 6f0bf3b3a8
26 changed files with 1545 additions and 784 deletions
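
The changes below are the mechanical output of running black over the importer scripts. As a minimal sketch of how such a pass can be reproduced with black's Python API (the bin/ directory and the *.py glob are illustrative assumptions, not recorded in this commit):

from pathlib import Path

import black


def reformat_in_place(path: Path) -> None:
    # Run the file through black with default settings (88-column lines,
    # double-quoted strings) and write it back only if the formatting changed.
    source = path.read_text(encoding="utf-8")
    formatted = black.format_str(source, mode=black.FileMode())
    if formatted != source:
        path.write_text(formatted, encoding="utf-8")


# Hypothetical location of the importer scripts.
for script in sorted(Path("bin").glob("*.py")):
    reformat_in_place(script)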


@@ -1,8 +1,8 @@
#!/usr/bin/env python
"""
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help
"""
from __future__ import absolute_import, division, print_function

import argparse
@@ -12,44 +12,60 @@ import progressbar
from deepspeech_training.util.audio import (
    AUDIO_TYPE_OPUS,
    AUDIO_TYPE_WAV,
    change_audio_types,
)
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.sample_collections import (
    DirectSDBWriter,
    samples_from_files,
)

AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}


def build_sdb():
    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
    with DirectSDBWriter(
        CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
    ) as sdb_writer:
        samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
        bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
        for sample in bar(
            change_audio_types(
                samples, audio_type=audio_type, processes=CLI_ARGS.workers
            )
        ):
            sdb_writer.add(sample)


def handle_args():
    parser = argparse.ArgumentParser(
        description="Tool for building Sample Databases (SDB files) "
        "from DeepSpeech CSV files and other SDB files"
    )
    parser.add_argument(
        "sources",
        nargs="+",
        help="Source CSV and/or SDB files - "
        "Note: For getting a correctly ordered target SDB, source SDBs have to have their samples "
        "already ordered from shortest to longest.",
    )
    parser.add_argument("target", help="SDB file to create")
    parser.add_argument(
        "--audio-type",
        default="opus",
        choices=AUDIO_TYPE_LOOKUP.keys(),
        help="Audio representation inside target SDB",
    )
    parser.add_argument(
        "--workers", type=int, default=None, help="Number of encoding SDB workers"
    )
    parser.add_argument(
        "--unlabeled",
        action="store_true",
        help="If to build an SDB with unlabeled (audio only) samples - "
        "typically used for building noise augmentation corpora",
    )
    return parser.parse_args()


@@ -9,12 +9,13 @@ from google.protobuf import text_format
def main():
    # Load and export as string
    with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
        graph_def = tfv1.GraphDef()
        graph_def.ParseFromString(fin.read())

    with tfv1.gfile.FastGFile(sys.argv[1] + "txt", "w") as fout:
        fout.write(text_format.MessageToString(graph_def))


if __name__ == "__main__":
    main()


@@ -9,11 +9,11 @@ import pandas
from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)
@@ -21,9 +21,9 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
    # First extract main archive and sub-archives
    extract(tgz_file, target_dir)
    main_folder = os.path.join(target_dir, "aidatatang_200zh")
    for targz in glob.glob(os.path.join(main_folder, "corpus", "*", "*.tar.gz")):
        extract(targz, os.path.dirname(targz))

    # Folder structure is now:
@@ -42,9 +42,11 @@ def preprocess_data(tgz_file, target_dir):
    # Since the transcripts themselves can contain spaces, we split on space but
    # only once, then build a mapping from file name to transcript
    transcripts_path = os.path.join(
        main_folder, "transcript", "aidatatang_200_zh_transcript.txt"
    )
    with open(transcripts_path) as fin:
        transcripts = dict((line.split(" ", maxsplit=1) for line in fin))

    def load_set(glob_path):
        set_files = []
@@ -53,33 +55,39 @@ def preprocess_data(tgz_file, target_dir):
                wav_filename = wav
                wav_filesize = os.path.getsize(wav)
                transcript_key = os.path.splitext(os.path.basename(wav))[0]
                transcript = transcripts[transcript_key].strip("\n")
                set_files.append((wav_filename, wav_filesize, transcript))
            except KeyError:
                print("Warning: Missing transcript for WAV file {}.".format(wav))
        return set_files

    for subset in ("train", "dev", "test"):
        print("Loading {} set samples...".format(subset))
        subset_files = load_set(
            os.path.join(main_folder, "corpus", subset, "*", "*.wav")
        )
        df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)

        # Trim train set to under 10s by removing the last couple hundred samples
        if subset == "train":
            durations = (df["wav_filesize"] - 44) / 16000 / 2
            df = df[durations <= 10.0]
            print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))

        dest_csv = os.path.join(target_dir, "aidatatang_{}.csv".format(subset))
        print("Saving {} set into {}...".format(subset, dest_csv))
        df.to_csv(dest_csv, index=False)


def main():
    # https://www.openslr.org/62/
    parser = get_importers_parser(description="Import aidatatang_200zh corpus")
    parser.add_argument("tgz_file", help="Path to aidatatang_200zh.tgz")
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
    )
    params = parser.parse_args()
    if not params.target_dir:


@@ -9,11 +9,11 @@ import pandas
from deepspeech_training.util.importers import get_importers_parser

COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)
@@ -21,10 +21,10 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
    # First extract main archive and sub-archives
    extract(tgz_file, target_dir)
    main_folder = os.path.join(target_dir, "data_aishell")
    wav_archives_folder = os.path.join(main_folder, "wav")
    for targz in glob.glob(os.path.join(wav_archives_folder, "*.tar.gz")):
        extract(targz, main_folder)

    # Folder structure is now:
@@ -41,9 +41,11 @@ def preprocess_data(tgz_file, target_dir):
    # Since the transcripts themselves can contain spaces, we split on space but
    # only once, then build a mapping from file name to transcript
    transcripts_path = os.path.join(
        main_folder, "transcript", "aishell_transcript_v0.8.txt"
    )
    with open(transcripts_path) as fin:
        transcripts = dict((line.split(" ", maxsplit=1) for line in fin))

    def load_set(glob_path):
        set_files = []
@@ -52,33 +54,37 @@ def preprocess_data(tgz_file, target_dir):
                wav_filename = wav
                wav_filesize = os.path.getsize(wav)
                transcript_key = os.path.splitext(os.path.basename(wav))[0]
                transcript = transcripts[transcript_key].strip("\n")
                set_files.append((wav_filename, wav_filesize, transcript))
            except KeyError:
                print("Warning: Missing transcript for WAV file {}.".format(wav))
        return set_files

    for subset in ("train", "dev", "test"):
        print("Loading {} set samples...".format(subset))
        subset_files = load_set(os.path.join(main_folder, subset, "S*", "*.wav"))
        df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES)

        # Trim train set to under 10s by removing the last couple hundred samples
        if subset == "train":
            durations = (df["wav_filesize"] - 44) / 16000 / 2
            df = df[durations <= 10.0]
            print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))

        dest_csv = os.path.join(target_dir, "aishell_{}.csv".format(subset))
        print("Saving {} set into {}...".format(subset, dest_csv))
        df.to_csv(dest_csv, index=False)


def main():
    # http://www.openslr.org/33/
    parser = get_importers_parser(description="Import AISHELL corpus")
    parser.add_argument("aishell_tgz_file", help="Path to data_aishell.tgz")
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
    )
    params = parser.parse_args()
    if not params.target_dir:


@@ -15,17 +15,19 @@ from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
    get_counter,
    get_imported_samples,
    print_import_report,
)
from deepspeech_training.util.importers import validate_label_eng as validate_label

FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
ARCHIVE_DIR_NAME = "cv_corpus_v1"
ARCHIVE_NAME = ARCHIVE_DIR_NAME + ".tar.gz"
ARCHIVE_URL = (
    "https://s3.us-east-2.amazonaws.com/common-voice-data-download/" + ARCHIVE_NAME
)


def _download_and_preprocess_data(target_dir):
    # Making path absolute
@@ -37,6 +39,7 @@ def _download_and_preprocess_data(target_dir):
    # Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav
    _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)


def _maybe_extract(target_dir, extracted_data, archive_path):
    # If target_dir/extracted_data does not exist, extract archive in target_dir
    extracted_path = os.join(target_dir, extracted_data)
@@ -47,43 +50,56 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
    else:
        print('Found directory "%s" - not extracting it from archive.' % extracted_path)


def _maybe_convert_sets(target_dir, extracted_data):
    extracted_dir = os.path.join(target_dir, extracted_data)
    for source_csv in glob(os.path.join(extracted_dir, "*.csv")):
        _maybe_convert_set(
            extracted_dir,
            source_csv,
            os.path.join(target_dir, os.path.split(source_csv)[-1]),
        )


def one_sample(sample):
    mp3_filename = sample[0]
    # Storing wav files next to the mp3 ones - just with a different suffix
    wav_filename = path.splitext(mp3_filename)[0] + ".wav"
    _maybe_convert_wav(mp3_filename, wav_filename)
    frames = int(
        subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
    )
    file_size = -1
    if os.path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(
            subprocess.check_output(
                ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
            )
        )
    label = validate_label(sample[1])
    rows = []
    counter = get_counter()
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((wav_filename, file_size, label))
    counter["all"] += 1
    counter["total_time"] += frames
    return (counter, rows)


def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    print()
    if os.path.exists(target_csv):
@@ -94,14 +110,14 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    with open(source_csv) as source_csv_file:
        reader = csv.DictReader(source_csv_file)
        for row in reader:
            samples.append((os.path.join(extracted_dir, row["filename"]), row["text"]))

    # Mutable counters for the concurrent embedded routine
    counter = get_counter()
    num_samples = len(samples)
    rows = []

    print("Importing mp3 files...")
    pool = Pool()
    bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
@@ -113,19 +129,26 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    pool.join()

    print('Writing "%s"...' % target_csv)
    with open(target_csv, "w") as target_csv_file:
        writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
        writer.writeheader()
        bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
        for filename, file_size, transcript in bar(rows):
            writer.writerow(
                {
                    "wav_filename": filename,
                    "wav_filesize": file_size,
                    "transcript": transcript,
                }
            )

    imported_samples = get_imported_samples(counter)
    assert counter["all"] == num_samples
    assert len(rows) == imported_samples

    print_import_report(counter, SAMPLE_RATE, MAX_SECS)


def _maybe_convert_wav(mp3_filename, wav_filename):
    if not os.path.exists(wav_filename):
        transformer = sox.Transformer()
@@ -135,5 +158,6 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
        except sox.core.SoxError:
            pass


if __name__ == "__main__":
    _download_and_preprocess_data(sys.argv[1])


@@ -1,11 +1,11 @@
#!/usr/bin/env python
"""
Broadly speaking, this script takes the audio downloaded from Common Voice
for a certain language, in addition to the *.tsv files output by CorporaCreator,
and the script formats the data and transcripts to be in a state usable by
DeepSpeech.py
Use "python3 import_cv2.py -h" for help
"""
from __future__ import absolute_import, division, print_function

import csv
@@ -23,26 +23,27 @@ from deepspeech_training.util.importers import (
    get_imported_samples,
    get_importers_parser,
    get_validate_label,
    print_import_report,
)
from deepspeech_training.util.text import Alphabet

FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10


def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
    for dataset in ["train", "test", "dev", "validated", "other"]:
        input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
        if os.path.isfile(input_tsv):
            print("Loading TSV file: ", input_tsv)
            _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)


def one_sample(sample):
    """ Take a audio file, and optionally convert it to 16kHz WAV """
    mp3_filename = sample[0]
    if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
        mp3_filename += ".mp3"
    # Storing wav files next to the mp3 ones - just with a different suffix
    wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
@@ -51,40 +52,47 @@ def one_sample(sample):
    frames = 0
    if os.path.exists(wav_filename):
        file_size = os.path.getsize(wav_filename)
        frames = int(
            subprocess.check_output(
                ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
            )
        )
    label = label_filter_fun(sample[1])
    rows = []
    counter = get_counter()
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((os.path.split(wav_filename)[-1], file_size, label))
    counter["all"] += 1
    counter["total_time"] += frames
    return (counter, rows)


def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
    output_csv = os.path.join(
        audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
    )
    print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)

    # Get audiofile path and transcript for each sentence in tsv
    samples = []
    with open(input_tsv, encoding="utf-8") as input_tsv_file:
        reader = csv.DictReader(input_tsv_file, delimiter="\t")
        for row in reader:
            samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))

    counter = get_counter()
    num_samples = len(samples)
@@ -101,19 +109,31 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
    pool.close()
    pool.join()

    with open(output_csv, "w", encoding="utf-8") as output_csv_file:
        print("Writing CSV file for DeepSpeech.py as: ", output_csv)
        writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
        writer.writeheader()
        bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
        for filename, file_size, transcript in bar(rows):
            if space_after_every_character:
                writer.writerow(
                    {
                        "wav_filename": filename,
                        "wav_filesize": file_size,
                        "transcript": " ".join(transcript),
                    }
                )
            else:
                writer.writerow(
                    {
                        "wav_filename": filename,
                        "wav_filesize": file_size,
                        "transcript": transcript,
                    }
                )

    imported_samples = get_imported_samples(counter)
    assert counter["all"] == num_samples
    assert len(rows) == imported_samples

    print_import_report(counter, SAMPLE_RATE, MAX_SECS)
@@ -130,24 +150,42 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
if __name__ == "__main__":
    PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
    PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
    PARSER.add_argument(
        "--audio_dir",
        help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
    )
    PARSER.add_argument(
        "--filter_alphabet",
        help="Exclude samples with characters not in provided alphabet",
    )
    PARSER.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    PARSER.add_argument(
        "--space_after_every_character",
        action="store_true",
        help="To help transcript join by white space",
    )

    PARAMS = PARSER.parse_args()
    validate_label = get_validate_label(PARAMS)

    AUDIO_DIR = (
        PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
    )
    ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None

    def label_filter_fun(label):
        if PARAMS.normalize:
            label = (
                unicodedata.normalize("NFKD", label.strip())
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        label = validate_label(label)
        if ALPHABET and label:
            try:


@@ -12,14 +12,12 @@ import librosa
import pandas
import soundfile  # <= Has an external dependency on libsndfile

from deepspeech_training.util.importers import validate_label_eng as validate_label

# Prerequisite: Having the sph2pipe tool in your PATH:
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools


def _download_and_preprocess_data(data_dir):
    # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
@@ -28,33 +26,55 @@ def _download_and_preprocess_data(data_dir):
    _maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav")

    # Conditionally split Fisher wav data
    all_2004 = _split_wav_and_sentences(
        data_dir,
        original_data="fisher-2004-wav",
        converted_data="fisher-2004-split-wav",
        trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"),
    )
    all_2005 = _split_wav_and_sentences(
        data_dir,
        original_data="fisher-2005-wav",
        converted_data="fisher-2005-split-wav",
        trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"),
    )

    # The following files have incorrect transcripts that are much longer than
    # their audio source. The result is that we end up with more labels than time
    # slices, which breaks CTC.
    all_2004.loc[
        all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"),
        "transcript",
    ] = "correct"
    all_2004.loc[
        all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"),
        "transcript",
    ] = "that's one of those"
    all_2005.loc[
        all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"),
        "transcript",
    ] = "they don't want"
    all_2005.loc[
        all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"),
        "transcript",
    ] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"

    # The following file is just a short sound and not at all transcribed like provided.
    # So we just exclude it.
    all_2004 = all_2004[
        ~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")
    ]

    # The following file is far too long and would ruin our training batch size.
    # So we just exclude it.
    all_2005 = all_2005[
        ~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")
    ]

    # The following file is too large for its transcript, so we just exclude it.
    all_2004 = all_2004[
        ~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")
    ]

    # Conditionally split Fisher data into train/validation/test sets
    train_2004, dev_2004, test_2004 = _split_sets(all_2004)
@@ -70,6 +90,7 @@ def _download_and_preprocess_data(data_dir):
    dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False)
    test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False)


def _maybe_convert_wav(data_dir, original_data, converted_data):
    source_dir = os.path.join(data_dir, original_data)
    target_dir = os.path.join(data_dir, converted_data)
@@ -87,10 +108,18 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
        for filename in fnmatch.filter(filenames, "*.sph"):
            sph_file = os.path.join(root, filename)
            for channel in ["1", "2"]:
                wav_filename = (
                    os.path.splitext(os.path.basename(sph_file))[0]
                    + "_c"
                    + channel
                    + ".wav"
                )
                wav_file = os.path.join(target_dir, wav_filename)
                print("converting {} to {}".format(sph_file, wav_file))
                subprocess.check_call(
                    ["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]
                )


def _parse_transcriptions(trans_file):
    segments = []
@@ -108,18 +137,23 @@ def _parse_transcriptions(trans_file):
            # We need to do the encode-decode dance here because encode
            # returns a bytes() object on Python 3, and text_to_char_array
            # expects a string.
            transcript = (
                unicodedata.normalize("NFKD", transcript)
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )

            segments.append(
                {
                    "start_time": start_time,
                    "stop_time": stop_time,
                    "speaker": speaker,
                    "transcript": transcript,
                }
            )
    return segments


def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data):
    trans_dir = os.path.join(data_dir, trans_data)
    source_dir = os.path.join(data_dir, original_data)
@@ -136,43 +170,73 @@ def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data
        segments = _parse_transcriptions(trans_file)

        # Open wav corresponding to transcription file
        wav_filenames = [
            os.path.splitext(os.path.basename(trans_file))[0]
            + "_c"
            + channel
            + ".wav"
            for channel in ["1", "2"]
        ]
        wav_files = [
            os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames
        ]

        print("splitting {} according to {}".format(wav_files, trans_file))

        origAudios = [
            librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files
        ]

        # Loop over segments and split wav_file for each segment
        for segment in segments:
            # Create wav segment filename
            start_time = segment["start_time"]
            stop_time = segment["stop_time"]
            new_wav_filename = (
                os.path.splitext(os.path.basename(trans_file))[0]
                + "-"
                + str(start_time)
                + "-"
                + str(stop_time)
                + ".wav"
            )
            new_wav_file = os.path.join(target_dir, new_wav_filename)

            channel = 0 if segment["speaker"] == "A:" else 1
            _split_and_resample_wav(
                origAudios[channel], start_time, stop_time, new_wav_file
            )

            new_wav_filesize = os.path.getsize(new_wav_file)
            transcript = validate_label(segment["transcript"])
            if transcript != None:
                files.append(
                    (os.path.abspath(new_wav_file), new_wav_filesize, transcript)
                )

    return pandas.DataFrame(
        data=files, columns=["wav_filename", "wav_filesize", "transcript"]
    )


def _split_audio(origAudio, start_time, stop_time):
    audioData, frameRate = origAudio
    nChannels = len(audioData.shape)
    startIndex = int(start_time * frameRate)
    stopIndex = int(stop_time * frameRate)
    return (
        audioData[startIndex:stopIndex]
        if 1 == nChannels
        else audioData[:, startIndex:stopIndex]
    )


def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
    frameRate = origAudio[1]
    chunkData = _split_audio(origAudio, start_time, stop_time)
    soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16")


def _split_sets(filelist):
    # We initially split the entire set into 80% train and 20% test, then
    # split the train set into 80% train and 20% validation.
@@ -186,9 +250,12 @@ def _split_sets(filelist):
    test_beg = dev_end
    test_end = len(filelist)

    return (
        filelist[train_beg:train_end],
        filelist[dev_beg:dev_end],
        filelist[test_beg:test_end],
    )


if __name__ == "__main__":
    _download_and_preprocess_data(sys.argv[1])


@@ -10,11 +10,11 @@ import pandas
from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)
@@ -22,7 +22,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
    # First extract main archive and sub-archives
    extract(tgz_file, target_dir)
    main_folder = os.path.join(target_dir, "ST-CMDS-20170001_1-OS")

    # Folder structure is now:
    # - ST-CMDS-20170001_1-OS/
@@ -35,16 +35,16 @@ def preprocess_data(tgz_file, target_dir):
        for wav in glob.glob(glob_path):
            wav_filename = wav
            wav_filesize = os.path.getsize(wav)
            txt_filename = os.path.splitext(wav_filename)[0] + ".txt"
            with open(txt_filename, "r") as fin:
                transcript = fin.read()
            set_files.append((wav_filename, wav_filesize, transcript))
        return set_files

    # Load all files, then deterministically split into train/dev/test sets
    all_files = load_set(os.path.join(main_folder, "*.wav"))
    df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
    df.sort_values(by="wav_filename", inplace=True)

    indices = np.arange(0, len(df))
    np.random.seed(12345)
@@ -57,29 +57,33 @@ def preprocess_data(tgz_file, target_dir):
    train_indices = indices[:-10000]

    train_files = df.iloc[train_indices]
    durations = (train_files["wav_filesize"] - 44) / 16000 / 2
    train_files = train_files[durations <= 10.0]
    print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
    dest_csv = os.path.join(target_dir, "freestmandarin_train.csv")
    print("Saving train set into {}...".format(dest_csv))
    train_files.to_csv(dest_csv, index=False)

    dev_files = df.iloc[dev_indices]
    dest_csv = os.path.join(target_dir, "freestmandarin_dev.csv")
    print("Saving dev set into {}...".format(dest_csv))
    dev_files.to_csv(dest_csv, index=False)

    test_files = df.iloc[test_indices]
    dest_csv = os.path.join(target_dir, "freestmandarin_test.csv")
    print("Saving test set into {}...".format(dest_csv))
    test_files.to_csv(dest_csv, index=False)


def main():
    # https://www.openslr.org/38/
    parser = get_importers_parser(description="Import Free ST Chinese Mandarin corpus")
    parser.add_argument("tgz_file", help="Path to ST-CMDS-20170001_1-OS.tar.gz")
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
    )
    params = parser.parse_args()
    if not params.target_dir:


@@ -12,10 +12,7 @@ import pandas as pd
from sox import Transformer
import swifter

from deepspeech_training.util.importers import get_importers_parser, get_validate_label

__version__ = "0.1.0"
_logger = logging.getLogger(__name__)
@@ -37,9 +34,7 @@ def parse_args(args):
    Returns:
      :obj:`argparse.Namespace`: command line parameters namespace
    """
    parser = get_importers_parser(description="Imports GramVaani data for Deep Speech")
    parser.add_argument(
        "--version",
        action="version",
@@ -79,6 +74,7 @@ def parse_args(args):
    )
    return parser.parse_args(args)


def setup_logging(level):
    """Setup basic logging
    Args:
@@ -89,6 +85,7 @@ def setup_logging(level):
        level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S"
    )


class GramVaaniCSV:
    """GramVaaniCSV representing a GramVaani dataset.
    Args:
@@ -104,7 +101,16 @@ class GramVaaniCSV:
        _logger.info("Parsing csv file...%s", os.path.abspath(csv_filename))
        data = pd.read_csv(
            os.path.abspath(csv_filename),
            names=[
                "piece_id",
                "audio_url",
                "transcript_labelled",
                "transcript",
                "labels",
                "content_filename",
                "audio_length",
                "user_id",
            ],
            usecols=["audio_url", "transcript", "audio_length"],
            skiprows=[0],
            engine="python",
@@ -116,6 +122,7 @@ class GramVaaniCSV:
        _logger.info("Parsed %d lines csv file." % len(data))
        return data


class GramVaaniDownloader:
    """GramVaaniDownloader downloads a GramVaani dataset.
    Args:
@@ -135,7 +142,9 @@ class GramVaaniDownloader:
            mp3_directory (os.path): The directory into which the associated mp3's were downloaded
        """
        mp3_directory = self._pre_download()
        self.data.swifter.apply(
            func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True
        )
        return mp3_directory

    def _pre_download(self):
@@ -158,6 +167,7 @@ class GramVaaniDownloader:
        else:
            _logger.debug("Already downloaded mp3 file...%s", audio_url)


class GramVaaniConverter:
    """GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
    Args:
@@ -178,15 +188,26 @@ class GramVaaniConverter:
            wav_directory (os.path): The directory into which the associated wav's were downloaded
        """
        wav_directory = self._pre_convert()
        for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
            wav_filename = os.path.join(
                wav_directory,
                os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav",
            )
            if not os.path.exists(wav_filename):
                _logger.debug(
                    "Converting mp3 file %s to wav file %s"
                    % (mp3_filename, wav_filename)
                )
                transformer = Transformer()
                transformer.convert(
                    samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
                )
                transformer.build(str(mp3_filename), str(wav_filename))
            else:
                _logger.debug(
                    "Already converted mp3 file %s to wav file %s"
                    % (mp3_filename, wav_filename)
                )
        return wav_directory

    def _pre_convert(self):
@@ -199,14 +220,19 @@ class GramVaaniConverter:
            os.mkdir(wav_directory)
        return wav_directory


class GramVaaniDataSets:
    def __init__(self, target_dir, wav_directory, gram_vaani_csv):
        self.target_dir = target_dir
        self.wav_directory = wav_directory
        self.csv_data = gram_vaani_csv.data
        self.raw = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
        self.valid = pd.DataFrame(
            columns=["wav_filename", "wav_filesize", "transcript"]
        )
        self.train = pd.DataFrame(
            columns=["wav_filename", "wav_filesize", "transcript"]
        )
        self.dev = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
        self.test = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
@@ -218,20 +244,30 @@ class GramVaaniDataSets:
        train_size, dev_size, test_size = self._calculate_data_set_sizes()
        self.train = self.valid.loc[0:train_size]
        self.dev = self.valid.loc[train_size : train_size + dev_size]
        self.test = self.valid.loc[
            train_size + dev_size : train_size + dev_size + test_size
        ]

    def _convert_csv_data_to_raw_data(self):
        self.raw[["wav_filename", "wav_filesize", "transcript"]] = self.csv_data[
            ["audio_url", "transcript", "audio_length"]
        ].swifter.apply(
            func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg),
            axis=1,
            raw=True,
        )
        self.raw.reset_index()

    def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
        if audio_url == "audio_url":
            return pd.Series(["wav_filename", "wav_filesize", "transcript"])
        mp3_filename = os.path.basename(audio_url)
        wav_relative_filename = os.path.join(
            "wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav"
        )
        wav_filesize = os.path.getsize(
            os.path.join(self.target_dir, wav_relative_filename)
        )
        transcript = validate_label(transcript)
        if None == transcript:
            transcript = ""
@@ -240,7 +276,12 @@ class GramVaaniDataSets:
    def _is_valid_raw_rows(self):
        is_valid_raw_transcripts = self._is_valid_raw_transcripts()
        is_valid_raw_wav_frames = self._is_valid_raw_wav_frames()
        is_valid_raw_row = [
            (is_valid_raw_transcript & is_valid_raw_wav_frame)
            for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(
                is_valid_raw_transcripts, is_valid_raw_wav_frames
            )
        ]
        series = pd.Series(is_valid_raw_row)
        return series
@@ -249,9 +290,22 @@ class GramVaaniDataSets:
    def _is_valid_raw_wav_frames(self):
        transcripts = [str(transcript) for transcript in self.raw.transcript]
        wav_filepaths = [
            os.path.join(self.target_dir, str(wav_filename))
            for wav_filename in self.raw.wav_filename
        ]
        wav_frames = [
            int(
                subprocess.check_output(
                    ["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
                )
            )
            for wav_filepath in wav_filepaths
        ]
        is_valid_raw_wav_frames = [
            self._is_wav_frame_valid(wav_frame, transcript)
            for wav_frame, transcript in zip(wav_frames, transcripts)
        ]
        return pd.Series(is_valid_raw_wav_frames)

    def _is_wav_frame_valid(self, wav_frame, transcript):
@@ -277,7 +331,14 @@ class GramVaaniDataSets:
    def _save(self, dataset):
        dataset_path = os.path.join(self.target_dir, dataset + ".csv")
        dataframe = getattr(self, dataset)
        dataframe.to_csv(
            dataset_path,
            index=False,
            encoding="utf-8",
            escapechar="\\",
            quoting=csv.QUOTE_MINIMAL,
        )


def main(args):
    """Main entry point allowing external calls
@ -301,4 +362,5 @@ def main(args):
datasets.save() datasets.save()
_logger.info("Finished GramVaani importer...") _logger.info("Finished GramVaani importer...")
main(sys.argv[1:]) main(sys.argv[1:])
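For reference, the frame counts consumed by _is_valid_raw_wav_frames above come from sox's soxi utility, which prints the number of samples in a file. A minimal sketch of that call, assuming sox (and therefore soxi) is installed and the input is a mono 16 kHz WAV:

import subprocess

def count_wav_frames(wav_filepath):
    # "soxi -s" prints the sample (frame) count of the audio file;
    # stderr is folded into the output just as the importer does
    output = subprocess.check_output(
        ["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
    )
    return int(output)

# a 10 second clip at 16 kHz should report roughly 160000 frames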
View File
@ -13,14 +13,23 @@ def _download_and_preprocess_data(data_dir):
# Conditionally download data # Conditionally download data
LDC93S1_BASE = "LDC93S1" LDC93S1_BASE = "LDC93S1"
LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/" LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/"
local_file = maybe_download(LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav") local_file = maybe_download(
trans_file = maybe_download(LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt") LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav"
)
trans_file = maybe_download(
LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt"
)
with open(trans_file, "r") as fin: with open(trans_file, "r") as fin:
transcript = ' '.join(fin.read().strip().lower().split(' ')[2:]).replace('.', '') transcript = " ".join(fin.read().strip().lower().split(" ")[2:]).replace(
".", ""
)
df = pandas.DataFrame(data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)], df = pandas.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]) data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"],
)
df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False) df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False)
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])
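The transcript cleanup above lowercases the text, drops the first two space-separated tokens (which appear to be leading sample offsets in the LDC93S1 .txt file), and removes the period. A small illustration on a made-up line of the same shape:

raw_line = "0 46797 She had your dark suit in greasy wash water all year."
transcript = " ".join(raw_line.strip().lower().split(" ")[2:]).replace(".", "")
# transcript == "she had your dark suit in greasy wash water all year"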
View File
@ -18,13 +18,24 @@ from deepspeech_training.util.downloader import maybe_download
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir # Conditionally download data to data_dir
print("Downloading Librivox data set (55GB) into {} if not already present...".format(data_dir)) print(
"Downloading Librivox data set (55GB) into {} if not already present...".format(
data_dir
)
)
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar: with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz" TRAIN_CLEAN_100_URL = (
TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz" "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz" )
TRAIN_CLEAN_360_URL = (
"http://www.openslr.org/resources/12/train-clean-360.tar.gz"
)
TRAIN_OTHER_500_URL = (
"http://www.openslr.org/resources/12/train-other-500.tar.gz"
)
DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz" DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz" DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
@ -32,12 +43,20 @@ def _download_and_preprocess_data(data_dir):
TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz" TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz" TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"
def filename_of(x): return os.path.split(x)[1] def filename_of(x):
train_clean_100 = maybe_download(filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL) return os.path.split(x)[1]
train_clean_100 = maybe_download(
filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL
)
bar.update(0) bar.update(0)
train_clean_360 = maybe_download(filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL) train_clean_360 = maybe_download(
filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL
)
bar.update(1) bar.update(1)
train_other_500 = maybe_download(filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL) train_other_500 = maybe_download(
filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL
)
bar.update(2) bar.update(2)
dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL) dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL)
@ -45,9 +64,13 @@ def _download_and_preprocess_data(data_dir):
dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL) dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL)
bar.update(4) bar.update(4)
test_clean = maybe_download(filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL) test_clean = maybe_download(
filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL
)
bar.update(5) bar.update(5)
test_other = maybe_download(filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL) test_other = maybe_download(
filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL
)
bar.update(6) bar.update(6)
# Conditionally extract LibriSpeech data # Conditionally extract LibriSpeech data
@ -58,11 +81,17 @@ def _download_and_preprocess_data(data_dir):
LIBRIVOX_DIR = "LibriSpeech" LIBRIVOX_DIR = "LibriSpeech"
work_dir = os.path.join(data_dir, LIBRIVOX_DIR) work_dir = os.path.join(data_dir, LIBRIVOX_DIR)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100) _maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100
)
bar.update(0) bar.update(0)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360) _maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360
)
bar.update(1) bar.update(1)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500) _maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500
)
bar.update(2) bar.update(2)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean) _maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
@ -89,27 +118,47 @@ def _download_and_preprocess_data(data_dir):
# ... # ...
print("Converting FLAC to WAV and splitting transcriptions...") print("Converting FLAC to WAV and splitting transcriptions...")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar: with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
train_100 = _convert_audio_and_split_sentences(work_dir, "train-clean-100", "train-clean-100-wav") train_100 = _convert_audio_and_split_sentences(
work_dir, "train-clean-100", "train-clean-100-wav"
)
bar.update(0) bar.update(0)
train_360 = _convert_audio_and_split_sentences(work_dir, "train-clean-360", "train-clean-360-wav") train_360 = _convert_audio_and_split_sentences(
work_dir, "train-clean-360", "train-clean-360-wav"
)
bar.update(1) bar.update(1)
train_500 = _convert_audio_and_split_sentences(work_dir, "train-other-500", "train-other-500-wav") train_500 = _convert_audio_and_split_sentences(
work_dir, "train-other-500", "train-other-500-wav"
)
bar.update(2) bar.update(2)
dev_clean = _convert_audio_and_split_sentences(work_dir, "dev-clean", "dev-clean-wav") dev_clean = _convert_audio_and_split_sentences(
work_dir, "dev-clean", "dev-clean-wav"
)
bar.update(3) bar.update(3)
dev_other = _convert_audio_and_split_sentences(work_dir, "dev-other", "dev-other-wav") dev_other = _convert_audio_and_split_sentences(
work_dir, "dev-other", "dev-other-wav"
)
bar.update(4) bar.update(4)
test_clean = _convert_audio_and_split_sentences(work_dir, "test-clean", "test-clean-wav") test_clean = _convert_audio_and_split_sentences(
work_dir, "test-clean", "test-clean-wav"
)
bar.update(5) bar.update(5)
test_other = _convert_audio_and_split_sentences(work_dir, "test-other", "test-other-wav") test_other = _convert_audio_and_split_sentences(
work_dir, "test-other", "test-other-wav"
)
bar.update(6) bar.update(6)
# Write sets to disk as CSV files # Write sets to disk as CSV files
train_100.to_csv(os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False) train_100.to_csv(
train_360.to_csv(os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False) os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False
train_500.to_csv(os.path.join(data_dir, "librivox-train-other-500.csv"), index=False) )
train_360.to_csv(
os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False
)
train_500.to_csv(
os.path.join(data_dir, "librivox-train-other-500.csv"), index=False
)
dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False) dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False)
dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False) dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False)
@ -117,6 +166,7 @@ def _download_and_preprocess_data(data_dir):
test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False) test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False)
test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False) test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive): def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir # If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(os.path.join(data_dir, extracted_data)): if not gfile.Exists(os.path.join(data_dir, extracted_data)):
@ -124,6 +174,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir) tar.extractall(data_dir)
tar.close() tar.close()
def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir): def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
source_dir = os.path.join(extracted_dir, data_set) source_dir = os.path.join(extracted_dir, data_set)
target_dir = os.path.join(extracted_dir, dest_dir) target_dir = os.path.join(extracted_dir, dest_dir)
@ -146,7 +197,7 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
# We also convert the corresponding FLACs to WAV in the same pass # We also convert the corresponding FLACs to WAV in the same pass
files = [] files = []
for root, dirnames, filenames in os.walk(source_dir): for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, '*.trans.txt'): for filename in fnmatch.filter(filenames, "*.trans.txt"):
trans_filename = os.path.join(root, filename) trans_filename = os.path.join(root, filename)
with codecs.open(trans_filename, "r", "utf-8") as fin: with codecs.open(trans_filename, "r", "utf-8") as fin:
for line in fin: for line in fin:
@ -157,9 +208,11 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
# We need to do the encode-decode dance here because encode # We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array # returns a bytes() object on Python 3, and text_to_char_array
# expects a string. # expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \ transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
transcript = transcript.lower().strip() transcript = transcript.lower().strip()
@ -174,7 +227,10 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
files.append((os.path.abspath(wav_file), wav_filesize, transcript)) files.append((os.path.abspath(wav_file), wav_filesize, transcript))
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])
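The encode/decode dance above is a standard way to strip accents: NFKD decomposition splits each accented character into a base character plus combining marks, and the ASCII encode with errors ignored then drops the marks. A standalone sketch:

import unicodedata

def to_ascii(text):
    # characters with no ASCII decomposition disappear entirely,
    # which is acceptable for LibriSpeech's English transcripts
    return (
        unicodedata.normalize("NFKD", text)
        .encode("ascii", "ignore")
        .decode("ascii", "ignore")
    )

# to_ascii("élève") == "eleve"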
View File
@ -20,17 +20,17 @@ from deepspeech_training.util.importers import (
get_imported_samples, get_imported_samples,
get_importers_parser, get_importers_parser,
get_validate_label, get_validate_label,
print_import_report print_import_report,
) )
from deepspeech_training.util.text import Alphabet from deepspeech_training.util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
ARCHIVE_DIR_NAME = 'lingua_libre' ARCHIVE_DIR_NAME = "lingua_libre"
ARCHIVE_NAME = 'Q{qId}-{iso639_3}-{language_English_name}.zip' ARCHIVE_NAME = "Q{qId}-{iso639_3}-{language_English_name}.zip"
ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME ARCHIVE_URL = "https://lingualibre.fr/datasets/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
@ -43,6 +43,7 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files and convert ogg data to wav # Produce CSV files and convert ogg data to wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME) _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = os.path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
@ -55,6 +56,7 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample): def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = sample[0] ogg_filename = sample[0]
@ -65,47 +67,59 @@ def one_sample(sample):
frames = 0 frames = 0
if os.path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1]) label = label_filter(sample[1])
rows = [] rows = []
counter = get_counter() counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
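The too_short branch above compares a rough count of usable time steps against the transcript length, the idea being that CTC needs at least one step per output character. A sketch of the check as written, with the interpretation of the 10 and 2 divisors left as an assumption (they look like feature-window parameters, but the importer does not say so):

SAMPLE_RATE = 16000

def too_short_for_transcript(frames, label):
    # frames / SAMPLE_RATE * 1000 is the clip duration in milliseconds;
    # the / 10 / 2 divisors are copied verbatim from the importer
    usable_steps = int(frames / SAMPLE_RATE * 1000 / 10 / 2)
    return usable_steps < len(str(label))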
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = os.path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv')) target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME + "_" + ARCHIVE_NAME.replace(".zip", "_{}.csv")
)
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', '')) ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace(".zip", ""))
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg') glob_dir = os.path.join(ogg_root_dir, "**/*.ogg")
for record in glob(glob_dir, recursive=True): for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '') record_file = record.replace(ogg_root_dir + os.path.sep, "")
if record_filter(record_file): if record_filter(record_file):
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0])) samples.append(
(
os.path.join(ogg_root_dir, record_file),
os.path.splitext(os.path.basename(record_file))[0],
)
)
counter = get_counter() counter = get_counter()
num_samples = len(samples) num_samples = len(samples)
@ -122,9 +136,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -136,7 +150,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript = validate_label(item[2]) transcript = validate_label(item[2])
if not transcript: if not transcript:
continue continue
wav_filename = os.path.join(ogg_root_dir, item[0].replace('.ogg', '.wav')) wav_filename = os.path.join(
ogg_root_dir, item[0].replace(".ogg", ".wav")
)
i_mod = i % 10 i_mod = i % 10
if i_mod == 0: if i_mod == 0:
writer = test_writer writer = test_writer
@ -144,18 +160,21 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
dict(
wav_filename=wav_filename, wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename), wav_filesize=os.path.getsize(wav_filename),
transcript=transcript, transcript=transcript,
)) )
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(ogg_filename, wav_filename): def _maybe_convert_wav(ogg_filename, wav_filename):
if not os.path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
@ -163,19 +182,41 @@ def _maybe_convert_wav(ogg_filename, wav_filename):
try: try:
transformer.build(ogg_filename, wav_filename) transformer.build(ogg_filename, wav_filename)
except sox.core.SoxError as ex: except sox.core.SoxError as ex:
print('SoX processing error', ex, ogg_filename, wav_filename) print("SoX processing error", ex, ogg_filename, wav_filename)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.') parser = get_importers_parser(
parser.add_argument(dest='target_dir') description="Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details."
parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId') )
parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code') parser.add_argument(dest="target_dir")
parser.add_argument('--english-name', type=str, required=True, help='English name of the language') parser.add_argument(
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') "--qId", type=int, required=True, help="LinguaLibre language qId"
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') )
parser.add_argument('--bogus-records', type=argparse.FileType('r'), required=False, help='Text file listing well-known bogus records to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items') parser.add_argument(
"--iso639-3", type=str, required=True, help="ISO639-3 language code"
)
parser.add_argument(
"--english-name", type=str, required=True, help="Enligh name of the language"
)
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--bogus-records",
type=argparse.FileType("r"),
required=False,
help="Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items",
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -188,15 +229,17 @@ if __name__ == "__main__":
def record_filter(path): def record_filter(path):
if any(regex.match(path) for regex in bogus_regexes): if any(regex.match(path) for regex in bogus_regexes):
print('Reject', path) print("Reject", path)
return False return False
return True return True
def label_filter(label): def label_filter(label):
if CLI_ARGS.normalize: if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
try: try:
@ -205,6 +248,14 @@ if __name__ == "__main__":
label = None label = None
return label return label
ARCHIVE_NAME = ARCHIVE_NAME.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name) ARCHIVE_NAME = ARCHIVE_NAME.format(
ARCHIVE_URL = ARCHIVE_URL.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name) qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
ARCHIVE_URL = ARCHIVE_URL.format(
qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir) _download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)
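The writer selection above implements a deterministic 80/10/10 split over the enumerated rows. A sketch of the rule; the test bucket (i % 10 == 0) is visible in the hunk, while the dev condition is cut off, so i % 10 == 1 is an assumption:

def pick_split(i):
    # roughly 10% test, 10% dev, 80% train over sequential row indices
    i_mod = i % 10
    if i_mod == 0:
        return "test"
    if i_mod == 1:  # assumed; this branch is not visible in the diff above
        return "dev"
    return "train"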
View File
@ -18,17 +18,17 @@ from deepspeech_training.util.importers import (
get_imported_samples, get_imported_samples,
get_importers_parser, get_importers_parser,
get_validate_label, get_validate_label,
print_import_report print_import_report,
) )
from deepspeech_training.util.text import Alphabet from deepspeech_training.util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 15 MAX_SECS = 15
ARCHIVE_DIR_NAME = '{language}' ARCHIVE_DIR_NAME = "{language}"
ARCHIVE_NAME = '{language}.tgz' ARCHIVE_NAME = "{language}.tgz"
ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME ARCHIVE_URL = "http://www.caito.de/data/Training/stt_tts/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
@ -63,7 +63,11 @@ def one_sample(sample):
frames = 0 frames = 0
if os.path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1]) label = label_filter(sample[1])
counter = get_counter() counter = get_counter()
rows = [] rows = []
@ -71,27 +75,30 @@ def one_sample(sample):
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
print("conversion failure", wav_filename) print("conversion failure", wav_filename)
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = os.path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv')) target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tgz", "_{}.csv")
)
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
@ -99,14 +106,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
glob_dir = os.path.join(wav_root_dir, '**/metadata.csv') glob_dir = os.path.join(wav_root_dir, "**/metadata.csv")
for record in glob(glob_dir, recursive=True): for record in glob(glob_dir, recursive=True):
if any(map(lambda sk: sk in record, SKIP_LIST)): # pylint: disable=cell-var-from-loop if any(
map(lambda sk: sk in record, SKIP_LIST)
): # pylint: disable=cell-var-from-loop
continue continue
with open(record, 'r') as rec: with open(record, "r") as rec:
for re in rec.readlines(): for re in rec.readlines():
re = re.strip().split('|') re = re.strip().split("|")
audio = os.path.join(os.path.dirname(record), 'wavs', re[0] + '.wav') audio = os.path.join(os.path.dirname(record), "wavs", re[0] + ".wav")
transcript = re[2] transcript = re[2]
samples.append((audio, transcript)) samples.append((audio, transcript))
@ -125,9 +134,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -147,39 +156,60 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
dict(
wav_filename=os.path.relpath(wav_filename, extracted_dir), wav_filename=os.path.relpath(wav_filename, extracted_dir),
wav_filesize=os.path.getsize(wav_filename), wav_filesize=os.path.getsize(wav_filename),
transcript=transcript, transcript=transcript,
)) )
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.') parser = get_importers_parser(
parser.add_argument(dest='target_dir') description="Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/."
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') )
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') parser.add_argument(dest="target_dir")
parser.add_argument('--skiplist', type=str, default='', help='Directories / books to skip, comma separated') parser.add_argument(
parser.add_argument('--language', required=True, type=str, help='Dataset language to use') "--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--skiplist",
type=str,
default="",
help="Directories / books to skip, comma separated",
)
parser.add_argument(
"--language", required=True, type=str, help="Dataset language to use"
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(',')) SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(","))
validate_label = get_validate_label(CLI_ARGS) validate_label = get_validate_label(CLI_ARGS)
def label_filter(label): def label_filter(label):
if CLI_ARGS.normalize: if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
try: try:
View File
@ -10,17 +10,17 @@ import pandas
from deepspeech_training.util.importers import get_importers_parser from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir): def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir)) print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
def is_file_truncated(wav_filename, wav_filesize): def is_file_truncated(wav_filename, wav_filesize):
with wave.open(wav_filename, mode='rb') as fin: with wave.open(wav_filename, mode="rb") as fin:
assert fin.getframerate() == 16000 assert fin.getframerate() == 16000
assert fin.getsampwidth() == 2 assert fin.getsampwidth() == 2
assert fin.getnchannels() == 1 assert fin.getnchannels() == 1
@ -33,8 +33,13 @@ def is_file_truncated(wav_filename, wav_filesize):
def preprocess_data(folder_with_archives, target_dir): def preprocess_data(folder_with_archives, target_dir):
# First extract subset archives # First extract subset archives
for subset in ('train', 'dev', 'test'): for subset in ("train", "dev", "test"):
extract(os.path.join(folder_with_archives, 'magicdata_{}_set.tar.gz'.format(subset)), target_dir) extract(
os.path.join(
folder_with_archives, "magicdata_{}_set.tar.gz".format(subset)
),
target_dir,
)
# Folder structure is now: # Folder structure is now:
# - magicdata_{train,dev,test}.tar.gz # - magicdata_{train,dev,test}.tar.gz
@ -50,58 +55,73 @@ def preprocess_data(folder_with_archives, target_dir):
# name, one containing the speaker ID, and one containing the transcription # name, one containing the speaker ID, and one containing the transcription
def load_set(set_path): def load_set(set_path):
transcripts = pandas.read_csv(os.path.join(set_path, 'TRANS.txt'), sep='\t', index_col=0) transcripts = pandas.read_csv(
glob_path = os.path.join(set_path, '*', '*.wav') os.path.join(set_path, "TRANS.txt"), sep="\t", index_col=0
)
glob_path = os.path.join(set_path, "*", "*.wav")
set_files = [] set_files = []
for wav in glob.glob(glob_path): for wav in glob.glob(glob_path):
try: try:
wav_filename = wav wav_filename = wav
wav_filesize = os.path.getsize(wav) wav_filesize = os.path.getsize(wav)
transcript_key = os.path.basename(wav) transcript_key = os.path.basename(wav)
transcript = transcripts.loc[transcript_key, 'Transcription'] transcript = transcripts.loc[transcript_key, "Transcription"]
# Some files in this dataset are truncated, the header duration # Some files in this dataset are truncated, the header duration
# doesn't match the file size. This causes errors at training # doesn't match the file size. This causes errors at training
# time, so check here if things are fine before including a file # time, so check here if things are fine before including a file
if is_file_truncated(wav_filename, wav_filesize): if is_file_truncated(wav_filename, wav_filesize):
print('Warning: File {} is corrupted, header duration does ' print(
'not match file size. Ignoring.'.format(wav_filename)) "Warning: File {} is corrupted, header duration does "
"not match file size. Ignoring.".format(wav_filename)
)
continue continue
set_files.append((wav_filename, wav_filesize, transcript)) set_files.append((wav_filename, wav_filesize, transcript))
except KeyError: except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav)) print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files return set_files
for subset in ('train', 'dev', 'test'): for subset in ("train", "dev", "test"):
print('Loading {} set samples...'.format(subset)) print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(target_dir, subset)) subset_files = load_set(os.path.join(target_dir, subset))
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES) df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
# Trim train set to under 10s # Trim train set to under 10s
if subset == 'train': if subset == "train":
durations = (df['wav_filesize'] - 44) / 16000 / 2 durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0] df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum())) print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
with_noise = df['transcript'].str.contains(r'\[(FIL|SPK)\]') with_noise = df["transcript"].str.contains(r"\[(FIL|SPK)\]")
df = df[~with_noise] df = df[~with_noise]
print('Trimming {} samples with noise ([FIL] or [SPK])'.format(sum(with_noise))) print(
"Trimming {} samples with noise ([FIL] or [SPK])".format(
sum(with_noise)
)
)
dest_csv = os.path.join(target_dir, 'magicdata_{}.csv'.format(subset)) dest_csv = os.path.join(target_dir, "magicdata_{}.csv".format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv)) print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False) df.to_csv(dest_csv, index=False)
def main(): def main():
# https://openslr.org/68/ # https://openslr.org/68/
parser = get_importers_parser(description='Import MAGICDATA corpus') parser = get_importers_parser(description="Import MAGICDATA corpus")
parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz') parser.add_argument(
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives') "folder_with_archives",
help="Path to folder containing magicdata_{train,dev,test}.tar.gz",
)
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives",
)
params = parser.parse_args() params = parser.parse_args()
if not params.target_dir: if not params.target_dir:
params.target_dir = os.path.join(params.folder_with_archives, 'magicdata') params.target_dir = os.path.join(params.folder_with_archives, "magicdata")
preprocess_data(params.folder_with_archives, params.target_dir) preprocess_data(params.folder_with_archives, params.target_dir)
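The duration filter above never opens the audio; it estimates clip length directly from the WAV file size, assuming 16 kHz, 16-bit, mono files with a standard 44-byte RIFF header. As a sketch:

def estimated_duration_seconds(wav_filesize):
    # subtract the 44-byte WAV header, then divide by the sample rate
    # (16000 Hz) and the bytes per sample (2 for 16-bit PCM)
    return (wav_filesize - 44) / 16000 / 2

# estimated_duration_seconds(320044) == 10.0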
View File
@ -11,11 +11,11 @@ import pandas
from deepspeech_training.util.importers import get_importers_parser from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir): def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir)) print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
@ -23,7 +23,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir): def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives # First extract main archive and sub-archives
extract(tgz_file, target_dir) extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'primewords_md_2018_set1') main_folder = os.path.join(target_dir, "primewords_md_2018_set1")
# Folder structure is now: # Folder structure is now:
# - primewords_md_2018_set1/ # - primewords_md_2018_set1/
@ -31,14 +31,11 @@ def preprocess_data(tgz_file, target_dir):
# - [0-f]/[00-0f]/*.wav # - [0-f]/[00-0f]/*.wav
# - set1_transcript.json # - set1_transcript.json
transcripts_path = os.path.join(main_folder, 'set1_transcript.json') transcripts_path = os.path.join(main_folder, "set1_transcript.json")
with open(transcripts_path) as fin: with open(transcripts_path) as fin:
transcripts = json.load(fin) transcripts = json.load(fin)
transcripts = { transcripts = {entry["file"]: entry["text"] for entry in transcripts}
entry['file']: entry['text']
for entry in transcripts
}
def load_set(glob_path): def load_set(glob_path):
set_files = [] set_files = []
@ -50,13 +47,13 @@ def preprocess_data(tgz_file, target_dir):
transcript = transcripts[transcript_key] transcript = transcripts[transcript_key]
set_files.append((wav_filename, wav_filesize, transcript)) set_files.append((wav_filename, wav_filesize, transcript))
except KeyError: except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav)) print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files return set_files
# Load all files, then deterministically split into train/dev/test sets # Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, 'audio_files', '*', '*', '*.wav')) all_files = load_set(os.path.join(main_folder, "audio_files", "*", "*", "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES) df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True) df.sort_values(by="wav_filename", inplace=True)
indices = np.arange(0, len(df)) indices = np.arange(0, len(df))
np.random.seed(12345) np.random.seed(12345)
@ -69,29 +66,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000] train_indices = indices[:-10000]
train_files = df.iloc[train_indices] train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2 durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 15.0] train_files = train_files[durations <= 15.0]
print('Trimming {} samples > 15 seconds'.format((durations > 15.0).sum())) print("Trimming {} samples > 15 seconds".format((durations > 15.0).sum()))
dest_csv = os.path.join(target_dir, 'primewords_train.csv') dest_csv = os.path.join(target_dir, "primewords_train.csv")
print('Saving train set into {}...'.format(dest_csv)) print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False) train_files.to_csv(dest_csv, index=False)
dev_files = df.iloc[dev_indices] dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'primewords_dev.csv') dest_csv = os.path.join(target_dir, "primewords_dev.csv")
print('Saving dev set into {}...'.format(dest_csv)) print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False) dev_files.to_csv(dest_csv, index=False)
test_files = df.iloc[test_indices] test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'primewords_test.csv') dest_csv = os.path.join(target_dir, "primewords_test.csv")
print('Saving test set into {}...'.format(dest_csv)) print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False) test_files.to_csv(dest_csv, index=False)
def main(): def main():
# https://www.openslr.org/47/ # https://www.openslr.org/47/
parser = get_importers_parser(description='Import Primewords Chinese corpus set 1') parser = get_importers_parser(description="Import Primewords Chinese corpus set 1")
parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz') parser.add_argument("tgz_file", help="Path to primewords_md_2018_set1.tar.gz")
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args() params = parser.parse_args()
if not params.target_dir: if not params.target_dir:
View File
@ -20,17 +20,17 @@ from deepspeech_training.util.importers import (
get_imported_samples, get_imported_samples,
get_importers_parser, get_importers_parser,
get_validate_label, get_validate_label,
print_import_report print_import_report,
) )
from deepspeech_training.util.text import Alphabet from deepspeech_training.util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 15 MAX_SECS = 15
ARCHIVE_DIR_NAME = 'African_Accented_French' ARCHIVE_DIR_NAME = "African_Accented_French"
ARCHIVE_NAME = 'African_Accented_French.tar.gz' ARCHIVE_NAME = "African_Accented_French.tar.gz"
ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME ARCHIVE_URL = "http://www.openslr.org/resources/57/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
@ -43,6 +43,7 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files # Produce CSV files
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME) _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = os.path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
@ -56,6 +57,7 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample): def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0] wav_filename = sample[0]
@ -63,74 +65,81 @@ def one_sample(sample):
frames = 0 frames = 0
if os.path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1]) label = label_filter(sample[1])
counter = get_counter() counter = get_counter()
rows = [] rows = []
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = os.path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv')) target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tar.gz", "_{}.csv")
)
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
wav_root_dir = os.path.join(extracted_dir) wav_root_dir = os.path.join(extracted_dir)
all_files = [ all_files = [
'transcripts/train/yaounde/fn_text.txt', "transcripts/train/yaounde/fn_text.txt",
'transcripts/train/ca16_conv/transcripts.txt', "transcripts/train/ca16_conv/transcripts.txt",
'transcripts/train/ca16_read/conditioned.txt', "transcripts/train/ca16_read/conditioned.txt",
'transcripts/dev/niger_west_african_fr/transcripts.txt', "transcripts/dev/niger_west_african_fr/transcripts.txt",
'speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv', "speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv",
'transcripts/devtest/ca16_read/conditioned.txt', "transcripts/devtest/ca16_read/conditioned.txt",
'transcripts/test/ca16/prompts.txt', "transcripts/test/ca16/prompts.txt",
] ]
transcripts = {} transcripts = {}
for tr in all_files: for tr in all_files:
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), 'r') as tr_source: with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), "r") as tr_source:
for line in tr_source.readlines(): for line in tr_source.readlines():
line = line.strip() line = line.strip()
if '.tsv' in tr: if ".tsv" in tr:
sep = ' ' sep = " "
else: else:
sep = ' ' sep = " "
audio = os.path.basename(line.split(sep)[0]) audio = os.path.basename(line.split(sep)[0])
if not ('.wav' in audio): if not (".wav" in audio):
if '.tdf' in audio: if ".tdf" in audio:
audio = audio.replace('.tdf', '.wav') audio = audio.replace(".tdf", ".wav")
else: else:
audio += '.wav' audio += ".wav"
transcript = ' '.join(line.split(sep)[1:]) transcript = " ".join(line.split(sep)[1:])
transcripts[audio] = transcript transcripts[audio] = transcript
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
glob_dir = os.path.join(wav_root_dir, '**/*.wav') glob_dir = os.path.join(wav_root_dir, "**/*.wav")
for record in glob(glob_dir, recursive=True): for record in glob(glob_dir, recursive=True):
record_file = os.path.basename(record) record_file = os.path.basename(record)
if record_file in transcripts: if record_file in transcripts:
@ -152,9 +161,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -174,25 +183,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
dict(
wav_filename=wav_filename, wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename), wav_filesize=os.path.getsize(wav_filename),
transcript=transcript, transcript=transcript,
)) )
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.') parser = get_importers_parser(
parser.add_argument(dest='target_dir') description="Importer for African Accented French dataset. More information on http://www.openslr.org/57/."
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') )
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') parser.add_argument(dest="target_dir")
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -200,9 +222,11 @@ if __name__ == "__main__":
def label_filter(label): def label_filter(label):
if CLI_ARGS.normalize: if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
try: try:
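The normalize branch above folds accented characters to their ASCII base form before validating the label. A self-contained sketch of that round trip:

import unicodedata

def fold_diacritics(label):
    # NFKD splits e.g. "é" into "e" plus a combining accent; encoding to ASCII with
    # errors="ignore" then drops the accent, leaving only the base character.
    return (
        unicodedata.normalize("NFKD", label.strip())
        .encode("ascii", "ignore")
        .decode("ascii", "ignore")
    )

print(fold_diacritics("café crème"))  # -> "cafe creme"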


@ -18,22 +18,21 @@ import pandas
import requests import requests
import soundfile # <= Has an external dependency on libsndfile import soundfile # <= Has an external dependency on libsndfile
from deepspeech_training.util.importers import \ from deepspeech_training.util.importers import validate_label_eng as validate_label
validate_label_eng as validate_label
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03 # ARCHIVE_NAME refers to ISIP alignments from 01/29/03
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz' ARCHIVE_NAME = "switchboard_word_alignments.tar.gz"
ARCHIVE_URL = 'http://www.openslr.org/resources/5/' ARCHIVE_URL = "http://www.openslr.org/resources/5/"
ARCHIVE_DIR_NAME = 'LDC97S62' ARCHIVE_DIR_NAME = "LDC97S62"
LDC_DATASET = 'swb1_LDC97S62.tgz' LDC_DATASET = "swb1_LDC97S62.tgz"
def download_file(folder, url): def download_file(folder, url):
# https://stackoverflow.com/a/16696317/738515 # https://stackoverflow.com/a/16696317/738515
local_filename = url.split('/')[-1] local_filename = url.split("/")[-1]
full_filename = os.path.join(folder, local_filename) full_filename = os.path.join(folder, local_filename)
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(full_filename, 'wb') as f: with open(full_filename, "wb") as f:
for chunk in r.iter_content(chunk_size=1024): for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
f.write(chunk) f.write(chunk)
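download_file streams the archive in 1 KiB chunks, so the large LDC tarball never has to be held in memory. A standalone sketch of the same idiom (URL and destination folder are placeholders):

import os
import requests

def stream_download(url, folder, chunk_size=1024):
    local_filename = url.split("/")[-1]
    full_filename = os.path.join(folder, local_filename)
    r = requests.get(url, stream=True)        # stream=True defers the body download
    with open(full_filename, "wb") as f:
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:                          # skip keep-alive chunks
                f.write(chunk)
    return full_filename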
@ -62,7 +61,7 @@ def _download_and_preprocess_data(data_dir):
archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET)) archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET))
# Check swb1_LDC97S62.tgz then extract # Check swb1_LDC97S62.tgz then extract
assert(os.path.isfile(archive_path)) assert os.path.isfile(archive_path)
_extract(target_dir, archive_path) _extract(target_dir, archive_path)
# Transcripts # Transcripts
@ -70,8 +69,14 @@ def _download_and_preprocess_data(data_dir):
_extract(target_dir, transcripts_path) _extract(target_dir, transcripts_path)
# Check swb1_d1/2/3/4/swb_ms98_transcriptions # Check swb1_d1/2/3/4/swb_ms98_transcriptions
expected_folders = ["swb1_d1","swb1_d2","swb1_d3","swb1_d4","swb_ms98_transcriptions"] expected_folders = [
assert(all([os.path.isdir(os.path.join(target_dir,e)) for e in expected_folders])) "swb1_d1",
"swb1_d2",
"swb1_d3",
"swb1_d4",
"swb_ms98_transcriptions",
]
assert all([os.path.isdir(os.path.join(target_dir, e)) for e in expected_folders])
# Conditionally convert swb sph data to wav # Conditionally convert swb sph data to wav
_maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav") _maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav")
@ -80,10 +85,18 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav") _maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav")
# Conditionally split wav data # Conditionally split wav data
d1 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav") d1 = _maybe_split_wav_and_sentences(
d2 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav") target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav"
d3 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav") )
d4 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav") d2 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav"
)
d3 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav"
)
d4 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav"
)
swb_files = d1.append(d2).append(d3).append(d4) swb_files = d1.append(d2).append(d3).append(d4)
@ -115,14 +128,35 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
# Loop over sph files in source_dir and convert each to 16-bit PCM wav # Loop over sph files in source_dir and convert each to 16-bit PCM wav
for root, dirnames, filenames in os.walk(source_dir): for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, "*.sph"): for filename in fnmatch.filter(filenames, "*.sph"):
for channel in ['1', '2']: for channel in ["1", "2"]:
sph_file = os.path.join(root, filename) sph_file = os.path.join(root, filename)
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav" wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename) wav_file = os.path.join(target_dir, wav_filename)
temp_wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + "-temp.wav" temp_wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ "-temp.wav"
)
temp_wav_file = os.path.join(target_dir, temp_wav_filename) temp_wav_file = os.path.join(target_dir, temp_wav_filename)
print("converting {} to {}".format(sph_file, temp_wav_file)) print("converting {} to {}".format(sph_file, temp_wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file]) subprocess.check_call(
[
"sph2pipe",
"-c",
channel,
"-p",
"-f",
"rif",
sph_file,
temp_wav_file,
]
)
print("upsampling {} to {}".format(temp_wav_file, wav_file)) print("upsampling {} to {}".format(temp_wav_file, wav_file))
audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True) audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True)
soundfile.write(wav_file, audioData, frameRate, "PCM_16") soundfile.write(wav_file, audioData, frameRate, "PCM_16")
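Each Switchboard .sph recording is split per channel with sph2pipe and then resampled to 16 kHz mono 16-bit PCM. A compressed sketch of that pipeline for a single channel, assuming the sph2pipe binary is on PATH:

import subprocess
import librosa
import soundfile

def sph_channel_to_wav16k(sph_file, channel, temp_wav_file, wav_file):
    # Extract one channel of the NIST SPHERE file to a RIFF wav
    subprocess.check_call(
        ["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file]
    )
    # Resample to 16 kHz mono and rewrite as 16-bit PCM
    audio, rate = librosa.load(temp_wav_file, sr=16000, mono=True)
    soundfile.write(wav_file, audio, rate, "PCM_16")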
@ -147,15 +181,19 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode # We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array # returns a bytes() object on Python 3, and text_to_char_array
# expects a string. # expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \ transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
segments.append({ segments.append(
{
"start_time": start_time, "start_time": start_time,
"stop_time": stop_time, "stop_time": stop_time,
"transcript": transcript, "transcript": transcript,
}) }
)
return segments return segments
@ -180,8 +218,16 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
segments = _parse_transcriptions(trans_file) segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file # Open wav corresponding to transcription file
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A'] channel = ("2", "1")[
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav" (os.path.splitext(os.path.basename(trans_file))[0])[6] == "A"
]
wav_filename = (
"sw0"
+ (os.path.splitext(os.path.basename(trans_file))[0])[2:6]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(source_dir, wav_filename) wav_file = os.path.join(source_dir, wav_filename)
print("splitting {} according to {}".format(wav_file, trans_file)) print("splitting {} according to {}".format(wav_file, trans_file))
@ -197,8 +243,14 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
# Create wav segment filename # Create wav segment filename
start_time = segment["start_time"] start_time = segment["start_time"]
stop_time = segment["stop_time"] stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str( new_wav_filename = (
start_time) + "-" + str(stop_time) + ".wav" os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
if _is_wav_too_short(new_wav_filename): if _is_wav_too_short(new_wav_filename):
continue continue
new_wav_file = os.path.join(target_dir, new_wav_filename) new_wav_file = os.path.join(target_dir, new_wav_filename)
@ -207,16 +259,23 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
new_wav_filesize = os.path.getsize(new_wav_file) new_wav_filesize = os.path.getsize(new_wav_file)
transcript = segment["transcript"] transcript = segment["transcript"]
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript)) files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)
# Close origAudio # Close origAudio
origAudio.close() origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _is_wav_too_short(wav_filename): def _is_wav_too_short(wav_filename):
short_wav_filenames = ['sw2986A-ms98-a-trans-80.6385-83.358875.wav', 'sw2663A-ms98-a-trans-161.12025-164.213375.wav'] short_wav_filenames = [
"sw2986A-ms98-a-trans-80.6385-83.358875.wav",
"sw2663A-ms98-a-trans-161.12025-164.213375.wav",
]
return wav_filename in short_wav_filenames return wav_filename in short_wav_filenames
@ -245,10 +304,24 @@ def _split_sets(filelist):
test_beg = dev_end test_beg = dev_end
test_end = len(filelist) test_end = len(filelist)
return (filelist[train_beg:train_end], filelist[dev_beg:dev_end], filelist[test_beg:test_end]) return (
filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)
def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, stride=1, offset=0, next_index=lambda i: i + 1, limit=0): def _read_data_set(
filelist,
thread_count,
batch_size,
numcep,
numcontext,
stride=1,
offset=0,
next_index=lambda i: i + 1,
limit=0,
):
# Optionally apply dataset size limit # Optionally apply dataset size limit
if limit > 0: if limit > 0:
filelist = filelist.iloc[:limit] filelist = filelist.iloc[:limit]
@ -256,7 +329,9 @@ def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, strid
filelist = filelist[offset::stride] filelist = filelist[offset::stride]
# Return DataSet # Return DataSet
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index) return DataSet(
txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index
)
if __name__ == "__main__": if __name__ == "__main__":


@ -1,8 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
Use "python3 import_swc.py -h" for help Use "python3 import_swc.py -h" for help
''' """
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import argparse import argparse
@ -24,44 +24,54 @@ import progressbar
import sox import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import \ from deepspeech_training.util.importers import validate_label_eng as validate_label
validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet from deepspeech_training.util.text import Alphabet
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar" SWC_ARCHIVE = "SWC_{language}.tar"
LANGUAGES = ['dutch', 'english', 'german'] LANGUAGES = ["dutch", "english", "german"]
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
FIELDNAMES_EXT = FIELDNAMES + ['article', 'speaker'] FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
CHANNELS = 1 CHANNELS = 1
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
UNKNOWN = '<unknown>' UNKNOWN = "<unknown>"
AUDIO_PATTERN = 'audio*.ogg' AUDIO_PATTERN = "audio*.ogg"
WAV_NAME = 'audio.wav' WAV_NAME = "audio.wav"
ALIGNED_NAME = 'aligned.swc' ALIGNED_NAME = "aligned.swc"
SUBSTITUTIONS = { SUBSTITUTIONS = {
'german': [ "german": [
(re.compile(r'\$'), 'dollar'), (re.compile(r"\$"), "dollar"),
(re.compile(r'€'), 'euro'), (re.compile(r"€"), "euro"),
(re.compile(r'£'), 'pfund'), (re.compile(r"£"), "pfund"),
(re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '), (
(re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'), re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
(re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'), r"\1zehnhundert \2er ",
(re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'), ),
(re.compile(r'eins punkt null null null punkt null null null'), 'eine million'), (re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
(re.compile(r'punkt null null null punkt null null null'), 'millionen'), (
(re.compile(r'eins punkt null null null'), 'ein tausend'), re.compile(
(re.compile(r'punkt null null null'), 'tausend'), r"eins punkt null null null punkt null null null punkt null null null"
(re.compile(r'punkt null'), None) ),
"eine milliarde",
),
(
re.compile(
r"punkt null null null punkt null null null punkt null null null"
),
"milliarden",
),
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
(re.compile(r"eins punkt null null null"), "ein tausend"),
(re.compile(r"punkt null null null"), "tausend"),
(re.compile(r"punkt null"), None),
] ]
} }
DONT_NORMALIZE = { DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
'german': 'ÄÖÜäöüß'
}
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:')) PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
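PRE_FILTER maps every character in "/()[]{}<>:" to None, so str.translate() simply deletes them before any other filtering, e.g.:

PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
print("Ein [Titel] (mit Klammern): Test".translate(PRE_FILTER))
# -> "Ein Titel mit Klammern Test"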
class Sample: class Sample:
@ -95,11 +105,14 @@ def get_sample_size(population_size):
margin_of_error = 0.01 margin_of_error = 0.01
fraction_picking = 0.50 fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99% z_score = 2.58 # Corresponds to confidence level 99%
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2) numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2
)
sample_size = 0 sample_size = 0
for train_size in range(population_size, 0, -1): for train_size in range(population_size, 0, -1):
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \ denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
(margin_of_error ** 2 * train_size) margin_of_error ** 2 * train_size
)
sample_size = int(numerator / denominator) sample_size = int(numerator / denominator)
if 2 * sample_size + train_size <= population_size: if 2 * sample_size + train_size <= population_size:
break break
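get_sample_size is essentially Cochran's sample-size formula with a finite-population correction: n = (z^2 * p * (1 - p) / e^2) / (1 + z^2 * p * (1 - p) / (e^2 * N)), with z = 2.58 (99% confidence), p = 0.5 and e = 0.01; the loop then shrinks the candidate training size N until the population can hold the training set plus a dev and a test set of size n. A rough standalone check with a hypothetical N (not a value taken from any SWC language):

z, p, e = 2.58, 0.5, 0.01
numerator = (z ** 2 * p * (1 - p)) / (e ** 2)                    # ~16641
train_size = 100_000                                              # hypothetical population slice
denominator = 1 + (z ** 2 * p * (1 - p)) / (e ** 2 * train_size)
print(int(numerator / denominator))                               # ~14266 samples each for dev and test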
@ -108,9 +121,11 @@ def get_sample_size(population_size):
def maybe_download_language(language): def maybe_download_language(language):
lang_upper = language[0].upper() + language[1:] lang_upper = language[0].upper() + language[1:]
return maybe_download(SWC_ARCHIVE.format(language=lang_upper), return maybe_download(
SWC_ARCHIVE.format(language=lang_upper),
CLI_ARGS.base_dir, CLI_ARGS.base_dir,
SWC_URL.format(language=lang_upper)) SWC_URL.format(language=lang_upper),
)
def maybe_extract(data_dir, extracted_data, archive): def maybe_extract(data_dir, extracted_data, archive):
@ -130,29 +145,29 @@ def maybe_extract(data_dir, extracted_data, archive):
def ignored(node): def ignored(node):
if node is None: if node is None:
return False return False
if node.tag == 'ignored': if node.tag == "ignored":
return True return True
return ignored(node.find('..')) return ignored(node.find(".."))
def read_token(token): def read_token(token):
texts, start, end = [], None, None texts, start, end = [], None, None
notes = token.findall('n') notes = token.findall("n")
if len(notes) > 0: if len(notes) > 0:
for note in notes: for note in notes:
attributes = note.attrib attributes = note.attrib
if start is None and 'start' in attributes: if start is None and "start" in attributes:
start = int(attributes['start']) start = int(attributes["start"])
if 'end' in attributes: if "end" in attributes:
token_end = int(attributes['end']) token_end = int(attributes["end"])
if end is None or token_end > end: if end is None or token_end > end:
end = token_end end = token_end
if 'pronunciation' in attributes: if "pronunciation" in attributes:
t = attributes['pronunciation'] t = attributes["pronunciation"]
texts.append(t) texts.append(t)
elif 'text' in token.attrib: elif "text" in token.attrib:
texts.append(token.attrib['text']) texts.append(token.attrib["text"])
return start, end, ' '.join(texts) return start, end, " ".join(texts)
def in_alphabet(alphabet, c): def in_alphabet(alphabet, c):
@ -160,10 +175,12 @@ def in_alphabet(alphabet, c):
ALPHABETS = {} ALPHABETS = {}
def get_alphabet(language): def get_alphabet(language):
if language in ALPHABETS: if language in ALPHABETS:
return ALPHABETS[language] return ALPHABETS[language]
alphabet_path = getattr(CLI_ARGS, language + '_alphabet') alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
alphabet = Alphabet(alphabet_path) if alphabet_path else None alphabet = Alphabet(alphabet_path) if alphabet_path else None
ALPHABETS[language] = alphabet ALPHABETS[language] = alphabet
return alphabet return alphabet
@ -173,27 +190,35 @@ def label_filter(label, language):
label = label.translate(PRE_FILTER) label = label.translate(PRE_FILTER)
label = validate_label(label) label = validate_label(label)
if label is None: if label is None:
return None, 'validation' return None, "validation"
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else [] substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
for pattern, replacement in substitutions: for pattern, replacement in substitutions:
if replacement is None: if replacement is None:
if pattern.match(label): if pattern.match(label):
return None, 'substitution rule' return None, "substitution rule"
else: else:
label = pattern.sub(replacement, label) label = pattern.sub(replacement, label)
chars = [] chars = []
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else '' dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
alphabet = get_alphabet(language) alphabet = get_alphabet(language)
for c in label: for c in label:
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c): if (
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") CLI_ARGS.normalize
and c not in dont_normalize
and not in_alphabet(alphabet, c)
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if not in_alphabet(alphabet, sc): if not in_alphabet(alphabet, sc):
return None, 'illegal character' return None, "illegal character"
chars.append(sc) chars.append(sc)
label = ''.join(chars) label = "".join(chars)
label = validate_label(label) label = validate_label(label)
return label, 'validation' if label is None else None return label, "validation" if label is None else None
def collect_samples(base_dir, language): def collect_samples(base_dir, language):
@ -204,7 +229,9 @@ def collect_samples(base_dir, language):
samples = [] samples = []
reasons = Counter() reasons = Counter()
def add_sample(p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason='complete'): def add_sample(
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
):
if p_start is not None and p_end is not None and p_text is not None: if p_start is not None and p_end is not None and p_text is not None:
duration = p_end - p_start duration = p_end - p_start
text, filter_reason = label_filter(p_text, language) text, filter_reason = label_filter(p_text, language)
@ -214,53 +241,67 @@ def collect_samples(base_dir, language):
p_reason = filter_reason p_reason = filter_reason
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN: elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
skip = True skip = True
p_reason = 'unknown speaker' p_reason = "unknown speaker"
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN: elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
skip = True skip = True
p_reason = 'unknown article' p_reason = "unknown article"
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long: elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
skip = True skip = True
p_reason = 'exceeded duration' p_reason = "exceeded duration"
elif int(duration / 30) < len(text): elif int(duration / 30) < len(text):
skip = True skip = True
p_reason = 'too short to decode' p_reason = "too short to decode"
elif duration / len(text) < 10: elif duration / len(text) < 10:
skip = True skip = True
p_reason = 'length duration ratio' p_reason = "length duration ratio"
if skip: if skip:
reasons[p_reason] += 1 reasons[p_reason] += 1
else: else:
samples.append(Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)) samples.append(
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
)
elif p_start is None or p_end is None: elif p_start is None or p_end is None:
reasons['missing timestamps'] += 1 reasons["missing timestamps"] += 1
else: else:
reasons['missing text'] += 1 reasons["missing text"] += 1
print('Collecting samples...') print("Collecting samples...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
for root in bar(roots): for root in bar(roots):
wav_path = os.path.join(root, WAV_NAME) wav_path = os.path.join(root, WAV_NAME)
aligned = ET.parse(path.join(root, ALIGNED_NAME)) aligned = ET.parse(path.join(root, ALIGNED_NAME))
article = UNKNOWN article = UNKNOWN
speaker = UNKNOWN speaker = UNKNOWN
for prop in aligned.iter('prop'): for prop in aligned.iter("prop"):
attributes = prop.attrib attributes = prop.attrib
if 'key' in attributes and 'value' in attributes: if "key" in attributes and "value" in attributes:
if attributes['key'] == 'DC.identifier': if attributes["key"] == "DC.identifier":
article = attributes['value'] article = attributes["value"]
elif attributes['key'] == 'reader.name': elif attributes["key"] == "reader.name":
speaker = attributes['value'] speaker = attributes["value"]
for sentence in aligned.iter('s'): for sentence in aligned.iter("s"):
if ignored(sentence): if ignored(sentence):
continue continue
split = False split = False
tokens = list(map(read_token, sentence.findall('t'))) tokens = list(map(read_token, sentence.findall("t")))
sample_start, sample_end, token_texts, sample_texts = None, None, [], [] sample_start, sample_end, token_texts, sample_texts = None, None, [], []
for token_start, token_end, token_text in tokens: for token_start, token_end, token_text in tokens:
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text): if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts), add_sample(
p_reason='has numbers') wav_path,
sample_start, sample_end, token_texts, sample_texts = None, None, [], [] article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="has numbers",
)
sample_start, sample_end, token_texts, sample_texts = (
None,
None,
[],
[],
)
continue continue
if sample_start is None: if sample_start is None:
sample_start = token_start sample_start = token_start
@ -268,20 +309,37 @@ def collect_samples(base_dir, language):
continue continue
token_texts.append(token_text) token_texts.append(token_text)
if token_end is not None: if token_end is not None:
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0: if (
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts), token_start != sample_start
p_reason='split') and token_end - sample_start > CLI_ARGS.max_duration > 0
):
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split",
)
sample_start = sample_end sample_start = sample_end
sample_texts = [] sample_texts = []
split = True split = True
sample_end = token_end sample_end = token_end
sample_texts.extend(token_texts) sample_texts.extend(token_texts)
token_texts = [] token_texts = []
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts), add_sample(
p_reason='split' if split else 'complete') wav_path,
print('Skipped samples:') article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split" if split else "complete",
)
print("Skipped samples:")
for reason, n in reasons.most_common(): for reason, n in reasons.most_common():
print(' - {}: {}'.format(reason, n)) print(" - {}: {}".format(reason, n))
return samples return samples
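add_sample above drops aligned segments mainly on timing heuristics, with start and end given in milliseconds. A condensed sketch of just those checks with the same thresholds (the CLI flags and the unknown-speaker/article exclusions are left out):

def rejection_reason(duration_ms, text, max_duration_ms=10000):
    # duration_ms corresponds to p_end - p_start of the aligned segment
    if not text:
        return "missing text"
    if 0 < max_duration_ms < duration_ms:
        return "exceeded duration"
    if int(duration_ms / 30) < len(text):
        return "too short to decode"    # fewer than one 30 ms frame per character
    if duration_ms / len(text) < 10:
        return "length duration ratio"  # under 10 ms of audio per character
    return None                          # keep the sample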
@ -301,18 +359,18 @@ def maybe_convert_one_to_wav(entry):
elif len(files) > 1: elif len(files) > 1:
wav_files = [] wav_files = []
for i, file in enumerate(files): for i, file in enumerate(files):
wav_path = os.path.join(root, 'audio{}.wav'.format(i)) wav_path = os.path.join(root, "audio{}.wav".format(i))
transformer.build(file, wav_path) transformer.build(file, wav_path)
wav_files.append(wav_path) wav_files.append(wav_path)
combiner.set_input_format(file_type=['wav'] * len(wav_files)) combiner.set_input_format(file_type=["wav"] * len(wav_files))
combiner.build(wav_files, output_wav, 'concatenate') combiner.build(wav_files, output_wav, "concatenate")
except sox.core.SoxError: except sox.core.SoxError:
return return
def maybe_convert_to_wav(base_dir): def maybe_convert_to_wav(base_dir):
roots = list(os.walk(base_dir)) roots = list(os.walk(base_dir))
print('Converting and joining source audio files...') print("Converting and joining source audio files...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
tp = ThreadPool() tp = ThreadPool()
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)): for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
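maybe_convert_one_to_wav drives pysox: a Transformer renders each source .ogg to wav and, for multi-part recordings, a Combiner concatenates the parts into a single audio.wav. A sketch of that flow, assuming the transformer is configured for the importer's 16 kHz mono target (that setup happens in the part of the function not shown in this hunk):

import os
import sox

def convert_and_join(ogg_files, output_wav):
    transformer = sox.Transformer()
    transformer.convert(samplerate=16000, n_channels=1)   # assumed target format
    if len(ogg_files) == 1:
        transformer.build(ogg_files[0], output_wav)
        return
    wav_files = []
    for i, ogg in enumerate(ogg_files):
        part = os.path.join(os.path.dirname(output_wav), "audio{}.wav".format(i))
        transformer.build(ogg, part)
        wav_files.append(part)
    combiner = sox.Combiner()
    combiner.set_input_format(file_type=["wav"] * len(wav_files))
    combiner.build(wav_files, output_wav, "concatenate")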
@ -332,53 +390,66 @@ def assign_sub_sets(samples):
sample_set.extend(speakers.pop(0)) sample_set.extend(speakers.pop(0))
train_set = sum(speakers, []) train_set = sum(speakers, [])
if len(train_set) == 0: if len(train_set) == 0:
print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data') print(
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
)
random.seed(42) # same source data == same output random.seed(42) # same source data == same output
random.shuffle(samples) random.shuffle(samples)
for index, sample in enumerate(samples): for index, sample in enumerate(samples):
if index < sample_size: if index < sample_size:
sample.sub_set = 'dev' sample.sub_set = "dev"
elif index < 2 * sample_size: elif index < 2 * sample_size:
sample.sub_set = 'test' sample.sub_set = "test"
else: else:
sample.sub_set = 'train' sample.sub_set = "train"
else: else:
for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]: for sub_set, sub_set_samples in [
("train", train_set),
("dev", sample_sets[0]),
("test", sample_sets[1]),
]:
for sample in sub_set_samples: for sample in sub_set_samples:
sample.sub_set = sub_set sample.sub_set = sub_set
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items(): for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60) t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
print('Sub-set "{}" with {} samples (duration: {:.2f} h)' print(
.format(sub_set, len(sub_set_samples), t)) 'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
sub_set, len(sub_set_samples), t
)
)
def create_sample_dirs(language): def create_sample_dirs(language):
print('Creating sample directories...') print("Creating sample directories...")
for set_name in ['train', 'dev', 'test']: for set_name in ["train", "dev", "test"]:
dir_path = os.path.join(CLI_ARGS.base_dir, language + '-' + set_name) dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
os.mkdir(dir_path) os.mkdir(dir_path)
def split_audio_files(samples, language): def split_audio_files(samples, language):
print('Splitting audio files...') print("Splitting audio files...")
sub_sets = Counter() sub_sets = Counter()
src_wav_files = group(samples, lambda s: s.wav_path).items() src_wav_files = group(samples, lambda s: s.wav_path).items()
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
for wav_path, file_samples in bar(src_wav_files): for wav_path, file_samples in bar(src_wav_files):
file_samples = sorted(file_samples, key=lambda s: s.start) file_samples = sorted(file_samples, key=lambda s: s.start)
with wave.open(wav_path, 'r') as src_wav_file: with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate() rate = src_wav_file.getframerate()
for sample in file_samples: for sample in file_samples:
index = sub_sets[sample.sub_set] index = sub_sets[sample.sub_set]
sample_wav_path = os.path.join(CLI_ARGS.base_dir, sample_wav_path = os.path.join(
language + '-' + sample.sub_set, CLI_ARGS.base_dir,
'sample-{0:06d}.wav'.format(index)) language + "-" + sample.sub_set,
"sample-{0:06d}.wav".format(index),
)
sample.wav_path = sample_wav_path sample.wav_path = sample_wav_path
sub_sets[sample.sub_set] += 1 sub_sets[sample.sub_set] += 1
src_wav_file.setpos(int(sample.start * rate / 1000.0)) src_wav_file.setpos(int(sample.start * rate / 1000.0))
data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0)) data = src_wav_file.readframes(
with wave.open(sample_wav_path, 'w') as sample_wav_file: int((sample.end - sample.start) * rate / 1000.0)
)
with wave.open(sample_wav_path, "w") as sample_wav_file:
sample_wav_file.setnchannels(src_wav_file.getnchannels()) sample_wav_file.setnchannels(src_wav_file.getnchannels())
sample_wav_file.setsampwidth(src_wav_file.getsampwidth()) sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
sample_wav_file.setframerate(rate) sample_wav_file.setframerate(rate)
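split_audio_files cuts the source wav with the standard-library wave module: seek to the start frame, read the frames covering the segment, and write them out with the source file's channel count, sample width and rate. A standalone sketch with start and end in milliseconds, matching the Sample fields:

import wave

def cut_wav(src_path, dst_path, start_ms, end_ms):
    with wave.open(src_path, "r") as src:
        rate = src.getframerate()
        src.setpos(int(start_ms * rate / 1000.0))
        data = src.readframes(int((end_ms - start_ms) * rate / 1000.0))
        with wave.open(dst_path, "w") as dst:
            dst.setnchannels(src.getnchannels())
            dst.setsampwidth(src.getsampwidth())
            dst.setframerate(rate)
            dst.writeframes(data)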
@ -389,21 +460,25 @@ def write_csvs(samples, language):
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items(): for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
set_samples = sorted(set_samples, key=lambda s: s.wav_path) set_samples = sorted(set_samples, key=lambda s: s.wav_path)
base_dir = os.path.abspath(CLI_ARGS.base_dir) base_dir = os.path.abspath(CLI_ARGS.base_dir)
csv_path = os.path.join(base_dir, language + '-' + sub_set + '.csv') csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
print('Writing "{}"...'.format(csv_path)) print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file: with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES) writer = csv.DictWriter(
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
)
writer.writeheader() writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(
max_value=len(set_samples), widgets=SIMPLE_BAR
)
for sample in bar(set_samples): for sample in bar(set_samples):
row = { row = {
'wav_filename': os.path.relpath(sample.wav_path, base_dir), "wav_filename": os.path.relpath(sample.wav_path, base_dir),
'wav_filesize': os.path.getsize(sample.wav_path), "wav_filesize": os.path.getsize(sample.wav_path),
'transcript': sample.text "transcript": sample.text,
} }
if CLI_ARGS.add_meta: if CLI_ARGS.add_meta:
row['article'] = sample.article row["article"] = sample.article
row['speaker'] = sample.speaker row["speaker"] = sample.speaker
writer.writerow(row) writer.writerow(row)
@ -430,34 +505,75 @@ def prepare_language(language):
def handle_args(): def handle_args():
parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora') parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
parser.add_argument('base_dir', help='Directory containing all data') parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES))) parser.add_argument(
parser.add_argument('--exclude_numbers', type=bool, default=True, "--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
help='If sequences with non-transliterated numbers should be excluded') )
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds') parser.add_argument(
parser.add_argument('--ignore_too_long', type=bool, default=False, "--exclude_numbers",
help='If samples exceeding max_duration should be removed') type=bool,
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') default=True,
help="If sequences with non-transliterated numbers should be excluded",
)
parser.add_argument(
"--max_duration",
type=int,
default=10000,
help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--ignore_too_long",
type=bool,
default=False,
help="If samples exceeding max_duration should be removed",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
for language in LANGUAGES: for language in LANGUAGES:
parser.add_argument('--{}_alphabet'.format(language), parser.add_argument(
help='Exclude {} samples with characters not in provided alphabet file'.format(language)) "--{}_alphabet".format(language),
parser.add_argument('--add_meta', action='store_true', help='Adds article and speaker CSV columns') help="Exclude {} samples with characters not in provided alphabet file".format(
parser.add_argument('--exclude_unknown_speakers', action='store_true', help='Exclude unknown speakers') language
parser.add_argument('--exclude_unknown_articles', action='store_true', help='Exclude unknown articles') ),
parser.add_argument('--keep_archive', type=bool, default=True, )
help='If downloaded archives should be kept') parser.add_argument(
parser.add_argument('--keep_intermediate', type=bool, default=False, "--add_meta", action="store_true", help="Adds article and speaker CSV columns"
help='If intermediate files should be kept') )
parser.add_argument(
"--exclude_unknown_speakers",
action="store_true",
help="Exclude unknown speakers",
)
parser.add_argument(
"--exclude_unknown_articles",
action="store_true",
help="Exclude unknown articles",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
parser.add_argument(
"--keep_intermediate",
type=bool,
default=False,
help="If intermediate files should be kept",
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
if CLI_ARGS.language == 'all': if CLI_ARGS.language == "all":
for lang in LANGUAGES: for lang in LANGUAGES:
prepare_language(lang) prepare_language(lang)
elif CLI_ARGS.language in LANGUAGES: elif CLI_ARGS.language in LANGUAGES:
prepare_language(CLI_ARGS.language) prepare_language(CLI_ARGS.language)
else: else:
fail('Wrong language id') fail("Wrong language id")


@ -37,6 +37,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False) dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False) test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive): def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir # If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(path.join(data_dir, extracted_data)): if not gfile.Exists(path.join(data_dir, extracted_data)):
@ -44,6 +45,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir) tar.extractall(data_dir)
tar.close() tar.close()
def _maybe_convert_wav(data_dir, extracted_data): def _maybe_convert_wav(data_dir, extracted_data):
# Create extracted_data dir # Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data) extracted_dir = path.join(data_dir, extracted_data)
@ -57,6 +59,7 @@ def _maybe_convert_wav(data_dir, extracted_data):
# Conditionally convert test sph to wav # Conditionally convert test sph to wav
_maybe_convert_wav_dataset(extracted_dir, "test") _maybe_convert_wav_dataset(extracted_dir, "test")
def _maybe_convert_wav_dataset(extracted_dir, data_set): def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Create source dir # Create source dir
source_dir = path.join(extracted_dir, data_set, "sph") source_dir = path.join(extracted_dir, data_set, "sph")
@ -80,6 +83,7 @@ def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Remove source_dir # Remove source_dir
rmdir(source_dir) rmdir(source_dir)
def _maybe_split_sentences(data_dir, extracted_data): def _maybe_split_sentences(data_dir, extracted_data):
# Create extracted_data dir # Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data) extracted_dir = path.join(data_dir, extracted_data)
@ -95,6 +99,7 @@ def _maybe_split_sentences(data_dir, extracted_data):
return train_files, dev_files, test_files return train_files, dev_files, test_files
def _maybe_split_dataset(extracted_dir, data_set): def _maybe_split_dataset(extracted_dir, data_set):
# Create stm dir # Create stm dir
stm_dir = path.join(extracted_dir, data_set, "stm") stm_dir = path.join(extracted_dir, data_set, "stm")
@ -112,14 +117,21 @@ def _maybe_split_dataset(extracted_dir, data_set):
# Open wav corresponding to stm_file # Open wav corresponding to stm_file
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav" wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
wav_file = path.join(wav_dir, wav_filename) wav_file = path.join(wav_dir, wav_filename)
origAudio = wave.open(wav_file,'r') origAudio = wave.open(wav_file, "r")
# Loop over stm_segments and split wav_file for each segment # Loop over stm_segments and split wav_file for each segment
for stm_segment in stm_segments: for stm_segment in stm_segments:
# Create wav segment filename # Create wav segment filename
start_time = stm_segment.start_time start_time = stm_segment.start_time
stop_time = stm_segment.stop_time stop_time = stm_segment.stop_time
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav" new_wav_filename = (
path.splitext(path.basename(stm_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = path.join(wav_dir, new_wav_filename) new_wav_file = path.join(wav_dir, new_wav_filename)
# If the wav segment filename does not exist create it # If the wav segment filename does not exist create it
@ -127,23 +139,29 @@ def _maybe_split_dataset(extracted_dir, data_set):
_split_wav(origAudio, start_time, stop_time, new_wav_file) _split_wav(origAudio, start_time, stop_time, new_wav_file)
new_wav_filesize = path.getsize(new_wav_file) new_wav_filesize = path.getsize(new_wav_file)
files.append((path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)) files.append(
(path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)
)
# Close origAudio # Close origAudio
origAudio.close() origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _split_wav(origAudio, start_time, stop_time, new_wav_file): def _split_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio.getframerate() frameRate = origAudio.getframerate()
origAudio.setpos(int(start_time * frameRate)) origAudio.setpos(int(start_time * frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate)) chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
chunkAudio = wave.open(new_wav_file,'w') chunkAudio = wave.open(new_wav_file, "w")
chunkAudio.setnchannels(origAudio.getnchannels()) chunkAudio.setnchannels(origAudio.getnchannels())
chunkAudio.setsampwidth(origAudio.getsampwidth()) chunkAudio.setsampwidth(origAudio.getsampwidth())
chunkAudio.setframerate(frameRate) chunkAudio.setframerate(frameRate)
chunkAudio.writeframes(chunkData) chunkAudio.writeframes(chunkData)
chunkAudio.close() chunkAudio.close()
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])


@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
NAME : LDC TIMIT Dataset NAME : LDC TIMIT Dataset
URL : https://catalog.ldc.upenn.edu/ldc93s1 URL : https://catalog.ldc.upenn.edu/ldc93s1
HOURS : 5 HOURS : 5
@ -8,7 +8,7 @@
AUTHORS : Garofolo, John, et al. AUTHORS : Garofolo, John, et al.
TYPE : LDC Membership TYPE : LDC Membership
LICENCE : LDC User Agreement LICENCE : LDC User Agreement
''' """
import errno import errno
import fnmatch import fnmatch
@ -23,16 +23,17 @@ import pandas as pd
def clean(word): def clean(word):
# LC ALL & strip punctuation which are not required # LC ALL & strip punctuation which are not required
new = word.lower().replace('.', '') new = word.lower().replace(".", "")
new = new.replace(',', '') new = new.replace(",", "")
new = new.replace(';', '') new = new.replace(";", "")
new = new.replace('"', '') new = new.replace('"', "")
new = new.replace('!', '') new = new.replace("!", "")
new = new.replace('?', '') new = new.replace("?", "")
new = new.replace(':', '') new = new.replace(":", "")
new = new.replace('-', '') new = new.replace("-", "")
return new return new
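clean() lowercases a TIMIT word and strips punctuation through eight chained replace() calls. The same effect can be had with a single translation table (an alternative sketch, not the importer's code):

_PUNCT_TABLE = str.maketrans("", "", '.,;"!?:-')

def clean(word):
    # lowercase and drop the same punctuation characters as the chained replaces
    return word.lower().translate(_PUNCT_TABLE)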
def _preprocess_data(args): def _preprocess_data(args):
# Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1 # Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
@ -42,16 +43,24 @@ def _preprocess_data(args):
if ignoreSASentences: if ignoreSASentences:
print("Using recommended ignore SA sentences") print("Using recommended ignore SA sentences")
print("Ignoring SA sentences (2 x sentences which are repeated by all speakers)") print(
"Ignoring SA sentences (2 x sentences which are repeated by all speakers)"
)
else: else:
print("Using unrecommended setting to include SA sentences") print("Using unrecommended setting to include SA sentences")
datapath = args datapath = args
target = path.join(datapath, "TIMIT") target = path.join(datapath, "TIMIT")
print("Checking to see if data has already been extracted in given argument: %s", target) print(
"Checking to see if data has already been extracted in given argument: %s",
target,
)
if not path.isdir(target): if not path.isdir(target):
print("Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ", datapath) print(
"Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ",
datapath,
)
filepath = path.join(datapath, "TIMIT-LDC93S1.tgz") filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
if path.isfile(filepath): if path.isfile(filepath):
print("File found, extracting") print("File found, extracting")
@ -105,40 +114,58 @@ def _preprocess_data(args):
# if ignoreSAsentences we only want those without SA in the name # if ignoreSAsentences we only want those without SA in the name
# OR # OR
# if not ignoreSAsentences we want all to be added # if not ignoreSAsentences we want all to be added
if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences): if (ignoreSASentences and not ("SA" in os.path.basename(full_wav))) or (
if 'train' in full_wav.lower(): not ignoreSASentences
):
if "train" in full_wav.lower():
train_list_wavs.append(full_wav) train_list_wavs.append(full_wav)
train_list_trans.append(trans) train_list_trans.append(trans)
train_list_size.append(wav_filesize) train_list_size.append(wav_filesize)
elif 'test' in full_wav.lower(): elif "test" in full_wav.lower():
test_list_wavs.append(full_wav) test_list_wavs.append(full_wav)
test_list_trans.append(trans) test_list_trans.append(trans)
test_list_size.append(wav_filesize) test_list_size.append(wav_filesize)
else: else:
raise IOError raise IOError
a = {'wav_filename': train_list_wavs, a = {
'wav_filesize': train_list_size, "wav_filename": train_list_wavs,
'transcript': train_list_trans "wav_filesize": train_list_size,
"transcript": train_list_trans,
} }
c = {'wav_filename': test_list_wavs, c = {
'wav_filesize': test_list_size, "wav_filename": test_list_wavs,
'transcript': test_list_trans "wav_filesize": test_list_size,
"transcript": test_list_trans,
} }
all = {'wav_filename': train_list_wavs + test_list_wavs, all = {
'wav_filesize': train_list_size + test_list_size, "wav_filename": train_list_wavs + test_list_wavs,
'transcript': train_list_trans + test_list_trans "wav_filesize": train_list_size + test_list_size,
"transcript": train_list_trans + test_list_trans,
} }
df_all = pd.DataFrame(all, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int) df_all = pd.DataFrame(
df_train = pd.DataFrame(a, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int) all, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
df_test = pd.DataFrame(c, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int) )
df_train = pd.DataFrame(
a, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_test = pd.DataFrame(
c, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_all.to_csv(
target + "/timit_all.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_train.to_csv(
target + "/timit_train.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_test.to_csv(
target + "/timit_test.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_all.to_csv(target+"/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
df_train.to_csv(target+"/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
df_test.to_csv(target+"/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
if __name__ == "__main__": if __name__ == "__main__":
_preprocess_data(sys.argv[1]) _preprocess_data(sys.argv[1])


@ -18,26 +18,32 @@ from deepspeech_training.util.importers import (
get_imported_samples, get_imported_samples,
get_importers_parser, get_importers_parser,
get_validate_label, get_validate_label,
print_import_report print_import_report,
) )
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 15 MAX_SECS = 15
ARCHIVE_NAME = '2019-04-11_fr_FR' ARCHIVE_NAME = "2019-04-11_fr_FR"
ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME ARCHIVE_DIR_NAME = "ts_" + ARCHIVE_NAME
ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE_NAME + '.zip' ARCHIVE_URL = (
"https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/" + ARCHIVE_NAME + ".zip"
)
def _download_and_preprocess_data(target_dir, english_compatible=False): def _download_and_preprocess_data(target_dir, english_compatible=False):
# Making path absolute # Making path absolute
target_dir = os.path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL) archive_path = maybe_download(
"ts_" + ARCHIVE_NAME + ".zip", target_dir, ARCHIVE_URL
)
# Conditionally extract archive data # Conditionally extract archive data
_maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path) _maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path)
# Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav # Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible) _maybe_convert_sets(
target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible
)
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
@ -55,7 +61,7 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def one_sample(sample): def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = sample['path'] orig_filename = sample["path"]
# Storing wav files next to the wav ones - just with a different suffix # Storing wav files next to the wav ones - just with a different suffix
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav" wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename) _maybe_convert_wav(orig_filename, wav_filename)
@ -63,8 +69,12 @@ def one_sample(sample):
frames = 0 frames = 0
if os.path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
label = sample['text'] subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = sample["text"]
rows = [] rows = []
@ -72,21 +82,21 @@ def one_sample(sample):
counter = get_counter() counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
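one_sample measures each converted clip with soxi -s (sample count) and keeps it only if it is long enough for its transcript and shorter than MAX_SECS. A minimal sketch of that gate, reusing the constants above and leaving out the counter bookkeeping:

import os
import subprocess

SAMPLE_RATE = 16000
MAX_SECS = 15

def keep_sample(wav_filename, label):
    if not os.path.exists(wav_filename) or label is None:
        return False
    frames = int(
        subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
    )
    if int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
        return False   # too short to fit the transcript
    if frames / SAMPLE_RATE > MAX_SECS:
        return False   # too long for a reasonable batch size
    return True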
@ -94,18 +104,19 @@ def one_sample(sample):
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = os.path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv') target_csv_template = os.path.join(target_dir, "ts_" + ARCHIVE_NAME + "_{}.csv")
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
path_to_original_csv = os.path.join(extracted_dir, 'data.csv') path_to_original_csv = os.path.join(extracted_dir, "data.csv")
with open(path_to_original_csv) as csv_f: with open(path_to_original_csv) as csv_f:
data = [ data = [
d for d in csv.DictReader(csv_f, delimiter=',') d
if float(d['duration']) <= MAX_SECS for d in csv.DictReader(csv_f, delimiter=",")
if float(d["duration"]) <= MAX_SECS
] ]
for line in data: for line in data:
line['path'] = os.path.join(extracted_dir, line['path']) line["path"] = os.path.join(extracted_dir, line["path"])
num_samples = len(data) num_samples = len(data)
rows = [] rows = []
@ -122,9 +133,9 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -133,7 +144,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
test_writer.writeheader() test_writer.writeheader()
for i, item in enumerate(rows): for i, item in enumerate(rows):
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible)) transcript = validate_label(
cleanup_transcript(
item[2], english_compatible=english_compatible
)
)
if not transcript: if not transcript:
continue continue
wav_filename = os.path.join(target_dir, extracted_data, item[0]) wav_filename = os.path.join(target_dir, extracted_data, item[0])
@ -144,18 +159,21 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
dict(
wav_filename=wav_filename, wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename), wav_filesize=os.path.getsize(wav_filename),
transcript=transcript, transcript=transcript,
)) )
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(orig_filename, wav_filename): def _maybe_convert_wav(orig_filename, wav_filename):
if not os.path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
@ -163,26 +181,31 @@ def _maybe_convert_wav(orig_filename, wav_filename):
try: try:
transformer.build(orig_filename, wav_filename) transformer.build(orig_filename, wav_filename)
except sox.core.SoxError as ex: except sox.core.SoxError as ex:
print('SoX processing error', ex, orig_filename, wav_filename) print("SoX processing error", ex, orig_filename, wav_filename)
PUNCTUATIONS_REG = re.compile(r"\-,;!?.()\[\]*…—]") PUNCTUATIONS_REG = re.compile(r"\-,;!?.()\[\]*…—]")
MULTIPLE_SPACES_REG = re.compile(r'\s{2,}') MULTIPLE_SPACES_REG = re.compile(r"\s{2,}")
def cleanup_transcript(text, english_compatible=False): def cleanup_transcript(text, english_compatible=False):
text = text.replace('’', "'").replace('\u00A0', ' ') text = text.replace("’", "'").replace("\u00A0", " ")
text = PUNCTUATIONS_REG.sub(' ', text) text = PUNCTUATIONS_REG.sub(" ", text)
text = MULTIPLE_SPACES_REG.sub(' ', text) text = MULTIPLE_SPACES_REG.sub(" ", text)
if english_compatible: if english_compatible:
text = unidecode.unidecode(text) text = unidecode.unidecode(text)
return text.strip().lower() return text.strip().lower()
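
For illustration, a minimal standalone sketch of the same cleanup steps: typographic apostrophes and non-breaking spaces are normalized, punctuation is stripped, whitespace is collapsed, and the text is optionally transliterated to ASCII. The character class below assumes the opening "[" that PUNCTUATIONS_REG above appears to omit.

import re
import unidecode

PUNCT = re.compile(r"[\-,;!?.()\[\]*…—]")   # assumed-intended character class
SPACES = re.compile(r"\s{2,}")

def cleanup(text, english_compatible=False):
    text = text.replace("\u2019", "'").replace("\u00A0", " ")  # curly apostrophe and NBSP
    text = SPACES.sub(" ", PUNCT.sub(" ", text))
    if english_compatible:
        text = unidecode.unidecode(text)  # drop accents for an English-only alphabet
    return text.strip().lower()

print(cleanup("Voilà… une phrase, déjà nettoyée !", english_compatible=True))
# voila une phrase deja nettoyee
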
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.') parser = get_importers_parser(description="Importer for TrainingSpeech dataset.")
parser.add_argument(dest='target_dir') parser.add_argument(dest="target_dir")
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diacritics and other non-ASCII chars.') parser.add_argument(
"--english-compatible",
action="store_true",
dest="english_compatible",
help="Remove diactrics and other non-ascii chars.",
)
return parser.parse_args() return parser.parse_args()
@ -1,8 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
Use "python3 import_tuda.py -h" for help Use "python3 import_tuda.py -h" for help
''' """
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import argparse import argparse
@ -17,20 +17,21 @@ from collections import Counter
import progressbar import progressbar
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import \ from deepspeech_training.util.importers import validate_label_eng as validate_label
validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet from deepspeech_training.util.text import Alphabet
TUDA_VERSION = 'v2' TUDA_VERSION = "v2"
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION) TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE) TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE) TUDA_PACKAGE
)
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)
CHANNELS = 1 CHANNELS = 1
SAMPLE_WIDTH = 2 SAMPLE_WIDTH = 2
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
def maybe_extract(archive): def maybe_extract(archive):
@ -48,69 +49,79 @@ def maybe_extract(archive):
def check_and_prepare_sentence(sentence): def check_and_prepare_sentence(sentence):
sentence = sentence.lower().replace('co2', 'c o zwei') sentence = sentence.lower().replace("co2", "c o zwei")
chars = [] chars = []
for c in sentence: for c in sentence:
if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)): if (
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") CLI_ARGS.normalize
and c not in "äöüß"
and (ALPHABET is None or not ALPHABET.has_char(c))
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if ALPHABET is not None and not ALPHABET.has_char(c): if ALPHABET is not None and not ALPHABET.has_char(c):
return None return None
chars.append(sc) chars.append(sc)
return validate_label(''.join(chars)) return validate_label("".join(chars))
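
The normalization branch above relies on NFKD decomposition followed by an ASCII round trip to reduce accented characters to their base form; a small sketch of that behaviour:

import unicodedata

def to_base(c):
    return unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")

print(to_base("é"))  # "e": the combining acute accent is decomposed, then dropped
print(to_base("ß"))  # "": no ASCII base form, which is why ä, ö, ü and ß are skipped above
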
def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements
try: try:
with wave.open(wav_path, 'r') as src_wav_file: with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate() rate = src_wav_file.getframerate()
channels = src_wav_file.getnchannels() channels = src_wav_file.getnchannels()
sample_width = src_wav_file.getsampwidth() sample_width = src_wav_file.getsampwidth()
milliseconds = int(src_wav_file.getnframes() * 1000 / rate) milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
if rate != SAMPLE_RATE: if rate != SAMPLE_RATE:
return False, 'wrong sample rate' return False, "wrong sample rate"
if channels != CHANNELS: if channels != CHANNELS:
return False, 'wrong number of channels' return False, "wrong number of channels"
if sample_width != SAMPLE_WIDTH: if sample_width != SAMPLE_WIDTH:
return False, 'wrong sample width' return False, "wrong sample width"
if milliseconds / len(sentence) < 30: if milliseconds / len(sentence) < 30:
return False, 'too short' return False, "too short"
if milliseconds > CLI_ARGS.max_duration > 0: if milliseconds > CLI_ARGS.max_duration > 0:
return False, 'too long' return False, "too long"
except wave.Error: except wave.Error:
return False, 'invalid wav file' return False, "invalid wav file"
except EOFError: except EOFError:
return False, 'premature EOF' return False, "premature EOF"
return True, 'OK' return True, "OK"
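
The "too short" rule above amounts to requiring roughly 30 ms of audio per transcript character; a quick worked check:

milliseconds = 550                        # a 0.55 s clip
sentence = "guten morgen zusammen"        # 21 characters
print(milliseconds / len(sentence))       # about 26 ms per character
print(milliseconds / len(sentence) < 30)  # True: the sample would be rejected as "too short"
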
def write_csvs(extracted): def write_csvs(extracted):
sample_counter = 0 sample_counter = 0
reasons = Counter() reasons = Counter()
for sub_set in ['train', 'dev', 'test']: for sub_set in ["train", "dev", "test"]:
set_path = os.path.join(extracted, sub_set) set_path = os.path.join(extracted, sub_set)
set_files = os.listdir(set_path) set_files = os.listdir(set_path)
recordings = {} recordings = {}
for file in set_files: for file in set_files:
if file.endswith('.xml'): if file.endswith(".xml"):
recordings[file[:-4]] = [] recordings[file[:-4]] = []
for file in set_files: for file in set_files:
if file.endswith('.wav') and '_' in file: if file.endswith(".wav") and "_" in file:
prefix = file.split('_')[0] prefix = file.split("_")[0]
if prefix in recordings: if prefix in recordings:
recordings[prefix].append(file) recordings[prefix].append(file)
recordings = recordings.items() recordings = recordings.items()
csv_path = os.path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set)) csv_path = os.path.join(
CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
)
print('Writing "{}"...'.format(csv_path)) print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file: with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES) writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
writer.writeheader() writer.writeheader()
set_dir = os.path.join(extracted, sub_set) set_dir = os.path.join(extracted, sub_set)
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
for prefix, wav_names in bar(recordings): for prefix, wav_names in bar(recordings):
xml_path = os.path.join(set_dir, prefix + '.xml') xml_path = os.path.join(set_dir, prefix + ".xml")
meta = ET.parse(xml_path).getroot() meta = ET.parse(xml_path).getroot()
sentence = list(meta.iter('cleaned_sentence'))[0].text sentence = list(meta.iter("cleaned_sentence"))[0].text
sentence = check_and_prepare_sentence(sentence) sentence = check_and_prepare_sentence(sentence)
if sentence is None: if sentence is None:
continue continue
@ -119,15 +130,19 @@ def write_csvs(extracted):
wav_path = os.path.join(set_path, wav_name) wav_path = os.path.join(set_path, wav_name)
keep, reason = check_wav_file(wav_path, sentence) keep, reason = check_wav_file(wav_path, sentence)
if keep: if keep:
writer.writerow({ writer.writerow(
'wav_filename': os.path.relpath(wav_path, CLI_ARGS.base_dir), {
'wav_filesize': os.path.getsize(wav_path), "wav_filename": os.path.relpath(
'transcript': sentence.lower() wav_path, CLI_ARGS.base_dir
}) ),
"wav_filesize": os.path.getsize(wav_path),
"transcript": sentence.lower(),
}
)
else: else:
reasons[reason] += 1 reasons[reason] += 1
if len(reasons.keys()) > 0: if len(reasons.keys()) > 0:
print('Excluded samples:') print("Excluded samples:")
for reason, n in reasons.most_common(): for reason, n in reasons.most_common():
print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter)) print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
@ -146,13 +161,29 @@ def download_and_prepare():
def handle_args(): def handle_args():
parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)') parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
parser.add_argument('base_dir', help='Directory containing all data') parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds') parser.add_argument(
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') "--max_duration",
parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file') type=int,
parser.add_argument('--keep_archive', type=bool, default=True, default=10000,
help='If downloaded archives should be kept') help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--alphabet",
help="Exclude samples with characters not in provided alphabet file",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
return parser.parse_args() return parser.parse_args()
@ -17,7 +17,7 @@ from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import ( from deepspeech_training.util.importers import (
get_counter, get_counter,
get_imported_samples, get_imported_samples,
print_import_report print_import_report,
) )
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -62,7 +62,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
all_samples = [] all_samples = []
for target in sorted(os.listdir(directory)): for target in sorted(os.listdir(directory)):
all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1])) all_samples += _maybe_prepare_set(
path.join(extracted_dir, os.path.split(target)[-1])
)
num_samples = len(all_samples) num_samples = len(all_samples)
print(f"Converting wav files to {SAMPLE_RATE}hz...") print(f"Converting wav files to {SAMPLE_RATE}hz...")
@ -76,6 +78,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
_write_csv(extracted_dir, txt_dir, target_dir) _write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample): def one_sample(sample):
if is_audio_file(sample): if is_audio_file(sample):
y, sr = librosa.load(sample, sr=16000) y, sr = librosa.load(sample, sr=16000)
@ -98,6 +101,7 @@ def _maybe_prepare_set(target_csv):
samples = new_samples samples = new_samples
return samples return samples
def _write_csv(extracted_dir, txt_dir, target_dir): def _write_csv(extracted_dir, txt_dir, target_dir):
print(f"Writing CSV file") print(f"Writing CSV file")
dset_abs_path = extracted_dir dset_abs_path = extracted_dir
@ -192,7 +196,9 @@ AUDIO_EXTENSIONS = [".wav", "WAV"]
def is_audio_file(filepath): def is_audio_file(filepath):
return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS) return any(
os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS
)
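
A quick check of the helper above; note that the second entry in AUDIO_EXTENSIONS has no leading dot, so any basename ending in "WAV" matches:

print(is_audio_file("clips/sample.wav"))   # True  (".wav" suffix)
print(is_audio_file("clips/SAMPLE.WAV"))   # True  (bare "WAV" suffix)
print(is_audio_file("clips/sample.mp3"))   # False
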
if __name__ == "__main__": if __name__ == "__main__":
@ -24,8 +24,10 @@ NUM_PARALLEL = 8
"""Lambda function returns the filename of a path""" """Lambda function returns the filename of a path"""
filename_of = lambda x: path.split(x)[1] filename_of = lambda x: path.split(x)[1]
class AtomicCounter(object): class AtomicCounter(object):
"""A class that atomically increments a counter""" """A class that atomically increments a counter"""
def __init__(self, start_count=0): def __init__(self, start_count=0):
"""Initialize the counter """Initialize the counter
:param start_count: the number to start counting at :param start_count: the number to start counting at
@ -48,6 +50,7 @@ class AtomicCounter(object):
"""Returns the current value of the counter (not atomic)""" """Returns the current value of the counter (not atomic)"""
return self.__count return self.__count
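
As its use further down suggests, increment() returns the updated count; a minimal sketch of the counter driving a thread pool:

from multiprocessing.dummy import Pool as ThreadPool

counter = AtomicCounter()  # the class defined above
results = ThreadPool(8).map(lambda _: counter.increment(), range(100))
print(max(results))        # 100: no update was lost to a race
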
def _parallel_downloader(voxforge_url, archive_dir, total, counter): def _parallel_downloader(voxforge_url, archive_dir, total, counter):
"""Generate a function to download a file based on given parameters """Generate a function to download a file based on given parameters
This works by currying the above given arguments into a closure This works by currying the above given arguments into a closure
@ -59,6 +62,7 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
:param counter: an atomic counter to keep track of # of downloaded files :param counter: an atomic counter to keep track of # of downloaded files
:return: a function that actually downloads a file given these params :return: a function that actually downloads a file given these params
""" """
def download(d): def download(d):
"""Binds voxforge_url, archive_dir, total, and counter into this scope """Binds voxforge_url, archive_dir, total, and counter into this scope
Downloads the given file Downloads the given file
@ -66,12 +70,14 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
of the file to download and file is the name of the file to download of the file to download and file is the name of the file to download
""" """
(i, file) = d (i, file) = d
download_url = voxforge_url + '/' + file download_url = voxforge_url + "/" + file
c = counter.increment() c = counter.increment()
print('Downloading file {} ({}/{})...'.format(i+1, c, total)) print("Downloading file {} ({}/{})...".format(i + 1, c, total))
maybe_download(filename_of(download_url), archive_dir, download_url) maybe_download(filename_of(download_url), archive_dir, download_url)
return download return download
def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter): def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter):
"""Generate a function to extract a tar file based on given parameters """Generate a function to extract a tar file based on given parameters
This works by currying the above given arguments into a closure This works by currying the above given arguments into a closure
@ -84,6 +90,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
:param counter: an atomic counter to keep track of # of extracted files :param counter: an atomic counter to keep track of # of extracted files
:return: a function that actually extracts a tar file given these params :return: a function that actually extracts a tar file given these params
""" """
def extract(d): def extract(d):
"""Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope """Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope
Extracts the given file Extracts the given file
@ -97,14 +104,18 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
dataset_dir = path.join(data_dir, "dev") dataset_dir = path.join(data_dir, "dev")
else: else:
dataset_dir = path.join(data_dir, "train") dataset_dir = path.join(data_dir, "train")
if not gfile.Exists(os.path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))): if not gfile.Exists(
os.path.join(dataset_dir, ".".join(filename_of(archive).split(".")[:-1]))
):
c = counter.increment() c = counter.increment()
print('Extracting file {} ({}/{})...'.format(i+1, c, total)) print("Extracting file {} ({}/{})...".format(i + 1, c, total))
tar = tarfile.open(archive) tar = tarfile.open(archive)
tar.extractall(dataset_dir) tar.extractall(dataset_dir)
tar.close() tar.close()
return extract return extract
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir # Conditionally download data to data_dir
if not path.isdir(data_dir): if not path.isdir(data_dir):
@ -114,18 +125,24 @@ def _download_and_preprocess_data(data_dir):
if not path.isdir(archive_dir): if not path.isdir(archive_dir):
makedirs(archive_dir) makedirs(archive_dir)
print("Downloading Voxforge data set into {} if not already present...".format(archive_dir)) print(
"Downloading Voxforge data set into {} if not already present...".format(
archive_dir
)
)
voxforge_url = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit' voxforge_url = "http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit"
html_page = urllib.request.urlopen(voxforge_url) html_page = urllib.request.urlopen(voxforge_url)
soup = BeautifulSoup(html_page, 'html.parser') soup = BeautifulSoup(html_page, "html.parser")
# list all links # list all links
refs = [l['href'] for l in soup.find_all('a') if ".tgz" in l['href']] refs = [l["href"] for l in soup.find_all("a") if ".tgz" in l["href"]]
# download files in parallel # download files in parallel
print('{} files to download'.format(len(refs))) print("{} files to download".format(len(refs)))
downloader = _parallel_downloader(voxforge_url, archive_dir, len(refs), AtomicCounter()) downloader = _parallel_downloader(
voxforge_url, archive_dir, len(refs), AtomicCounter()
)
p = ThreadPool(NUM_PARALLEL) p = ThreadPool(NUM_PARALLEL)
p.map(downloader, enumerate(refs)) p.map(downloader, enumerate(refs))
@ -143,8 +160,14 @@ def _download_and_preprocess_data(data_dir):
number_of_dev = number_of_files // 100 number_of_dev = number_of_files // 100
# extract tars in parallel # extract tars in parallel
print("Extracting Voxforge data set into {} if not already present...".format(data_dir)) print(
extracter = _parallel_extracter(data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()) "Extracting Voxforge data set into {} if not already present...".format(
data_dir
)
)
extracter = _parallel_extracter(
data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()
)
p.map(extracter, enumerate(tarfiles)) p.map(extracter, enumerate(tarfiles))
# Generate data set # Generate data set
@ -158,34 +181,46 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False) dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False) test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
def _generate_dataset(data_dir, data_set): def _generate_dataset(data_dir, data_set):
extracted_dir = path.join(data_dir, data_set) extracted_dir = path.join(data_dir, data_set)
files = [] files = []
for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")): for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")):
if path.isdir(os.path.join(promts_file[:-11], "wav")): if path.isdir(os.path.join(promts_file[:-11], "wav")):
with codecs.open(promts_file, 'r', 'utf-8') as f: with codecs.open(promts_file, "r", "utf-8") as f:
for line in f: for line in f:
id = line.split(' ')[0].split('/')[-1] id = line.split(" ")[0].split("/")[-1]
sentence = ' '.join(line.split(' ')[1:]) sentence = " ".join(line.split(" ")[1:])
sentence = re.sub("[^a-z']", " ", sentence.strip().lower()) sentence = re.sub("[^a-z']", " ", sentence.strip().lower())
transcript = "" transcript = ""
for token in sentence.split(" "): for token in sentence.split(" "):
word = token.strip() word = token.strip()
if word != "" and word != " ": if word != "" and word != " ":
transcript += word + " " transcript += word + " "
transcript = unicodedata.normalize("NFKD", transcript.strip()) \ transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", transcript.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav") wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
if gfile.Exists(wav_file): if gfile.Exists(wav_file):
wav_filesize = path.getsize(wav_file) wav_filesize = path.getsize(wav_file)
# remove audios that are shorter than 0.5s and longer than 20s. # remove audios that are shorter than 0.5s and longer than 20s.
# remove audios that are too short for transcript. # remove audios that are too short for transcript.
if ((wav_filesize/32000) > 0.5 and (wav_filesize/32000) < 20 and transcript != "" and if (
wav_filesize/len(transcript) > 1400): (wav_filesize / 32000) > 0.5
files.append((os.path.abspath(wav_file), wav_filesize, transcript)) and (wav_filesize / 32000) < 20
and transcript != ""
and wav_filesize / len(transcript) > 1400
):
files.append(
(os.path.abspath(wav_file), wav_filesize, transcript)
)
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
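
The 32000 divisor above converts WAV file size into seconds for 16 kHz, 16-bit mono audio (16000 samples per second times 2 bytes per sample); a rough worked example of the filter:

wav_filesize = 160000                         # roughly 5 s of 16 kHz / 16-bit mono PCM
duration_s = wav_filesize / 32000
print(duration_s)                             # 5.0, treating the whole file size as PCM payload
print(0.5 < duration_s < 20)                  # True: within the accepted length range
transcript = "the quick brown fox jumps"
print(wav_filesize / len(transcript) > 1400)  # True: enough audio bytes per character
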
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])
@ -7,11 +7,12 @@ import tensorflow.compat.v1 as tfv1
def main(): def main():
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin: with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef() graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read()) graph_def.ParseFromString(fin.read())
print('\n'.join(sorted(set(n.op for n in graph_def.node)))) print("\n".join(sorted(set(n.op for n in graph_def.node))))
if __name__ == '__main__':
if __name__ == "__main__":
main() main()
@ -10,10 +10,7 @@ import random
import sys import sys
from deepspeech_training.util.audio import AUDIO_TYPE_PCM from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.sample_collections import ( from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file
LabeledSample,
samples_from_file
)
def play_sample(samples, index): def play_sample(samples, index):
@ -22,7 +19,7 @@ def play_sample(samples, index):
if CLI_ARGS.random: if CLI_ARGS.random:
index = random.randint(0, len(samples)) index = random.randint(0, len(samples))
elif index >= len(samples): elif index >= len(samples):
print('No sample with index {}'.format(CLI_ARGS.start)) print("No sample with index {}".format(CLI_ARGS.start))
sys.exit(1) sys.exit(1)
sample = samples[index] sample = samples[index]
print('Sample "{}"'.format(sample.sample_id)) print('Sample "{}"'.format(sample.sample_id))
@ -48,13 +45,28 @@ def play_collection():
def handle_args(): def handle_args():
parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) ' parser = argparse.ArgumentParser(
'and DeepSpeech CSV files') description="Tool for playing samples from Sample Databases (SDB files) "
parser.add_argument('collection', help='Sample DB or CSV file to play samples from') "and DeepSpeech CSV files"
parser.add_argument('--start', type=int, default=0, )
help='Sample index to start at (negative numbers are relative to the end of the collection)') parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)') parser.add_argument(
parser.add_argument('--random', action='store_true', help='If samples should be played in random order') "--start",
type=int,
default=0,
help="Sample index to start at (negative numbers are relative to the end of the collection)",
)
parser.add_argument(
"--number",
type=int,
default=-1,
help="Number of samples to play (-1 for endless)",
)
parser.add_argument(
"--random",
action="store_true",
help="If samples should be played in random order",
)
return parser.parse_args() return parser.parse_args()
@ -68,5 +80,5 @@ if __name__ == "__main__":
try: try:
play_collection() play_collection()
except KeyboardInterrupt: except KeyboardInterrupt:
print(' Stopped') print(" Stopped")
sys.exit(0) sys.exit(0)