STT/bin/import_tuda.py
2021-05-18 13:45:52 +02:00

192 lines
6.6 KiB
Python
Executable File

#!/usr/bin/env python
"""
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for train.py
Use "python3 import_tuda.py -h" for help
"""
import argparse
import csv
import os
import tarfile
import unicodedata
import wave
import xml.etree.ElementTree as ET
from collections import Counter
import progressbar
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import validate_label_eng as validate_label
# Release of the corpus package; also embedded in the generated CSV names.
TUDA_VERSION = "v2"
# Name of the package, which is also the top-level directory looked for
# after extraction.
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
# File name of the downloaded archive.
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)
# Download location of the archive.
TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
    TUDA_PACKAGE
)

# Audio format every imported WAV file has to match (see check_wav_file).
SAMPLE_RATE = 16000  # compared against wave's getframerate()
SAMPLE_WIDTH = 2  # compared against wave's getsampwidth()
CHANNELS = 1  # compared against wave's getnchannels()

# Column order of the generated CSV files.
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
def maybe_extract(archive):
    """Unpack the corpus archive into the base directory if still needed.

    Extraction is skipped when the package directory already exists below
    ``CLI_ARGS.base_dir``.

    Args:
        archive: Path of the tar archive to extract.

    Returns:
        Path of the package directory inside the base directory.
    """
    target_dir = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
    if os.path.isdir(target_dir):
        print('Found directory "{}" - not extracting.'.format(target_dir))
        return target_dir

    print('Extracting "{}"...'.format(archive))
    with tarfile.open(archive) as tar_file:
        entries = tar_file.getmembers()
        progress = progressbar.ProgressBar(max_value=len(entries), widgets=SIMPLE_BAR)
        # Extract member by member so the progress bar can advance.
        # NOTE(review): member paths are not validated before extraction.
        for entry in progress(entries):
            tar_file.extract(member=entry, path=CLI_ARGS.base_dir)
    return target_dir
def in_alphabet(c):
    """Tell whether ``c`` is accepted by the configured alphabet.

    Without a configured alphabet (``ALPHABET`` is falsy) everything is
    accepted.
    """
    if not ALPHABET:
        return True
    return ALPHABET.CanEncode(c)
def check_and_prepare_sentence(sentence):
    """Normalize a transcript and reject it if it leaves the alphabet.

    The sentence is lower-cased and "co2" is spelled out. With
    ``CLI_ARGS.normalize`` set, characters outside the alphabet (except the
    German letters "äöüß") are decomposed with NFKD and reduced to their
    ASCII part, which may yield zero, one or several characters.

    Args:
        sentence: Raw transcript text, or None when no text is available.

    Returns:
        The value returned by ``validate_label`` for the prepared text, or
        None when the sentence is None or contains a character that is not
        in the alphabet.
    """
    if sentence is None:
        # An empty XML element has ``text`` None; there is nothing to prepare.
        return None
    sentence = sentence.lower().replace("co2", "c o zwei")
    chars = []
    for char in sentence:
        if CLI_ARGS.normalize and char not in "äöüß" and not in_alphabet(char):
            char = (
                unicodedata.normalize("NFKD", char)
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        # After normalization ``char`` may hold several characters (or none);
        # validate each one that is actually appended.
        for sub_char in char:
            if not in_alphabet(sub_char):
                return None
            chars.append(sub_char)
    return validate_label("".join(chars))
def check_wav_file(wav_path, sentence):  # pylint: disable=too-many-return-statements
    """Decide whether a WAV file is usable as a sample for ``sentence``.

    Args:
        wav_path: Path of the WAV file to inspect.
        sentence: Prepared transcript belonging to the recording.

    Returns:
        A ``(keep, reason)`` tuple. ``keep`` is True only when the file has
        the expected sample rate, channel count and sample width, offers at
        least 30 ms of audio per transcript character, and does not exceed
        ``CLI_ARGS.max_duration`` (a limit of 0 or less disables that check).
        ``reason`` is "OK" or a short description of the rejection.
    """
    if not sentence:
        # Guards the per-character duration check against division by zero.
        return False, "empty transcript"
    try:
        with wave.open(wav_path, "r") as src_wav_file:
            rate = src_wav_file.getframerate()
            channels = src_wav_file.getnchannels()
            sample_width = src_wav_file.getsampwidth()
            frames = src_wav_file.getnframes()
    except wave.Error:
        return False, "invalid wav file"
    except EOFError:
        return False, "premature EOF"

    if rate != SAMPLE_RATE:
        return False, "wrong sample rate"
    if channels != CHANNELS:
        return False, "wrong number of channels"
    if sample_width != SAMPLE_WIDTH:
        return False, "wrong sample width"
    # Only computed once the rate is known to be valid, so a header with a
    # frame rate of 0 is reported as a wrong rate instead of crashing here.
    milliseconds = int(frames * 1000 / rate)
    if milliseconds / len(sentence) < 30:
        return False, "too short"
    if milliseconds > CLI_ARGS.max_duration > 0:
        return False, "too long"
    return True, "OK"
def write_csvs(extracted):
    """Write one CSV file per sub-set ("train", "dev", "test").

    Every ``<prefix>.xml`` file in a sub-set directory describes a recording;
    all ``<prefix>_*.wav`` files next to it are its samples. Samples whose
    transcript or audio is rejected are left out and counted per reason; a
    summary of the exclusions is printed at the end.

    Args:
        extracted: Path of the extracted corpus package directory.
    """
    sample_counter = 0
    reasons = Counter()

    def exclude(reason, count):
        # Counts are always per WAV sample so they share the unit of
        # ``sample_counter``; zero counts are dropped to avoid empty reasons.
        if count:
            reasons[reason] += count

    for sub_set in ["train", "dev", "test"]:
        set_path = os.path.join(extracted, sub_set)
        set_files = os.listdir(set_path)
        # Recordings are keyed by the XML file name without its extension.
        recordings = {
            file_name[:-4]: [] for file_name in set_files if file_name.endswith(".xml")
        }
        for file_name in set_files:
            if file_name.endswith(".wav") and "_" in file_name:
                prefix = file_name.split("_")[0]
                if prefix in recordings:
                    recordings[prefix].append(file_name)
        csv_path = os.path.join(
            CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
        )
        print('Writing "{}"...'.format(csv_path))
        with open(csv_path, "w", encoding="utf-8", newline="") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
            writer.writeheader()
            bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
            for prefix, wav_names in bar(recordings.items()):
                xml_path = os.path.join(set_path, prefix + ".xml")
                meta = ET.parse(xml_path).getroot()
                element = next(meta.iter("cleaned_sentence"), None)
                if element is None or element.text is None:
                    # No transcript in the metadata: drop all of its samples.
                    sample_counter += len(wav_names)
                    exclude("missing transcript", len(wav_names))
                    continue
                sentence = check_and_prepare_sentence(element.text)
                if sentence is None:
                    sample_counter += len(wav_names)
                    exclude("alphabet filter", len(wav_names))
                    continue
                for wav_name in wav_names:
                    sample_counter += 1
                    wav_path = os.path.join(set_path, wav_name)
                    keep, reason = check_wav_file(wav_path, sentence)
                    if keep:
                        writer.writerow(
                            {
                                "wav_filename": os.path.relpath(
                                    wav_path, CLI_ARGS.base_dir
                                ),
                                "wav_filesize": os.path.getsize(wav_path),
                                "transcript": sentence.lower(),
                            }
                        )
                    else:
                        exclude(reason, 1)
    if reasons:
        # Every excluded sample was also added to sample_counter, so the
        # divisor is positive whenever there is something to report.
        print("Excluded samples:")
        for reason, n in reasons.most_common():
            print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
def cleanup(archive):
    """Delete the downloaded archive unless it should be kept.

    Args:
        archive: Path of the archive file to remove.
    """
    if CLI_ARGS.keep_archive:
        return
    print('Removing archive "{}"...'.format(archive))
    os.remove(archive)
def download_and_prepare():
    """Run the import: download, extract, write the CSV files, clean up."""
    archive_path = maybe_download(TUDA_ARCHIVE, CLI_ARGS.base_dir, TUDA_URL)
    package_dir = maybe_extract(archive_path)
    write_csvs(package_dir)
    cleanup(archive_path)
def _str_to_bool(value):
    """Convert a command line string such as "true" or "0" into a bool.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, as it was with the former ``type=bool``.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


def handle_args():
    """Parse the command line arguments of the importer.

    Returns:
        The ``argparse.Namespace`` with ``base_dir``, ``max_duration``,
        ``normalize``, ``alphabet`` and ``keep_archive``.
    """
    parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
    parser.add_argument("base_dir", help="Directory containing all data")
    parser.add_argument(
        "--max_duration",
        type=int,
        default=10000,
        help="Maximum sample duration in milliseconds",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    parser.add_argument(
        "--alphabet",
        help="Exclude samples with characters not in provided alphabet file",
    )
    parser.add_argument(
        "--keep_archive",
        # ``type=bool`` would turn every non-empty string, "False" included,
        # into True; parse the text explicitly instead.
        type=_str_to_bool,
        default=True,
        help="If downloaded archives should be kept",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Both names are read as module globals by the functions above.
    CLI_ARGS = handle_args()
    ALPHABET = None
    if CLI_ARGS.alphabet:
        ALPHABET = Alphabet(CLI_ARGS.alphabet)
    download_and_prepare()