From 45d8f7cd617ebccafda3b0228e352dfdb545f998 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 25 May 2020 17:47:17 +0200 Subject: [PATCH 1/2] Explicitly pass filter context to multiprocessing function --- bin/import_cv2.py | 103 ++++++++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 45 deletions(-) diff --git a/bin/import_cv2.py b/bin/import_cv2.py index 7c85c8f0..50434605 100755 --- a/bin/import_cv2.py +++ b/bin/import_cv2.py @@ -7,6 +7,7 @@ DeepSpeech.py Use "python3 import_cv2.py -h" for help """ import csv +import itertools import os import subprocess import unicodedata @@ -30,19 +31,20 @@ SAMPLE_RATE = 16000 MAX_SECS = 10 -def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False): +def _preprocess_data(tsv_dir, audio_dir, filter_obj, space_after_every_character=False): exclude = [] for dataset in ["test", "dev", "train", "validated", "other"]: - set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character) + set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character) if dataset in ["test", "dev"]: exclude += set_samples if dataset == "validated": - _maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character, - rows=set_samples, exclude=exclude) + _maybe_convert_set("train-all", tsv_dir, audio_dir, filter_obj, space_after_every_character, + filter_obj, rows=set_samples, exclude=exclude) -def one_sample(sample): - """ Take a audio file, and optionally convert it to 16kHz WAV """ +def one_sample(args): + """ Take an audio file, and optionally convert it to 16kHz WAV """ + sample, filter_obj = args mp3_filename = sample[0] if not os.path.splitext(mp3_filename.lower())[1] == ".mp3": mp3_filename += ".mp3" @@ -58,7 +60,7 @@ def one_sample(sample): ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT ) ) - label = label_filter_fun(sample[1]) + label = filter_obj.filter(sample[1]) rows = [] counter = get_counter() if file_size == -1: @@ -82,7 +84,7 @@ def one_sample(sample): return (counter, rows) -def _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character=None, rows=None, exclude=None): +def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character=None, rows=None, exclude=None): exclude_transcripts = set() exclude_speakers = set() if exclude is not None: @@ -109,7 +111,8 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character= print("Importing mp3 files...") pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): + samples_with_context = itertools.zip_longest(samples, [], fillvalue=filter_obj) + for i, processed in enumerate(pool.imap_unordered(one_sample, samples_with_context), start=1): counter += processed[0] rows += processed[1] bar.update(i) @@ -160,50 +163,60 @@ def _maybe_convert_wav(mp3_filename, wav_filename): except sox.core.SoxError: pass +class LabelFilter: + def __init__(self, normalize, alphabet, validate_fun): + self.normalize = normalize + self.alphabet = alphabet + self.validate_fun = validate_fun -if __name__ == "__main__": - PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora") - PARSER.add_argument("tsv_dir", help="Directory containing tsv files") - PARSER.add_argument( - "--audio_dir", - help='Directory containing the audio clips - defaults to "/clips"', - ) - PARSER.add_argument( - "--filter_alphabet", - help="Exclude samples with characters not in provided alphabet", - ) - PARSER.add_argument( - "--normalize", - action="store_true", - help="Converts diacritic characters to their base ones", - ) - PARSER.add_argument( - "--space_after_every_character", - action="store_true", - help="To help transcript join by white space", - ) - - PARAMS = PARSER.parse_args() - validate_label = get_validate_label(PARAMS) - - AUDIO_DIR = ( - PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips") - ) - ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None - - def label_filter_fun(label): - if PARAMS.normalize: + def filter(self, label): + if self.normalize: label = ( unicodedata.normalize("NFKD", label.strip()) .encode("ascii", "ignore") .decode("ascii", "ignore") ) - label = validate_label(label) - if ALPHABET and label: + label = self.validate_fun(label) + if self.alphabet and label: try: - ALPHABET.encode(label) + self.alphabet.encode(label) except KeyError: label = None return label - _preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character) +def main(): + parser = get_importers_parser(description="Import CommonVoice v2.0 corpora") + parser.add_argument("tsv_dir", help="Directory containing tsv files") + parser.add_argument( + "--audio_dir", + help='Directory containing the audio clips - defaults to "/clips"', + ) + parser.add_argument( + "--filter_alphabet", + help="Exclude samples with characters not in provided alphabet", + ) + parser.add_argument( + "--normalize", + action="store_true", + help="Converts diacritic characters to their base ones", + ) + parser.add_argument( + "--space_after_every_character", + action="store_true", + help="To help transcript join by white space", + ) + + params = parser.parse_args() + validate_label = get_validate_label(params) + + audio_dir = ( + params.audio_dir if params.audio_dir else os.path.join(params.tsv_dir, "clips") + ) + alphabet = Alphabet(params.filter_alphabet) if params.filter_alphabet else None + + filter_obj = LabelFilter(params.normalize, alphabet, validate_label) + _preprocess_data(params.tsv_dir, audio_dir, filter_obj, + params.space_after_every_character) + +if __name__ == "__main__": + main() From 3d0ec01853fe9f5b73a6e35b0226e4c97c237655 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Wed, 27 May 2020 19:02:55 +0200 Subject: [PATCH 2/2] Fix typo from argument reordering --- bin/import_cv2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/import_cv2.py b/bin/import_cv2.py index 50434605..3b7e8fba 100755 --- a/bin/import_cv2.py +++ b/bin/import_cv2.py @@ -39,7 +39,7 @@ def _preprocess_data(tsv_dir, audio_dir, filter_obj, space_after_every_character exclude += set_samples if dataset == "validated": _maybe_convert_set("train-all", tsv_dir, audio_dir, filter_obj, space_after_every_character, - filter_obj, rows=set_samples, exclude=exclude) + rows=set_samples, exclude=exclude) def one_sample(args):