Spoken Wikipedia importer

This commit is contained in:
Tilman Kamp 2019-10-23 14:19:27 +02:00
parent d35107acdb
commit 3424ab2b5d

455
bin/import_swc.py Executable file
View File

@ -0,0 +1,455 @@
#!/usr/bin/env python
'''
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
Use "python3 import_swc.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import re
import csv
import sox
import wave
import shutil
import random
import tarfile
import argparse
import progressbar
import unicodedata
import xml.etree.cElementTree as ET
from os import path
from glob import glob
from collections import Counter
from multiprocessing.pool import ThreadPool
from util.text import Alphabet, validate_label
from util.downloader import maybe_download, SIMPLE_BAR
# Download URL and archive file name patterns ({language} is capitalized, e.g. "German")
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar"
# Languages this importer knows how to prepare
LANGUAGES = ['dutch', 'english', 'german']
# Column headers of the generated DeepSpeech CSV files
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
# Target audio format: mono, 16 kHz WAV
CHANNELS = 1
SAMPLE_RATE = 16000
# Source audio glob and the per-article file names inside the extracted corpus
AUDIO_PATTERN = 'audio*.ogg'
WAV_NAME = 'audio.wav'
ALIGNED_NAME = 'aligned.swc'
# Language specific transcript substitutions, applied in order.
# A replacement of None marks any transcript matching the pattern as unusable.
SUBSTITUTIONS = {
    'german': [
        (re.compile(r'\$'), 'dollar'),
        # The euro sign was lost in an encoding round-trip, leaving an empty
        # pattern r'' that matched at every position and sprayed 'euro'
        # between all characters; restored to the intended '€'.
        (re.compile(r'€'), 'euro'),
        (re.compile(r'£'), 'pfund'),
        # Spoken year forms: "ein tausend acht hundert ...er" -> "achtzehnhundert ...er"
        (re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
        (re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
        # Digit-grouped large numbers read out digit by digit
        (re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
        (re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
        (re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
        (re.compile(r'punkt null null null punkt null null null'), 'millionen'),
        (re.compile(r'eins punkt null null null'), 'ein tausend'),
        (re.compile(r'punkt null null null'), 'tausend'),
        # Leftover decimal fragments cannot be fixed reliably -> drop the sample
        (re.compile(r'punkt null'), None)
    ]  # TODO: Add Dutch and English
}

# Characters that must survive Unicode normalization, per language
DONT_NORMALIZE = {
    'german': 'ÄÖÜäöüß',
    'dutch': 'IJij'
}

# Translation table deleting structural characters before label validation
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))
class Sample:
    """One labeled audio snippet: a [start, end] millisecond range inside a
    WAV file together with its transcript, speaker and assigned sub-set."""

    def __init__(self, wav_path, start, end, text, speaker, sub_set=None):
        self.wav_path, self.start, self.end = wav_path, start, end
        self.text, self.speaker, self.sub_set = text, speaker, sub_set
def fail(message):
    """Print an error message and terminate the process with exit code 1.

    Uses sys.exit() rather than the bare exit() builtin: exit() is an
    interactive helper injected by the site module and is not guaranteed to
    exist (e.g. when running under ``python -S`` or in frozen builds).
    """
    print(message)
    sys.exit(1)
def group(lst, get_key):
    """Bucket the items of lst by get_key(item) and return the dict of buckets."""
    buckets = {}
    for item in lst:
        buckets.setdefault(get_key(item), []).append(item)
    return buckets
def get_sample_size(population_size):
    """Compute how many samples dev and test should each get.

    Applies the standard sample-size formula with finite population
    correction (99% confidence, 1% margin of error, worst-case proportion
    0.5), searching downward for the largest train set that still leaves
    room for two sub-sets of the computed size.
    """
    margin = 0.01
    proportion = 0.50
    z = 2.58  # z-score for a 99% confidence level
    numer = (z ** 2 * proportion * (1 - proportion)) / (margin ** 2)
    size = 0
    for train_size in range(population_size, 0, -1):
        denom = 1 + (z ** 2 * proportion * (1 - proportion)) / (margin ** 2 * train_size)
        size = int(numer / denom)
        if 2 * size + train_size <= population_size:
            break
    return size
def maybe_download_language(language):
    """Download the SWC archive for a language into base_dir (skipped when
    already present) and return the local archive path."""
    capitalized = language[0].upper() + language[1:]
    archive_name = SWC_ARCHIVE.format(language=capitalized)
    archive_url = SWC_URL.format(language=capitalized)
    return maybe_download(archive_name, CLI_ARGS.base_dir, archive_url)
def maybe_extract(data_dir, extracted_data, archive):
    """Extract archive into data_dir/extracted_data unless that directory
    already exists; return the extraction directory path."""
    target = path.join(data_dir, extracted_data)
    if path.isdir(target):
        print('Found directory "{}" - not extracting.'.format(target))
        return target
    print('Extracting "{}"...'.format(archive))
    with tarfile.open(archive) as tar:
        members = tar.getmembers()
        bar = progressbar.ProgressBar(max_value=len(members), widgets=SIMPLE_BAR)
        # Extract member by member so the progress bar advances per file.
        for member in bar(members):
            tar.extract(member=member, path=target)
    return target
def ignored(node):
    # Return True if this element (or, by intention, an ancestor) is an
    # <ignored> wrapper; None nodes count as not ignored.
    if node is None:
        return False
    if node.tag == 'ignored':
        return True
    # NOTE(review): ElementTree's find('..') cannot reach an element's
    # parent — paths above the start element return None — so this
    # recursion ends after a single step and ancestor <ignored> wrappers
    # are never detected. Confirm whether ancestor checking was intended;
    # fixing it would require a parent map built from the tree.
    return ignored(node.find('..'))
def read_token(token):
    """Extract (start_ms, end_ms, text) from a SWC <t> token element.

    Timing and pronunciation come from the token's <n> children; the first
    'start' wins and the largest 'end' wins. When there are no <n> children,
    the token's own "text" attribute supplies the text. Missing timings are
    returned as None.
    """
    start = None
    end = None
    texts = []
    notes = token.findall('n')
    if len(notes) > 0:
        for note in notes:
            attrs = note.attrib
            if start is None and 'start' in attrs:
                start = int(attrs['start'])
            if 'end' in attrs:
                note_end = int(attrs['end'])
                if end is None or note_end > end:
                    end = note_end
            if 'pronunciation' in attrs:
                texts.append(attrs['pronunciation'])
    elif 'text' in token.attrib:
        texts.append(token.attrib['text'])
    return start, end, ' '.join(texts)
def in_alphabet(alphabet, c):
    """Return True when character c is representable in the alphabet.

    A missing alphabet (None) means nothing is representable; the alphabet
    signals unknown characters by raising KeyError from label_from_string.
    """
    if alphabet is None:
        return False
    try:
        alphabet.label_from_string(c)
    except KeyError:
        return False
    return True
alphabets = {}  # cache: language -> Alphabet instance (or None when no path configured)


def get_alphabet(language):
    """Lazily load and cache the Alphabet configured for a language.

    Returns None when no --<language>_alphabet path was provided.
    """
    if language not in alphabets:
        alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
        alphabets[language] = Alphabet(alphabet_path) if alphabet_path else None
    return alphabets[language]
def label_filter(label, language):
    # Clean and validate a transcript for the given language.
    # Returns (label, None) on success or (None, reason) when the label
    # must be dropped; 'reason' feeds the skip-statistics counter.
    label = label.translate(PRE_FILTER)  # delete structural chars like brackets and colons
    label = validate_label(label)
    if label is None:
        return None, 'validation'
    substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
    for pattern, replacement in substitutions:
        if replacement is None:
            # A None replacement marks matching labels as unusable.
            if pattern.match(label):
                return None, 'substitution rule'
        else:
            label = pattern.sub(replacement, label)
    chars = []
    dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
    alphabet = get_alphabet(language)
    for c in label:
        if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
            # Decompose diacritics to their ASCII base form (e.g. "é" -> "e");
            # characters in dont_normalize (umlauts etc.) are kept as-is.
            c = (unicodedata.normalize("NFKD", c)
                 .encode("ascii", "ignore")
                 .decode("ascii", "ignore"))
        # Normalization may yield zero or several characters; check each one.
        for sc in c:
            if alphabet is not None and not in_alphabet(alphabet, sc):
                return None, 'illegal character'
            chars.append(sc)
    label = ''.join(chars)
    # Re-validate: substitutions/normalization may have emptied the label.
    label = validate_label(label)
    return label, 'validation' if label is None else None
def collect_samples(base_dir, language):
    # Scan the extracted corpus for article directories (those containing
    # both the aligned transcript and the converted WAV), then cut each
    # aligned sentence into Sample objects, splitting overly long sentences
    # at token boundaries. Returns the list of accepted samples and prints
    # statistics about skipped ones.
    roots = []
    for root, dirs, files in os.walk(base_dir):
        if ALIGNED_NAME in files and WAV_NAME in files:
            roots.append(root)
    samples = []
    reasons = Counter()  # skip-reason -> count, for the final report
    print('Collecting samples...')
    bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
    for root in bar(roots):
        wav_path = path.join(root, WAV_NAME)
        aligned = ET.parse(path.join(root, ALIGNED_NAME))
        # Speaker comes from the 'reader.name' property, when present.
        speaker = '<unknown>'
        for prop in aligned.iter('prop'):
            attributes = prop.attrib
            if 'key' in attributes and 'value' in attributes and attributes['key'] == 'reader.name':
                speaker = attributes['value']
                break
        for sentence in aligned.iter('s'):
            # Closure appending one candidate sample; applies the label
            # filter plus duration sanity checks and tallies skip reasons.
            def add_sample(start, end, text, reason='complete'):
                if start is not None and end is not None and text is not None:
                    duration = end - start  # milliseconds
                    text, filter_reason = label_filter(text, language)
                    skip = False
                    if filter_reason is not None:
                        skip = True
                        reason = filter_reason
                    elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
                        skip = True
                        reason = 'exceeded duration'
                    elif int(duration / 20) < len(text):
                        # Less than 20 ms of audio per character cannot be decoded.
                        skip = True
                        reason = 'too short to decode'
                    elif duration / len(text) < 10:
                        skip = True
                        reason = 'length duration ratio'
                    if skip:
                        reasons[reason] += 1
                    else:
                        samples.append(Sample(wav_path, start, end, text, speaker))
                elif start is None or end is None:
                    reasons['missing timestamps'] += 1
                else:
                    reasons['missing text'] += 1
            if ignored(sentence):
                continue
            split = False
            tokens = list(map(lambda token: read_token(token), sentence.findall('t')))
            # Accumulate tokens into a sample until max_duration would be
            # exceeded, then flush and start the next chunk at sample_end.
            sample_start, sample_end, token_texts, sample_texts = None, None, [], []
            for token_start, token_end, token_text in tokens:
                if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
                    # Flush what we have and drop the digit-bearing token.
                    add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='has numbers')
                    sample_start, sample_end, token_texts, sample_texts = None, None, [], []
                    continue
                if sample_start is None:
                    sample_start = token_start
                if sample_start is None:
                    # Still no timing information — skip until a timed token appears.
                    continue
                else:
                    token_texts.append(token_text)
                if token_end is not None:
                    if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
                        # Chunk full: emit it and continue from its end time.
                        add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split')
                        sample_start = sample_end
                        sample_texts = []
                        split = True
                    sample_end = token_end
                    # Texts only become part of the sample once an end time is known.
                    sample_texts.extend(token_texts)
                    token_texts = []
            # Emit the remainder of the sentence.
            add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split' if split else 'complete')
    print('Skipped samples:')
    for reason, n in reasons.most_common():
        print(' - {}: {}'.format(reason, n))
    return samples
def maybe_convert_one_to_wav(entry):
    # Convert the .ogg audio of one os.walk() entry into a single mono
    # 16 kHz WAV file; multiple source parts are converted individually and
    # then concatenated. Does nothing when the WAV already exists.
    root, dirs, files = entry
    transformer = sox.Transformer()
    transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    combiner = sox.Combiner()
    combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    output_wav = path.join(root, WAV_NAME)
    if path.isfile(output_wav):
        return
    files = sorted(glob(path.join(root, AUDIO_PATTERN)))
    try:
        if len(files) == 1:
            transformer.build(files[0], output_wav)
        elif len(files) > 1:
            # Multi-part audio: convert each part to an intermediate WAV,
            # then concatenate them into the final file.
            # NOTE(review): the intermediate audio<i>.wav files are left on
            # disk afterwards — presumably removed later by cleanup();
            # confirm that is intended.
            wav_files = []
            for i, file in enumerate(files):
                wav_path = path.join(root, 'audio{}.wav'.format(i))
                transformer.build(file, wav_path)
                wav_files.append(wav_path)
            combiner.set_input_format(file_type=['wav'] * len(wav_files))
            combiner.build(wav_files, output_wav, 'concatenate')
    except sox.core.SoxError:
        # Best effort: a failed conversion simply leaves this directory
        # without a WAV, so collect_samples() will skip it.
        return
def maybe_convert_to_wav(base_dir):
    """Convert/join the source audio of every directory under base_dir,
    running the per-directory conversions on a thread pool."""
    walk_entries = list(os.walk(base_dir))
    print('Converting and joining source audio files...')
    bar = progressbar.ProgressBar(max_value=len(walk_entries), widgets=SIMPLE_BAR)
    pool = ThreadPool()
    # Drain the iterator so the progress bar ticks as conversions finish.
    for _ in bar(pool.imap_unordered(maybe_convert_one_to_wav, walk_entries)):
        pass
    pool.close()
    pool.join()
def assign_sub_sets(samples):
    # Partition samples into train/dev/test sub-sets. Dev and test each get
    # sample_size samples, filled speaker-by-speaker so that no speaker
    # appears in more than one sub-set.
    sample_size = get_sample_size(len(samples))
    # Per-speaker sample lists, smallest speakers first, so dev/test consume
    # the speakers that cost the train set the least.
    speakers = group(samples, lambda sample: sample.speaker).values()
    speakers = list(sorted(speakers, key=lambda speaker_samples: len(speaker_samples)))
    sample_sets = [[], []]  # index 0 becomes 'dev', index 1 becomes 'test'
    while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
        for sample_set in sample_sets:
            if len(sample_set) < sample_size and len(speakers) > 0:
                sample_set.extend(speakers.pop(0))
    # Whatever speakers remain form the train set.
    train_set = sum(speakers, [])
    if len(train_set) == 0:
        # No speaker metadata worth using: fall back to a seeded random
        # split so the same corpus always produces the same sub-sets.
        print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
        random.seed(42)  # same source data == same output
        random.shuffle(samples)
        for index, sample in enumerate(samples):
            if index < sample_size:
                sample.sub_set = 'dev'
            elif index < 2 * sample_size:
                sample.sub_set = 'test'
            else:
                sample.sub_set = 'train'
    else:
        for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
            for sample in sub_set_samples:
                sample.sub_set = sub_set
    # Report each sub-set's size and total duration (start/end are in ms).
    for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
        t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
        print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
              .format(sub_set, len(sub_set_samples), t))
def create_sample_dirs(language):
    """Ensure the <base_dir>/<language>-{train,dev,test} output directories exist."""
    print('Creating sample directories...')
    for set_name in ['train', 'dev', 'test']:
        set_dir = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
        if path.isdir(set_dir):
            continue
        os.mkdir(set_dir)
def split_audio_files(samples, language):
    # Cut every per-article WAV into one small WAV per sample, writing them
    # into the per-sub-set directories and re-pointing each sample's
    # wav_path at its new file.
    print('Splitting audio files...')
    sub_sets = Counter()  # per-sub-set running index for output file names
    src_wav_files = group(samples, lambda s: s.wav_path).items()
    bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
    for wav_path, file_samples in bar(src_wav_files):
        # Process samples in time order so the source file is read forward.
        file_samples = sorted(file_samples, key=lambda s: s.start)
        with wave.open(wav_path, 'r') as src_wav_file:
            rate = src_wav_file.getframerate()
            for sample in file_samples:
                index = sub_sets[sample.sub_set]
                sample_wav_path = path.join(CLI_ARGS.base_dir,
                                            language + '-' + sample.sub_set,
                                            'sample-{0:06d}.wav'.format(index))
                sample.wav_path = sample_wav_path
                sub_sets[sample.sub_set] += 1
                # start/end are in milliseconds; convert to frame positions.
                src_wav_file.setpos(int(sample.start * rate / 1000.0))
                data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
                with wave.open(sample_wav_path, 'w') as sample_wav_file:
                    sample_wav_file.setnchannels(src_wav_file.getnchannels())
                    sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
                    sample_wav_file.setframerate(rate)
                    sample_wav_file.writeframes(data)
def write_csvs(samples, language):
    """Write one DeepSpeech CSV per sub-set ("<language>-<sub_set>.csv").

    Each row holds the WAV path (relative to base_dir), the file size in
    bytes and the transcript. The file is opened with newline='' — required
    by the csv module, which handles line endings itself (otherwise rows get
    doubled newlines on Windows) — and an explicit UTF-8 encoding so
    transcripts with non-ASCII characters do not depend on the locale's
    default codec.
    """
    for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
        set_samples = sorted(set_samples, key=lambda s: s.wav_path)
        base_dir = path.abspath(CLI_ARGS.base_dir)
        csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
        print('Writing "{}"...'.format(csv_path))
        with open(csv_path, 'w', newline='', encoding='utf-8') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
            writer.writeheader()
            bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
            for sample in bar(set_samples):
                writer.writerow({
                    'wav_filename': path.relpath(sample.wav_path, base_dir),
                    'wav_filesize': path.getsize(sample.wav_path),
                    'transcript': sample.text
                })
def cleanup(archive, language):
    """Delete the downloaded archive and the extracted intermediate files,
    honoring the --keep_archive and --keep_intermediate flags."""
    if not CLI_ARGS.keep_archive:
        print('Removing archive "{}"...'.format(archive))
        os.remove(archive)
    intermediate_dir = path.join(CLI_ARGS.base_dir, language)
    if path.isdir(intermediate_dir) and not CLI_ARGS.keep_intermediate:
        print('Removing intermediate files in "{}"...'.format(intermediate_dir))
        shutil.rmtree(intermediate_dir)
def prepare_language(language):
    # Full import pipeline for one language: download the archive, extract
    # it, convert audio to WAV, collect aligned samples, split them into
    # sub-sets, cut the audio, write the CSVs and finally clean up.
    archive = maybe_download_language(language)
    extracted = maybe_extract(CLI_ARGS.base_dir, language, archive)
    maybe_convert_to_wav(extracted)
    samples = collect_samples(extracted, language)
    assign_sub_sets(samples)
    create_sample_dirs(language)
    split_audio_files(samples, language)
    write_csvs(samples, language)
    cleanup(archive, language)
def _parse_bool(value):
    """Parse a boolean command line value.

    argparse's ``type=bool`` is a known trap: ``bool(s)`` is True for every
    non-empty string, so ``--keep_archive false`` silently meant True. This
    parser accepts the usual true/false spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got "{}"'.format(value))


def handle_args():
    """Build and parse the importer's command line interface."""
    parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
    parser.add_argument('base_dir', help='Directory containing all data')
    parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
    parser.add_argument('--exclude_numbers', type=_parse_bool, default=True,
                        help='If sequences with non-transliterated numbers should be excluded')
    parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
    parser.add_argument('--ignore_too_long', type=_parse_bool, default=False,
                        help='If samples exceeding max_duration should be removed')
    parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
    for language in LANGUAGES:
        parser.add_argument('--{}_alphabet'.format(language),
                            help='Exclude {} samples with characters not in provided alphabet file'.format(language))
    parser.add_argument('--keep_archive', type=_parse_bool, default=True,
                        help='If downloaded archives should be kept')
    parser.add_argument('--keep_intermediate', type=_parse_bool, default=False,
                        help='If intermediate files should be kept')
    return parser.parse_args()
if __name__ == "__main__":
    CLI_ARGS = handle_args()
    # Import every supported language, a single requested one, or bail out.
    if CLI_ARGS.language == 'all':
        for lang in LANGUAGES:
            prepare_language(lang)
    elif CLI_ARGS.language in LANGUAGES:
        prepare_language(CLI_ARGS.language)
    else:
        fail('Wrong language id')