diff --git a/bin/import_swc.py b/bin/import_swc.py index 2046c584..cada7a9a 100755 --- a/bin/import_swc.py +++ b/bin/import_swc.py @@ -54,7 +54,7 @@ SUBSTITUTIONS = { (re.compile(r'eins punkt null null null'), 'ein tausend'), (re.compile(r'punkt null null null'), 'tausend'), (re.compile(r'punkt null'), None) - ] # TODO: Add Dutch and English + ] } DONT_NORMALIZE = { @@ -77,7 +77,7 @@ class Sample: def fail(message): print(message) - exit(1) + sys.exit(1) def group(lst, get_key): @@ -95,14 +95,11 @@ def get_sample_size(population_size): margin_of_error = 0.01 fraction_picking = 0.50 z_score = 2.58 # Corresponds to confidence level 99% - numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / ( - margin_of_error ** 2 - ) + numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2) sample_size = 0 for train_size in range(population_size, 0, -1): - denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / ( - margin_of_error ** 2 * train_size - ) + denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \ + (margin_of_error ** 2 * train_size) sample_size = int(numerator / denominator) if 2 * sample_size + train_size <= population_size: break @@ -162,14 +159,13 @@ def in_alphabet(alphabet, c): return True if alphabet is None else alphabet.has_char(c) -alphabets = {} +ALPHABETS = {} def get_alphabet(language): - global alphabets - if language in alphabets: - return alphabets[language] + if language in ALPHABETS: + return ALPHABETS[language] alphabet_path = getattr(CLI_ARGS, language + '_alphabet') alphabet = Alphabet(alphabet_path) if alphabet_path else None - alphabets[language] = alphabet + ALPHABETS[language] = alphabet return alphabet @@ -190,9 +186,7 @@ def label_filter(label, language): alphabet = get_alphabet(language) for c in label: if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c): - c = (unicodedata.normalize("NFKD", c) - .encode("ascii", "ignore") - .decode("ascii", "ignore")) + c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") for sc in c: if not in_alphabet(alphabet, sc): return None, 'illegal character' @@ -204,11 +198,38 @@ def label_filter(label, language): def collect_samples(base_dir, language): roots = [] - for root, dirs, files in os.walk(base_dir): + for root, _, files in os.walk(base_dir): if ALIGNED_NAME in files and WAV_NAME in files: roots.append(root) samples = [] reasons = Counter() + + def add_sample(p_wav_path, p_speaker, p_start, p_end, p_text, p_reason='complete'): + if p_start is not None and p_end is not None and p_text is not None: + duration = p_end - p_start + text, filter_reason = label_filter(p_text, language) + skip = False + if filter_reason is not None: + skip = True + p_reason = filter_reason + elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long: + skip = True + p_reason = 'exceeded duration' + elif int(duration / 20) < len(text): + skip = True + p_reason = 'too short to decode' + elif duration / len(text) < 10: + skip = True + p_reason = 'length duration ratio' + if skip: + reasons[p_reason] += 1 + else: + samples.append(Sample(p_wav_path, p_start, p_end, text, p_speaker)) + elif p_start is None or p_end is None: + reasons['missing timestamps'] += 1 + else: + reasons['missing text'] += 1 + print('Collecting samples...') bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR) for root in bar(roots): @@ -221,57 +242,34 @@ def collect_samples(base_dir, language): speaker = attributes['value'] break for sentence in aligned.iter('s'): - def add_sample(start, end, text, reason='complete'): - if start is not None and end is not None and text is not None: - duration = end - start - text, filter_reason = label_filter(text, language) - skip = False - if filter_reason is not None: - skip = True - reason = filter_reason - elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long: - skip = True - reason = 'exceeded duration' - elif int(duration / 20) < len(text): - skip = True - reason = 'too short to decode' - elif duration / len(text) < 10: - skip = True - reason = 'length duration ratio' - if skip: - reasons[reason] += 1 - else: - samples.append(Sample(wav_path, start, end, text, speaker)) - elif start is None or end is None: - reasons['missing timestamps'] += 1 - else: - reasons['missing text'] += 1 if ignored(sentence): continue split = False - tokens = list(map(lambda token: read_token(token), sentence.findall('t'))) + tokens = list(map(read_token, sentence.findall('t'))) sample_start, sample_end, token_texts, sample_texts = None, None, [], [] for token_start, token_end, token_text in tokens: if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text): - add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='has numbers') + add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts), + p_reason='has numbers') sample_start, sample_end, token_texts, sample_texts = None, None, [], [] continue if sample_start is None: sample_start = token_start if sample_start is None: continue - else: - token_texts.append(token_text) + token_texts.append(token_text) if token_end is not None: if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0: - add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split') + add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts), + p_reason='split') sample_start = sample_end sample_texts = [] split = True sample_end = token_end sample_texts.extend(token_texts) token_texts = [] - add_sample(sample_start, sample_end, ' '.join(sample_texts), reason='split' if split else 'complete') + add_sample(wav_path, speaker, sample_start, sample_end, ' '.join(sample_texts), + p_reason='split' if split else 'complete') print('Skipped samples:') for reason, n in reasons.most_common(): print(' - {}: {}'.format(reason, n)) @@ -279,7 +277,7 @@ def collect_samples(base_dir, language): def maybe_convert_one_to_wav(entry): - root, dirs, files = entry + root, _, files = entry transformer = sox.Transformer() transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS) combiner = sox.Combiner() @@ -317,7 +315,7 @@ def maybe_convert_to_wav(base_dir): def assign_sub_sets(samples): sample_size = get_sample_size(len(samples)) speakers = group(samples, lambda sample: sample.speaker).values() - speakers = list(sorted(speakers, key=lambda speaker_samples: len(speaker_samples))) + speakers = list(sorted(speakers, key=len)) sample_sets = [[], []] while any(map(lambda s: len(s) < sample_size, sample_sets)) and len(speakers) > 0: for sample_set in sample_sets: @@ -441,8 +439,8 @@ def handle_args(): if __name__ == "__main__": CLI_ARGS = handle_args() if CLI_ARGS.language == 'all': - for language in LANGUAGES: - prepare_language(language) + for lang in LANGUAGES: + prepare_language(lang) elif CLI_ARGS.language in LANGUAGES: prepare_language(CLI_ARGS.language) else: