Merge pull request #2459 from tilmankamp/import_swc

Spoken Wikipedia importer
Tilman Kamp 2019-10-24 15:35:49 +02:00 committed by GitHub
commit ede4dd6f93
2 changed files with 449 additions and 0 deletions

bin/import_swc.py Executable file

@@ -0,0 +1,446 @@
#!/usr/bin/env python
'''
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for training with DeepSpeech.py
Use "python3 import_swc.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import re
import csv
import sox
import wave
import shutil
import random
import tarfile
import argparse
import progressbar
import unicodedata
import xml.etree.cElementTree as ET
from os import path
from glob import glob
from collections import Counter
from multiprocessing.pool import ThreadPool
from util.text import Alphabet, validate_label
from util.downloader import maybe_download, SIMPLE_BAR
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar"
LANGUAGES = ['dutch', 'english', 'german']
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
CHANNELS = 1
SAMPLE_RATE = 16000
AUDIO_PATTERN = 'audio*.ogg'
WAV_NAME = 'audio.wav'
ALIGNED_NAME = 'aligned.swc'
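
# Language-specific transcript rewrites: currency symbols are spelled out and
# some spelled-out German number constructions are contracted, e.g.
# "ein tausend neun hundert ..." becomes "neunzehnhundert ..." and dotted
# thousand groups read digit by digit become "million"/"milliarde".
# A rule whose replacement is None drops the matching sample entirely.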
SUBSTITUTIONS = {
    'german': [
        (re.compile(r'\$'), 'dollar'),
        (re.compile(r'€'), 'euro'),
        (re.compile(r'£'), 'pfund'),
        (re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
        (re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
        (re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
        (re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
        (re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
        (re.compile(r'punkt null null null punkt null null null'), 'millionen'),
        (re.compile(r'eins punkt null null null'), 'ein tausend'),
        (re.compile(r'punkt null null null'), 'tausend'),
        (re.compile(r'punkt null'), None)
    ]
}

DONT_NORMALIZE = {
    'german': 'ÄÖÜäöüß'
}
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))

class Sample:
    def __init__(self, wav_path, start, end, text, speaker, sub_set=None):
        self.wav_path = wav_path
        self.start = start
        self.end = end
        self.text = text
        self.speaker = speaker
        self.sub_set = sub_set

def fail(message):
    print(message)
    sys.exit(1)

def group(lst, get_key):
    groups = {}
    for obj in lst:
        key = get_key(obj)
        if key in groups:
            groups[key].append(obj)
        else:
            groups[key] = [obj]
    return groups
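
# get_sample_size computes Cochran's sample-size formula with a finite
# population correction: n0 = z^2 * p * (1 - p) / e^2 and n = n0 / (1 + n0 / N),
# here with z = 2.58 (99% confidence), p = 0.5 (worst case) and e = 0.01.
# The loop searches downward for the largest train set size that still leaves
# room for a dev and a test set of that size (2 * sample_size + train <= N).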
def get_sample_size(population_size):
    margin_of_error = 0.01
    fraction_picking = 0.50
    z_score = 2.58  # Corresponds to confidence level 99%
    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
    sample_size = 0
    for train_size in range(population_size, 0, -1):
        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
                      (margin_of_error ** 2 * train_size)
        sample_size = int(numerator / denominator)
        if 2 * sample_size + train_size <= population_size:
            break
    return sample_size

def maybe_download_language(language):
    lang_upper = language[0].upper() + language[1:]
    return maybe_download(SWC_ARCHIVE.format(language=lang_upper),
                          CLI_ARGS.base_dir,
                          SWC_URL.format(language=lang_upper))

def maybe_extract(data_dir, extracted_data, archive):
    extracted = path.join(data_dir, extracted_data)
    if path.isdir(extracted):
        print('Found directory "{}" - not extracting.'.format(extracted))
    else:
        print('Extracting "{}"...'.format(archive))
        with tarfile.open(archive) as tar:
            members = tar.getmembers()
            bar = progressbar.ProgressBar(max_value=len(members), widgets=SIMPLE_BAR)
            for member in bar(members):
                tar.extract(member=member, path=extracted)
    return extracted

def ignored(node):
    if node is None:
        return False
    if node.tag == 'ignored':
        return True
    return ignored(node.find('..'))
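
# In SWC "aligned.swc" files a sentence element <s> contains token elements <t>;
# a token can carry its own "text" attribute and/or normalization children <n>
# with "start"/"end" timestamps in milliseconds and a "pronunciation" attribute.
# read_token merges these into a (start, end, text) triple.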
def read_token(token):
    texts, start, end = [], None, None
    notes = token.findall('n')
    if len(notes) > 0:
        for note in notes:
            attributes = note.attrib
            if start is None and 'start' in attributes:
                start = int(attributes['start'])
            if 'end' in attributes:
                token_end = int(attributes['end'])
                if end is None or token_end > end:
                    end = token_end
            if 'pronunciation' in attributes:
                t = attributes['pronunciation']
                texts.append(t)
    elif 'text' in token.attrib:
        texts.append(token.attrib['text'])
    return start, end, ' '.join(texts)

def in_alphabet(alphabet, c):
    return True if alphabet is None else alphabet.has_char(c)
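
# Alphabets are loaded lazily and cached per language; if no alphabet file was
# passed for a language, alphabet filtering is effectively disabled (None).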
ALPHABETS = {}

def get_alphabet(language):
    if language in ALPHABETS:
        return ALPHABETS[language]
    alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
    alphabet = Alphabet(alphabet_path) if alphabet_path else None
    ALPHABETS[language] = alphabet
    return alphabet
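
# label_filter cleans a transcript in stages: strip bracket/punctuation noise
# (PRE_FILTER), validate it, apply language-specific SUBSTITUTIONS (a rule with
# replacement None drops the sample), then optionally NFKD-normalize characters
# that are neither whitelisted in DONT_NORMALIZE nor covered by the alphabet.
# It returns (label, None) on success or (None, reason) when the sample is dropped.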
def label_filter(label, language):
    label = label.translate(PRE_FILTER)
    label = validate_label(label)
    if label is None:
        return None, 'validation'
    substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
    for pattern, replacement in substitutions:
        if replacement is None:
            if pattern.match(label):
                return None, 'substitution rule'
        else:
            label = pattern.sub(replacement, label)
    chars = []
    dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
    alphabet = get_alphabet(language)
    for c in label:
        if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
            c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
        for sc in c:
            if not in_alphabet(alphabet, sc):
                return None, 'illegal character'
            chars.append(sc)
    label = ''.join(chars)
    label = validate_label(label)
    return label, 'validation' if label is None else None
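
# collect_samples walks all document directories containing both an audio.wav
# and an aligned.swc file. Sentence tokens are accumulated and flushed whenever
# they would exceed --max_duration (producing "split" samples); with
# --exclude_numbers, a token containing digits terminates the running sample.
# Sanity heuristics drop implausible samples, e.g. less than 20 ms of audio per
# transcript character is counted as 'too short to decode'.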
def collect_samples(base_dir, language):
    roots = []
    for root, _, files in os.walk(base_dir):
        if ALIGNED_NAME in files and WAV_NAME in files:
            roots.append(root)
    samples = []
    reasons = Counter()

    def add_sample(p_wav_path, p_speaker, p_start, p_end, p_text, p_reason='complete'):
        if p_start is not None and p_end is not None and p_text is not None:
            duration = p_end - p_start
            text, filter_reason = label_filter(p_text, language)
            skip = False
            if filter_reason is not None:
                skip = True
                p_reason = filter_reason
            elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
                skip = True
                p_reason = 'exceeded duration'
            elif int(duration / 20) < len(text):
                skip = True
                p_reason = 'too short to decode'
            elif duration / len(text) < 10:
                skip = True
                p_reason = 'length duration ratio'
            if skip:
                reasons[p_reason] += 1
            else:
                samples.append(Sample(p_wav_path, p_start, p_end, text, p_speaker))
        elif p_start is None or p_end is None:
            reasons['missing timestamps'] += 1
        else:
            reasons['missing text'] += 1

    print('Collecting samples...')
    bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
    for root in bar(roots):
        wav_path = path.join(root, WAV_NAME)
        aligned = ET.parse(path.join(root, ALIGNED_NAME))
        speaker = '<unknown>'
        for prop in aligned.iter('prop'):
            attributes = prop.attrib
            if 'key' in attributes and 'value' in attributes and attributes['key'] == 'reader.name':
                speaker = attributes['value']
                break
        for sentence in aligned.iter('s'):
            if ignored(sentence):
                continue
            split = False
            tokens = list(map(read_token, sentence.findall('t')))
            sample_start, sample_end, token_texts, sample_texts = None, None, [], []
            for token_start, token_end, token_text in tokens:
                if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
                    add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
                               p_reason='has numbers')
                    sample_start, sample_end, token_texts, sample_texts = None, None, [], []
                    continue
                if sample_start is None:
                    sample_start = token_start
                if sample_start is None:
                    continue
                token_texts.append(token_text)
                if token_end is not None:
                    if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
                        add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
                                   p_reason='split')
                        sample_start = sample_end
                        sample_texts = []
                        split = True
                    sample_end = token_end
                    sample_texts.extend(token_texts)
                    token_texts = []
            add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
                       p_reason='split' if split else 'complete')
    print('Skipped samples:')
    for reason, n in reasons.most_common():
        print(' - {}: {}'.format(reason, n))
    return samples
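
# Each document directory holds one or more audio*.ogg parts. A single part is
# converted to a 16 kHz mono WAV; multiple parts are converted individually and
# then concatenated with sox into one audio.wav per document. Existing WAVs are
# left alone, so the conversion step is resumable.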
def maybe_convert_one_to_wav(entry):
    root, _, files = entry
    transformer = sox.Transformer()
    transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    combiner = sox.Combiner()
    combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    output_wav = path.join(root, WAV_NAME)
    if path.isfile(output_wav):
        return
    files = sorted(glob(path.join(root, AUDIO_PATTERN)))
    try:
        if len(files) == 1:
            transformer.build(files[0], output_wav)
        elif len(files) > 1:
            wav_files = []
            for i, file in enumerate(files):
                wav_path = path.join(root, 'audio{}.wav'.format(i))
                transformer.build(file, wav_path)
                wav_files.append(wav_path)
            combiner.set_input_format(file_type=['wav'] * len(wav_files))
            combiner.build(wav_files, output_wav, 'concatenate')
    except sox.core.SoxError:
        return

def maybe_convert_to_wav(base_dir):
    roots = list(os.walk(base_dir))
    print('Converting and joining source audio files...')
    bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
    tp = ThreadPool()
    for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
        pass
    tp.close()
    tp.join()
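
# Dev and test sets are built speaker-disjoint: speakers, smallest sample count
# first, are handed out alternately to dev and test until both reach the size
# suggested by get_sample_size; all remaining speakers form the train set. If
# too few distinct speakers exist (e.g. no speaker metadata at all), every
# speaker gets consumed, and the code falls back to a seeded random split.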
def assign_sub_sets(samples):
    sample_size = get_sample_size(len(samples))
    speakers = group(samples, lambda sample: sample.speaker).values()
    speakers = list(sorted(speakers, key=len))
    sample_sets = [[], []]
    while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
        for sample_set in sample_sets:
            if len(sample_set) < sample_size and len(speakers) > 0:
                sample_set.extend(speakers.pop(0))
    train_set = sum(speakers, [])
    if len(train_set) == 0:
        print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
        random.seed(42)  # same source data == same output
        random.shuffle(samples)
        for index, sample in enumerate(samples):
            if index < sample_size:
                sample.sub_set = 'dev'
            elif index < 2 * sample_size:
                sample.sub_set = 'test'
            else:
                sample.sub_set = 'train'
    else:
        for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
            for sample in sub_set_samples:
                sample.sub_set = sub_set
    for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
        t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
        print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
              .format(sub_set, len(sub_set_samples), t))

def create_sample_dirs(language):
    print('Creating sample directories...')
    for set_name in ['train', 'dev', 'test']:
        dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
        if not path.isdir(dir_path):
            os.mkdir(dir_path)
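
# split_audio_files cuts each document-level audio.wav into per-sample WAVs.
# Timestamps are in milliseconds, so a position translates to frames via
# int(ms * framerate / 1000); the frames are read from the source file and
# written to sample-NNNNNN.wav in the matching sub-set directory with
# unchanged channel count, sample width and rate.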
def split_audio_files(samples, language):
    print('Splitting audio files...')
    sub_sets = Counter()
    src_wav_files = group(samples, lambda s: s.wav_path).items()
    bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
    for wav_path, file_samples in bar(src_wav_files):
        file_samples = sorted(file_samples, key=lambda s: s.start)
        with wave.open(wav_path, 'r') as src_wav_file:
            rate = src_wav_file.getframerate()
            for sample in file_samples:
                index = sub_sets[sample.sub_set]
                sample_wav_path = path.join(CLI_ARGS.base_dir,
                                            language + '-' + sample.sub_set,
                                            'sample-{0:06d}.wav'.format(index))
                sample.wav_path = sample_wav_path
                sub_sets[sample.sub_set] += 1
                src_wav_file.setpos(int(sample.start * rate / 1000.0))
                data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
                with wave.open(sample_wav_path, 'w') as sample_wav_file:
                    sample_wav_file.setnchannels(src_wav_file.getnchannels())
                    sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
                    sample_wav_file.setframerate(rate)
                    sample_wav_file.writeframes(data)

def write_csvs(samples, language):
    for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
        set_samples = sorted(set_samples, key=lambda s: s.wav_path)
        base_dir = path.abspath(CLI_ARGS.base_dir)
        csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
        print('Writing "{}"...'.format(csv_path))
        with open(csv_path, 'w') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
            writer.writeheader()
            bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
            for sample in bar(set_samples):
                writer.writerow({
                    'wav_filename': path.relpath(sample.wav_path, base_dir),
                    'wav_filesize': path.getsize(sample.wav_path),
                    'transcript': sample.text
                })
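
# The resulting CSVs follow the usual DeepSpeech importer layout, for example
# (illustrative row only - path, size and transcript are hypothetical):
#   wav_filename,wav_filesize,transcript
#   german-train/sample-000000.wav,123456,das ist ein beispielsatz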
def cleanup(archive, language):
    if not CLI_ARGS.keep_archive:
        print('Removing archive "{}"...'.format(archive))
        os.remove(archive)
    language_dir = path.join(CLI_ARGS.base_dir, language)
    if not CLI_ARGS.keep_intermediate and path.isdir(language_dir):
        print('Removing intermediate files in "{}"...'.format(language_dir))
        shutil.rmtree(language_dir)

def prepare_language(language):
    archive = maybe_download_language(language)
    extracted = maybe_extract(CLI_ARGS.base_dir, language, archive)
    maybe_convert_to_wav(extracted)
    samples = collect_samples(extracted, language)
    assign_sub_sets(samples)
    create_sample_dirs(language)
    split_audio_files(samples, language)
    write_csvs(samples, language)
    cleanup(archive, language)

def handle_args():
    parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
    parser.add_argument('base_dir', help='Directory containing all data')
    parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
    parser.add_argument('--exclude_numbers', type=bool, default=True,
                        help='If sequences with non-transliterated numbers should be excluded')
    parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
    parser.add_argument('--ignore_too_long', type=bool, default=False,
                        help='If samples exceeding max_duration should be removed')
    parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
    for language in LANGUAGES:
        parser.add_argument('--{}_alphabet'.format(language),
                            help='Exclude {} samples with characters not in provided alphabet file'.format(language))
    parser.add_argument('--keep_archive', type=bool, default=True,
                        help='If downloaded archives should be kept')
    parser.add_argument('--keep_intermediate', type=bool, default=False,
                        help='If intermediate files should be kept')
    return parser.parse_args()
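
# Example invocation with hypothetical paths, importing the German corpus with
# an alphabet filter. Note that argparse's type=bool treats any non-empty
# string as True, so the type=bool options are best left at their defaults:
#   python3 bin/import_swc.py --language german --normalize \
#       --german_alphabet data/alphabet.txt /data/swc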
if __name__ == "__main__":
    CLI_ARGS = handle_args()
    if CLI_ARGS.language == 'all':
        for lang in LANGUAGES:
            prepare_language(lang)
    elif CLI_ARGS.language in LANGUAGES:
        prepare_language(CLI_ARGS.language)
    else:
        fail('Wrong language id')

util/text.py

@@ -36,6 +36,9 @@ class Alphabet(object):
                'then add all these to data/alphabet.txt.'.format(string)
            ).with_traceback(e.__traceback__)

    def has_char(self, char):
        return char in self._str_to_label

    def encode(self, string):
        res = []
        for char in string: