Linter-induced changes

Tilman Kamp 2019-10-23 16:47:50 +02:00
parent 010f24578f
commit 122a007d33


@@ -54,7 +54,7 @@ SUBSTITUTIONS = {
         (re.compile(r'eins punkt null null null'), 'ein tausend'),
         (re.compile(r'punkt null null null'), 'tausend'),
         (re.compile(r'punkt null'), None)
-    ] # TODO: Add Dutch and English
+    ]
 }
 DONT_NORMALIZE = {
@@ -77,7 +77,7 @@ class Sample:
 def fail(message):
     print(message)
-    exit(1)
+    sys.exit(1)
 def group(lst, get_key):
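The exit(1) to sys.exit(1) change is the standard linter fix: the bare exit() builtin is only injected by the site module for interactive use and is not guaranteed to be available, whereas sys.exit raises SystemExit explicitly. A small illustrative snippet, not part of the commit:

    import sys

    # sys.exit raises SystemExit, so the exit code is explicit and testable.
    try:
        sys.exit(1)
    except SystemExit as err:
        print(err.code)  # 1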
@@ -95,14 +95,11 @@ def get_sample_size(population_size):
     margin_of_error = 0.01
     fraction_picking = 0.50
     z_score = 2.58 # Corresponds to confidence level 99%
-    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
-        margin_of_error ** 2
-    )
+    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
     sample_size = 0
     for train_size in range(population_size, 0, -1):
-        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
-            margin_of_error ** 2 * train_size
-        )
+        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
+            (margin_of_error ** 2 * train_size)
         sample_size = int(numerator / denominator)
         if 2 * sample_size + train_size <= population_size:
             break
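The hunk above only reflows line continuations; the computation itself is Cochran's sample-size formula with a finite-population correction, used to size the dev and test sets. A minimal sketch of that formula, assuming the same constants as the script (the helper name and the example population are mine):

    def cochran_sample_size(population, z=2.58, e=0.01, p=0.5):
        # n0: required sample size for an infinite population
        n0 = (z ** 2 * p * (1 - p)) / (e ** 2)
        # finite-population correction
        return int(n0 / (1 + n0 / population))

    print(cochran_sample_size(10000))  # 6246

get_sample_size additionally shrinks train_size until the train set plus two such sample-sized sets (dev and test) fit inside the population.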
@@ -162,14 +159,13 @@ def in_alphabet(alphabet, c):
     return True if alphabet is None else alphabet.has_char(c)
-alphabets = {}
+ALPHABETS = {}
 def get_alphabet(language):
-    global alphabets
-    if language in alphabets:
-        return alphabets[language]
+    if language in ALPHABETS:
+        return ALPHABETS[language]
     alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
     alphabet = Alphabet(alphabet_path) if alphabet_path else None
-    alphabets[language] = alphabet
+    ALPHABETS[language] = alphabet
     return alphabet
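Renaming the cache to the module-level constant ALPHABETS and dropping the global statement is safe because the function only mutates the dict and never rebinds the name. The same caching pattern in isolation (hypothetical names, illustrative only):

    CACHE = {}

    def get_cached(key, factory):
        if key in CACHE:
            return CACHE[key]
        CACHE[key] = value = factory(key)  # mutates CACHE, no rebinding, so no `global` needed
        return value

    print(get_cached('de', str.upper))  # 'DE'
    print(get_cached('de', str.upper))  # 'DE' (served from the cache)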
@@ -190,9 +186,7 @@ def label_filter(label, language):
     alphabet = get_alphabet(language)
     for c in label:
         if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
-            c = (unicodedata.normalize("NFKD", c)
-                 .encode("ascii", "ignore")
-                 .decode("ascii", "ignore"))
+            c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
         for sc in c:
             if not in_alphabet(alphabet, sc):
                 return None, 'illegal character'
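The normalization chain is merely folded onto one line; behaviour is unchanged. For reference, this is what the NFKD-plus-ASCII round trip does to characters outside the target alphabet (illustrative snippet, not part of the commit):

    import unicodedata

    def ascii_fold(c):
        return unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")

    print(ascii_fold("é"))  # 'e'  (decomposes into 'e' plus a combining accent, which is dropped)
    print(ascii_fold("ß"))  # ''   (no ASCII decomposition, the character disappears)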
@@ -204,11 +198,38 @@ def label_filter(label, language):
 def collect_samples(base_dir, language):
     roots = []
-    for root, dirs, files in os.walk(base_dir):
+    for root, _, files in os.walk(base_dir):
         if ALIGNED_NAME in files and WAV_NAME in files:
             roots.append(root)
     samples = []
     reasons = Counter()
+    def add_sample(p_wav_path, p_speaker, p_start, p_end, p_text, p_reason='complete'):
+        if p_start is not None and p_end is not None and p_text is not None:
+            duration = p_end - p_start
+            text, filter_reason = label_filter(p_text, language)
+            skip = False
+            if filter_reason is not None:
+                skip = True
+                p_reason = filter_reason
+            elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
+                skip = True
+                p_reason = 'exceeded duration'
+            elif int(duration / 20) < len(text):
+                skip = True
+                p_reason = 'too short to decode'
+            elif duration / len(text) < 10:
+                skip = True
+                p_reason = 'length duration ratio'
+            if skip:
+                reasons[p_reason] += 1
+            else:
+                samples.append(Sample(p_wav_path, p_start, p_end, text, p_speaker))
+        elif p_start is None or p_end is None:
+            reasons['missing timestamps'] += 1
+        else:
+            reasons['missing text'] += 1
     print('Collecting samples...')
     bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
     for root in bar(roots):
@@ -221,57 +242,34 @@ def collect_samples(base_dir, language):
                     speaker = attributes['value']
                     break
         for sentence in aligned.iter('s'):
-            def add_sample(start, end, text, reason='complete'):
-                if start is not None and end is not None and text is not None:
-                    duration = end - start
-                    text, filter_reason = label_filter(text, language)
-                    skip = False
-                    if filter_reason is not None:
-                        skip = True
-                        reason = filter_reason
-                    elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
-                        skip = True
-                        reason = 'exceeded duration'
-                    elif int(duration / 20) < len(text):
-                        skip = True
-                        reason = 'too short to decode'
-                    elif duration / len(text) < 10:
-                        skip = True
-                        reason = 'length duration ratio'
-                    if skip:
-                        reasons[reason] += 1
-                    else:
-                        samples.append(Sample(wav_path, start, end, text, speaker))
-                elif start is None or end is None:
-                    reasons['missing timestamps'] += 1
-                else:
-                    reasons['missing text'] += 1
             if ignored(sentence):
                 continue
             split = False
-            tokens = list(map(lambda token: read_token(token), sentence.findall('t')))
+            tokens = list(map(read_token, sentence.findall('t')))
             sample_start, sample_end, token_texts, sample_texts = None, None, [], []
             for token_start, token_end, token_text in tokens:
                 if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
-                    add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='has numbers')
+                    add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
+                               p_reason='has numbers')
                     sample_start, sample_end, token_texts, sample_texts = None, None, [], []
                     continue
                 if sample_start is None:
                     sample_start = token_start
                 if sample_start is None:
                     continue
-                else:
-                    token_texts.append(token_text)
+                token_texts.append(token_text)
                 if token_end is not None:
                     if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
-                        add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split')
+                        add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
+                                   p_reason='split')
                         sample_start = sample_end
                         sample_texts = []
                         split = True
                     sample_end = token_end
                     sample_texts.extend(token_texts)
                     token_texts = []
-            add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split' if split else 'complete')
+            add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
+                       p_reason='split' if split else 'complete')
     print('Skipped samples:')
     for reason, n in reasons.most_common():
         print(' - {}: {}'.format(reason, n))
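Taken together, the two collect_samples hunks hoist add_sample out of the per-sentence loop: instead of a closure that is re-created for every <s> element and captures wav_path and speaker, there is now one helper per collect_samples call that receives them as explicit p_-prefixed parameters. The pattern in miniature (hypothetical names, not from the script):

    results = []
    items = ['a', 'b']

    # Before: a closure defined inside the loop, capturing the loop variable.
    for item in items:
        def handle(x):
            results.append((item, x))
        handle(item.upper())

    # After: one helper defined once, with the captured value passed explicitly.
    def handle_item(p_item, x):
        results.append((p_item, x))

    for item in items:
        handle_item(item, item.upper())

    print(results)  # [('a', 'A'), ('b', 'B'), ('a', 'A'), ('b', 'B')]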
@@ -279,7 +277,7 @@ def collect_samples(base_dir, language):
 def maybe_convert_one_to_wav(entry):
-    root, dirs, files = entry
+    root, _, files = entry
     transformer = sox.Transformer()
     transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
     combiner = sox.Combiner()
@@ -317,7 +315,7 @@ def maybe_convert_to_wav(base_dir):
 def assign_sub_sets(samples):
     sample_size = get_sample_size(len(samples))
     speakers = group(samples, lambda sample: sample.speaker).values()
-    speakers = list(sorted(speakers, key=lambda speaker_samples: len(speaker_samples)))
+    speakers = list(sorted(speakers, key=len))
     sample_sets = [[], []]
     while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
         for sample_set in sample_sets:
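The key=len change in assign_sub_sets is behaviour-preserving, since len is exactly the callable the lambda wrapped. For example (illustrative):

    groups = [[1, 2, 3], [4], [5, 6]]
    print(sorted(groups, key=len))  # [[4], [5, 6], [1, 2, 3]]
    print(sorted(groups, key=len) == sorted(groups, key=lambda g: len(g)))  # True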
@@ -441,8 +439,8 @@ def handle_args():
 if __name__ == "__main__":
     CLI_ARGS = handle_args()
     if CLI_ARGS.language == 'all':
-        for language in LANGUAGES:
-            prepare_language(language)
+        for lang in LANGUAGES:
+            prepare_language(lang)
     elif CLI_ARGS.language in LANGUAGES:
         prepare_language(CLI_ARGS.language)
     else:
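The language to lang rename in the __main__ block presumably silences pylint's redefined-outer-name warning: a module-level binding named language would shadow the language parameters used throughout the script. A hypothetical minimal module showing the effect (illustrative only):

    LANGUAGES = ['de']

    def prepare_language(language):  # parameter named `language`
        print('preparing', language)

    # A module-level loop variable also named `language` would trigger W0621
    # (redefined-outer-name) on the parameter above; `lang` avoids that.
    for lang in LANGUAGES:
        prepare_language(lang)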