578 lines
20 KiB
Python
Executable File
578 lines
20 KiB
Python
Executable File
#!/usr/bin/env python
|
|
"""
|
|
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for train.py
|
|
Use "python3 import_swc.py -h" for help
|
|
"""
|
|
|
|
import argparse
|
|
import csv
|
|
import os
|
|
import random
|
|
import re
|
|
import shutil
|
|
import sys
|
|
import tarfile
|
|
import unicodedata
|
|
import wave
|
|
import xml.etree.ElementTree as ET
|
|
from collections import Counter
|
|
from glob import glob
|
|
from multiprocessing.pool import ThreadPool
|
|
|
|
import progressbar
|
|
import sox
|
|
from coqui_stt_ctcdecoder import Alphabet
|
|
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
|
|
from coqui_stt_training.util.importers import validate_label_eng as validate_label
|
|
|
|
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
|
SWC_ARCHIVE = "SWC_{language}.tar"
|
|
LANGUAGES = ["dutch", "english", "german"]
|
|
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
|
FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
|
|
CHANNELS = 1
|
|
SAMPLE_RATE = 16000
|
|
UNKNOWN = "<unknown>"
|
|
AUDIO_PATTERN = "audio*.ogg"
|
|
WAV_NAME = "audio.wav"
|
|
ALIGNED_NAME = "aligned.swc"
|
|
|
|
SUBSTITUTIONS = {
|
|
"german": [
|
|
(re.compile(r"\$"), "dollar"),
|
|
(re.compile(r"€"), "euro"),
|
|
(re.compile(r"£"), "pfund"),
|
|
(
|
|
re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
|
|
r"\1zehnhundert \2er ",
|
|
),
|
|
(re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
|
|
(
|
|
re.compile(
|
|
r"eins punkt null null null punkt null null null punkt null null null"
|
|
),
|
|
"eine milliarde",
|
|
),
|
|
(
|
|
re.compile(
|
|
r"punkt null null null punkt null null null punkt null null null"
|
|
),
|
|
"milliarden",
|
|
),
|
|
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
|
|
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
|
|
(re.compile(r"eins punkt null null null"), "ein tausend"),
|
|
(re.compile(r"punkt null null null"), "tausend"),
|
|
(re.compile(r"punkt null"), None),
|
|
]
|
|
}
|
|
|
|
DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
|
|
|
|
PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
|
|
|
|
|
|
class Sample:
|
|
def __init__(self, wav_path, start, end, text, article, speaker, sub_set=None):
|
|
self.wav_path = wav_path
|
|
self.start = start
|
|
self.end = end
|
|
self.text = text
|
|
self.article = article
|
|
self.speaker = speaker
|
|
self.sub_set = sub_set
|
|
|
|
|
|
def fail(message):
|
|
print(message)
|
|
sys.exit(1)
|
|
|
|
|
|
def group(lst, get_key):
|
|
groups = {}
|
|
for obj in lst:
|
|
key = get_key(obj)
|
|
if key in groups:
|
|
groups[key].append(obj)
|
|
else:
|
|
groups[key] = [obj]
|
|
return groups
|
|
|
|
|
|
def get_sample_size(population_size):
|
|
margin_of_error = 0.01
|
|
fraction_picking = 0.50
|
|
z_score = 2.58 # Corresponds to confidence level 99%
|
|
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
|
margin_of_error ** 2
|
|
)
|
|
sample_size = 0
|
|
for train_size in range(population_size, 0, -1):
|
|
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
|
margin_of_error ** 2 * train_size
|
|
)
|
|
sample_size = int(numerator / denominator)
|
|
if 2 * sample_size + train_size <= population_size:
|
|
break
|
|
return sample_size
|
|
|
|
|
|
def maybe_download_language(language):
|
|
lang_upper = language[0].upper() + language[1:]
|
|
return maybe_download(
|
|
SWC_ARCHIVE.format(language=lang_upper),
|
|
CLI_ARGS.base_dir,
|
|
SWC_URL.format(language=lang_upper),
|
|
)
|
|
|
|
|
|
def maybe_extract(data_dir, extracted_data, archive):
|
|
extracted = os.path.join(data_dir, extracted_data)
|
|
if os.path.isdir(extracted):
|
|
print('Found directory "{}" - not extracting.'.format(extracted))
|
|
else:
|
|
print('Extracting "{}"...'.format(archive))
|
|
with tarfile.open(archive) as tar:
|
|
members = tar.getmembers()
|
|
bar = progressbar.ProgressBar(max_value=len(members), widgets=SIMPLE_BAR)
|
|
for member in bar(members):
|
|
tar.extract(member=member, path=extracted)
|
|
return extracted
|
|
|
|
|
|
def ignored(node):
|
|
if node is None:
|
|
return False
|
|
if node.tag == "ignored":
|
|
return True
|
|
return ignored(node.find(".."))
|
|
|
|
|
|
def read_token(token):
|
|
texts, start, end = [], None, None
|
|
notes = token.findall("n")
|
|
if len(notes) > 0:
|
|
for note in notes:
|
|
attributes = note.attrib
|
|
if start is None and "start" in attributes:
|
|
start = int(attributes["start"])
|
|
if "end" in attributes:
|
|
token_end = int(attributes["end"])
|
|
if end is None or token_end > end:
|
|
end = token_end
|
|
if "pronunciation" in attributes:
|
|
t = attributes["pronunciation"]
|
|
texts.append(t)
|
|
elif "text" in token.attrib:
|
|
texts.append(token.attrib["text"])
|
|
return start, end, " ".join(texts)
|
|
|
|
|
|
def in_alphabet(alphabet, c):
|
|
return alphabet.CanEncode(c) if alphabet else True
|
|
|
|
|
|
ALPHABETS = {}
|
|
|
|
|
|
def get_alphabet(language):
|
|
if language in ALPHABETS:
|
|
return ALPHABETS[language]
|
|
alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
|
|
alphabet = Alphabet(alphabet_path) if alphabet_path else None
|
|
ALPHABETS[language] = alphabet
|
|
return alphabet
|
|
|
|
|
|
def label_filter(label, language):
|
|
label = label.translate(PRE_FILTER)
|
|
label = validate_label(label)
|
|
if label is None:
|
|
return None, "validation"
|
|
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
|
|
for pattern, replacement in substitutions:
|
|
if replacement is None:
|
|
if pattern.match(label):
|
|
return None, "substitution rule"
|
|
else:
|
|
label = pattern.sub(replacement, label)
|
|
chars = []
|
|
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
|
|
alphabet = get_alphabet(language)
|
|
for c in label:
|
|
if (
|
|
CLI_ARGS.normalize
|
|
and c not in dont_normalize
|
|
and not in_alphabet(alphabet, c)
|
|
):
|
|
c = (
|
|
unicodedata.normalize("NFKD", c)
|
|
.encode("ascii", "ignore")
|
|
.decode("ascii", "ignore")
|
|
)
|
|
for sc in c:
|
|
if not in_alphabet(alphabet, sc):
|
|
return None, "illegal character"
|
|
chars.append(sc)
|
|
label = "".join(chars)
|
|
label = validate_label(label)
|
|
return label, "validation" if label is None else None
|
|
|
|
|
|
def collect_samples(base_dir, language):
|
|
roots = []
|
|
for root, _, files in os.walk(base_dir):
|
|
if ALIGNED_NAME in files and WAV_NAME in files:
|
|
roots.append(root)
|
|
samples = []
|
|
reasons = Counter()
|
|
|
|
def add_sample(
|
|
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
|
|
):
|
|
if p_start is not None and p_end is not None and p_text is not None:
|
|
duration = p_end - p_start
|
|
text, filter_reason = label_filter(p_text, language)
|
|
skip = False
|
|
if filter_reason is not None:
|
|
skip = True
|
|
p_reason = filter_reason
|
|
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
|
|
skip = True
|
|
p_reason = "unknown speaker"
|
|
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
|
|
skip = True
|
|
p_reason = "unknown article"
|
|
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
|
|
skip = True
|
|
p_reason = "exceeded duration"
|
|
elif int(duration / 30) < len(text):
|
|
skip = True
|
|
p_reason = "too short to decode"
|
|
elif duration / len(text) < 10:
|
|
skip = True
|
|
p_reason = "length duration ratio"
|
|
if skip:
|
|
reasons[p_reason] += 1
|
|
else:
|
|
samples.append(
|
|
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
|
|
)
|
|
elif p_start is None or p_end is None:
|
|
reasons["missing timestamps"] += 1
|
|
else:
|
|
reasons["missing text"] += 1
|
|
|
|
print("Collecting samples...")
|
|
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
|
for root in bar(roots):
|
|
wav_path = os.path.join(root, WAV_NAME)
|
|
aligned = ET.parse(os.path.join(root, ALIGNED_NAME))
|
|
article = UNKNOWN
|
|
speaker = UNKNOWN
|
|
for prop in aligned.iter("prop"):
|
|
attributes = prop.attrib
|
|
if "key" in attributes and "value" in attributes:
|
|
if attributes["key"] == "DC.identifier":
|
|
article = attributes["value"]
|
|
elif attributes["key"] == "reader.name":
|
|
speaker = attributes["value"]
|
|
for sentence in aligned.iter("s"):
|
|
if ignored(sentence):
|
|
continue
|
|
split = False
|
|
tokens = list(map(read_token, sentence.findall("t")))
|
|
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
|
for token_start, token_end, token_text in tokens:
|
|
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
|
|
add_sample(
|
|
wav_path,
|
|
article,
|
|
speaker,
|
|
sample_start,
|
|
sample_end,
|
|
" ".join(sample_texts),
|
|
p_reason="has numbers",
|
|
)
|
|
sample_start, sample_end, token_texts, sample_texts = (
|
|
None,
|
|
None,
|
|
[],
|
|
[],
|
|
)
|
|
continue
|
|
if sample_start is None:
|
|
sample_start = token_start
|
|
if sample_start is None:
|
|
continue
|
|
token_texts.append(token_text)
|
|
if token_end is not None:
|
|
if (
|
|
token_start != sample_start
|
|
and token_end - sample_start > CLI_ARGS.max_duration > 0
|
|
):
|
|
add_sample(
|
|
wav_path,
|
|
article,
|
|
speaker,
|
|
sample_start,
|
|
sample_end,
|
|
" ".join(sample_texts),
|
|
p_reason="split",
|
|
)
|
|
sample_start = sample_end
|
|
sample_texts = []
|
|
split = True
|
|
sample_end = token_end
|
|
sample_texts.extend(token_texts)
|
|
token_texts = []
|
|
add_sample(
|
|
wav_path,
|
|
article,
|
|
speaker,
|
|
sample_start,
|
|
sample_end,
|
|
" ".join(sample_texts),
|
|
p_reason="split" if split else "complete",
|
|
)
|
|
print("Skipped samples:")
|
|
for reason, n in reasons.most_common():
|
|
print(" - {}: {}".format(reason, n))
|
|
return samples
|
|
|
|
|
|
def maybe_convert_one_to_wav(entry):
|
|
root, _, files = entry
|
|
transformer = sox.Transformer()
|
|
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
|
combiner = sox.Combiner()
|
|
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
|
output_wav = os.path.join(root, WAV_NAME)
|
|
if os.path.isfile(output_wav):
|
|
return
|
|
files = sorted(glob(os.path.join(root, AUDIO_PATTERN)))
|
|
try:
|
|
if len(files) == 1:
|
|
transformer.build(files[0], output_wav)
|
|
elif len(files) > 1:
|
|
wav_files = []
|
|
for i, file in enumerate(files):
|
|
wav_path = os.path.join(root, "audio{}.wav".format(i))
|
|
transformer.build(file, wav_path)
|
|
wav_files.append(wav_path)
|
|
combiner.set_input_format(file_type=["wav"] * len(wav_files))
|
|
combiner.build(wav_files, output_wav, "concatenate")
|
|
except sox.core.SoxError:
|
|
return
|
|
|
|
|
|
def maybe_convert_to_wav(base_dir):
|
|
roots = list(os.walk(base_dir))
|
|
print("Converting and joining source audio files...")
|
|
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
|
tp = ThreadPool()
|
|
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
|
|
pass
|
|
tp.close()
|
|
tp.join()
|
|
|
|
|
|
def assign_sub_sets(samples):
|
|
sample_size = get_sample_size(len(samples))
|
|
speakers = group(samples, lambda sample: sample.speaker).values()
|
|
speakers = list(sorted(speakers, key=len))
|
|
sample_sets = [[], []]
|
|
while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
|
|
for sample_set in sample_sets:
|
|
if len(sample_set) < sample_size and len(speakers) > 0:
|
|
sample_set.extend(speakers.pop(0))
|
|
train_set = sum(speakers, [])
|
|
if len(train_set) == 0:
|
|
print(
|
|
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
|
|
)
|
|
random.seed(42) # same source data == same output
|
|
random.shuffle(samples)
|
|
for index, sample in enumerate(samples):
|
|
if index < sample_size:
|
|
sample.sub_set = "dev"
|
|
elif index < 2 * sample_size:
|
|
sample.sub_set = "test"
|
|
else:
|
|
sample.sub_set = "train"
|
|
else:
|
|
for sub_set, sub_set_samples in [
|
|
("train", train_set),
|
|
("dev", sample_sets[0]),
|
|
("test", sample_sets[1]),
|
|
]:
|
|
for sample in sub_set_samples:
|
|
sample.sub_set = sub_set
|
|
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
|
|
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
|
|
print(
|
|
'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
|
|
sub_set, len(sub_set_samples), t
|
|
)
|
|
)
|
|
|
|
|
|
def create_sample_dirs(language):
|
|
print("Creating sample directories...")
|
|
for set_name in ["train", "dev", "test"]:
|
|
dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
|
|
if not os.path.isdir(dir_path):
|
|
os.mkdir(dir_path)
|
|
|
|
|
|
def split_audio_files(samples, language):
|
|
print("Splitting audio files...")
|
|
sub_sets = Counter()
|
|
src_wav_files = group(samples, lambda s: s.wav_path).items()
|
|
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
|
|
for wav_path, file_samples in bar(src_wav_files):
|
|
file_samples = sorted(file_samples, key=lambda s: s.start)
|
|
with wave.open(wav_path, "r") as src_wav_file:
|
|
rate = src_wav_file.getframerate()
|
|
for sample in file_samples:
|
|
index = sub_sets[sample.sub_set]
|
|
sample_wav_path = os.path.join(
|
|
CLI_ARGS.base_dir,
|
|
language + "-" + sample.sub_set,
|
|
"sample-{0:06d}.wav".format(index),
|
|
)
|
|
sample.wav_path = sample_wav_path
|
|
sub_sets[sample.sub_set] += 1
|
|
src_wav_file.setpos(int(sample.start * rate / 1000.0))
|
|
data = src_wav_file.readframes(
|
|
int((sample.end - sample.start) * rate / 1000.0)
|
|
)
|
|
with wave.open(sample_wav_path, "w") as sample_wav_file:
|
|
sample_wav_file.setnchannels(src_wav_file.getnchannels())
|
|
sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
|
|
sample_wav_file.setframerate(rate)
|
|
sample_wav_file.writeframes(data)
|
|
|
|
|
|
def write_csvs(samples, language):
|
|
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
|
|
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
|
|
base_dir = os.path.abspath(CLI_ARGS.base_dir)
|
|
csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
|
|
print('Writing "{}"...'.format(csv_path))
|
|
with open(csv_path, "w", encoding="utf-8", newline="") as csv_file:
|
|
writer = csv.DictWriter(
|
|
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
|
|
)
|
|
writer.writeheader()
|
|
bar = progressbar.ProgressBar(
|
|
max_value=len(set_samples), widgets=SIMPLE_BAR
|
|
)
|
|
for sample in bar(set_samples):
|
|
row = {
|
|
"wav_filename": os.path.relpath(sample.wav_path, base_dir),
|
|
"wav_filesize": os.path.getsize(sample.wav_path),
|
|
"transcript": sample.text,
|
|
}
|
|
if CLI_ARGS.add_meta:
|
|
row["article"] = sample.article
|
|
row["speaker"] = sample.speaker
|
|
writer.writerow(row)
|
|
|
|
|
|
def cleanup(archive, language):
|
|
if not CLI_ARGS.keep_archive:
|
|
print('Removing archive "{}"...'.format(archive))
|
|
os.remove(archive)
|
|
language_dir = os.path.join(CLI_ARGS.base_dir, language)
|
|
if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
|
|
print('Removing intermediate files in "{}"...'.format(language_dir))
|
|
shutil.rmtree(language_dir)
|
|
|
|
|
|
def prepare_language(language):
|
|
archive = maybe_download_language(language)
|
|
extracted = maybe_extract(CLI_ARGS.base_dir, language, archive)
|
|
maybe_convert_to_wav(extracted)
|
|
samples = collect_samples(extracted, language)
|
|
assign_sub_sets(samples)
|
|
create_sample_dirs(language)
|
|
split_audio_files(samples, language)
|
|
write_csvs(samples, language)
|
|
cleanup(archive, language)
|
|
|
|
|
|
def handle_args():
|
|
parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
|
|
parser.add_argument("base_dir", help="Directory containing all data")
|
|
parser.add_argument(
|
|
"--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
|
|
)
|
|
parser.add_argument(
|
|
"--exclude_numbers",
|
|
type=bool,
|
|
default=True,
|
|
help="If sequences with non-transliterated numbers should be excluded",
|
|
)
|
|
parser.add_argument(
|
|
"--max_duration",
|
|
type=int,
|
|
default=10000,
|
|
help="Maximum sample duration in milliseconds",
|
|
)
|
|
parser.add_argument(
|
|
"--ignore_too_long",
|
|
type=bool,
|
|
default=False,
|
|
help="If samples exceeding max_duration should be removed",
|
|
)
|
|
parser.add_argument(
|
|
"--normalize",
|
|
action="store_true",
|
|
help="Converts diacritic characters to their base ones",
|
|
)
|
|
for language in LANGUAGES:
|
|
parser.add_argument(
|
|
"--{}_alphabet".format(language),
|
|
help="Exclude {} samples with characters not in provided alphabet file".format(
|
|
language
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--add_meta", action="store_true", help="Adds article and speaker CSV columns"
|
|
)
|
|
parser.add_argument(
|
|
"--exclude_unknown_speakers",
|
|
action="store_true",
|
|
help="Exclude unknown speakers",
|
|
)
|
|
parser.add_argument(
|
|
"--exclude_unknown_articles",
|
|
action="store_true",
|
|
help="Exclude unknown articles",
|
|
)
|
|
parser.add_argument(
|
|
"--keep_archive",
|
|
type=bool,
|
|
default=True,
|
|
help="If downloaded archives should be kept",
|
|
)
|
|
parser.add_argument(
|
|
"--keep_intermediate",
|
|
type=bool,
|
|
default=False,
|
|
help="If intermediate files should be kept",
|
|
)
|
|
return parser.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
CLI_ARGS = handle_args()
|
|
if CLI_ARGS.language == "all":
|
|
for lang in LANGUAGES:
|
|
prepare_language(lang)
|
|
elif CLI_ARGS.language in LANGUAGES:
|
|
prepare_language(CLI_ARGS.language)
|
|
else:
|
|
fail("Wrong language id")
|