TUDA importer

This commit is contained in:
Tilman Kamp 2019-10-25 13:32:59 +02:00
parent 44a605c8b7
commit 2cdfcff4c6

165
bin/import_tuda.py Executable file
View File

@ -0,0 +1,165 @@
#!/usr/bin/env python
'''
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
Use "python3 import_tuda.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import wave
import tarfile
import argparse
import progressbar
import unicodedata
import xml.etree.cElementTree as ET
from os import path
from collections import Counter
from util.text import Alphabet, validate_label
from util.downloader import maybe_download, SIMPLE_BAR
TUDA_VERSION = 'v2'
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE)
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE)
CHANNELS = 1
SAMPLE_WIDTH = 2
SAMPLE_RATE = 16000
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
def maybe_extract(archive):
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
if path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted))
else:
print('Extracting "{}"...'.format(archive))
with tarfile.open(archive) as tar:
members = tar.getmembers()
bar = progressbar.ProgressBar(max_value=len(members), widgets=SIMPLE_BAR)
for member in bar(members):
tar.extract(member=member, path=CLI_ARGS.base_dir)
return extracted
def check_and_prepare_sentence(sentence):
sentence = sentence.lower().replace('co2', 'c o zwei')
chars = []
for c in sentence:
if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)):
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
for sc in c:
if ALPHABET is not None and not ALPHABET.has_char(c):
return None
chars.append(sc)
return validate_label(''.join(chars))
def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements
try:
with wave.open(wav_path, 'r') as src_wav_file:
rate = src_wav_file.getframerate()
channels = src_wav_file.getnchannels()
sample_width = src_wav_file.getsampwidth()
milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
if rate != SAMPLE_RATE:
return False, 'wrong sample rate'
if channels != CHANNELS:
return False, 'wrong number of channels'
if sample_width != SAMPLE_WIDTH:
return False, 'wrong sample width'
if milliseconds / len(sentence) < 20:
return False, 'too short'
if milliseconds > CLI_ARGS.max_duration > 0:
return False, 'too long'
except wave.Error:
return False, 'invalid wav file'
except EOFError:
return False, 'premature EOF'
return True, 'OK'
def write_csvs(extracted):
sample_counter = 0
reasons = Counter()
for sub_set in ['train', 'dev', 'test']:
set_path = path.join(extracted, sub_set)
set_files = os.listdir(set_path)
recordings = {}
for file in set_files:
if file.endswith('.xml'):
recordings[file[:-4]] = []
for file in set_files:
if file.endswith('.wav') and '_' in file:
prefix = file.split('_')[0]
if prefix in recordings:
recordings[prefix].append(file)
recordings = recordings.items()
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
set_dir = path.join(extracted, sub_set)
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
for prefix, wav_names in bar(recordings):
xml_path = path.join(set_dir, prefix + '.xml')
meta = ET.parse(xml_path).getroot()
sentence = list(meta.iter('cleaned_sentence'))[0].text
sentence = check_and_prepare_sentence(sentence)
if sentence is None:
continue
for wav_name in wav_names:
sample_counter += 1
wav_path = path.join(set_path, wav_name)
keep, reason = check_wav_file(wav_path, sentence)
if keep:
writer.writerow({
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir),
'wav_filesize': path.getsize(wav_path),
'transcript': sentence.lower()
})
else:
reasons[reason] += 1
if len(reasons.keys()) > 0:
print('Excluded samples:')
for reason, n in reasons.most_common():
print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
def cleanup(archive):
if not CLI_ARGS.keep_archive:
print('Removing archive "{}"...'.format(archive))
os.remove(archive)
def download_and_prepare():
archive = maybe_download(TUDA_ARCHIVE, CLI_ARGS.base_dir, TUDA_URL)
extracted = maybe_extract(archive)
write_csvs(extracted)
cleanup(archive)
def handle_args():
parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)')
parser.add_argument('base_dir', help='Directory containing all data')
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file')
parser.add_argument('--keep_archive', type=bool, default=True,
help='If downloaded archives should be kept')
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.alphabet) if CLI_ARGS.alphabet else None
download_and_prepare()