Spoken Wikipedia importer
This commit is contained in:
parent
d35107acdb
commit
3424ab2b5d
455
bin/import_swc.py
Executable file
455
bin/import_swc.py
Executable file
@ -0,0 +1,455 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
|
||||
Use "python3 import_swc.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import re
|
||||
import csv
|
||||
import sox
|
||||
import wave
|
||||
import shutil
|
||||
import random
|
||||
import tarfile
|
||||
import argparse
|
||||
import progressbar
|
||||
import unicodedata
|
||||
import xml.etree.cElementTree as ET
|
||||
|
||||
from os import path
|
||||
from glob import glob
|
||||
from collections import Counter
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from util.text import Alphabet, validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
# Download location of the Spoken Wikipedia Corpora tarballs (one per language).
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar"
# Languages offered by the corpus; also the accepted values for --language.
LANGUAGES = ['dutch', 'english', 'german']
# Column layout of the CSV files consumed by DeepSpeech.py.
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
# Target audio format: 16 kHz mono WAV.
CHANNELS = 1
SAMPLE_RATE = 16000
# File names/patterns expected inside each extracted article directory.
AUDIO_PATTERN = 'audio*.ogg'
WAV_NAME = 'audio.wav'
ALIGNED_NAME = 'aligned.swc'

# Per-language regex substitutions applied to transcripts before filtering.
# A replacement of None means: drop the whole sample if the pattern matches.
SUBSTITUTIONS = {
    'german': [
        (re.compile(r'\$'), 'dollar'),
        (re.compile(r'€'), 'euro'),
        (re.compile(r'£'), 'pfund'),
        # Rewrite literally spelled years ("ein tausend neun hundert ...er")
        # into the spoken decade form ("neunzehnhundert ...er").
        (re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
        (re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
        # Collapse digit-grouped large numbers read digit by digit
        # ("eins punkt null null null ..." == 1.000.000...).
        (re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
        (re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
        (re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
        (re.compile(r'punkt null null null punkt null null null'), 'millionen'),
        (re.compile(r'eins punkt null null null'), 'ein tausend'),
        (re.compile(r'punkt null null null'), 'tausend'),
        (re.compile(r'punkt null'), None)
    ] # TODO: Add Dutch and English
}

# Characters that must survive --normalize (NFKD would strip their diacritics).
DONT_NORMALIZE = {
    'german': 'ÄÖÜäöüß',
    'dutch': 'IJij'
}

# Translation table deleting characters that never belong in a transcript.
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))
|
||||
|
||||
|
||||
class Sample:
    """One audio snippet plus its transcript and metadata.

    start/end are offsets into the source WAV in milliseconds; sub_set is
    assigned later and becomes one of 'train', 'dev' or 'test'.
    """
    def __init__(self, wav_path, start, end, text, speaker, sub_set=None):
        self.wav_path = wav_path
        self.start = start
        self.end = end
        self.text = text
        self.speaker = speaker
        self.sub_set = sub_set

    def __repr__(self):
        # Debugging aid only - not part of any on-disk format.
        return 'Sample({!r}, {}, {}, {!r}, {!r}, sub_set={!r})'.format(
            self.wav_path, self.start, self.end, self.text, self.speaker, self.sub_set)
|
||||
|
||||
|
||||
def fail(message):
    """Print an error message to stderr and terminate with exit code 1.

    Uses sys.exit rather than the builtin exit(): the latter is injected by
    the optional site module and is not guaranteed to exist; error text
    belongs on stderr, not stdout.
    """
    print(message, file=sys.stderr)
    sys.exit(1)
|
||||
|
||||
|
||||
def group(lst, get_key):
    """Bucket the items of *lst* into a dict keyed by get_key(item).

    Items keep their original relative order within each bucket.
    """
    buckets = {}
    for item in lst:
        buckets.setdefault(get_key(item), []).append(item)
    return buckets
|
||||
|
||||
|
||||
def get_sample_size(population_size):
    """Size of the dev/test sets for a corpus of *population_size* samples.

    Cochran-style estimate at 99% confidence and 1% margin of error,
    shrinking the hypothetical train set until two sample-sized splits fit.
    """
    margin_of_error = 0.01
    fraction_picking = 0.50
    z_score = 2.58  # corresponds to a 99% confidence level
    variance_term = z_score ** 2 * fraction_picking * (1 - fraction_picking)
    numerator = variance_term / margin_of_error ** 2
    sample_size = 0
    for train_size in range(population_size, 0, -1):
        denominator = 1 + variance_term / (margin_of_error ** 2 * train_size)
        sample_size = int(numerator / denominator)
        if 2 * sample_size + train_size <= population_size:
            break
    return sample_size
|
||||
|
||||
|
||||
def maybe_download_language(language):
    """Download the SWC archive for *language* if missing; return its local path."""
    # Archive names use a capitalized language id, e.g. "SWC_German.tar".
    capitalized = language[0].upper() + language[1:]
    archive_name = SWC_ARCHIVE.format(language=capitalized)
    archive_url = SWC_URL.format(language=capitalized)
    return maybe_download(archive_name, CLI_ARGS.base_dir, archive_url)
|
||||
|
||||
|
||||
def maybe_extract(data_dir, extracted_data, archive):
    """Extract *archive* into data_dir/extracted_data unless already present.

    Returns the path of the (possibly pre-existing) extraction directory.
    NOTE(review): members are extracted as-is; the archive is assumed
    trusted (no path-traversal sanitizing).
    """
    target = path.join(data_dir, extracted_data)
    if path.isdir(target):
        print('Found directory "{}" - not extracting.'.format(target))
        return target
    print('Extracting "{}"...'.format(archive))
    with tarfile.open(archive) as tar:
        members = tar.getmembers()
        bar = progressbar.ProgressBar(max_value=len(members), widgets=SIMPLE_BAR)
        for member in bar(members):
            tar.extract(member=member, path=target)
    return target
|
||||
|
||||
|
||||
def ignored(node):
    """True if *node* (or an ancestor reachable via find('..')) is <ignored>.

    NOTE(review): ElementTree cannot navigate above the element find() is
    called on, so find('..') returns None here and in practice only the
    node itself is checked - kept as-is to preserve behavior.
    """
    current = node
    while current is not None:
        if current.tag == 'ignored':
            return True
        current = current.find('..')
    return False
|
||||
|
||||
|
||||
def read_token(token):
    """Extract (start_ms, end_ms, text) from a <t> token of an aligned.swc file.

    Timing and pronunciation come from the token's <n> children when present;
    otherwise the token's own 'text' attribute is used (with no timing).
    """
    texts = []
    start = None
    end = None
    notes = token.findall('n')
    if notes:
        for note in notes:
            attrib = note.attrib
            # The first note carrying a start time wins.
            if start is None and 'start' in attrib:
                start = int(attrib['start'])
            # Keep the latest end time seen across all notes.
            if 'end' in attrib:
                note_end = int(attrib['end'])
                if end is None or note_end > end:
                    end = note_end
            if 'pronunciation' in attrib:
                texts.append(attrib['pronunciation'])
    elif 'text' in token.attrib:
        texts.append(token.attrib['text'])
    return start, end, ' '.join(texts)
|
||||
|
||||
|
||||
def in_alphabet(alphabet, c):
    """True if character *c* is encodable by *alphabet*; False with no alphabet."""
    if alphabet is None:
        return False
    try:
        alphabet.label_from_string(c)
    except KeyError:
        # The alphabet raises KeyError for characters it does not contain.
        return False
    return True
|
||||
|
||||
|
||||
# Cache of language id -> Alphabet instance (or None when no alphabet file given).
alphabets = {}


def get_alphabet(language):
    """Return the cached Alphabet for *language*, or None if none was configured."""
    global alphabets
    if language not in alphabets:
        alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
        alphabets[language] = Alphabet(alphabet_path) if alphabet_path else None
    return alphabets[language]
|
||||
|
||||
|
||||
def label_filter(label, language):
    # Clean and validate one transcript. Returns (label, None) on success or
    # (None, reason) when the sample should be dropped.
    # Strip characters that never belong in a transcript (brackets, slashes, ...).
    label = label.translate(PRE_FILTER)
    label = validate_label(label)
    if label is None:
        return None, 'validation'
    # Apply per-language regex substitutions; a None replacement marks the
    # whole sample as unusable when its pattern matches.
    substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
    for pattern, replacement in substitutions:
        if replacement is None:
            if pattern.match(label):
                return None, 'substitution rule'
        else:
            label = pattern.sub(replacement, label)
    chars = []
    dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
    alphabet = get_alphabet(language)
    for c in label:
        # Optionally strip diacritics from characters outside the alphabet,
        # except those the language explicitly protects (e.g. German umlauts).
        if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
            c = (unicodedata.normalize("NFKD", c)
                 .encode("ascii", "ignore")
                 .decode("ascii", "ignore"))
        # NFKD may expand one character into several (or none); check each.
        for sc in c:
            if alphabet is not None and not in_alphabet(alphabet, sc):
                return None, 'illegal character'
            chars.append(sc)
    label = ''.join(chars)
    # Re-validate: substitutions/normalization may have emptied the label.
    label = validate_label(label)
    return label, 'validation' if label is None else None
|
||||
|
||||
|
||||
def collect_samples(base_dir, language):
    # Walk base_dir and build Sample objects from every article directory
    # that contains both the aligned transcript and the joined WAV.
    # Samples failing any filter are counted (and reported) per skip reason.
    roots = []
    for root, dirs, files in os.walk(base_dir):
        if ALIGNED_NAME in files and WAV_NAME in files:
            roots.append(root)
    samples = []
    reasons = Counter()
    print('Collecting samples...')
    bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
    for root in bar(roots):
        wav_path = path.join(root, WAV_NAME)
        aligned = ET.parse(path.join(root, ALIGNED_NAME))
        # Speaker name (if annotated) is used later for bias-free set splitting.
        speaker = '<unknown>'
        for prop in aligned.iter('prop'):
            attributes = prop.attrib
            if 'key' in attributes and 'value' in attributes and attributes['key'] == 'reader.name':
                speaker = attributes['value']
                break
        for sentence in aligned.iter('s'):
            def add_sample(start, end, text, reason='complete'):
                # Validate the candidate snippet and either append it to
                # `samples` or account for why it was skipped in `reasons`.
                if start is not None and end is not None and text is not None:
                    duration = end - start
                    text, filter_reason = label_filter(text, language)
                    skip = False
                    if filter_reason is not None:
                        skip = True
                        reason = filter_reason
                    elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
                        skip = True
                        reason = 'exceeded duration'
                    elif int(duration / 20) < len(text):
                        # Less than ~20 ms of audio per character cannot be decoded.
                        skip = True
                        reason = 'too short to decode'
                    elif duration / len(text) < 10:
                        skip = True
                        reason = 'length duration ratio'
                    if skip:
                        reasons[reason] += 1
                    else:
                        samples.append(Sample(wav_path, start, end, text, speaker))
                elif start is None or end is None:
                    reasons['missing timestamps'] += 1
                else:
                    reasons['missing text'] += 1
            # Skip sentences inside <ignored> sections of the alignment.
            if ignored(sentence):
                continue
            split = False
            tokens = list(map(lambda token: read_token(token), sentence.findall('t')))
            sample_start, sample_end, token_texts, sample_texts = None, None, [], []
            for token_start, token_end, token_text in tokens:
                if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
                    # Flush what we have and restart after the numeric token.
                    add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='has numbers')
                    sample_start, sample_end, token_texts, sample_texts = None, None, [], []
                    continue
                if sample_start is None:
                    sample_start = token_start
                # No start time yet: nothing can be anchored, skip the token.
                if sample_start is None:
                    continue
                else:
                    token_texts.append(token_text)
                if token_end is not None:
                    # Cut the sentence once it would exceed --max_duration.
                    if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
                        add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split')
                        sample_start = sample_end
                        sample_texts = []
                        split = True
                    sample_end = token_end
                    # Only tokens with a confirmed end time join the sample text.
                    sample_texts.extend(token_texts)
                    token_texts = []
            # Flush the sentence remainder.
            add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split' if split else 'complete')
    print('Skipped samples:')
    for reason, n in reasons.most_common():
        print(' - {}: {}'.format(reason, n))
    return samples
|
||||
|
||||
|
||||
def maybe_convert_one_to_wav(entry):
    """Convert (and concatenate) one directory's OGG sources into a single WAV.

    *entry* is an (root, dirs, files) triple from os.walk. An existing output
    WAV is kept as-is; sox failures are silently skipped (best effort).
    """
    root, dirs, files = entry
    output_wav = path.join(root, WAV_NAME)
    if path.isfile(output_wav):
        return
    transformer = sox.Transformer()
    transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    combiner = sox.Combiner()
    combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    ogg_files = sorted(glob(path.join(root, AUDIO_PATTERN)))
    try:
        if len(ogg_files) == 1:
            transformer.build(ogg_files[0], output_wav)
        elif len(ogg_files) > 1:
            # Convert each part to WAV first, then concatenate the parts.
            wav_files = []
            for i, ogg_file in enumerate(ogg_files):
                part_path = path.join(root, 'audio{}.wav'.format(i))
                transformer.build(ogg_file, part_path)
                wav_files.append(part_path)
            combiner.set_input_format(file_type=['wav'] * len(wav_files))
            combiner.build(wav_files, output_wav, 'concatenate')
    except sox.core.SoxError:
        # Best effort: a broken source just leaves this article without audio.
        return
|
||||
|
||||
|
||||
def maybe_convert_to_wav(base_dir):
    """Run maybe_convert_one_to_wav over every directory below *base_dir*
    on a thread pool, with a progress bar driven by completion order."""
    work_items = list(os.walk(base_dir))
    print('Converting and joining source audio files...')
    bar = progressbar.ProgressBar(max_value=len(work_items), widgets=SIMPLE_BAR)
    pool = ThreadPool()
    for _ in bar(pool.imap_unordered(maybe_convert_one_to_wav, work_items)):
        pass
    pool.close()
    pool.join()
|
||||
|
||||
|
||||
def assign_sub_sets(samples):
    # Partition samples into train/dev/test, preferring speaker-disjoint
    # dev/test sets (no speaker bias); falls back to a seeded random split
    # when speaker metadata is missing.
    sample_size = get_sample_size(len(samples))
    # Fill dev and test with whole speakers, smallest speakers first.
    speakers = group(samples, lambda sample: sample.speaker).values()
    speakers = list(sorted(speakers, key=lambda speaker_samples: len(speaker_samples)))
    sample_sets = [[], []]
    while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
        for sample_set in sample_sets:
            if len(sample_set) < sample_size and len(speakers) > 0:
                sample_set.extend(speakers.pop(0))
    # Whatever speakers remain make up the training set.
    train_set = sum(speakers, [])
    if len(train_set) == 0:
        # All samples were swallowed by dev/test (e.g. every speaker is
        # '<unknown>') - split randomly instead, reproducibly.
        print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
        random.seed(42) # same source data == same output
        random.shuffle(samples)
        for index, sample in enumerate(samples):
            if index < sample_size:
                sample.sub_set = 'dev'
            elif index < 2 * sample_size:
                sample.sub_set = 'test'
            else:
                sample.sub_set = 'train'
    else:
        for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
            for sample in sub_set_samples:
                sample.sub_set = sub_set
    # Report set sizes and total durations (start/end are milliseconds).
    for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
        t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
        print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
              .format(sub_set, len(sub_set_samples), t))
|
||||
|
||||
|
||||
def create_sample_dirs(language):
    """Create the <language>-{train,dev,test} output directories under base_dir."""
    print('Creating sample directories...')
    for set_name in ['train', 'dev', 'test']:
        dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
        # makedirs with exist_ok avoids the isdir/mkdir race of the previous
        # check-then-create and also creates any missing parent directories.
        os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
|
||||
def split_audio_files(samples, language):
    # Cut each source WAV into per-sample WAV files inside the per-sub-set
    # directories, updating every sample's wav_path in place.
    print('Splitting audio files...')
    sub_sets = Counter()  # running per-sub-set index used for file naming
    src_wav_files = group(samples, lambda s: s.wav_path).items()
    bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
    for wav_path, file_samples in bar(src_wav_files):
        # Sort by start offset so the source file is read mostly forward.
        file_samples = sorted(file_samples, key=lambda s: s.start)
        with wave.open(wav_path, 'r') as src_wav_file:
            rate = src_wav_file.getframerate()
            for sample in file_samples:
                index = sub_sets[sample.sub_set]
                sample_wav_path = path.join(CLI_ARGS.base_dir,
                                            language + '-' + sample.sub_set,
                                            'sample-{0:06d}.wav'.format(index))
                sample.wav_path = sample_wav_path
                sub_sets[sample.sub_set] += 1
                # start/end are milliseconds; convert to frame positions.
                src_wav_file.setpos(int(sample.start * rate / 1000.0))
                data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
                # Write the snippet with the source file's channel count,
                # sample width and rate.
                with wave.open(sample_wav_path, 'w') as sample_wav_file:
                    sample_wav_file.setnchannels(src_wav_file.getnchannels())
                    sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
                    sample_wav_file.setframerate(rate)
                    sample_wav_file.writeframes(data)
|
||||
|
||||
|
||||
def write_csvs(samples, language):
    """Write one DeepSpeech CSV (wav_filename, wav_filesize, transcript) per sub-set."""
    for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
        set_samples = sorted(set_samples, key=lambda s: s.wav_path)
        base_dir = path.abspath(CLI_ARGS.base_dir)
        csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
        print('Writing "{}"...'.format(csv_path))
        # newline='' is required by the csv module (otherwise \r\n line
        # endings get doubled on Windows); transcripts may contain
        # non-ASCII characters, so pin the encoding to UTF-8.
        with open(csv_path, 'w', newline='', encoding='utf-8') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
            writer.writeheader()
            bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
            for sample in bar(set_samples):
                writer.writerow({
                    # Paths are stored relative to base_dir so the data set
                    # stays relocatable.
                    'wav_filename': path.relpath(sample.wav_path, base_dir),
                    'wav_filesize': path.getsize(sample.wav_path),
                    'transcript': sample.text
                })
|
||||
|
||||
|
||||
def cleanup(archive, language):
    """Remove the downloaded archive and the extracted intermediate files,
    honoring the --keep_archive and --keep_intermediate flags."""
    if not CLI_ARGS.keep_archive:
        print('Removing archive "{}"...'.format(archive))
        os.remove(archive)
    extracted_dir = path.join(CLI_ARGS.base_dir, language)
    keep_intermediate = CLI_ARGS.keep_intermediate or not path.isdir(extracted_dir)
    if not keep_intermediate:
        print('Removing intermediate files in "{}"...'.format(extracted_dir))
        shutil.rmtree(extracted_dir)
|
||||
|
||||
|
||||
def prepare_language(language):
    # Full import pipeline for one language:
    # download -> extract -> convert audio -> collect & filter samples ->
    # split into sets -> write per-sample WAVs and CSVs -> clean up.
    archive = maybe_download_language(language)
    extracted = maybe_extract(CLI_ARGS.base_dir, language, archive)
    maybe_convert_to_wav(extracted)
    samples = collect_samples(extracted, language)
    assign_sub_sets(samples)
    create_sample_dirs(language)
    split_audio_files(samples, language)
    write_csvs(samples, language)
    cleanup(archive, language)
|
||||
|
||||
|
||||
def handle_args():
    """Parse command line arguments.

    Boolean options accept explicit true/false spellings. The previous
    `type=bool` idiom was broken: argparse passes the raw string to bool(),
    so any non-empty value - including "False" - parsed as True.
    """
    def str_to_bool(value):
        # argparse calls type() with the raw string; map common spellings.
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Expected a boolean value, got "{}"'.format(value))

    parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
    parser.add_argument('base_dir', help='Directory containing all data')
    parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
    parser.add_argument('--exclude_numbers', type=str_to_bool, default=True,
                        help='If sequences with non-transliterated numbers should be excluded')
    parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
    parser.add_argument('--ignore_too_long', type=str_to_bool, default=False,
                        help='If samples exceeding max_duration should be removed')
    parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
    for language in LANGUAGES:
        parser.add_argument('--{}_alphabet'.format(language),
                            help='Exclude {} samples with characters not in provided alphabet file'.format(language))
    parser.add_argument('--keep_archive', type=str_to_bool, default=True,
                        help='If downloaded archives should be kept')
    parser.add_argument('--keep_intermediate', type=str_to_bool, default=False,
                        help='If intermediate files should be kept')
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    CLI_ARGS = handle_args()
    # Resolve the requested selection into the list of languages to import.
    selection = CLI_ARGS.language
    if selection == 'all':
        targets = LANGUAGES
    elif selection in LANGUAGES:
        targets = [selection]
    else:
        fail('Wrong language id')  # exits the process
    for language in targets:
        prepare_language(language)
|
Loading…
x
Reference in New Issue
Block a user