STT/bin/import_swc.py
2021-05-18 13:45:52 +02:00

578 lines
20 KiB
Python
Executable File

#!/usr/bin/env python
"""
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for train.py
Use "python3 import_swc.py -h" for help
"""
import argparse
import csv
import os
import random
import re
import shutil
import sys
import tarfile
import unicodedata
import wave
import xml.etree.ElementTree as ET
from collections import Counter
from glob import glob
from multiprocessing.pool import ThreadPool
import progressbar
import sox
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import validate_label_eng as validate_label
# Download locations for the Spoken Wikipedia Corpora archives.
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar"
LANGUAGES = ["dutch", "english", "german"]

# CSV schema; the extended variant adds per-sample metadata columns.
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]

# Target format for the converted WAV files.
CHANNELS = 1
SAMPLE_RATE = 16000

UNKNOWN = "<unknown>"
AUDIO_PATTERN = "audio*.ogg"
WAV_NAME = "audio.wav"
ALIGNED_NAME = "aligned.swc"

# Per-language transcript substitutions, applied in order by label_filter().
# A replacement of None means: reject the whole sample if the pattern matches.
SUBSTITUTIONS = {
    "german": [
        (re.compile(r"\$"), "dollar"),
        # FIX: the euro sign had been lost to an encoding error, leaving an
        # empty pattern that would have inserted "euro" at every position.
        (re.compile(r"€"), "euro"),
        (re.compile(r"£"), "pfund"),
        (
            re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
            r"\1zehnhundert \2er ",
        ),
        (re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
        (
            re.compile(
                r"eins punkt null null null punkt null null null punkt null null null"
            ),
            "eine milliarde",
        ),
        (
            re.compile(
                r"punkt null null null punkt null null null punkt null null null"
            ),
            "milliarden",
        ),
        (re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
        (re.compile(r"punkt null null null punkt null null null"), "millionen"),
        (re.compile(r"eins punkt null null null"), "ein tausend"),
        (re.compile(r"punkt null null null"), "tausend"),
        (re.compile(r"punkt null"), None),
    ]
}

# Characters that must survive ASCII normalization (umlauts, sharp s).
DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
# Strips structural punctuation from transcripts before validation.
PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
class Sample:
    """One aligned audio snippet plus its transcript and metadata.

    `start` and `end` are offsets in milliseconds within `wav_path`
    (see split_audio_files, which converts them to frame positions).
    """

    def __init__(self, wav_path, start, end, text, article, speaker, sub_set=None):
        self.wav_path = wav_path
        self.start = start
        self.end = end
        self.text = text
        self.article = article
        self.speaker = speaker
        # Assigned later by assign_sub_sets ("train"/"dev"/"test").
        self.sub_set = sub_set
def fail(message):
    """Report `message` and abort the process with exit status 1."""
    print(message)
    sys.exit(1)
def group(lst, get_key):
    """Bucket the items of `lst` into a dict keyed by get_key(item)."""
    grouped = {}
    for item in lst:
        grouped.setdefault(get_key(item), []).append(item)
    return grouped
def get_sample_size(population_size):
    """Cochran-style sample size for the dev and test sets.

    Uses a 99% confidence level and 1% margin of error, shrinking the
    hypothetical train-set size until dev + test + train fit the population.
    """
    margin_of_error = 0.01
    fraction_picking = 0.50
    z_score = 2.58  # corresponds to a confidence level of 99%
    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
        margin_of_error ** 2
    )
    sample_size = 0
    train_size = population_size
    while train_size > 0:
        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
            margin_of_error ** 2 * train_size
        )
        sample_size = int(numerator / denominator)
        # Stop once dev + test (sample_size each) plus train fit the population.
        if 2 * sample_size + train_size <= population_size:
            break
        train_size -= 1
    return sample_size
def maybe_download_language(language):
    """Download the SWC archive for `language` into base_dir if not present.

    Returns the local path of the (possibly pre-existing) archive.
    """
    # Remote archive names are capitalized, e.g. "SWC_German.tar".
    capitalized = language[0].upper() + language[1:]
    archive_name = SWC_ARCHIVE.format(language=capitalized)
    archive_url = SWC_URL.format(language=capitalized)
    return maybe_download(archive_name, CLI_ARGS.base_dir, archive_url)
def maybe_extract(data_dir, extracted_data, archive):
    """Extract `archive` into data_dir/extracted_data unless already done.

    Returns the path of the extraction directory either way.
    """
    extracted = os.path.join(data_dir, extracted_data)
    if os.path.isdir(extracted):
        print('Found directory "{}" - not extracting.'.format(extracted))
        return extracted
    print('Extracting "{}"...'.format(archive))
    with tarfile.open(archive) as tar:
        members = tar.getmembers()
        bar = progressbar.ProgressBar(max_value=len(members), widgets=SIMPLE_BAR)
        for member in bar(members):
            tar.extract(member=member, path=extracted)
    return extracted
def ignored(node):
    """True if `node` (or a reachable ancestor) is an <ignored> element.

    NOTE(review): stdlib ElementTree's find("..") returns None when asked
    for the parent of the element the search started from, so in practice
    the recursion terminates after inspecting `node` itself — confirm
    against the parser actually used for aligned.swc.
    """
    if node is None:
        return False
    return node.tag == "ignored" or ignored(node.find(".."))
def read_token(token):
    """Read one <t> token element into a (start_ms, end_ms, text) triple.

    Timing and pronunciation come from the child <n> notes when present;
    otherwise the token's own "text" attribute is used (with no timing).
    """
    start = None
    end = None
    texts = []
    notes = token.findall("n")
    if notes:
        for note in notes:
            attrs = note.attrib
            # First note carrying a "start" wins.
            if start is None and "start" in attrs:
                start = int(attrs["start"])
            # Keep the latest (maximum) end timestamp seen.
            if "end" in attrs:
                candidate = int(attrs["end"])
                if end is None or candidate > end:
                    end = candidate
            if "pronunciation" in attrs:
                texts.append(attrs["pronunciation"])
    elif "text" in token.attrib:
        texts.append(token.attrib["text"])
    return start, end, " ".join(texts)
def in_alphabet(alphabet, c):
    """True if `c` is encodable by `alphabet`; trivially True without one."""
    if not alphabet:
        return True
    return alphabet.CanEncode(c)
# Cache of language -> Alphabet (or None) so each alphabet file loads once.
ALPHABETS = {}


def get_alphabet(language):
    """Return the cached Alphabet for `language`; None if no file was given."""
    if language not in ALPHABETS:
        alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
        ALPHABETS[language] = Alphabet(alphabet_path) if alphabet_path else None
    return ALPHABETS[language]
def label_filter(label, language):
    """Clean, substitute and validate a transcript for `language`.

    Returns (label, None) on success or (None, reason) when the sample
    should be rejected.
    """
    label = label.translate(PRE_FILTER)
    label = validate_label(label)
    if label is None:
        return None, "validation"
    # Language-specific rewrites; a None replacement rejects the sample.
    for pattern, replacement in SUBSTITUTIONS.get(language, []):
        if replacement is None:
            if pattern.match(label):
                return None, "substitution rule"
        else:
            label = pattern.sub(replacement, label)
    dont_normalize = DONT_NORMALIZE.get(language, "")
    alphabet = get_alphabet(language)
    result_chars = []
    for ch in label:
        if (
            CLI_ARGS.normalize
            and ch not in dont_normalize
            and not in_alphabet(alphabet, ch)
        ):
            # Strip diacritics down to the ASCII base character(s).
            ch = (
                unicodedata.normalize("NFKD", ch)
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        for sub_ch in ch:
            if not in_alphabet(alphabet, sub_ch):
                return None, "illegal character"
            result_chars.append(sub_ch)
    label = validate_label("".join(result_chars))
    return label, "validation" if label is None else None
def collect_samples(base_dir, language):
    """Parse every aligned.swc under `base_dir` and collect Sample objects.

    Sentences longer than max_duration are split into multiple samples.
    Returns the accepted samples; prints a tally of rejection reasons.
    """
    # Only directories that contain both the alignment XML and the joined WAV.
    roots = []
    for root, _, files in os.walk(base_dir):
        if ALIGNED_NAME in files and WAV_NAME in files:
            roots.append(root)
    samples = []
    reasons = Counter()

    def add_sample(
        p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
    ):
        # Accept one candidate sample, or count the reason it was rejected.
        if p_start is not None and p_end is not None and p_text is not None:
            duration = p_end - p_start  # milliseconds
            text, filter_reason = label_filter(p_text, language)
            skip = False
            if filter_reason is not None:
                skip = True
                p_reason = filter_reason
            elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
                skip = True
                p_reason = "unknown speaker"
            elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
                skip = True
                p_reason = "unknown article"
            elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
                skip = True
                p_reason = "exceeded duration"
            elif int(duration / 30) < len(text):
                # Less than ~30 ms of audio per character cannot be decoded.
                skip = True
                p_reason = "too short to decode"
            elif duration / len(text) < 10:
                # NOTE(review): raises ZeroDivisionError if `text` is ever the
                # empty string here — presumably validate_label never returns
                # "" (it would return None instead); confirm.
                skip = True
                p_reason = "length duration ratio"
            if skip:
                reasons[p_reason] += 1
            else:
                samples.append(
                    Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
                )
        elif p_start is None or p_end is None:
            reasons["missing timestamps"] += 1
        else:
            reasons["missing text"] += 1

    print("Collecting samples...")
    bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
    for root in bar(roots):
        wav_path = os.path.join(root, WAV_NAME)
        aligned = ET.parse(os.path.join(root, ALIGNED_NAME))
        article = UNKNOWN
        speaker = UNKNOWN
        # Article/speaker metadata lives in <prop key="..." value="..."/> nodes.
        for prop in aligned.iter("prop"):
            attributes = prop.attrib
            if "key" in attributes and "value" in attributes:
                if attributes["key"] == "DC.identifier":
                    article = attributes["value"]
                elif attributes["key"] == "reader.name":
                    speaker = attributes["value"]
        for sentence in aligned.iter("s"):
            if ignored(sentence):
                continue
            split = False
            tokens = list(map(read_token, sentence.findall("t")))
            sample_start, sample_end, token_texts, sample_texts = None, None, [], []
            for token_start, token_end, token_text in tokens:
                if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
                    # Flush what was accumulated so far (counted under
                    # "has numbers") and restart after the numeric token.
                    add_sample(
                        wav_path,
                        article,
                        speaker,
                        sample_start,
                        sample_end,
                        " ".join(sample_texts),
                        p_reason="has numbers",
                    )
                    sample_start, sample_end, token_texts, sample_texts = (
                        None,
                        None,
                        [],
                        [],
                    )
                    continue
                if sample_start is None:
                    sample_start = token_start
                if sample_start is None:
                    # Still no timestamp — wait for the first aligned token.
                    continue
                token_texts.append(token_text)
                if token_end is not None:
                    if (
                        token_start != sample_start
                        and token_end - sample_start > CLI_ARGS.max_duration > 0
                    ):
                        # Sentence would exceed max_duration: emit a partial
                        # sample and continue from the previous end timestamp.
                        add_sample(
                            wav_path,
                            article,
                            speaker,
                            sample_start,
                            sample_end,
                            " ".join(sample_texts),
                            p_reason="split",
                        )
                        sample_start = sample_end
                        sample_texts = []
                        split = True
                    sample_end = token_end
                    # Only commit token texts once their end timestamp is known.
                    sample_texts.extend(token_texts)
                    token_texts = []
            # Emit the (remainder of the) sentence.
            add_sample(
                wav_path,
                article,
                speaker,
                sample_start,
                sample_end,
                " ".join(sample_texts),
                p_reason="split" if split else "complete",
            )
    print("Skipped samples:")
    for reason, n in reasons.most_common():
        print(" - {}: {}".format(reason, n))
    return samples
def maybe_convert_one_to_wav(entry):
    """Convert one directory's OGG part files into a single audio.wav.

    `entry` is a (root, dirs, files) tuple as yielded by os.walk. Multiple
    parts are converted individually and then concatenated. Sox failures are
    swallowed on purpose: the directory is simply left without audio.wav and
    is skipped later by collect_samples.
    """
    root, _, _ = entry
    output_wav = os.path.join(root, WAV_NAME)
    # FIX: bail out before constructing sox transformers — the original
    # built (and discarded) them even when the output already existed.
    if os.path.isfile(output_wav):
        return
    files = sorted(glob(os.path.join(root, AUDIO_PATTERN)))
    transformer = sox.Transformer()
    transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    try:
        if len(files) == 1:
            transformer.build(files[0], output_wav)
        elif len(files) > 1:
            wav_files = []
            for i, file in enumerate(files):
                wav_path = os.path.join(root, "audio{}.wav".format(i))
                transformer.build(file, wav_path)
                wav_files.append(wav_path)
            combiner = sox.Combiner()
            combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
            combiner.set_input_format(file_type=["wav"] * len(wav_files))
            combiner.build(wav_files, output_wav, "concatenate")
    except sox.core.SoxError:
        # Best-effort conversion: leave this directory without audio.wav.
        return
def maybe_convert_to_wav(base_dir):
    """Convert all source audio below `base_dir` to WAV using a thread pool."""
    entries = list(os.walk(base_dir))
    print("Converting and joining source audio files...")
    bar = progressbar.ProgressBar(max_value=len(entries), widgets=SIMPLE_BAR)
    pool = ThreadPool()
    # Drain the iterator so the progress bar advances as workers finish.
    for _ in bar(pool.imap_unordered(maybe_convert_one_to_wav, entries)):
        pass
    pool.close()
    pool.join()
def assign_sub_sets(samples):
    """Tag every sample with a sub_set of "train", "dev" or "test".

    Dev and test are filled whole-speaker-by-speaker (smallest speaker
    groups first) so no speaker spans two sets; falls back to a seeded
    random split when speaker metadata is unusable.
    """
    sample_size = get_sample_size(len(samples))
    speakers = group(samples, lambda sample: sample.speaker).values()
    # Smallest speaker groups first, so dev/test overshoot as little as possible.
    speakers = list(sorted(speakers, key=len))
    sample_sets = [[], []]  # [dev, test]
    while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
        for sample_set in sample_sets:
            if len(sample_set) < sample_size and len(speakers) > 0:
                sample_set.extend(speakers.pop(0))
    # Whatever speaker groups remain make up the training set.
    train_set = sum(speakers, [])
    if len(train_set) == 0:
        print(
            "WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
        )
        random.seed(42)  # same source data == same output
        random.shuffle(samples)
        for index, sample in enumerate(samples):
            if index < sample_size:
                sample.sub_set = "dev"
            elif index < 2 * sample_size:
                sample.sub_set = "test"
            else:
                sample.sub_set = "train"
    else:
        for sub_set, sub_set_samples in [
            ("train", train_set),
            ("dev", sample_sets[0]),
            ("test", sample_sets[1]),
        ]:
            for sample in sub_set_samples:
                sample.sub_set = sub_set
    # Report split sizes and total durations (start/end are milliseconds).
    for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
        t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
        print(
            'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
                sub_set, len(sub_set_samples), t
            )
        )
def create_sample_dirs(language):
    """Ensure the "<language>-{train,dev,test}" output directories exist."""
    print("Creating sample directories...")
    for set_name in ["train", "dev", "test"]:
        dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
        # FIX: makedirs(exist_ok=True) replaces the isdir/mkdir check-then-act
        # pair (race-free) and also creates any missing parent directories.
        os.makedirs(dir_path, exist_ok=True)
def split_audio_files(samples, language):
    """Cut each source WAV into one small WAV file per sample.

    Rewrites every sample's wav_path to point at its new numbered file
    inside the "<language>-<sub_set>" directory.
    """
    print("Splitting audio files...")
    sub_sets = Counter()  # per-sub_set counter used to number output files
    src_wav_files = group(samples, lambda s: s.wav_path).items()
    bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
    for wav_path, file_samples in bar(src_wav_files):
        # Process in timeline order so setpos only ever seeks forward.
        file_samples = sorted(file_samples, key=lambda s: s.start)
        with wave.open(wav_path, "r") as src_wav_file:
            rate = src_wav_file.getframerate()
            for sample in file_samples:
                index = sub_sets[sample.sub_set]
                sample_wav_path = os.path.join(
                    CLI_ARGS.base_dir,
                    language + "-" + sample.sub_set,
                    "sample-{0:06d}.wav".format(index),
                )
                sample.wav_path = sample_wav_path
                sub_sets[sample.sub_set] += 1
                # start/end are milliseconds; convert to frame positions.
                src_wav_file.setpos(int(sample.start * rate / 1000.0))
                data = src_wav_file.readframes(
                    int((sample.end - sample.start) * rate / 1000.0)
                )
                with wave.open(sample_wav_path, "w") as sample_wav_file:
                    sample_wav_file.setnchannels(src_wav_file.getnchannels())
                    sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
                    sample_wav_file.setframerate(rate)
                    sample_wav_file.writeframes(data)
def write_csvs(samples, language):
    """Write one "<language>-<sub_set>.csv" index file per sub-set."""
    for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
        set_samples = sorted(set_samples, key=lambda s: s.wav_path)
        base_dir = os.path.abspath(CLI_ARGS.base_dir)
        csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
        print('Writing "{}"...'.format(csv_path))
        fieldnames = FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
        with open(csv_path, "w", encoding="utf-8", newline="") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()
            bar = progressbar.ProgressBar(
                max_value=len(set_samples), widgets=SIMPLE_BAR
            )
            for sample in bar(set_samples):
                row = {
                    "wav_filename": os.path.relpath(sample.wav_path, base_dir),
                    "wav_filesize": os.path.getsize(sample.wav_path),
                    "transcript": sample.text,
                }
                if CLI_ARGS.add_meta:
                    row["article"] = sample.article
                    row["speaker"] = sample.speaker
                writer.writerow(row)
def cleanup(archive, language):
    """Remove the downloaded archive and intermediate files unless kept."""
    if not CLI_ARGS.keep_archive:
        print('Removing archive "{}"...'.format(archive))
        os.remove(archive)
    extracted_dir = os.path.join(CLI_ARGS.base_dir, language)
    if not CLI_ARGS.keep_intermediate and os.path.isdir(extracted_dir):
        print('Removing intermediate files in "{}"...'.format(extracted_dir))
        shutil.rmtree(extracted_dir)
def prepare_language(language):
    """Run the full import pipeline for one language: download, extract,
    convert to WAV, collect/split samples and write the CSV index files."""
    archive = maybe_download_language(language)
    extracted = maybe_extract(CLI_ARGS.base_dir, language, archive)
    maybe_convert_to_wav(extracted)
    samples = collect_samples(extracted, language)
    assign_sub_sets(samples)
    create_sample_dirs(language)
    split_audio_files(samples, language)
    write_csvs(samples, language)
    cleanup(archive, language)
def handle_args():
parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument(
"--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
)
parser.add_argument(
"--exclude_numbers",
type=bool,
default=True,
help="If sequences with non-transliterated numbers should be excluded",
)
parser.add_argument(
"--max_duration",
type=int,
default=10000,
help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--ignore_too_long",
type=bool,
default=False,
help="If samples exceeding max_duration should be removed",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
for language in LANGUAGES:
parser.add_argument(
"--{}_alphabet".format(language),
help="Exclude {} samples with characters not in provided alphabet file".format(
language
),
)
parser.add_argument(
"--add_meta", action="store_true", help="Adds article and speaker CSV columns"
)
parser.add_argument(
"--exclude_unknown_speakers",
action="store_true",
help="Exclude unknown speakers",
)
parser.add_argument(
"--exclude_unknown_articles",
action="store_true",
help="Exclude unknown articles",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
parser.add_argument(
"--keep_intermediate",
type=bool,
default=False,
help="If intermediate files should be kept",
)
return parser.parse_args()
if __name__ == "__main__":
    CLI_ARGS = handle_args()
    # "all" expands to every supported language; anything else must be one of them.
    selected = LANGUAGES if CLI_ARGS.language == "all" else [CLI_ARGS.language]
    if all(language in LANGUAGES for language in selected):
        for language in selected:
            prepare_language(language)
    else:
        fail("Wrong language id")