Linter induced changes

This commit is contained in:
Tilman Kamp 2019-10-23 16:47:50 +02:00
parent 010f24578f
commit 122a007d33

View File

@ -54,7 +54,7 @@ SUBSTITUTIONS = {
(re.compile(r'eins punkt null null null'), 'ein tausend'),
(re.compile(r'punkt null null null'), 'tausend'),
(re.compile(r'punkt null'), None)
] # TODO: Add Dutch and English
]
}
DONT_NORMALIZE = {
@ -77,7 +77,7 @@ class Sample:
def fail(message):
print(message)
exit(1)
sys.exit(1)
def group(lst, get_key):
@ -95,14 +95,11 @@ def get_sample_size(population_size):
margin_of_error = 0.01
fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99%
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2
)
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
sample_size = 0
for train_size in range(population_size, 0, -1):
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2 * train_size
)
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
(margin_of_error ** 2 * train_size)
sample_size = int(numerator / denominator)
if 2 * sample_size + train_size <= population_size:
break
@ -162,14 +159,13 @@ def in_alphabet(alphabet, c):
return True if alphabet is None else alphabet.has_char(c)
alphabets = {}
ALPHABETS = {}
def get_alphabet(language):
global alphabets
if language in alphabets:
return alphabets[language]
if language in ALPHABETS:
return ALPHABETS[language]
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
alphabet = Alphabet(alphabet_path) if alphabet_path else None
alphabets[language] = alphabet
ALPHABETS[language] = alphabet
return alphabet
@ -190,9 +186,7 @@ def label_filter(label, language):
alphabet = get_alphabet(language)
for c in label:
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
c = (unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore"))
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
for sc in c:
if not in_alphabet(alphabet, sc):
return None, 'illegal character'
@ -204,11 +198,38 @@ def label_filter(label, language):
def collect_samples(base_dir, language):
roots = []
for root, dirs, files in os.walk(base_dir):
for root, _, files in os.walk(base_dir):
if ALIGNED_NAME in files and WAV_NAME in files:
roots.append(root)
samples = []
reasons = Counter()
def add_sample(p_wav_path, p_speaker, p_start, p_end, p_text, p_reason='complete'):
if p_start is not None and p_end is not None and p_text is not None:
duration = p_end - p_start
text, filter_reason = label_filter(p_text, language)
skip = False
if filter_reason is not None:
skip = True
p_reason = filter_reason
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
skip = True
p_reason = 'exceeded duration'
elif int(duration / 20) < len(text):
skip = True
p_reason = 'too short to decode'
elif duration / len(text) < 10:
skip = True
p_reason = 'length duration ratio'
if skip:
reasons[p_reason] += 1
else:
samples.append(Sample(p_wav_path, p_start, p_end, text, p_speaker))
elif p_start is None or p_end is None:
reasons['missing timestamps'] += 1
else:
reasons['missing text'] += 1
print('Collecting samples...')
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
for root in bar(roots):
@ -221,57 +242,34 @@ def collect_samples(base_dir, language):
speaker = attributes['value']
break
for sentence in aligned.iter('s'):
def add_sample(start, end, text, reason='complete'):
if start is not None and end is not None and text is not None:
duration = end - start
text, filter_reason = label_filter(text, language)
skip = False
if filter_reason is not None:
skip = True
reason = filter_reason
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
skip = True
reason = 'exceeded duration'
elif int(duration / 20) < len(text):
skip = True
reason = 'too short to decode'
elif duration / len(text) < 10:
skip = True
reason = 'length duration ratio'
if skip:
reasons[reason] += 1
else:
samples.append(Sample(wav_path, start, end, text, speaker))
elif start is None or end is None:
reasons['missing timestamps'] += 1
else:
reasons['missing text'] += 1
if ignored(sentence):
continue
split = False
tokens = list(map(lambda token: read_token(token), sentence.findall('t')))
tokens = list(map(read_token, sentence.findall('t')))
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
for token_start, token_end, token_text in tokens:
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='has numbers')
add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='has numbers')
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
continue
if sample_start is None:
sample_start = token_start
if sample_start is None:
continue
else:
token_texts.append(token_text)
token_texts.append(token_text)
if token_end is not None:
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split')
add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='split')
sample_start = sample_end
sample_texts = []
split = True
sample_end = token_end
sample_texts.extend(token_texts)
token_texts = []
add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split' if split else 'complete')
add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='split' if split else 'complete')
print('Skipped samples:')
for reason, n in reasons.most_common():
print(' - {}: {}'.format(reason, n))
@ -279,7 +277,7 @@ def collect_samples(base_dir, language):
def maybe_convert_one_to_wav(entry):
root, dirs, files = entry
root, _, files = entry
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
combiner = sox.Combiner()
@ -317,7 +315,7 @@ def maybe_convert_to_wav(base_dir):
def assign_sub_sets(samples):
sample_size = get_sample_size(len(samples))
speakers = group(samples, lambda sample: sample.speaker).values()
speakers = list(sorted(speakers, key=lambda speaker_samples: len(speaker_samples)))
speakers = list(sorted(speakers, key=len))
sample_sets = [[], []]
while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0:
for sample_set in sample_sets:
@ -441,8 +439,8 @@ def handle_args():
if __name__ == "__main__":
CLI_ARGS = handle_args()
if CLI_ARGS.language == 'all':
for language in LANGUAGES:
prepare_language(language)
for lang in LANGUAGES:
prepare_language(lang)
elif CLI_ARGS.language in LANGUAGES:
prepare_language(CLI_ARGS.language)
else: