Reformat importers with black

parent b7e6b8c3e6
commit 6f0bf3b3a8
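Black is an opinionated, deterministic Python formatter, so a reformat commit like this is normally produced by running the tool over the importer scripts rather than by editing them by hand. The exact invocation is not recorded in the commit; a typical run (file paths are illustrative) would be:

    python -m black build_sdb.py import_cv2.py <other importer scripts>

Because Black is idempotent, re-running the same Black version over the touched files should produce no further changes, which is a quick way to confirm that only formatting (quote style, line wrapping, spacing) changed and not behavior.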
@@ -1,8 +1,8 @@
#!/usr/bin/env python
'''
"""
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help
'''
"""
from __future__ import absolute_import, division, print_function

import argparse
@@ -12,44 +12,60 @@ import progressbar
from deepspeech_training.util.audio import (
AUDIO_TYPE_OPUS,
AUDIO_TYPE_WAV,
change_audio_types
change_audio_types,
)
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.sample_collections import (
DirectSDBWriter,
samples_from_files
samples_from_files,
)

AUDIO_TYPE_LOOKUP = {
'wav': AUDIO_TYPE_WAV,
'opus': AUDIO_TYPE_OPUS
}
AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}


def build_sdb():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
with DirectSDBWriter(
CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
for sample in bar(
change_audio_types(
samples, audio_type=audio_type, processes=CLI_ARGS.workers
)
):
sdb_writer.add(sample)


def handle_args():
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
'from DeepSpeech CSV files and other SDB files')
parser.add_argument('sources', nargs='+',
help='Source CSV and/or SDB files - '
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
'already ordered from shortest to longest.')
parser.add_argument('target', help='SDB file to create')
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
help='Audio representation inside target SDB')
parser.add_argument('--workers', type=int, default=None,
help='Number of encoding SDB workers')
parser.add_argument('--unlabeled', action='store_true',
help='If to build an SDB with unlabeled (audio only) samples - '
'typically used for building noise augmentation corpora')
parser = argparse.ArgumentParser(
description="Tool for building Sample Databases (SDB files) "
"from DeepSpeech CSV files and other SDB files"
)
parser.add_argument(
"sources",
nargs="+",
help="Source CSV and/or SDB files - "
"Note: For getting a correctly ordered target SDB, source SDBs have to have their samples "
"already ordered from shortest to longest.",
)
parser.add_argument("target", help="SDB file to create")
parser.add_argument(
"--audio-type",
default="opus",
choices=AUDIO_TYPE_LOOKUP.keys(),
help="Audio representation inside target SDB",
)
parser.add_argument(
"--workers", type=int, default=None, help="Number of encoding SDB workers"
)
parser.add_argument(
"--unlabeled",
action="store_true",
help="If to build an SDB with unlabeled (audio only) samples - "
"typically used for building noise augmentation corpora",
)
return parser.parse_args()

@@ -9,12 +9,13 @@ from google.protobuf import text_format

def main():
# Load and export as string
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read())

with tfv1.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout:
with tfv1.gfile.FastGFile(sys.argv[1] + "txt", "w") as fout:
fout.write(text_format.MessageToString(graph_def))

if __name__ == '__main__':

if __name__ == "__main__":
main()

@@ -9,11 +9,11 @@ import pandas

from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)

@@ -21,9 +21,9 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'aidatatang_200zh')
main_folder = os.path.join(target_dir, "aidatatang_200zh")

for targz in glob.glob(os.path.join(main_folder, 'corpus', '*', '*.tar.gz')):
for targz in glob.glob(os.path.join(main_folder, "corpus", "*", "*.tar.gz")):
extract(targz, os.path.dirname(targz))

# Folder structure is now:
@@ -42,9 +42,11 @@ def preprocess_data(tgz_file, target_dir):

# Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aidatatang_200_zh_transcript.txt')
transcripts_path = os.path.join(
main_folder, "transcript", "aidatatang_200_zh_transcript.txt"
)
with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))

def load_set(glob_path):
set_files = []
@@ -53,33 +55,39 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n')
transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files

for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
subset_files = load_set(os.path.join(main_folder, 'corpus', subset, '*', '*.wav'))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(
os.path.join(main_folder, "corpus", subset, "*", "*.wav")
)
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)

# Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))

dest_csv = os.path.join(target_dir, 'aidatatang_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
dest_csv = os.path.join(target_dir, "aidatatang_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)


def main():
# https://www.openslr.org/62/
parser = get_importers_parser(description='Import aidatatang_200zh corpus')
parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import aidatatang_200zh corpus")
parser.add_argument("tgz_file", help="Path to aidatatang_200zh.tgz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()

if not params.target_dir:

@@ -9,11 +9,11 @@ import pandas

from deepspeech_training.util.importers import get_importers_parser

COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)

@@ -21,10 +21,10 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'data_aishell')
main_folder = os.path.join(target_dir, "data_aishell")

wav_archives_folder = os.path.join(main_folder, 'wav')
for targz in glob.glob(os.path.join(wav_archives_folder, '*.tar.gz')):
wav_archives_folder = os.path.join(main_folder, "wav")
for targz in glob.glob(os.path.join(wav_archives_folder, "*.tar.gz")):
extract(targz, main_folder)

# Folder structure is now:
@@ -41,9 +41,11 @@ def preprocess_data(tgz_file, target_dir):

# Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aishell_transcript_v0.8.txt')
transcripts_path = os.path.join(
main_folder, "transcript", "aishell_transcript_v0.8.txt"
)
with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))

def load_set(glob_path):
set_files = []
@@ -52,33 +54,37 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n')
transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files

for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
subset_files = load_set(os.path.join(main_folder, subset, 'S*', '*.wav'))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(main_folder, subset, "S*", "*.wav"))
df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES)

# Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))

dest_csv = os.path.join(target_dir, 'aishell_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
dest_csv = os.path.join(target_dir, "aishell_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)


def main():
# http://www.openslr.org/33/
parser = get_importers_parser(description='Import AISHELL corpus')
parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import AISHELL corpus")
parser.add_argument("aishell_tgz_file", help="Path to data_aishell.tgz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()

if not params.target_dir:

@@ -15,17 +15,19 @@ from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
print_import_report
print_import_report,
)
from deepspeech_training.util.importers import \
validate_label_eng as validate_label
from deepspeech_training.util.importers import validate_label_eng as validate_label

FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
ARCHIVE_DIR_NAME = 'cv_corpus_v1'
ARCHIVE_NAME = ARCHIVE_DIR_NAME + '.tar.gz'
ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "cv_corpus_v1"
ARCHIVE_NAME = ARCHIVE_DIR_NAME + ".tar.gz"
ARCHIVE_URL = (
"https://s3.us-east-2.amazonaws.com/common-voice-data-download/" + ARCHIVE_NAME
)


def _download_and_preprocess_data(target_dir):
# Making path absolute
@@ -37,6 +39,7 @@ def _download_and_preprocess_data(target_dir):
# Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)


def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = os.join(target_dir, extracted_data)
@@ -47,43 +50,56 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else:
print('Found directory "%s" - not extracting it from archive.' % extracted_path)


def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = os.path.join(target_dir, extracted_data)
for source_csv in glob(os.path.join(extracted_dir, '*.csv')):
_maybe_convert_set(extracted_dir, source_csv, os.path.join(target_dir, os.path.split(source_csv)[-1]))
for source_csv in glob(os.path.join(extracted_dir, "*.csv")):
_maybe_convert_set(
extracted_dir,
source_csv,
os.path.join(target_dir, os.path.split(source_csv)[-1]),
)


def one_sample(sample):
mp3_filename = sample[0]
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
frames = int(
subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
)
file_size = -1
if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = validate_label(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)


def _maybe_convert_set(extracted_dir, source_csv, target_csv):
print()
if os.path.exists(target_csv):
@@ -94,14 +110,14 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file)
for row in reader:
samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
samples.append((os.path.join(extracted_dir, row["filename"]), row["text"]))

# Mutable counters for the concurrent embedded routine
counter = get_counter()
num_samples = len(samples)
rows = []

print('Importing mp3 files...')
print("Importing mp3 files...")
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
@@ -113,19 +129,26 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
pool.join()

print('Writing "%s"...' % target_csv)
with open(target_csv, 'w') as target_csv_file:
with open(target_csv, "w") as target_csv_file:
writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows):
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)

imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples

print_import_report(counter, SAMPLE_RATE, MAX_SECS)


def _maybe_convert_wav(mp3_filename, wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
@@ -135,5 +158,6 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
except sox.core.SoxError:
pass


if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

@@ -1,11 +1,11 @@
#!/usr/bin/env python
'''
"""
Broadly speaking, this script takes the audio downloaded from Common Voice
for a certain language, in addition to the *.tsv files output by CorporaCreator,
and the script formats the data and transcripts to be in a state usable by
DeepSpeech.py
Use "python3 import_cv2.py -h" for help
'''
"""
from __future__ import absolute_import, division, print_function

import csv
@@ -23,26 +23,27 @@ from deepspeech_training.util.importers import (
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report
print_import_report,
)
from deepspeech_training.util.text import Alphabet

FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10


def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset+".tsv")
for dataset in ["train", "test", "dev", "validated", "other"]:
input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv)
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)


def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = sample[0]
if not os.path.splitext(mp3_filename.lower())[1] == '.mp3':
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
@@ -51,40 +52,47 @@ def one_sample(sample):
frames = 0
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter_fun(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((os.path.split(wav_filename)[-1], file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames

return (counter, rows)


def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
output_csv = os.path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
output_csv = os.path.join(
audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
)
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)

# Get audiofile path and transcript for each sentence in tsv
samples = []
with open(input_tsv, encoding='utf-8') as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t')
with open(input_tsv, encoding="utf-8") as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter="\t")
for row in reader:
samples.append((os.path.join(audio_dir, row['path']), row['sentence']))
samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))

counter = get_counter()
num_samples = len(samples)
@@ -101,19 +109,31 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
pool.close()
pool.join()

with open(output_csv, 'w', encoding='utf-8') as output_csv_file:
print('Writing CSV file for DeepSpeech.py as: ', output_csv)
with open(output_csv, "w", encoding="utf-8") as output_csv_file:
print("Writing CSV file for DeepSpeech.py as: ", output_csv)
writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows):
if space_after_every_character:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': ' '.join(transcript)})
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": " ".join(transcript),
}
)
else:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)

imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples

print_import_report(counter, SAMPLE_RATE, MAX_SECS)
@@ -130,24 +150,42 @@ def _maybe_convert_wav(mp3_filename, wav_filename):


if __name__ == "__main__":
PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora')
PARSER.add_argument('tsv_dir', help='Directory containing tsv files')
PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
PARSER.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
PARSER.add_argument(
"--audio_dir",
help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
)
PARSER.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
PARSER.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
PARSER.add_argument(
"--space_after_every_character",
action="store_true",
help="To help transcript join by white space",
)

PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS)

AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
AUDIO_DIR = (
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
)
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None

def label_filter_fun(label):
if PARAMS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:

@@ -12,14 +12,12 @@ import librosa
import pandas
import soundfile # <= Has an external dependency on libsndfile

from deepspeech_training.util.importers import \
validate_label_eng as validate_label
from deepspeech_training.util.importers import validate_label_eng as validate_label

# Prerequisite: Having the sph2pipe tool in your PATH:
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools



def _download_and_preprocess_data(data_dir):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19

@@ -28,33 +26,55 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav")

# Conditionally split Fisher wav data
all_2004 = _split_wav_and_sentences(data_dir,
original_data="fisher-2004-wav",
converted_data="fisher-2004-split-wav",
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"))
all_2005 = _split_wav_and_sentences(data_dir,
original_data="fisher-2005-wav",
converted_data="fisher-2005-split-wav",
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"))
all_2004 = _split_wav_and_sentences(
data_dir,
original_data="fisher-2004-wav",
converted_data="fisher-2004-split-wav",
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"),
)
all_2005 = _split_wav_and_sentences(
data_dir,
original_data="fisher-2005-wav",
converted_data="fisher-2005-split-wav",
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"),
)

# The following files have incorrect transcripts that are much longer than
# their audio source. The result is that we end up with more labels than time
# slices, which breaks CTC.
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"), "transcript"] = "correct"
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"), "transcript"] = "that's one of those"
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"), "transcript"] = "they don't want"
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"), "transcript"] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
all_2004.loc[
all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"),
"transcript",
] = "correct"
all_2004.loc[
all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"),
"transcript",
] = "that's one of those"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"),
"transcript",
] = "they don't want"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"),
"transcript",
] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"

# The following file is just a short sound and not at all transcribed like provided.
# So we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")]
all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")
]

# The following file is far too long and would ruin our training batch size.
# So we just exclude it.
all_2005 = all_2005[~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")]
all_2005 = all_2005[
~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")
]

# The following file is too large for its transcript, so we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")]
all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")
]

# Conditionally split Fisher data into train/validation/test sets
train_2004, dev_2004, test_2004 = _split_sets(all_2004)
@@ -70,6 +90,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False)


def _maybe_convert_wav(data_dir, original_data, converted_data):
source_dir = os.path.join(data_dir, original_data)
target_dir = os.path.join(data_dir, converted_data)
@@ -87,10 +108,18 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
for filename in fnmatch.filter(filenames, "*.sph"):
sph_file = os.path.join(root, filename)
for channel in ["1", "2"]:
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "_c" + channel + ".wav"
wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "_c"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename)
print("converting {} to {}".format(sph_file, wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file])
subprocess.check_call(
["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]
)


def _parse_transcriptions(trans_file):
segments = []
@@ -108,18 +137,23 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array
# expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)

segments.append({
"start_time": start_time,
"stop_time": stop_time,
"speaker": speaker,
"transcript": transcript,
})
segments.append(
{
"start_time": start_time,
"stop_time": stop_time,
"speaker": speaker,
"transcript": transcript,
}
)
return segments


def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data):
trans_dir = os.path.join(data_dir, trans_data)
source_dir = os.path.join(data_dir, original_data)
@@ -136,43 +170,73 @@ def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data
segments = _parse_transcriptions(trans_file)

# Open wav corresponding to transcription file
wav_filenames = [os.path.splitext(os.path.basename(trans_file))[0] + "_c" + channel + ".wav" for channel in ["1", "2"]]
wav_files = [os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames]
wav_filenames = [
os.path.splitext(os.path.basename(trans_file))[0]
+ "_c"
+ channel
+ ".wav"
for channel in ["1", "2"]
]
wav_files = [
os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames
]

print("splitting {} according to {}".format(wav_files, trans_file))

origAudios = [librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files]
origAudios = [
librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files
]

# Loop over segments and split wav_file for each segment
for segment in segments:
# Create wav segment filename
start_time = segment["start_time"]
stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
new_wav_filename = (
os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = os.path.join(target_dir, new_wav_filename)

channel = 0 if segment["speaker"] == "A:" else 1
_split_and_resample_wav(origAudios[channel], start_time, stop_time, new_wav_file)
_split_and_resample_wav(
origAudios[channel], start_time, stop_time, new_wav_file
)

new_wav_filesize = os.path.getsize(new_wav_file)
transcript = validate_label(segment["transcript"])
if transcript != None:
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)

return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)

return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])

def _split_audio(origAudio, start_time, stop_time):
audioData, frameRate = origAudio
nChannels = len(audioData.shape)
startIndex = int(start_time * frameRate)
stopIndex = int(stop_time * frameRate)
return audioData[startIndex: stopIndex] if 1 == nChannels else audioData[:, startIndex: stopIndex]
return (
audioData[startIndex:stopIndex]
if 1 == nChannels
else audioData[:, startIndex:stopIndex]
)


def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio[1]
chunkData = _split_audio(origAudio, start_time, stop_time)
soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16")


def _split_sets(filelist):
# We initially split the entire set into 80% train and 20% test, then
# split the train set into 80% train and 20% validation.
@@ -186,9 +250,12 @@ def _split_sets(filelist):
test_beg = dev_end
test_end = len(filelist)

return (filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end])
return (
filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)


if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

@@ -10,11 +10,11 @@ import pandas

from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)

@@ -22,7 +22,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'ST-CMDS-20170001_1-OS')
main_folder = os.path.join(target_dir, "ST-CMDS-20170001_1-OS")

# Folder structure is now:
# - ST-CMDS-20170001_1-OS/
@@ -35,16 +35,16 @@ def preprocess_data(tgz_file, target_dir):
for wav in glob.glob(glob_path):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
txt_filename = os.path.splitext(wav_filename)[0] + '.txt'
with open(txt_filename, 'r') as fin:
txt_filename = os.path.splitext(wav_filename)[0] + ".txt"
with open(txt_filename, "r") as fin:
transcript = fin.read()
set_files.append((wav_filename, wav_filesize, transcript))
return set_files

# Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, '*.wav'))
all_files = load_set(os.path.join(main_folder, "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True)
df.sort_values(by="wav_filename", inplace=True)

indices = np.arange(0, len(df))
np.random.seed(12345)
@@ -57,29 +57,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000]

train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'freestmandarin_train.csv')
print('Saving train set into {}...'.format(dest_csv))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, "freestmandarin_train.csv")
print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False)

dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_dev.csv')
print('Saving dev set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "freestmandarin_dev.csv")
print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False)

test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_test.csv')
print('Saving test set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "freestmandarin_test.csv")
print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False)


def main():
# https://www.openslr.org/38/
parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus')
parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import Free ST Chinese Mandarin corpus")
parser.add_argument("tgz_file", help="Path to ST-CMDS-20170001_1-OS.tar.gz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()

if not params.target_dir:

@@ -12,10 +12,7 @@ import pandas as pd
from sox import Transformer

import swifter
from deepspeech_training.util.importers import (
get_importers_parser,
get_validate_label
)
from deepspeech_training.util.importers import get_importers_parser, get_validate_label

__version__ = "0.1.0"
_logger = logging.getLogger(__name__)
@@ -37,9 +34,7 @@ def parse_args(args):
Returns:
:obj:`argparse.Namespace`: command line parameters namespace
"""
parser = get_importers_parser(
description="Imports GramVaani data for Deep Speech"
)
parser = get_importers_parser(description="Imports GramVaani data for Deep Speech")
parser.add_argument(
"--version",
action="version",
@@ -79,6 +74,7 @@ def parse_args(args):
)
return parser.parse_args(args)


def setup_logging(level):
"""Setup basic logging
Args:
@@ -89,6 +85,7 @@ def setup_logging(level):
level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S"
)


class GramVaaniCSV:
"""GramVaaniCSV representing a GramVaani dataset.
Args:
@@ -104,8 +101,17 @@ class GramVaaniCSV:
_logger.info("Parsing csv file...%s", os.path.abspath(csv_filename))
data = pd.read_csv(
os.path.abspath(csv_filename),
names=["piece_id","audio_url","transcript_labelled","transcript","labels","content_filename","audio_length","user_id"],
usecols=["audio_url","transcript","audio_length"],
names=[
"piece_id",
"audio_url",
"transcript_labelled",
"transcript",
"labels",
"content_filename",
"audio_length",
"user_id",
],
usecols=["audio_url", "transcript", "audio_length"],
skiprows=[0],
engine="python",
encoding="utf-8",
@@ -116,6 +122,7 @@ class GramVaaniCSV:
_logger.info("Parsed %d lines csv file." % len(data))
return data


class GramVaaniDownloader:
"""GramVaaniDownloader downloads a GramVaani dataset.
Args:
@@ -135,7 +142,9 @@ class GramVaaniDownloader:
mp3_directory (os.path): The directory into which the associated mp3's were downloaded
"""
mp3_directory = self._pre_download()
self.data.swifter.apply(func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True)
self.data.swifter.apply(
func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True
)
return mp3_directory

def _pre_download(self):
@@ -158,6 +167,7 @@ class GramVaaniDownloader:
else:
_logger.debug("Already downloaded mp3 file...%s", audio_url)


class GramVaaniConverter:
"""GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
Args:
@@ -178,15 +188,26 @@ class GramVaaniConverter:
wav_directory (os.path): The directory into which the associated wav's were downloaded
"""
wav_directory = self._pre_convert()
for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
wav_filename = os.path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
wav_filename = os.path.join(
wav_directory,
os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav",
)
if not os.path.exists(wav_filename):
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
_logger.debug(
"Converting mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
transformer = Transformer()
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
transformer.convert(
samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
)
transformer.build(str(mp3_filename), str(wav_filename))
else:
_logger.debug("Already converted mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
_logger.debug(
"Already converted mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
return wav_directory

def _pre_convert(self):
@@ -199,16 +220,21 @@ class GramVaaniConverter:
os.mkdir(wav_directory)
return wav_directory


class GramVaaniDataSets:
def __init__(self, target_dir, wav_directory, gram_vaani_csv):
self.target_dir = target_dir
self.wav_directory = wav_directory
self.csv_data = gram_vaani_csv.data
self.raw = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.valid = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.train = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.dev = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.test = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.raw = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.valid = pd.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]
)
self.train = pd.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]
)
self.dev = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.test = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])

def create(self):
self._convert_csv_data_to_raw_data()
@@ -217,30 +243,45 @@ class GramVaaniDataSets:
self.valid = self.valid.sample(frac=1).reset_index(drop=True)
train_size, dev_size, test_size = self._calculate_data_set_sizes()
self.train = self.valid.loc[0:train_size]
self.dev = self.valid.loc[train_size:train_size+dev_size]
self.test = self.valid.loc[train_size+dev_size:train_size+dev_size+test_size]
self.dev = self.valid.loc[train_size : train_size + dev_size]
self.test = self.valid.loc[
train_size + dev_size : train_size + dev_size + test_size
]

def _convert_csv_data_to_raw_data(self):
self.raw[["wav_filename","wav_filesize","transcript"]] = self.csv_data[
["audio_url","transcript","audio_length"]
].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True)
self.raw[["wav_filename", "wav_filesize", "transcript"]] = self.csv_data[
["audio_url", "transcript", "audio_length"]
].swifter.apply(
func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg),
axis=1,
raw=True,
)
self.raw.reset_index()

def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
if audio_url == "audio_url":
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
mp3_filename = os.path.basename(audio_url)
wav_relative_filename = os.path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
wav_filesize = os.path.getsize(os.path.join(self.target_dir, wav_relative_filename))
wav_relative_filename = os.path.join(
"wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav"
)
wav_filesize = os.path.getsize(
os.path.join(self.target_dir, wav_relative_filename)
)
transcript = validate_label(transcript)
if None == transcript:
transcript = ""
return pd.Series([wav_relative_filename, wav_filesize, transcript])
return pd.Series([wav_relative_filename, wav_filesize, transcript])

def _is_valid_raw_rows(self):
is_valid_raw_transcripts = self._is_valid_raw_transcripts()
is_valid_raw_wav_frames = self._is_valid_raw_wav_frames()
is_valid_raw_row = [(is_valid_raw_transcript & is_valid_raw_wav_frame) for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(is_valid_raw_transcripts, is_valid_raw_wav_frames)]
is_valid_raw_row = [
(is_valid_raw_transcript & is_valid_raw_wav_frame)
for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(
is_valid_raw_transcripts, is_valid_raw_wav_frames
)
]
series = pd.Series(is_valid_raw_row)
return series

@@ -249,16 +290,29 @@ class GramVaaniDataSets:

def _is_valid_raw_wav_frames(self):
transcripts = [str(transcript) for transcript in self.raw.transcript]
wav_filepaths = [os.path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
wav_filepaths = [
os.path.join(self.target_dir, str(wav_filename))
for wav_filename in self.raw.wav_filename
]
wav_frames = [
int(
subprocess.check_output(
["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
)
)
for wav_filepath in wav_filepaths
]
is_valid_raw_wav_frames = [
self._is_wav_frame_valid(wav_frame, transcript)
for wav_frame, transcript in zip(wav_frames, transcripts)
]
return pd.Series(is_valid_raw_wav_frames)

def _is_wav_frame_valid(self, wav_frame, transcript):
is_wav_frame_valid = True
if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)):
if int(wav_frame / SAMPLE_RATE * 1000 / 10 / 2) < len(str(transcript)):
is_wav_frame_valid = False
elif wav_frame/SAMPLE_RATE > MAX_SECS:
elif wav_frame / SAMPLE_RATE > MAX_SECS:
is_wav_frame_valid = False
return is_wav_frame_valid

@@ -277,7 +331,14 @@ class GramVaaniDataSets:
def _save(self, dataset):
dataset_path = os.path.join(self.target_dir, dataset + ".csv")
dataframe = getattr(self, dataset)
dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)
dataframe.to_csv(
dataset_path,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_MINIMAL,
)


def main(args):
"""Main entry point allowing external calls
@@ -301,4 +362,5 @@ def main(args):
datasets.save()
_logger.info("Finished GramVaani importer...")


main(sys.argv[1:])

@@ -13,14 +13,23 @@ def _download_and_preprocess_data(data_dir):
# Conditionally download data
LDC93S1_BASE = "LDC93S1"
LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/"
local_file = maybe_download(LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav")
trans_file = maybe_download(LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt")
local_file = maybe_download(
LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav"
)
trans_file = maybe_download(
LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt"
)
with open(trans_file, "r") as fin:
transcript = ' '.join(fin.read().strip().lower().split(' ')[2:]).replace('.', '')
transcript = " ".join(fin.read().strip().lower().split(" ")[2:]).replace(
".", ""
)

df = pandas.DataFrame(data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"])
df = pandas.DataFrame(
data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"],
)
df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False)


if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

@@ -18,13 +18,24 @@ from deepspeech_training.util.downloader import maybe_download

SAMPLE_RATE = 16000


def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir
print("Downloading Librivox data set (55GB) into {} if not already present...".format(data_dir))
print(
"Downloading Librivox data set (55GB) into {} if not already present...".format(
data_dir
)
)
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz"
TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz"
TRAIN_CLEAN_100_URL = (
"http://www.openslr.org/resources/12/train-clean-100.tar.gz"
)
TRAIN_CLEAN_360_URL = (
"http://www.openslr.org/resources/12/train-clean-360.tar.gz"
)
TRAIN_OTHER_500_URL = (
"http://www.openslr.org/resources/12/train-other-500.tar.gz"
)

DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
@@ -32,12 +43,20 @@ def _download_and_preprocess_data(data_dir):
TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"

def filename_of(x): return os.path.split(x)[1]
train_clean_100 = maybe_download(filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL)
def filename_of(x):
return os.path.split(x)[1]

train_clean_100 = maybe_download(
filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL
)
bar.update(0)
train_clean_360 = maybe_download(filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL)
train_clean_360 = maybe_download(
filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL
)
bar.update(1)
train_other_500 = maybe_download(filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL)
train_other_500 = maybe_download(
filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL
)
bar.update(2)

dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL)
@@ -45,9 +64,13 @@ def _download_and_preprocess_data(data_dir):
dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL)
bar.update(4)

test_clean = maybe_download(filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL)
test_clean = maybe_download(
filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL
)
bar.update(5)
test_other = maybe_download(filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL)
test_other = maybe_download(
filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL
)
bar.update(6)

# Conditionally extract LibriSpeech data
@@ -58,11 +81,17 @@ def _download_and_preprocess_data(data_dir):
LIBRIVOX_DIR = "LibriSpeech"
work_dir = os.path.join(data_dir, LIBRIVOX_DIR)

_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100
)
bar.update(0)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360
)
bar.update(1)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500
)
bar.update(2)

_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
@@ -88,28 +117,48 @@ def _download_and_preprocess_data(data_dir):
# data_dir/LibriSpeech/split-wav/1-2-2.txt
# ...
print("Converting FLAC to WAV and splitting transcriptions...")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
train_100 = _convert_audio_and_split_sentences(work_dir, "train-clean-100", "train-clean-100-wav")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
|
||||
train_100 = _convert_audio_and_split_sentences(
|
||||
work_dir, "train-clean-100", "train-clean-100-wav"
|
||||
)
|
||||
bar.update(0)
|
||||
train_360 = _convert_audio_and_split_sentences(work_dir, "train-clean-360", "train-clean-360-wav")
|
||||
train_360 = _convert_audio_and_split_sentences(
|
||||
work_dir, "train-clean-360", "train-clean-360-wav"
|
||||
)
|
||||
bar.update(1)
|
||||
train_500 = _convert_audio_and_split_sentences(work_dir, "train-other-500", "train-other-500-wav")
|
||||
train_500 = _convert_audio_and_split_sentences(
|
||||
work_dir, "train-other-500", "train-other-500-wav"
|
||||
)
|
||||
bar.update(2)
|
||||
|
||||
dev_clean = _convert_audio_and_split_sentences(work_dir, "dev-clean", "dev-clean-wav")
|
||||
dev_clean = _convert_audio_and_split_sentences(
|
||||
work_dir, "dev-clean", "dev-clean-wav"
|
||||
)
|
||||
bar.update(3)
|
||||
dev_other = _convert_audio_and_split_sentences(work_dir, "dev-other", "dev-other-wav")
|
||||
dev_other = _convert_audio_and_split_sentences(
|
||||
work_dir, "dev-other", "dev-other-wav"
|
||||
)
|
||||
bar.update(4)
|
||||
|
||||
test_clean = _convert_audio_and_split_sentences(work_dir, "test-clean", "test-clean-wav")
|
||||
test_clean = _convert_audio_and_split_sentences(
|
||||
work_dir, "test-clean", "test-clean-wav"
|
||||
)
|
||||
bar.update(5)
|
||||
test_other = _convert_audio_and_split_sentences(work_dir, "test-other", "test-other-wav")
|
||||
test_other = _convert_audio_and_split_sentences(
|
||||
work_dir, "test-other", "test-other-wav"
|
||||
)
|
||||
bar.update(6)
|
||||
|
||||
# Write sets to disk as CSV files
|
||||
train_100.to_csv(os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False)
|
||||
train_360.to_csv(os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False)
|
||||
train_500.to_csv(os.path.join(data_dir, "librivox-train-other-500.csv"), index=False)
|
||||
train_100.to_csv(
|
||||
os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False
|
||||
)
|
||||
train_360.to_csv(
|
||||
os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False
|
||||
)
|
||||
train_500.to_csv(
|
||||
os.path.join(data_dir, "librivox-train-other-500.csv"), index=False
|
||||
)
|
||||
|
||||
dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False)
|
||||
dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False)
|
||||
@ -117,6 +166,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False)
|
||||
test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False)
|
||||
|
||||
|
||||
def _maybe_extract(data_dir, extracted_data, archive):
|
||||
# If data_dir/extracted_data does not exist, extract archive in data_dir
|
||||
if not gfile.Exists(os.path.join(data_dir, extracted_data)):
|
||||
@ -124,6 +174,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
|
||||
tar.extractall(data_dir)
|
||||
tar.close()
|
||||
|
||||
|
||||
def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
|
||||
source_dir = os.path.join(extracted_dir, data_set)
|
||||
target_dir = os.path.join(extracted_dir, dest_dir)
|
||||
@ -146,20 +197,22 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
|
||||
# We also convert the corresponding FLACs to WAV in the same pass
|
||||
files = []
|
||||
for root, dirnames, filenames in os.walk(source_dir):
|
||||
for filename in fnmatch.filter(filenames, '*.trans.txt'):
|
||||
for filename in fnmatch.filter(filenames, "*.trans.txt"):
|
||||
trans_filename = os.path.join(root, filename)
|
||||
with codecs.open(trans_filename, "r", "utf-8") as fin:
|
||||
for line in fin:
|
||||
# Parse each segment line
|
||||
first_space = line.find(" ")
|
||||
seqid, transcript = line[:first_space], line[first_space+1:]
|
||||
seqid, transcript = line[:first_space], line[first_space + 1 :]
|
||||
|
||||
# We need to do the encode-decode dance here because encode
|
||||
# returns a bytes() object on Python 3, and text_to_char_array
|
||||
# expects a string.
|
||||
transcript = unicodedata.normalize("NFKD", transcript) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
|
||||
transcript = transcript.lower().strip()
|
||||
|
||||
@ -174,7 +227,10 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
|
||||
|
||||
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -20,17 +20,17 @@ from deepspeech_training.util.importers import (
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
|
||||
ARCHIVE_DIR_NAME = 'lingua_libre'
|
||||
ARCHIVE_NAME = 'Q{qId}-{iso639_3}-{language_English_name}.zip'
|
||||
ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "lingua_libre"
|
||||
ARCHIVE_NAME = "Q{qId}-{iso639_3}-{language_English_name}.zip"
|
||||
ARCHIVE_URL = "https://lingualibre.fr/datasets/" + ARCHIVE_NAME
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
@ -43,6 +43,7 @@ def _download_and_preprocess_data(target_dir):
|
||||
# Produce CSV files and convert ogg data to wav
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
@ -55,6 +56,7 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
else:
|
||||
print('Found directory "%s" - not extracting it from archive.' % archive_path)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
ogg_filename = sample[0]
|
||||
@ -65,47 +67,59 @@ def one_sample(sample):
|
||||
frames = 0
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter(sample[1])
|
||||
rows = []
|
||||
counter = get_counter()
|
||||
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
|
||||
target_csv_template = os.path.join(
|
||||
target_dir, ARCHIVE_DIR_NAME + "_" + ARCHIVE_NAME.replace(".zip", "_{}.csv")
|
||||
)
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
|
||||
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
|
||||
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace(".zip", ""))
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
|
||||
glob_dir = os.path.join(ogg_root_dir, "**/*.ogg")
|
||||
for record in glob(glob_dir, recursive=True):
|
||||
record_file = record.replace(ogg_root_dir + os.path.sep, '')
|
||||
record_file = record.replace(ogg_root_dir + os.path.sep, "")
|
||||
if record_filter(record_file):
|
||||
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
|
||||
samples.append(
|
||||
(
|
||||
os.path.join(ogg_root_dir, record_file),
|
||||
os.path.splitext(os.path.basename(record_file))[0],
|
||||
)
|
||||
)
|
||||
|
||||
counter = get_counter()
|
||||
num_samples = len(samples)
|
||||
@ -122,9 +136,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -136,7 +150,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
transcript = validate_label(item[2])
|
||||
if not transcript:
|
||||
continue
|
||||
wav_filename = os.path.join(ogg_root_dir, item[0].replace('.ogg', '.wav'))
|
||||
wav_filename = os.path.join(
|
||||
ogg_root_dir, item[0].replace(".ogg", ".wav")
|
||||
)
|
||||
i_mod = i % 10
|
||||
if i_mod == 0:
|
||||
writer = test_writer
|
||||
@ -144,18 +160,21 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def _maybe_convert_wav(ogg_filename, wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
@ -163,19 +182,41 @@ def _maybe_convert_wav(ogg_filename, wav_filename):
|
||||
try:
|
||||
transformer.build(ogg_filename, wav_filename)
|
||||
except sox.core.SoxError as ex:
|
||||
print('SoX processing error', ex, ogg_filename, wav_filename)
|
||||
print("SoX processing error", ex, ogg_filename, wav_filename)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId')
|
||||
parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code')
|
||||
parser.add_argument('--english-name', type=str, required=True, help='Enligh name of the language')
|
||||
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser.add_argument('--bogus-records', type=argparse.FileType('r'), required=False, help='Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items')
|
||||
parser = get_importers_parser(
|
||||
description="Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details."
|
||||
)
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--qId", type=int, required=True, help="LinguaLibre language qId"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iso639-3", type=str, required=True, help="ISO639-3 language code"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--english-name", type=str, required=True, help="Enligh name of the language"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bogus-records",
|
||||
type=argparse.FileType("r"),
|
||||
required=False,
|
||||
help="Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
@ -188,15 +229,17 @@ if __name__ == "__main__":
|
||||
|
||||
def record_filter(path):
|
||||
if any(regex.match(path) for regex in bogus_regexes):
|
||||
print('Reject', path)
|
||||
print("Reject", path)
|
||||
return False
|
||||
return True
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
@ -205,6 +248,14 @@ if __name__ == "__main__":
|
||||
label = None
|
||||
return label
|
||||
|
||||
ARCHIVE_NAME = ARCHIVE_NAME.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
|
||||
ARCHIVE_URL = ARCHIVE_URL.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
|
||||
ARCHIVE_NAME = ARCHIVE_NAME.format(
|
||||
qId=CLI_ARGS.qId,
|
||||
iso639_3=CLI_ARGS.iso639_3,
|
||||
language_English_name=CLI_ARGS.english_name,
|
||||
)
|
||||
ARCHIVE_URL = ARCHIVE_URL.format(
|
||||
qId=CLI_ARGS.qId,
|
||||
iso639_3=CLI_ARGS.iso639_3,
|
||||
language_English_name=CLI_ARGS.english_name,
|
||||
)
|
||||
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)
|
||||
|
@ -18,17 +18,17 @@ from deepspeech_training.util.importers import (
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 15
|
||||
|
||||
ARCHIVE_DIR_NAME = '{language}'
|
||||
ARCHIVE_NAME = '{language}.tgz'
|
||||
ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "{language}"
|
||||
ARCHIVE_NAME = "{language}.tgz"
|
||||
ARCHIVE_URL = "http://www.caito.de/data/Training/stt_tts/" + ARCHIVE_NAME
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
@ -63,7 +63,11 @@ def one_sample(sample):
|
||||
frames = 0
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter(sample[1])
|
||||
counter = get_counter()
|
||||
rows = []
|
||||
@ -71,27 +75,30 @@ def one_sample(sample):
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
print("conversion failure", wav_filename)
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
|
||||
target_csv_template = os.path.join(
|
||||
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tgz", "_{}.csv")
|
||||
)
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
|
||||
@ -99,14 +106,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
glob_dir = os.path.join(wav_root_dir, '**/metadata.csv')
|
||||
glob_dir = os.path.join(wav_root_dir, "**/metadata.csv")
|
||||
for record in glob(glob_dir, recursive=True):
|
||||
if any(map(lambda sk: sk in record, SKIP_LIST)): # pylint: disable=cell-var-from-loop
|
||||
if any(
|
||||
map(lambda sk: sk in record, SKIP_LIST)
|
||||
): # pylint: disable=cell-var-from-loop
|
||||
continue
|
||||
with open(record, 'r') as rec:
|
||||
with open(record, "r") as rec:
|
||||
for re in rec.readlines():
|
||||
re = re.strip().split('|')
|
||||
audio = os.path.join(os.path.dirname(record), 'wavs', re[0] + '.wav')
|
||||
re = re.strip().split("|")
|
||||
audio = os.path.join(os.path.dirname(record), "wavs", re[0] + ".wav")
|
||||
transcript = re[2]
|
||||
samples.append((audio, transcript))
|
||||
|
||||
@ -125,9 +134,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -147,39 +156,60 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=os.path.relpath(wav_filename, extracted_dir),
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=os.path.relpath(wav_filename, extracted_dir),
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser.add_argument('--skiplist', type=str, default='', help='Directories / books to skip, comma separated')
|
||||
parser.add_argument('--language', required=True, type=str, help='Dataset language to use')
|
||||
parser = get_importers_parser(
|
||||
description="Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/."
|
||||
)
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skiplist",
|
||||
type=str,
|
||||
default="",
|
||||
help="Directories / books to skip, comma separated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language", required=True, type=str, help="Dataset language to use"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
|
||||
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(","))
|
||||
validate_label = get_validate_label(CLI_ARGS)
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
|
@ -10,17 +10,17 @@ import pandas
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def extract(archive_path, target_dir):
|
||||
print('Extracting {} into {}...'.format(archive_path, target_dir))
|
||||
print("Extracting {} into {}...".format(archive_path, target_dir))
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
|
||||
|
||||
def is_file_truncated(wav_filename, wav_filesize):
|
||||
with wave.open(wav_filename, mode='rb') as fin:
|
||||
with wave.open(wav_filename, mode="rb") as fin:
|
||||
assert fin.getframerate() == 16000
|
||||
assert fin.getsampwidth() == 2
|
||||
assert fin.getnchannels() == 1
|
||||
@ -33,8 +33,13 @@ def is_file_truncated(wav_filename, wav_filesize):
|
||||
|
||||
def preprocess_data(folder_with_archives, target_dir):
|
||||
# First extract subset archives
|
||||
for subset in ('train', 'dev', 'test'):
|
||||
extract(os.path.join(folder_with_archives, 'magicdata_{}_set.tar.gz'.format(subset)), target_dir)
|
||||
for subset in ("train", "dev", "test"):
|
||||
extract(
|
||||
os.path.join(
|
||||
folder_with_archives, "magicdata_{}_set.tar.gz".format(subset)
|
||||
),
|
||||
target_dir,
|
||||
)
|
||||
|
||||
# Folder structure is now:
|
||||
# - magicdata_{train,dev,test}.tar.gz
|
||||
@ -50,58 +55,73 @@ def preprocess_data(folder_with_archives, target_dir):
|
||||
# name, one containing the speaker ID, and one containing the transcription
|
||||
|
||||
def load_set(set_path):
|
||||
transcripts = pandas.read_csv(os.path.join(set_path, 'TRANS.txt'), sep='\t', index_col=0)
|
||||
glob_path = os.path.join(set_path, '*', '*.wav')
|
||||
transcripts = pandas.read_csv(
|
||||
os.path.join(set_path, "TRANS.txt"), sep="\t", index_col=0
|
||||
)
|
||||
glob_path = os.path.join(set_path, "*", "*.wav")
|
||||
set_files = []
|
||||
for wav in glob.glob(glob_path):
|
||||
try:
|
||||
wav_filename = wav
|
||||
wav_filesize = os.path.getsize(wav)
|
||||
transcript_key = os.path.basename(wav)
|
||||
transcript = transcripts.loc[transcript_key, 'Transcription']
|
||||
transcript = transcripts.loc[transcript_key, "Transcription"]
|
||||
|
||||
# Some files in this dataset are truncated, the header duration
|
||||
# doesn't match the file size. This causes errors at training
|
||||
# time, so check here if things are fine before including a file
|
||||
if is_file_truncated(wav_filename, wav_filesize):
|
||||
print('Warning: File {} is corrupted, header duration does '
|
||||
'not match file size. Ignoring.'.format(wav_filename))
|
||||
print(
|
||||
"Warning: File {} is corrupted, header duration does "
|
||||
"not match file size. Ignoring.".format(wav_filename)
|
||||
)
|
||||
continue
|
||||
|
||||
set_files.append((wav_filename, wav_filesize, transcript))
|
||||
except KeyError:
|
||||
print('Warning: Missing transcript for WAV file {}.'.format(wav))
|
||||
print("Warning: Missing transcript for WAV file {}.".format(wav))
|
||||
return set_files
|
||||
|
||||
for subset in ('train', 'dev', 'test'):
|
||||
print('Loading {} set samples...'.format(subset))
|
||||
for subset in ("train", "dev", "test"):
|
||||
print("Loading {} set samples...".format(subset))
|
||||
subset_files = load_set(os.path.join(target_dir, subset))
|
||||
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
|
||||
|
||||
# Trim train set to under 10s
|
||||
if subset == 'train':
|
||||
durations = (df['wav_filesize'] - 44) / 16000 / 2
|
||||
if subset == "train":
|
||||
durations = (df["wav_filesize"] - 44) / 16000 / 2
|
||||
df = df[durations <= 10.0]
|
||||
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
|
||||
|
||||
with_noise = df['transcript'].str.contains(r'\[(FIL|SPK)\]')
|
||||
df = df[~with_noise]
|
||||
print('Trimming {} samples with noise ([FIL] or [SPK])'.format(sum(with_noise)))
|
||||
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
|
||||
|
||||
dest_csv = os.path.join(target_dir, 'magicdata_{}.csv'.format(subset))
|
||||
print('Saving {} set into {}...'.format(subset, dest_csv))
|
||||
with_noise = df["transcript"].str.contains(r"\[(FIL|SPK)\]")
|
||||
df = df[~with_noise]
|
||||
print(
|
||||
"Trimming {} samples with noise ([FIL] or [SPK])".format(
|
||||
sum(with_noise)
|
||||
)
|
||||
)
|
||||
|
||||
dest_csv = os.path.join(target_dir, "magicdata_{}.csv".format(subset))
|
||||
print("Saving {} set into {}...".format(subset, dest_csv))
|
||||
df.to_csv(dest_csv, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
# https://openslr.org/68/
|
||||
parser = get_importers_parser(description='Import MAGICDATA corpus')
|
||||
parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz')
|
||||
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives')
|
||||
parser = get_importers_parser(description="Import MAGICDATA corpus")
|
||||
parser.add_argument(
|
||||
"folder_with_archives",
|
||||
help="Path to folder containing magicdata_{train,dev,test}.tar.gz",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default="",
|
||||
help="Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives",
|
||||
)
|
||||
params = parser.parse_args()
|
||||
|
||||
if not params.target_dir:
|
||||
params.target_dir = os.path.join(params.folder_with_archives, 'magicdata')
|
||||
params.target_dir = os.path.join(params.folder_with_archives, "magicdata")
|
||||
|
||||
preprocess_data(params.folder_with_archives, params.target_dir)
|
||||
|
||||
|
@ -11,11 +11,11 @@ import pandas
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def extract(archive_path, target_dir):
|
||||
print('Extracting {} into {}...'.format(archive_path, target_dir))
|
||||
print("Extracting {} into {}...".format(archive_path, target_dir))
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
|
||||
@ -23,7 +23,7 @@ def extract(archive_path, target_dir):
|
||||
def preprocess_data(tgz_file, target_dir):
|
||||
# First extract main archive and sub-archives
|
||||
extract(tgz_file, target_dir)
|
||||
main_folder = os.path.join(target_dir, 'primewords_md_2018_set1')
|
||||
main_folder = os.path.join(target_dir, "primewords_md_2018_set1")
|
||||
|
||||
# Folder structure is now:
|
||||
# - primewords_md_2018_set1/
|
||||
@ -31,14 +31,11 @@ def preprocess_data(tgz_file, target_dir):
|
||||
# - [0-f]/[00-0f]/*.wav
|
||||
# - set1_transcript.json
|
||||
|
||||
transcripts_path = os.path.join(main_folder, 'set1_transcript.json')
|
||||
transcripts_path = os.path.join(main_folder, "set1_transcript.json")
|
||||
with open(transcripts_path) as fin:
|
||||
transcripts = json.load(fin)
|
||||
|
||||
transcripts = {
|
||||
entry['file']: entry['text']
|
||||
for entry in transcripts
|
||||
}
|
||||
transcripts = {entry["file"]: entry["text"] for entry in transcripts}
|
||||
|
||||
def load_set(glob_path):
|
||||
set_files = []
|
||||
@ -50,13 +47,13 @@ def preprocess_data(tgz_file, target_dir):
|
||||
transcript = transcripts[transcript_key]
|
||||
set_files.append((wav_filename, wav_filesize, transcript))
|
||||
except KeyError:
|
||||
print('Warning: Missing transcript for WAV file {}.'.format(wav))
|
||||
print("Warning: Missing transcript for WAV file {}.".format(wav))
|
||||
return set_files
|
||||
|
||||
# Load all files, then deterministically split into train/dev/test sets
|
||||
all_files = load_set(os.path.join(main_folder, 'audio_files', '*', '*', '*.wav'))
|
||||
all_files = load_set(os.path.join(main_folder, "audio_files", "*", "*", "*.wav"))
|
||||
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
|
||||
df.sort_values(by='wav_filename', inplace=True)
|
||||
df.sort_values(by="wav_filename", inplace=True)
|
||||
|
||||
indices = np.arange(0, len(df))
|
||||
np.random.seed(12345)
|
||||
@ -69,29 +66,33 @@ def preprocess_data(tgz_file, target_dir):
|
||||
train_indices = indices[:-10000]
|
||||
|
||||
train_files = df.iloc[train_indices]
|
||||
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
|
||||
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
|
||||
train_files = train_files[durations <= 15.0]
|
||||
print('Trimming {} samples > 15 seconds'.format((durations > 15.0).sum()))
|
||||
dest_csv = os.path.join(target_dir, 'primewords_train.csv')
|
||||
print('Saving train set into {}...'.format(dest_csv))
|
||||
print("Trimming {} samples > 15 seconds".format((durations > 15.0).sum()))
|
||||
dest_csv = os.path.join(target_dir, "primewords_train.csv")
|
||||
print("Saving train set into {}...".format(dest_csv))
|
||||
train_files.to_csv(dest_csv, index=False)
|
||||
|
||||
dev_files = df.iloc[dev_indices]
|
||||
dest_csv = os.path.join(target_dir, 'primewords_dev.csv')
|
||||
print('Saving dev set into {}...'.format(dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "primewords_dev.csv")
|
||||
print("Saving dev set into {}...".format(dest_csv))
|
||||
dev_files.to_csv(dest_csv, index=False)
|
||||
|
||||
test_files = df.iloc[test_indices]
|
||||
dest_csv = os.path.join(target_dir, 'primewords_test.csv')
|
||||
print('Saving test set into {}...'.format(dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "primewords_test.csv")
|
||||
print("Saving test set into {}...".format(dest_csv))
|
||||
test_files.to_csv(dest_csv, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
# https://www.openslr.org/47/
|
||||
parser = get_importers_parser(description='Import Primewords Chinese corpus set 1')
|
||||
parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz')
|
||||
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
|
||||
parser = get_importers_parser(description="Import Primewords Chinese corpus set 1")
|
||||
parser.add_argument("tgz_file", help="Path to primewords_md_2018_set1.tar.gz")
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default="",
|
||||
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
|
||||
)
|
||||
params = parser.parse_args()
|
||||
|
||||
if not params.target_dir:
|
||||
|
@ -20,17 +20,17 @@ from deepspeech_training.util.importers import (
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 15
|
||||
|
||||
ARCHIVE_DIR_NAME = 'African_Accented_French'
|
||||
ARCHIVE_NAME = 'African_Accented_French.tar.gz'
|
||||
ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "African_Accented_French"
|
||||
ARCHIVE_NAME = "African_Accented_French.tar.gz"
|
||||
ARCHIVE_URL = "http://www.openslr.org/resources/57/" + ARCHIVE_NAME
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
@ -43,6 +43,7 @@ def _download_and_preprocess_data(target_dir):
|
||||
# Produce CSV files
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
@ -56,6 +57,7 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
else:
|
||||
print('Found directory "%s" - not extracting it from archive.' % archive_path)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
wav_filename = sample[0]
|
||||
@ -63,74 +65,81 @@ def one_sample(sample):
|
||||
frames = 0
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter(sample[1])
|
||||
counter = get_counter()
|
||||
rows = []
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
|
||||
target_csv_template = os.path.join(
|
||||
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tar.gz", "_{}.csv")
|
||||
)
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
|
||||
wav_root_dir = os.path.join(extracted_dir)
|
||||
|
||||
all_files = [
|
||||
'transcripts/train/yaounde/fn_text.txt',
|
||||
'transcripts/train/ca16_conv/transcripts.txt',
|
||||
'transcripts/train/ca16_read/conditioned.txt',
|
||||
'transcripts/dev/niger_west_african_fr/transcripts.txt',
|
||||
'speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv',
|
||||
'transcripts/devtest/ca16_read/conditioned.txt',
|
||||
'transcripts/test/ca16/prompts.txt',
|
||||
"transcripts/train/yaounde/fn_text.txt",
|
||||
"transcripts/train/ca16_conv/transcripts.txt",
|
||||
"transcripts/train/ca16_read/conditioned.txt",
|
||||
"transcripts/dev/niger_west_african_fr/transcripts.txt",
|
||||
"speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv",
|
||||
"transcripts/devtest/ca16_read/conditioned.txt",
|
||||
"transcripts/test/ca16/prompts.txt",
|
||||
]
|
||||
|
||||
transcripts = {}
|
||||
for tr in all_files:
|
||||
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), 'r') as tr_source:
|
||||
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), "r") as tr_source:
|
||||
for line in tr_source.readlines():
|
||||
line = line.strip()
|
||||
|
||||
if '.tsv' in tr:
|
||||
sep = ' '
|
||||
if ".tsv" in tr:
|
||||
sep = " "
|
||||
else:
|
||||
sep = ' '
|
||||
sep = " "
|
||||
|
||||
audio = os.path.basename(line.split(sep)[0])
|
||||
|
||||
if not ('.wav' in audio):
|
||||
if '.tdf' in audio:
|
||||
audio = audio.replace('.tdf', '.wav')
|
||||
if not (".wav" in audio):
|
||||
if ".tdf" in audio:
|
||||
audio = audio.replace(".tdf", ".wav")
|
||||
else:
|
||||
audio += '.wav'
|
||||
audio += ".wav"
|
||||
|
||||
transcript = ' '.join(line.split(sep)[1:])
|
||||
transcript = " ".join(line.split(sep)[1:])
|
||||
transcripts[audio] = transcript
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
glob_dir = os.path.join(wav_root_dir, '**/*.wav')
|
||||
glob_dir = os.path.join(wav_root_dir, "**/*.wav")
|
||||
for record in glob(glob_dir, recursive=True):
|
||||
record_file = os.path.basename(record)
|
||||
if record_file in transcripts:
|
||||
@ -152,9 +161,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -174,25 +183,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser = get_importers_parser(
|
||||
description="Importer for African Accented French dataset. More information on http://www.openslr.org/57/."
|
||||
)
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
@ -200,9 +222,11 @@ if __name__ == "__main__":
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
|
@ -18,24 +18,23 @@ import pandas
|
||||
import requests
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
from deepspeech_training.util.importers import \
|
||||
validate_label_eng as validate_label
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
|
||||
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
|
||||
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
|
||||
ARCHIVE_URL = 'http://www.openslr.org/resources/5/'
|
||||
ARCHIVE_DIR_NAME = 'LDC97S62'
|
||||
LDC_DATASET = 'swb1_LDC97S62.tgz'
|
||||
ARCHIVE_NAME = "switchboard_word_alignments.tar.gz"
|
||||
ARCHIVE_URL = "http://www.openslr.org/resources/5/"
|
||||
ARCHIVE_DIR_NAME = "LDC97S62"
|
||||
LDC_DATASET = "swb1_LDC97S62.tgz"
|
||||
|
||||
|
||||
def download_file(folder, url):
|
||||
# https://stackoverflow.com/a/16696317/738515
|
||||
local_filename = url.split('/')[-1]
|
||||
local_filename = url.split("/")[-1]
|
||||
full_filename = os.path.join(folder, local_filename)
|
||||
r = requests.get(url, stream=True)
|
||||
with open(full_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
with open(full_filename, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
return full_filename
|
||||
|
||||
@ -43,7 +42,7 @@ def download_file(folder, url):
|
||||
def maybe_download(archive_url, target_dir, ldc_dataset):
|
||||
# If archive file does not exist, download it...
|
||||
archive_path = os.path.join(target_dir, ldc_dataset)
|
||||
ldc_path = archive_url+ldc_dataset
|
||||
ldc_path = archive_url + ldc_dataset
|
||||
if not os.path.exists(target_dir):
|
||||
print('No path "%s" - creating ...' % target_dir)
|
||||
makedirs(target_dir)
|
||||
@ -62,17 +61,23 @@ def _download_and_preprocess_data(data_dir):
|
||||
archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET))
|
||||
|
||||
# Check swb1_LDC97S62.tgz then extract
|
||||
assert(os.path.isfile(archive_path))
|
||||
assert os.path.isfile(archive_path)
|
||||
_extract(target_dir, archive_path)
|
||||
|
||||
|
||||
# Transcripts
|
||||
transcripts_path = maybe_download(ARCHIVE_URL, target_dir, ARCHIVE_NAME)
|
||||
_extract(target_dir, transcripts_path)
|
||||
|
||||
# Check swb1_d1/2/3/4/swb_ms98_transcriptions
|
||||
expected_folders = ["swb1_d1","swb1_d2","swb1_d3","swb1_d4","swb_ms98_transcriptions"]
|
||||
assert(all([os.path.isdir(os.path.join(target_dir,e)) for e in expected_folders]))
|
||||
|
||||
expected_folders = [
|
||||
"swb1_d1",
|
||||
"swb1_d2",
|
||||
"swb1_d3",
|
||||
"swb1_d4",
|
||||
"swb_ms98_transcriptions",
|
||||
]
|
||||
assert all([os.path.isdir(os.path.join(target_dir, e)) for e in expected_folders])
|
||||
|
||||
# Conditionally convert swb sph data to wav
|
||||
_maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav")
|
||||
_maybe_convert_wav(target_dir, "swb1_d2", "swb1_d2-wav")
|
||||
@ -80,13 +85,21 @@ def _download_and_preprocess_data(data_dir):
|
||||
_maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav")
|
||||
|
||||
# Conditionally split wav data
|
||||
d1 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav")
|
||||
d2 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav")
|
||||
d3 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav")
|
||||
d4 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav")
|
||||
|
||||
d1 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav"
|
||||
)
|
||||
d2 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav"
|
||||
)
|
||||
d3 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav"
|
||||
)
|
||||
d4 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav"
|
||||
)
|
||||
|
||||
swb_files = d1.append(d2).append(d3).append(d4)
|
||||
|
||||
|
||||
train_files, dev_files, test_files = _split_sets(swb_files)
|
||||
|
||||
# Write sets to disk as CSV files
|
||||
@ -94,7 +107,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_files.to_csv(os.path.join(target_dir, "swb-dev.csv"), index=False)
|
||||
test_files.to_csv(os.path.join(target_dir, "swb-test.csv"), index=False)
|
||||
|
||||
|
||||
|
||||
def _extract(target_dir, archive_path):
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
@ -115,25 +128,46 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
|
||||
# Loop over sph files in source_dir and convert each to 16-bit PCM wav
|
||||
for root, dirnames, filenames in os.walk(source_dir):
|
||||
for filename in fnmatch.filter(filenames, "*.sph"):
|
||||
for channel in ['1', '2']:
|
||||
for channel in ["1", "2"]:
|
||||
sph_file = os.path.join(root, filename)
|
||||
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav"
|
||||
wav_filename = (
|
||||
os.path.splitext(os.path.basename(sph_file))[0]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
)
|
||||
wav_file = os.path.join(target_dir, wav_filename)
|
||||
temp_wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + "-temp.wav"
|
||||
temp_wav_filename = (
|
||||
os.path.splitext(os.path.basename(sph_file))[0]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ "-temp.wav"
|
||||
)
|
||||
temp_wav_file = os.path.join(target_dir, temp_wav_filename)
|
||||
print("converting {} to {}".format(sph_file, temp_wav_file))
|
||||
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file])
|
||||
subprocess.check_call(
|
||||
[
|
||||
"sph2pipe",
|
||||
"-c",
|
||||
channel,
|
||||
"-p",
|
||||
"-f",
|
||||
"rif",
|
||||
sph_file,
|
||||
temp_wav_file,
|
||||
]
|
||||
)
|
||||
print("upsampling {} to {}".format(temp_wav_file, wav_file))
|
||||
audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True)
|
||||
soundfile.write(wav_file, audioData, frameRate, "PCM_16")
|
||||
os.remove(temp_wav_file)
|
||||
|
||||
|
||||
|
||||
def _parse_transcriptions(trans_file):
|
||||
segments = []
|
||||
with codecs.open(trans_file, "r", "utf-8") as fin:
|
||||
for line in fin:
|
||||
if line.startswith("#") or len(line) <= 1:
|
||||
if line.startswith("#") or len(line) <= 1:
|
||||
continue
|
||||
|
||||
tokens = line.split()
|
||||
@ -147,15 +181,19 @@ def _parse_transcriptions(trans_file):
|
||||
# We need to do the encode-decode dance here because encode
|
||||
# returns a bytes() object on Python 3, and text_to_char_array
|
||||
# expects a string.
|
||||
transcript = unicodedata.normalize("NFKD", transcript) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
|
||||
segments.append({
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"transcript": transcript,
|
||||
})
|
||||
segments.append(
|
||||
{
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"transcript": transcript,
|
||||
}
|
||||
)
|
||||
return segments
|
||||
|
||||
|
||||
@ -180,8 +218,16 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
|
||||
segments = _parse_transcriptions(trans_file)
|
||||
|
||||
# Open wav corresponding to transcription file
|
||||
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A']
|
||||
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav"
|
||||
channel = ("2", "1")[
|
||||
(os.path.splitext(os.path.basename(trans_file))[0])[6] == "A"
|
||||
]
|
||||
wav_filename = (
|
||||
"sw0"
|
||||
+ (os.path.splitext(os.path.basename(trans_file))[0])[2:6]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
)
|
||||
wav_file = os.path.join(source_dir, wav_filename)
|
||||
|
||||
print("splitting {} according to {}".format(wav_file, trans_file))
|
||||
@ -197,26 +243,39 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
|
||||
# Create wav segment filename
|
||||
start_time = segment["start_time"]
|
||||
stop_time = segment["stop_time"]
|
||||
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(
|
||||
start_time) + "-" + str(stop_time) + ".wav"
|
||||
new_wav_filename = (
|
||||
os.path.splitext(os.path.basename(trans_file))[0]
|
||||
+ "-"
|
||||
+ str(start_time)
|
||||
+ "-"
|
||||
+ str(stop_time)
|
||||
+ ".wav"
|
||||
)
|
||||
if _is_wav_too_short(new_wav_filename):
|
||||
continue
|
||||
continue
|
||||
new_wav_file = os.path.join(target_dir, new_wav_filename)
|
||||
|
||||
_split_wav(origAudio, start_time, stop_time, new_wav_file)
|
||||
|
||||
new_wav_filesize = os.path.getsize(new_wav_file)
|
||||
transcript = segment["transcript"]
|
||||
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
|
||||
files.append(
|
||||
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
|
||||
)
|
||||
|
||||
# Close origAudio
|
||||
origAudio.close()
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
def _is_wav_too_short(wav_filename):
|
||||
short_wav_filenames = ['sw2986A-ms98-a-trans-80.6385-83.358875.wav', 'sw2663A-ms98-a-trans-161.12025-164.213375.wav']
|
||||
short_wav_filenames = [
|
||||
"sw2986A-ms98-a-trans-80.6385-83.358875.wav",
|
||||
"sw2663A-ms98-a-trans-161.12025-164.213375.wav",
|
||||
]
|
||||
return wav_filename in short_wav_filenames
@ -231,7 +290,7 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
chunkAudio.writeframes(chunkData)
|
||||
chunkAudio.close()
def _split_sets(filelist):
|
||||
# We initially split the entire set into 80% train and 20% test, then
|
||||
# split the train set into 80% train and 20% validation.
|
||||
@ -245,10 +304,24 @@ def _split_sets(filelist):
|
||||
test_beg = dev_end
|
||||
test_end = len(filelist)
return (filelist[train_beg:train_end], filelist[dev_beg:dev_end], filelist[test_beg:test_end])
|
||||
return (
|
||||
filelist[train_beg:train_end],
|
||||
filelist[dev_beg:dev_end],
|
||||
filelist[test_beg:test_end],
|
||||
)
def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, stride=1, offset=0, next_index=lambda i: i + 1, limit=0):
|
||||
def _read_data_set(
|
||||
filelist,
|
||||
thread_count,
|
||||
batch_size,
|
||||
numcep,
|
||||
numcontext,
|
||||
stride=1,
|
||||
offset=0,
|
||||
next_index=lambda i: i + 1,
|
||||
limit=0,
|
||||
):
|
||||
# Optionally apply dataset size limit
|
||||
if limit > 0:
|
||||
filelist = filelist.iloc[:limit]
|
||||
@ -256,7 +329,9 @@ def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, strid
|
||||
filelist = filelist[offset::stride]
# Return DataSet
|
||||
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index)
|
||||
return DataSet(
|
||||
txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index
|
||||
)
if __name__ == "__main__":
@ -1,8 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
"""
|
||||
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
|
||||
Use "python3 import_swc.py -h" for help
|
||||
'''
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
import argparse
|
||||
@ -24,44 +24,54 @@ import progressbar
|
||||
import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import \
|
||||
validate_label_eng as validate_label
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
||||
SWC_ARCHIVE = "SWC_{language}.tar"
|
||||
LANGUAGES = ['dutch', 'english', 'german']
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES_EXT = FIELDNAMES + ['article', 'speaker']
|
||||
LANGUAGES = ["dutch", "english", "german"]
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
|
||||
CHANNELS = 1
|
||||
SAMPLE_RATE = 16000
|
||||
UNKNOWN = '<unknown>'
|
||||
AUDIO_PATTERN = 'audio*.ogg'
|
||||
WAV_NAME = 'audio.wav'
|
||||
ALIGNED_NAME = 'aligned.swc'
|
||||
UNKNOWN = "<unknown>"
|
||||
AUDIO_PATTERN = "audio*.ogg"
|
||||
WAV_NAME = "audio.wav"
|
||||
ALIGNED_NAME = "aligned.swc"
SUBSTITUTIONS = {
|
||||
'german': [
|
||||
(re.compile(r'\$'), 'dollar'),
|
||||
(re.compile(r'€'), 'euro'),
|
||||
(re.compile(r'£'), 'pfund'),
|
||||
(re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
|
||||
(re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
|
||||
(re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
|
||||
(re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
|
||||
(re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
|
||||
(re.compile(r'punkt null null null punkt null null null'), 'millionen'),
|
||||
(re.compile(r'eins punkt null null null'), 'ein tausend'),
|
||||
(re.compile(r'punkt null null null'), 'tausend'),
|
||||
(re.compile(r'punkt null'), None)
|
||||
"german": [
|
||||
(re.compile(r"\$"), "dollar"),
|
||||
(re.compile(r"€"), "euro"),
|
||||
(re.compile(r"£"), "pfund"),
|
||||
(
|
||||
re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
|
||||
r"\1zehnhundert \2er ",
|
||||
),
|
||||
(re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
|
||||
(
|
||||
re.compile(
|
||||
r"eins punkt null null null punkt null null null punkt null null null"
|
||||
),
|
||||
"eine milliarde",
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"punkt null null null punkt null null null punkt null null null"
|
||||
),
|
||||
"milliarden",
|
||||
),
|
||||
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
|
||||
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
|
||||
(re.compile(r"eins punkt null null null"), "ein tausend"),
|
||||
(re.compile(r"punkt null null null"), "tausend"),
|
||||
(re.compile(r"punkt null"), None),
|
||||
]
|
||||
}
DONT_NORMALIZE = {
|
||||
'german': 'ÄÖÜäöüß'
|
||||
}
|
||||
DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))
|
||||
PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
class Sample:
|
||||
@ -95,11 +105,14 @@ def get_sample_size(population_size):
|
||||
margin_of_error = 0.01
|
||||
fraction_picking = 0.50
|
||||
z_score = 2.58 # Corresponds to confidence level 99%
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2
|
||||
)
|
||||
sample_size = 0
|
||||
for train_size in range(population_size, 0, -1):
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
|
||||
(margin_of_error ** 2 * train_size)
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2 * train_size
|
||||
)
|
||||
sample_size = int(numerator / denominator)
|
||||
if 2 * sample_size + train_size <= population_size:
|
||||
break
|
||||
@ -108,9 +121,11 @@ def get_sample_size(population_size):
|
||||
|
||||
def maybe_download_language(language):
|
||||
lang_upper = language[0].upper() + language[1:]
|
||||
return maybe_download(SWC_ARCHIVE.format(language=lang_upper),
|
||||
CLI_ARGS.base_dir,
|
||||
SWC_URL.format(language=lang_upper))
|
||||
return maybe_download(
|
||||
SWC_ARCHIVE.format(language=lang_upper),
|
||||
CLI_ARGS.base_dir,
|
||||
SWC_URL.format(language=lang_upper),
|
||||
)
def maybe_extract(data_dir, extracted_data, archive):
|
||||
@ -130,29 +145,29 @@ def maybe_extract(data_dir, extracted_data, archive):
|
||||
def ignored(node):
|
||||
if node is None:
|
||||
return False
|
||||
if node.tag == 'ignored':
|
||||
if node.tag == "ignored":
|
||||
return True
|
||||
return ignored(node.find('..'))
|
||||
return ignored(node.find(".."))
def read_token(token):
|
||||
texts, start, end = [], None, None
|
||||
notes = token.findall('n')
|
||||
notes = token.findall("n")
|
||||
if len(notes) > 0:
|
||||
for note in notes:
|
||||
attributes = note.attrib
|
||||
if start is None and 'start' in attributes:
|
||||
start = int(attributes['start'])
|
||||
if 'end' in attributes:
|
||||
token_end = int(attributes['end'])
|
||||
if start is None and "start" in attributes:
|
||||
start = int(attributes["start"])
|
||||
if "end" in attributes:
|
||||
token_end = int(attributes["end"])
|
||||
if end is None or token_end > end:
|
||||
end = token_end
|
||||
if 'pronunciation' in attributes:
|
||||
t = attributes['pronunciation']
|
||||
if "pronunciation" in attributes:
|
||||
t = attributes["pronunciation"]
|
||||
texts.append(t)
|
||||
elif 'text' in token.attrib:
|
||||
texts.append(token.attrib['text'])
|
||||
return start, end, ' '.join(texts)
|
||||
elif "text" in token.attrib:
|
||||
texts.append(token.attrib["text"])
|
||||
return start, end, " ".join(texts)
def in_alphabet(alphabet, c):
|
||||
@ -160,10 +175,12 @@ def in_alphabet(alphabet, c):
ALPHABETS = {}
def get_alphabet(language):
|
||||
if language in ALPHABETS:
|
||||
return ALPHABETS[language]
|
||||
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
|
||||
alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
|
||||
alphabet = Alphabet(alphabet_path) if alphabet_path else None
|
||||
ALPHABETS[language] = alphabet
|
||||
return alphabet
|
||||
@ -173,27 +190,35 @@ def label_filter(label, language):
|
||||
label = label.translate(PRE_FILTER)
|
||||
label = validate_label(label)
|
||||
if label is None:
|
||||
return None, 'validation'
|
||||
return None, "validation"
|
||||
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
|
||||
for pattern, replacement in substitutions:
|
||||
if replacement is None:
|
||||
if pattern.match(label):
|
||||
return None, 'substitution rule'
|
||||
return None, "substitution rule"
|
||||
else:
|
||||
label = pattern.sub(replacement, label)
|
||||
chars = []
|
||||
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
|
||||
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
|
||||
alphabet = get_alphabet(language)
|
||||
for c in label:
|
||||
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
|
||||
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
if (
|
||||
CLI_ARGS.normalize
|
||||
and c not in dont_normalize
|
||||
and not in_alphabet(alphabet, c)
|
||||
):
|
||||
c = (
|
||||
unicodedata.normalize("NFKD", c)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
for sc in c:
|
||||
if not in_alphabet(alphabet, sc):
|
||||
return None, 'illegal character'
|
||||
return None, "illegal character"
|
||||
chars.append(sc)
|
||||
label = ''.join(chars)
|
||||
label = "".join(chars)
|
||||
label = validate_label(label)
|
||||
return label, 'validation' if label is None else None
|
||||
return label, "validation" if label is None else None
def collect_samples(base_dir, language):
|
||||
@ -204,7 +229,9 @@ def collect_samples(base_dir, language):
|
||||
samples = []
|
||||
reasons = Counter()
def add_sample(p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason='complete'):
|
||||
def add_sample(
|
||||
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
|
||||
):
|
||||
if p_start is not None and p_end is not None and p_text is not None:
|
||||
duration = p_end - p_start
|
||||
text, filter_reason = label_filter(p_text, language)
|
||||
@ -214,53 +241,67 @@ def collect_samples(base_dir, language):
|
||||
p_reason = filter_reason
|
||||
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
|
||||
skip = True
|
||||
p_reason = 'unknown speaker'
|
||||
p_reason = "unknown speaker"
|
||||
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
|
||||
skip = True
|
||||
p_reason = 'unknown article'
|
||||
p_reason = "unknown article"
|
||||
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
|
||||
skip = True
|
||||
p_reason = 'exceeded duration'
|
||||
p_reason = "exceeded duration"
|
||||
elif int(duration / 30) < len(text):
|
||||
skip = True
|
||||
p_reason = 'too short to decode'
|
||||
p_reason = "too short to decode"
|
||||
elif duration / len(text) < 10:
|
||||
skip = True
|
||||
p_reason = 'length duration ratio'
|
||||
p_reason = "length duration ratio"
|
||||
if skip:
|
||||
reasons[p_reason] += 1
|
||||
else:
|
||||
samples.append(Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker))
|
||||
samples.append(
|
||||
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
|
||||
)
|
||||
elif p_start is None or p_end is None:
|
||||
reasons['missing timestamps'] += 1
|
||||
reasons["missing timestamps"] += 1
|
||||
else:
|
||||
reasons['missing text'] += 1
|
||||
reasons["missing text"] += 1
print('Collecting samples...')
|
||||
print("Collecting samples...")
|
||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||
for root in bar(roots):
|
||||
wav_path = os.path.join(root, WAV_NAME)
|
||||
aligned = ET.parse(path.join(root, ALIGNED_NAME))
|
||||
article = UNKNOWN
|
||||
speaker = UNKNOWN
|
||||
for prop in aligned.iter('prop'):
|
||||
for prop in aligned.iter("prop"):
|
||||
attributes = prop.attrib
|
||||
if 'key' in attributes and 'value' in attributes:
|
||||
if attributes['key'] == 'DC.identifier':
|
||||
article = attributes['value']
|
||||
elif attributes['key'] == 'reader.name':
|
||||
speaker = attributes['value']
|
||||
for sentence in aligned.iter('s'):
|
||||
if "key" in attributes and "value" in attributes:
|
||||
if attributes["key"] == "DC.identifier":
|
||||
article = attributes["value"]
|
||||
elif attributes["key"] == "reader.name":
|
||||
speaker = attributes["value"]
|
||||
for sentence in aligned.iter("s"):
|
||||
if ignored(sentence):
|
||||
continue
|
||||
split = False
|
||||
tokens = list(map(read_token, sentence.findall('t')))
|
||||
tokens = list(map(read_token, sentence.findall("t")))
|
||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||
for token_start, token_end, token_text in tokens:
|
||||
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='has numbers')
|
||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="has numbers",
|
||||
)
|
||||
sample_start, sample_end, token_texts, sample_texts = (
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
continue
|
||||
if sample_start is None:
|
||||
sample_start = token_start
|
||||
@ -268,20 +309,37 @@ def collect_samples(base_dir, language):
|
||||
continue
|
||||
token_texts.append(token_text)
|
||||
if token_end is not None:
|
||||
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='split')
|
||||
if (
|
||||
token_start != sample_start
|
||||
and token_end - sample_start > CLI_ARGS.max_duration > 0
|
||||
):
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="split",
|
||||
)
|
||||
sample_start = sample_end
|
||||
sample_texts = []
|
||||
split = True
|
||||
sample_end = token_end
|
||||
sample_texts.extend(token_texts)
|
||||
token_texts = []
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='split' if split else 'complete')
|
||||
print('Skipped samples:')
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="split" if split else "complete",
|
||||
)
|
||||
print("Skipped samples:")
|
||||
for reason, n in reasons.most_common():
|
||||
print(' - {}: {}'.format(reason, n))
|
||||
print(" - {}: {}".format(reason, n))
|
||||
return samples
@ -301,18 +359,18 @@ def maybe_convert_one_to_wav(entry):
|
||||
elif len(files) > 1:
|
||||
wav_files = []
|
||||
for i, file in enumerate(files):
|
||||
wav_path = os.path.join(root, 'audio{}.wav'.format(i))
|
||||
wav_path = os.path.join(root, "audio{}.wav".format(i))
|
||||
transformer.build(file, wav_path)
|
||||
wav_files.append(wav_path)
|
||||
combiner.set_input_format(file_type=['wav'] * len(wav_files))
|
||||
combiner.build(wav_files, output_wav, 'concatenate')
|
||||
combiner.set_input_format(file_type=["wav"] * len(wav_files))
|
||||
combiner.build(wav_files, output_wav, "concatenate")
|
||||
except sox.core.SoxError:
|
||||
return
def maybe_convert_to_wav(base_dir):
|
||||
roots = list(os.walk(base_dir))
|
||||
print('Converting and joining source audio files...')
|
||||
print("Converting and joining source audio files...")
|
||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||
tp = ThreadPool()
|
||||
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
|
||||
@ -332,53 +390,66 @@ def assign_sub_sets(samples):
|
||||
sample_set.extend(speakers.pop(0))
|
||||
train_set = sum(speakers, [])
|
||||
if len(train_set) == 0:
|
||||
print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
|
||||
print(
|
||||
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
|
||||
)
|
||||
random.seed(42) # same source data == same output
|
||||
random.shuffle(samples)
|
||||
for index, sample in enumerate(samples):
|
||||
if index < sample_size:
|
||||
sample.sub_set = 'dev'
|
||||
sample.sub_set = "dev"
|
||||
elif index < 2 * sample_size:
|
||||
sample.sub_set = 'test'
|
||||
sample.sub_set = "test"
|
||||
else:
|
||||
sample.sub_set = 'train'
|
||||
sample.sub_set = "train"
|
||||
else:
|
||||
for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
|
||||
for sub_set, sub_set_samples in [
|
||||
("train", train_set),
|
||||
("dev", sample_sets[0]),
|
||||
("test", sample_sets[1]),
|
||||
]:
|
||||
for sample in sub_set_samples:
|
||||
sample.sub_set = sub_set
|
||||
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
|
||||
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
|
||||
print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
|
||||
.format(sub_set, len(sub_set_samples), t))
|
||||
print(
|
||||
'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
|
||||
sub_set, len(sub_set_samples), t
|
||||
)
|
||||
)
def create_sample_dirs(language):
|
||||
print('Creating sample directories...')
|
||||
for set_name in ['train', 'dev', 'test']:
|
||||
dir_path = os.path.join(CLI_ARGS.base_dir, language + '-' + set_name)
|
||||
print("Creating sample directories...")
|
||||
for set_name in ["train", "dev", "test"]:
|
||||
dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
|
||||
if not os.path.isdir(dir_path):
|
||||
os.mkdir(dir_path)
def split_audio_files(samples, language):
|
||||
print('Splitting audio files...')
|
||||
print("Splitting audio files...")
|
||||
sub_sets = Counter()
|
||||
src_wav_files = group(samples, lambda s: s.wav_path).items()
|
||||
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
|
||||
for wav_path, file_samples in bar(src_wav_files):
|
||||
file_samples = sorted(file_samples, key=lambda s: s.start)
|
||||
with wave.open(wav_path, 'r') as src_wav_file:
|
||||
with wave.open(wav_path, "r") as src_wav_file:
|
||||
rate = src_wav_file.getframerate()
|
||||
for sample in file_samples:
|
||||
index = sub_sets[sample.sub_set]
|
||||
sample_wav_path = os.path.join(CLI_ARGS.base_dir,
|
||||
language + '-' + sample.sub_set,
|
||||
'sample-{0:06d}.wav'.format(index))
|
||||
sample_wav_path = os.path.join(
|
||||
CLI_ARGS.base_dir,
|
||||
language + "-" + sample.sub_set,
|
||||
"sample-{0:06d}.wav".format(index),
|
||||
)
|
||||
sample.wav_path = sample_wav_path
|
||||
sub_sets[sample.sub_set] += 1
|
||||
src_wav_file.setpos(int(sample.start * rate / 1000.0))
|
||||
data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
|
||||
with wave.open(sample_wav_path, 'w') as sample_wav_file:
|
||||
data = src_wav_file.readframes(
|
||||
int((sample.end - sample.start) * rate / 1000.0)
|
||||
)
|
||||
with wave.open(sample_wav_path, "w") as sample_wav_file:
|
||||
sample_wav_file.setnchannels(src_wav_file.getnchannels())
|
||||
sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
|
||||
sample_wav_file.setframerate(rate)
|
||||
@ -389,21 +460,25 @@ def write_csvs(samples, language):
|
||||
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
|
||||
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
|
||||
base_dir = os.path.abspath(CLI_ARGS.base_dir)
|
||||
csv_path = os.path.join(base_dir, language + '-' + sub_set + '.csv')
|
||||
csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
|
||||
print('Writing "{}"...'.format(csv_path))
|
||||
with open(csv_path, 'w') as csv_file:
|
||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
|
||||
with open(csv_path, "w") as csv_file:
|
||||
writer = csv.DictWriter(
|
||||
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
|
||||
)
|
||||
writer.writeheader()
|
||||
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
|
||||
bar = progressbar.ProgressBar(
|
||||
max_value=len(set_samples), widgets=SIMPLE_BAR
|
||||
)
|
||||
for sample in bar(set_samples):
|
||||
row = {
|
||||
'wav_filename': os.path.relpath(sample.wav_path, base_dir),
|
||||
'wav_filesize': os.path.getsize(sample.wav_path),
|
||||
'transcript': sample.text
|
||||
"wav_filename": os.path.relpath(sample.wav_path, base_dir),
|
||||
"wav_filesize": os.path.getsize(sample.wav_path),
|
||||
"transcript": sample.text,
|
||||
}
|
||||
if CLI_ARGS.add_meta:
|
||||
row['article'] = sample.article
|
||||
row['speaker'] = sample.speaker
|
||||
row["article"] = sample.article
|
||||
row["speaker"] = sample.speaker
|
||||
writer.writerow(row)
@ -430,34 +505,75 @@ def prepare_language(language):
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
|
||||
parser.add_argument('base_dir', help='Directory containing all data')
|
||||
parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
|
||||
parser.add_argument('--exclude_numbers', type=bool, default=True,
|
||||
help='If sequences with non-transliterated numbers should be excluded')
|
||||
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
|
||||
parser.add_argument('--ignore_too_long', type=bool, default=False,
|
||||
help='If samples exceeding max_duration should be removed')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
|
||||
parser.add_argument("base_dir", help="Directory containing all data")
|
||||
parser.add_argument(
|
||||
"--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_numbers",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If sequences with non-transliterated numbers should be excluded",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_duration",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Maximum sample duration in milliseconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_too_long",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="If samples exceeding max_duration should be removed",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
for language in LANGUAGES:
|
||||
parser.add_argument('--{}_alphabet'.format(language),
|
||||
help='Exclude {} samples with characters not in provided alphabet file'.format(language))
|
||||
parser.add_argument('--add_meta', action='store_true', help='Adds article and speaker CSV columns')
|
||||
parser.add_argument('--exclude_unknown_speakers', action='store_true', help='Exclude unknown speakers')
|
||||
parser.add_argument('--exclude_unknown_articles', action='store_true', help='Exclude unknown articles')
|
||||
parser.add_argument('--keep_archive', type=bool, default=True,
|
||||
help='If downloaded archives should be kept')
|
||||
parser.add_argument('--keep_intermediate', type=bool, default=False,
|
||||
help='If intermediate files should be kept')
|
||||
parser.add_argument(
|
||||
"--{}_alphabet".format(language),
|
||||
help="Exclude {} samples with characters not in provided alphabet file".format(
|
||||
language
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_meta", action="store_true", help="Adds article and speaker CSV columns"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_unknown_speakers",
|
||||
action="store_true",
|
||||
help="Exclude unknown speakers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_unknown_articles",
|
||||
action="store_true",
|
||||
help="Exclude unknown articles",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_archive",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If downloaded archives should be kept",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_intermediate",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="If intermediate files should be kept",
|
||||
)
|
||||
return parser.parse_args()
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
if CLI_ARGS.language == 'all':
|
||||
if CLI_ARGS.language == "all":
|
||||
for lang in LANGUAGES:
|
||||
prepare_language(lang)
|
||||
elif CLI_ARGS.language in LANGUAGES:
|
||||
prepare_language(CLI_ARGS.language)
|
||||
else:
|
||||
fail('Wrong language id')
|
||||
fail("Wrong language id")
@ -37,6 +37,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False)
|
||||
test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive):
|
||||
# If data_dir/extracted_data does not exist, extract archive in data_dir
|
||||
if not gfile.Exists(path.join(data_dir, extracted_data)):
|
||||
@ -44,6 +45,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
|
||||
tar.extractall(data_dir)
|
||||
tar.close()
def _maybe_convert_wav(data_dir, extracted_data):
|
||||
# Create extracted_data dir
|
||||
extracted_dir = path.join(data_dir, extracted_data)
|
||||
@ -57,6 +59,7 @@ def _maybe_convert_wav(data_dir, extracted_data):
|
||||
# Conditionally convert test sph to wav
|
||||
_maybe_convert_wav_dataset(extracted_dir, "test")
def _maybe_convert_wav_dataset(extracted_dir, data_set):
|
||||
# Create source dir
|
||||
source_dir = path.join(extracted_dir, data_set, "sph")
|
||||
@ -80,6 +83,7 @@ def _maybe_convert_wav_dataset(extracted_dir, data_set):
|
||||
# Remove source_dir
|
||||
rmdir(source_dir)
def _maybe_split_sentences(data_dir, extracted_data):
|
||||
# Create extracted_data dir
|
||||
extracted_dir = path.join(data_dir, extracted_data)
|
||||
@ -95,6 +99,7 @@ def _maybe_split_sentences(data_dir, extracted_data):
|
||||
|
||||
return train_files, dev_files, test_files
def _maybe_split_dataset(extracted_dir, data_set):
|
||||
# Create stm dir
|
||||
stm_dir = path.join(extracted_dir, data_set, "stm")
|
||||
@ -112,14 +117,21 @@ def _maybe_split_dataset(extracted_dir, data_set):
|
||||
# Open wav corresponding to stm_file
|
||||
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
|
||||
wav_file = path.join(wav_dir, wav_filename)
|
||||
origAudio = wave.open(wav_file,'r')
|
||||
origAudio = wave.open(wav_file, "r")
|
||||
|
||||
# Loop over stm_segments and split wav_file for each segment
|
||||
for stm_segment in stm_segments:
|
||||
# Create wav segment filename
|
||||
start_time = stm_segment.start_time
|
||||
stop_time = stm_segment.stop_time
|
||||
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
|
||||
new_wav_filename = (
|
||||
path.splitext(path.basename(stm_file))[0]
|
||||
+ "-"
|
||||
+ str(start_time)
|
||||
+ "-"
|
||||
+ str(stop_time)
|
||||
+ ".wav"
|
||||
)
|
||||
new_wav_file = path.join(wav_dir, new_wav_filename)
|
||||
|
||||
# If the wav segment filename does not exist create it
|
||||
@ -127,23 +139,29 @@ def _maybe_split_dataset(extracted_dir, data_set):
|
||||
_split_wav(origAudio, start_time, stop_time, new_wav_file)
|
||||
|
||||
new_wav_filesize = path.getsize(new_wav_file)
|
||||
files.append((path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript))
|
||||
files.append(
|
||||
(path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)
|
||||
)
|
||||
|
||||
# Close origAudio
|
||||
origAudio.close()
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
frameRate = origAudio.getframerate()
|
||||
origAudio.setpos(int(start_time*frameRate))
|
||||
chunkData = origAudio.readframes(int((stop_time - start_time)*frameRate))
|
||||
chunkAudio = wave.open(new_wav_file,'w')
|
||||
origAudio.setpos(int(start_time * frameRate))
|
||||
chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
|
||||
chunkAudio = wave.open(new_wav_file, "w")
|
||||
chunkAudio.setnchannels(origAudio.getnchannels())
|
||||
chunkAudio.setsampwidth(origAudio.getsampwidth())
|
||||
chunkAudio.setframerate(frameRate)
|
||||
chunkAudio.writeframes(chunkData)
|
||||
chunkAudio.close()
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
NAME : LDC TIMIT Dataset
|
||||
URL : https://catalog.ldc.upenn.edu/ldc93s1
|
||||
HOURS : 5
|
||||
@ -8,7 +8,7 @@
|
||||
AUTHORS : Garofolo, John, et al.
|
||||
TYPE : LDC Membership
|
||||
LICENCE : LDC User Agreement
|
||||
'''
|
||||
"""
|
||||
|
||||
import errno
|
||||
import fnmatch
|
||||
@ -23,16 +23,17 @@ import pandas as pd
|
||||
|
||||
def clean(word):
|
||||
# LC ALL & strip punctuation which are not required
|
||||
new = word.lower().replace('.', '')
|
||||
new = new.replace(',', '')
|
||||
new = new.replace(';', '')
|
||||
new = new.replace('"', '')
|
||||
new = new.replace('!', '')
|
||||
new = new.replace('?', '')
|
||||
new = new.replace(':', '')
|
||||
new = new.replace('-', '')
|
||||
new = word.lower().replace(".", "")
|
||||
new = new.replace(",", "")
|
||||
new = new.replace(";", "")
|
||||
new = new.replace('"', "")
|
||||
new = new.replace("!", "")
|
||||
new = new.replace("?", "")
|
||||
new = new.replace(":", "")
|
||||
new = new.replace("-", "")
|
||||
return new
def _preprocess_data(args):
|
||||
|
||||
# Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
|
||||
@ -42,16 +43,24 @@ def _preprocess_data(args):
|
||||
|
||||
if ignoreSASentences:
|
||||
print("Using recommended ignore SA sentences")
|
||||
print("Ignoring SA sentences (2 x sentences which are repeated by all speakers)")
|
||||
print(
|
||||
"Ignoring SA sentences (2 x sentences which are repeated by all speakers)"
|
||||
)
|
||||
else:
|
||||
print("Using unrecommended setting to include SA sentences")
|
||||
|
||||
datapath = args
|
||||
target = path.join(datapath, "TIMIT")
|
||||
print("Checking to see if data has already been extracted in given argument: %s", target)
|
||||
print(
|
||||
"Checking to see if data has already been extracted in given argument: %s",
|
||||
target,
|
||||
)
|
||||
|
||||
if not path.isdir(target):
|
||||
print("Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ", datapath)
|
||||
print(
|
||||
"Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ",
|
||||
datapath,
|
||||
)
|
||||
filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
|
||||
if path.isfile(filepath):
|
||||
print("File found, extracting")
|
||||
@ -105,40 +114,58 @@ def _preprocess_data(args):
|
||||
# if ignoreSAsentences we only want those without SA in the name
|
||||
# OR
|
||||
# if not ignoreSAsentences we want all to be added
|
||||
if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences):
|
||||
if 'train' in full_wav.lower():
|
||||
if (ignoreSASentences and not ("SA" in os.path.basename(full_wav))) or (
|
||||
not ignoreSASentences
|
||||
):
|
||||
if "train" in full_wav.lower():
|
||||
train_list_wavs.append(full_wav)
|
||||
train_list_trans.append(trans)
|
||||
train_list_size.append(wav_filesize)
|
||||
elif 'test' in full_wav.lower():
|
||||
elif "test" in full_wav.lower():
|
||||
test_list_wavs.append(full_wav)
|
||||
test_list_trans.append(trans)
|
||||
test_list_size.append(wav_filesize)
|
||||
else:
|
||||
raise IOError
|
||||
|
||||
a = {'wav_filename': train_list_wavs,
|
||||
'wav_filesize': train_list_size,
|
||||
'transcript': train_list_trans
|
||||
}
|
||||
a = {
|
||||
"wav_filename": train_list_wavs,
|
||||
"wav_filesize": train_list_size,
|
||||
"transcript": train_list_trans,
|
||||
}
|
||||
|
||||
c = {'wav_filename': test_list_wavs,
|
||||
'wav_filesize': test_list_size,
|
||||
'transcript': test_list_trans
|
||||
}
|
||||
c = {
|
||||
"wav_filename": test_list_wavs,
|
||||
"wav_filesize": test_list_size,
|
||||
"transcript": test_list_trans,
|
||||
}
|
||||
|
||||
all = {'wav_filename': train_list_wavs + test_list_wavs,
|
||||
'wav_filesize': train_list_size + test_list_size,
|
||||
'transcript': train_list_trans + test_list_trans
|
||||
}
|
||||
all = {
|
||||
"wav_filename": train_list_wavs + test_list_wavs,
|
||||
"wav_filesize": train_list_size + test_list_size,
|
||||
"transcript": train_list_trans + test_list_trans,
|
||||
}
|
||||
|
||||
df_all = pd.DataFrame(all, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_train = pd.DataFrame(a, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_test = pd.DataFrame(c, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_all = pd.DataFrame(
|
||||
all, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
df_train = pd.DataFrame(
|
||||
a, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
df_test = pd.DataFrame(
|
||||
c, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
|
||||
df_all.to_csv(
|
||||
target + "/timit_all.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
df_train.to_csv(
|
||||
target + "/timit_train.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
df_test.to_csv(
|
||||
target + "/timit_test.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
|
||||
df_all.to_csv(target+"/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
df_train.to_csv(target+"/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
df_test.to_csv(target+"/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
if __name__ == "__main__":
|
||||
_preprocess_data(sys.argv[1])
|
bin/import_ts.py
@ -18,26 +18,32 @@ from deepspeech_training.util.importers import (
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report
|
||||
print_import_report,
|
||||
)
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 15
|
||||
ARCHIVE_NAME = '2019-04-11_fr_FR'
|
||||
ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME
|
||||
ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE_NAME + '.zip'
|
||||
ARCHIVE_NAME = "2019-04-11_fr_FR"
|
||||
ARCHIVE_DIR_NAME = "ts_" + ARCHIVE_NAME
|
||||
ARCHIVE_URL = (
|
||||
"https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/" + ARCHIVE_NAME + ".zip"
|
||||
)
def _download_and_preprocess_data(target_dir, english_compatible=False):
|
||||
# Making path absolute
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
|
||||
archive_path = maybe_download(
|
||||
"ts_" + ARCHIVE_NAME + ".zip", target_dir, ARCHIVE_URL
|
||||
)
|
||||
# Conditionally extract archive data
|
||||
_maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path)
|
||||
# Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible)
|
||||
_maybe_convert_sets(
|
||||
target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible
|
||||
)
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
@ -55,7 +61,7 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
orig_filename = sample['path']
|
||||
orig_filename = sample["path"]
|
||||
# Storing wav files next to the wav ones - just with a different suffix
|
||||
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
|
||||
_maybe_convert_wav(orig_filename, wav_filename)
|
||||
@ -63,8 +69,12 @@ def one_sample(sample):
|
||||
frames = 0
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = sample['text']
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = sample["text"]
|
||||
|
||||
rows = []
|
||||
|
||||
@ -72,21 +82,21 @@ def one_sample(sample):
|
||||
counter = get_counter()
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
|
||||
@ -94,18 +104,19 @@ def one_sample(sample):
|
||||
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
|
||||
target_csv_template = os.path.join(target_dir, "ts_" + ARCHIVE_NAME + "_{}.csv")
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
|
||||
path_to_original_csv = os.path.join(extracted_dir, "data.csv")
|
||||
with open(path_to_original_csv) as csv_f:
|
||||
data = [
|
||||
d for d in csv.DictReader(csv_f, delimiter=',')
|
||||
if float(d['duration']) <= MAX_SECS
|
||||
d
|
||||
for d in csv.DictReader(csv_f, delimiter=",")
|
||||
if float(d["duration"]) <= MAX_SECS
|
||||
]
|
||||
|
||||
for line in data:
|
||||
line['path'] = os.path.join(extracted_dir, line['path'])
|
||||
line["path"] = os.path.join(extracted_dir, line["path"])
|
||||
|
||||
num_samples = len(data)
|
||||
rows = []
|
||||
@ -122,9 +133,9 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -133,7 +144,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
test_writer.writeheader()
|
||||
|
||||
for i, item in enumerate(rows):
|
||||
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
|
||||
transcript = validate_label(
|
||||
cleanup_transcript(
|
||||
item[2], english_compatible=english_compatible
|
||||
)
|
||||
)
|
||||
if not transcript:
|
||||
continue
|
||||
wav_filename = os.path.join(target_dir, extracted_data, item[0])
|
||||
@ -144,18 +159,21 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(orig_filename, wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
@ -163,26 +181,31 @@ def _maybe_convert_wav(orig_filename, wav_filename):
|
||||
try:
|
||||
transformer.build(orig_filename, wav_filename)
|
||||
except sox.core.SoxError as ex:
|
||||
print('SoX processing error', ex, orig_filename, wav_filename)
|
||||
print("SoX processing error", ex, orig_filename, wav_filename)
PUNCTUATIONS_REG = re.compile(r"[°\-,;!?.()\[\]*…—]")
|
||||
MULTIPLE_SPACES_REG = re.compile(r'\s{2,}')
|
||||
MULTIPLE_SPACES_REG = re.compile(r"\s{2,}")
def cleanup_transcript(text, english_compatible=False):
|
||||
text = text.replace('’', "'").replace('\u00A0', ' ')
|
||||
text = PUNCTUATIONS_REG.sub(' ', text)
|
||||
text = MULTIPLE_SPACES_REG.sub(' ', text)
|
||||
text = text.replace("’", "'").replace("\u00A0", " ")
|
||||
text = PUNCTUATIONS_REG.sub(" ", text)
|
||||
text = MULTIPLE_SPACES_REG.sub(" ", text)
|
||||
if english_compatible:
|
||||
text = unidecode.unidecode(text)
|
||||
return text.strip().lower()
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diactrics and other non-ascii chars.')
|
||||
parser = get_importers_parser(description="Importer for TrainingSpeech dataset.")
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--english-compatible",
|
||||
action="store_true",
|
||||
dest="english_compatible",
|
||||
help="Remove diactrics and other non-ascii chars.",
|
||||
)
|
||||
return parser.parse_args()
@ -1,8 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
"""
|
||||
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
|
||||
Use "python3 import_tuda.py -h" for help
|
||||
'''
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
@ -17,20 +17,21 @@ from collections import Counter
|
||||
import progressbar
|
||||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import \
|
||||
validate_label_eng as validate_label
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
TUDA_VERSION = 'v2'
|
||||
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
|
||||
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE)
|
||||
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE)
|
||||
TUDA_VERSION = "v2"
|
||||
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
|
||||
TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
|
||||
TUDA_PACKAGE
|
||||
)
|
||||
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)
|
||||
|
||||
CHANNELS = 1
|
||||
SAMPLE_WIDTH = 2
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
def maybe_extract(archive):
|
||||
@ -48,69 +49,79 @@ def maybe_extract(archive):
def check_and_prepare_sentence(sentence):
|
||||
sentence = sentence.lower().replace('co2', 'c o zwei')
|
||||
sentence = sentence.lower().replace("co2", "c o zwei")
|
||||
chars = []
|
||||
for c in sentence:
|
||||
if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)):
|
||||
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
if (
|
||||
CLI_ARGS.normalize
|
||||
and c not in "äöüß"
|
||||
and (ALPHABET is None or not ALPHABET.has_char(c))
|
||||
):
|
||||
c = (
|
||||
unicodedata.normalize("NFKD", c)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
for sc in c:
|
||||
if ALPHABET is not None and not ALPHABET.has_char(c):
|
||||
return None
|
||||
chars.append(sc)
|
||||
return validate_label(''.join(chars))
|
||||
return validate_label("".join(chars))
def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements
|
||||
try:
|
||||
with wave.open(wav_path, 'r') as src_wav_file:
|
||||
with wave.open(wav_path, "r") as src_wav_file:
|
||||
rate = src_wav_file.getframerate()
|
||||
channels = src_wav_file.getnchannels()
|
||||
sample_width = src_wav_file.getsampwidth()
|
||||
milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
|
||||
if rate != SAMPLE_RATE:
|
||||
return False, 'wrong sample rate'
|
||||
return False, "wrong sample rate"
|
||||
if channels != CHANNELS:
|
||||
return False, 'wrong number of channels'
|
||||
return False, "wrong number of channels"
|
||||
if sample_width != SAMPLE_WIDTH:
|
||||
return False, 'wrong sample width'
|
||||
return False, "wrong sample width"
|
||||
if milliseconds / len(sentence) < 30:
|
||||
return False, 'too short'
|
||||
return False, "too short"
|
||||
if milliseconds > CLI_ARGS.max_duration > 0:
|
||||
return False, 'too long'
|
||||
return False, "too long"
|
||||
except wave.Error:
|
||||
return False, 'invalid wav file'
|
||||
return False, "invalid wav file"
|
||||
except EOFError:
|
||||
return False, 'premature EOF'
|
||||
return True, 'OK'
|
||||
return False, "premature EOF"
|
||||
return True, "OK"
def write_csvs(extracted):
|
||||
sample_counter = 0
|
||||
reasons = Counter()
|
||||
for sub_set in ['train', 'dev', 'test']:
|
||||
for sub_set in ["train", "dev", "test"]:
|
||||
set_path = os.path.join(extracted, sub_set)
|
||||
set_files = os.listdir(set_path)
|
||||
recordings = {}
|
||||
for file in set_files:
|
||||
if file.endswith('.xml'):
|
||||
if file.endswith(".xml"):
|
||||
recordings[file[:-4]] = []
|
||||
for file in set_files:
|
||||
if file.endswith('.wav') and '_' in file:
|
||||
prefix = file.split('_')[0]
|
||||
if file.endswith(".wav") and "_" in file:
|
||||
prefix = file.split("_")[0]
|
||||
if prefix in recordings:
|
||||
recordings[prefix].append(file)
|
||||
recordings = recordings.items()
|
||||
csv_path = os.path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
|
||||
csv_path = os.path.join(
|
||||
CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
|
||||
)
|
||||
print('Writing "{}"...'.format(csv_path))
|
||||
with open(csv_path, 'w') as csv_file:
|
||||
with open(csv_path, "w") as csv_file:
|
||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
|
||||
writer.writeheader()
|
||||
set_dir = os.path.join(extracted, sub_set)
|
||||
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
|
||||
for prefix, wav_names in bar(recordings):
|
||||
xml_path = os.path.join(set_dir, prefix + '.xml')
|
||||
xml_path = os.path.join(set_dir, prefix + ".xml")
|
||||
meta = ET.parse(xml_path).getroot()
|
||||
sentence = list(meta.iter('cleaned_sentence'))[0].text
|
||||
sentence = list(meta.iter("cleaned_sentence"))[0].text
|
||||
sentence = check_and_prepare_sentence(sentence)
|
||||
if sentence is None:
|
||||
continue
|
||||
@ -119,15 +130,19 @@ def write_csvs(extracted):
|
||||
wav_path = os.path.join(set_path, wav_name)
|
||||
keep, reason = check_wav_file(wav_path, sentence)
|
||||
if keep:
|
||||
writer.writerow({
|
||||
'wav_filename': os.path.relpath(wav_path, CLI_ARGS.base_dir),
|
||||
'wav_filesize': os.path.getsize(wav_path),
|
||||
'transcript': sentence.lower()
|
||||
})
|
||||
writer.writerow(
|
||||
{
|
||||
"wav_filename": os.path.relpath(
|
||||
wav_path, CLI_ARGS.base_dir
|
||||
),
|
||||
"wav_filesize": os.path.getsize(wav_path),
|
||||
"transcript": sentence.lower(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
reasons[reason] += 1
|
||||
if len(reasons.keys()) > 0:
|
||||
print('Excluded samples:')
|
||||
print("Excluded samples:")
|
||||
for reason, n in reasons.most_common():
|
||||
print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
@ -146,13 +161,29 @@ def download_and_prepare():
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)')
|
||||
parser.add_argument('base_dir', help='Directory containing all data')
|
||||
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file')
|
||||
parser.add_argument('--keep_archive', type=bool, default=True,
|
||||
help='If downloaded archives should be kept')
|
||||
parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
|
||||
parser.add_argument("base_dir", help="Directory containing all data")
|
||||
parser.add_argument(
|
||||
"--max_duration",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Maximum sample duration in milliseconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_archive",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If downloaded archives should be kept",
|
||||
)
|
||||
return parser.parse_args()
@ -17,7 +17,7 @@ from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
print_import_report
|
||||
print_import_report,
|
||||
)
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
@ -62,7 +62,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
all_samples = []
|
||||
|
||||
for target in sorted(os.listdir(directory)):
|
||||
all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
|
||||
all_samples += _maybe_prepare_set(
|
||||
path.join(extracted_dir, os.path.split(target)[-1])
|
||||
)
|
||||
|
||||
num_samples = len(all_samples)
|
||||
print(f"Converting wav files to {SAMPLE_RATE}hz...")
|
||||
@ -76,6 +78,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
|
||||
_write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample):
|
||||
if is_audio_file(sample):
|
||||
y, sr = librosa.load(sample, sr=16000)
|
||||
@ -98,6 +101,7 @@ def _maybe_prepare_set(target_csv):
|
||||
samples = new_samples
|
||||
return samples
def _write_csv(extracted_dir, txt_dir, target_dir):
|
||||
print(f"Writing CSV file")
|
||||
dset_abs_path = extracted_dir
|
||||
@ -192,7 +196,9 @@ AUDIO_EXTENSIONS = [".wav", "WAV"]
|
||||
|
||||
|
||||
def is_audio_file(filepath):
|
||||
return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
|
||||
return any(
|
||||
os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS
|
||||
)
if __name__ == "__main__":
@ -24,8 +24,10 @@ NUM_PARALLEL = 8
|
||||
"""Lambda function returns the filename of a path"""
|
||||
filename_of = lambda x: path.split(x)[1]
|
||||
class AtomicCounter(object):
|
||||
"""A class that atomically increments a counter"""
|
||||
|
||||
def __init__(self, start_count=0):
|
||||
"""Initialize the counter
|
||||
:param start_count: the number to start counting at
|
||||
@ -48,6 +50,7 @@ class AtomicCounter(object):
|
||||
"""Returns the current value of the counter (not atomic)"""
|
||||
return self.__count
def _parallel_downloader(voxforge_url, archive_dir, total, counter):
|
||||
"""Generate a function to download a file based on given parameters
|
||||
This works by currying the above given arguments into a closure
|
||||
@ -59,6 +62,7 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
|
||||
:param counter: an atomic counter to keep track of # of downloaded files
|
||||
:return: a function that actually downloads a file given these params
|
||||
"""
|
||||
|
||||
def download(d):
|
||||
"""Binds voxforge_url, archive_dir, total, and counter into this scope
|
||||
Downloads the given file
|
||||
@ -66,12 +70,14 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
|
||||
of the file to download and file is the name of the file to download
|
||||
"""
|
||||
(i, file) = d
|
||||
download_url = voxforge_url + '/' + file
|
||||
download_url = voxforge_url + "/" + file
|
||||
c = counter.increment()
|
||||
print('Downloading file {} ({}/{})...'.format(i+1, c, total))
|
||||
print("Downloading file {} ({}/{})...".format(i + 1, c, total))
|
||||
maybe_download(filename_of(download_url), archive_dir, download_url)
|
||||
|
||||
return download
|
|
||||
"""Generate a function to extract a tar file based on given parameters
|
||||
This works by currying the above given arguments into a closure
|
||||
@ -84,6 +90,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
|
||||
:param counter: an atomic counter to keep track of # of extracted files
|
||||
:return: a function that actually extracts a tar file given these params
|
||||
"""
|
||||
|
||||
def extract(d):
|
||||
"""Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope
|
||||
Extracts the given file
|
||||
@ -93,39 +100,49 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
|
(i, archive) = d
if i < number_of_test:
dataset_dir = path.join(data_dir, "test")
elif i<number_of_test+number_of_dev:
elif i < number_of_test + number_of_dev:
dataset_dir = path.join(data_dir, "dev")
else:
dataset_dir = path.join(data_dir, "train")
if not gfile.Exists(os.path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
if not gfile.Exists(
os.path.join(dataset_dir, ".".join(filename_of(archive).split(".")[:-1]))
):
c = counter.increment()
print('Extracting file {} ({}/{})...'.format(i+1, c, total))
print("Extracting file {} ({}/{})...".format(i + 1, c, total))
tar = tarfile.open(archive)
tar.extractall(dataset_dir)
tar.close()

return extract


def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir
if not path.isdir(data_dir):
makedirs(data_dir)

archive_dir = data_dir+"/archive"
archive_dir = data_dir + "/archive"
if not path.isdir(archive_dir):
makedirs(archive_dir)

print("Downloading Voxforge data set into {} if not already present...".format(archive_dir))
print(
"Downloading Voxforge data set into {} if not already present...".format(
archive_dir
)
)

voxforge_url = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit'
voxforge_url = "http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit"
html_page = urllib.request.urlopen(voxforge_url)
soup = BeautifulSoup(html_page, 'html.parser')
soup = BeautifulSoup(html_page, "html.parser")

# list all links
refs = [l['href'] for l in soup.find_all('a') if ".tgz" in l['href']]
refs = [l["href"] for l in soup.find_all("a") if ".tgz" in l["href"]]

# download files in parallel
print('{} files to download'.format(len(refs)))
downloader = _parallel_downloader(voxforge_url, archive_dir, len(refs), AtomicCounter())
print("{} files to download".format(len(refs)))
downloader = _parallel_downloader(
voxforge_url, archive_dir, len(refs), AtomicCounter()
)
p = ThreadPool(NUM_PARALLEL)
p.map(downloader, enumerate(refs))

@ -139,12 +156,18 @@ def _download_and_preprocess_data(data_dir):

tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
number_of_files = len(tarfiles)
number_of_test = number_of_files//100
number_of_dev = number_of_files//100
number_of_test = number_of_files // 100
number_of_dev = number_of_files // 100

# extract tars in parallel
print("Extracting Voxforge data set into {} if not already present...".format(data_dir))
extracter = _parallel_extracter(data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter())
print(
"Extracting Voxforge data set into {} if not already present...".format(
data_dir
)
)
extracter = _parallel_extracter(
data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()
)
p.map(extracter, enumerate(tarfiles))

# Generate data set
@ -158,34 +181,46 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)


def _generate_dataset(data_dir, data_set):
extracted_dir = path.join(data_dir, data_set)
files = []
for promts_file in glob(os.path.join(extracted_dir+"/*/etc/", "PROMPTS")):
for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")):
if path.isdir(os.path.join(promts_file[:-11], "wav")):
with codecs.open(promts_file, 'r', 'utf-8') as f:
with codecs.open(promts_file, "r", "utf-8") as f:
for line in f:
id = line.split(' ')[0].split('/')[-1]
sentence = ' '.join(line.split(' ')[1:])
sentence = re.sub("[^a-z']", " ",sentence.strip().lower())
id = line.split(" ")[0].split("/")[-1]
sentence = " ".join(line.split(" ")[1:])
sentence = re.sub("[^a-z']", " ", sentence.strip().lower())
transcript = ""
for token in sentence.split(" "):
word = token.strip()
if word!="" and word!=" ":
if word != "" and word != " ":
transcript += word + " "
transcript = unicodedata.normalize("NFKD", transcript.strip()) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
if gfile.Exists(wav_file):
wav_filesize = path.getsize(wav_file)
# remove audios that are shorter than 0.5s and longer than 20s.
# remove audios that are too short for transcript.
if ((wav_filesize/32000) > 0.5 and (wav_filesize/32000) < 20 and transcript != "" and
wav_filesize/len(transcript) > 1400):
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
if (
(wav_filesize / 32000) > 0.5
and (wav_filesize / 32000) < 20
and transcript != ""
and wav_filesize / len(transcript) > 1400
):
files.append(
(os.path.abspath(wav_file), wav_filesize, transcript)
)

return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)

if __name__=="__main__":

if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

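For reference, a minimal sketch of the curry-into-a-closure pattern that _parallel_downloader and _parallel_extracter use together with ThreadPool and AtomicCounter above; Counter, make_worker, and the item list are illustrative stand-ins, not code from the importer:

import threading
from multiprocessing.pool import ThreadPool

class Counter:
    """Thread-safe counter, analogous in spirit to AtomicCounter."""

    def __init__(self):
        self._lock = threading.Lock()
        self._count = 0

    def increment(self):
        with self._lock:
            self._count += 1
            return self._count

def make_worker(total, counter):
    # Curry total and counter into the closure that ThreadPool will call.
    def work(item):
        i, name = item
        c = counter.increment()
        print("Processing item {} ({}/{})...".format(i + 1, c, total))

    return work

items = ["a.tgz", "b.tgz", "c.tgz"]
worker = make_worker(len(items), Counter())
ThreadPool(8).map(worker, enumerate(items))
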
@ -7,11 +7,12 @@ import tensorflow.compat.v1 as tfv1


def main():
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read())

print('\n'.join(sorted(set(n.op for n in graph_def.node))))
print("\n".join(sorted(set(n.op for n in graph_def.node))))

if __name__ == '__main__':

if __name__ == "__main__":
main()

38 bin/play.py
@ -10,10 +10,7 @@ import random
import sys

from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.sample_collections import (
LabeledSample,
samples_from_file
)
from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file


def play_sample(samples, index):
@ -22,7 +19,7 @@ def play_sample(samples, index):
if CLI_ARGS.random:
index = random.randint(0, len(samples))
elif index >= len(samples):
print('No sample with index {}'.format(CLI_ARGS.start))
print("No sample with index {}".format(CLI_ARGS.start))
sys.exit(1)
sample = samples[index]
print('Sample "{}"'.format(sample.sample_id))
@ -48,13 +45,28 @@ def play_collection():


def handle_args():
parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
'and DeepSpeech CSV files')
parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
parser.add_argument('--start', type=int, default=0,
help='Sample index to start at (negative numbers are relative to the end of the collection)')
parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
parser = argparse.ArgumentParser(
description="Tool for playing samples from Sample Databases (SDB files) "
"and DeepSpeech CSV files"
)
parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
parser.add_argument(
"--start",
type=int,
default=0,
help="Sample index to start at (negative numbers are relative to the end of the collection)",
)
parser.add_argument(
"--number",
type=int,
default=-1,
help="Number of samples to play (-1 for endless)",
)
parser.add_argument(
"--random",
action="store_true",
help="If samples should be played in random order",
)
return parser.parse_args()


@ -68,5 +80,5 @@ if __name__ == "__main__":
try:
play_collection()
except KeyboardInterrupt:
print(' Stopped')
print(" Stopped")
sys.exit(0)

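As a quick check of the reformatted argument parser above, a hedged sketch of how the flags parse; the collection name and values are made up:

import argparse

parser = argparse.ArgumentParser(description="play samples")
parser.add_argument("collection")
parser.add_argument("--start", type=int, default=0)
parser.add_argument("--number", type=int, default=-1)
parser.add_argument("--random", action="store_true")

# A negative --start is accepted, matching the help text about indexing from the end.
args = parser.parse_args(["dev.sdb", "--start", "-5", "--number", "3"])
print(args.collection, args.start, args.number, args.random)  # dev.sdb -5 3 False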