Merge branch 'import_cv2_multiprocessing' (Fixes #3008)

This commit is contained in:
Reuben Morais 2020-05-28 00:00:43 +02:00
commit b9f9b3cedd

View File

@ -7,6 +7,7 @@ DeepSpeech.py
Use "python3 import_cv2.py -h" for help
"""
import csv
import itertools
import os
import subprocess
import unicodedata
@ -30,19 +31,20 @@ SAMPLE_RATE = 16000
MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
def _preprocess_data(tsv_dir, audio_dir, filter_obj, space_after_every_character=False):
exclude = []
for dataset in ["test", "dev", "train", "validated", "other"]:
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character)
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character)
if dataset in ["test", "dev"]:
exclude += set_samples
if dataset == "validated":
_maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character,
_maybe_convert_set("train-all", tsv_dir, audio_dir, filter_obj, space_after_every_character,
rows=set_samples, exclude=exclude)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
def one_sample(args):
""" Take an audio file, and optionally convert it to 16kHz WAV """
sample, filter_obj = args
mp3_filename = sample[0]
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
mp3_filename += ".mp3"
@ -58,7 +60,7 @@ def one_sample(sample):
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter_fun(sample[1])
label = filter_obj.filter(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
@ -82,7 +84,7 @@ def one_sample(sample):
return (counter, rows)
def _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character=None, rows=None, exclude=None):
def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character=None, rows=None, exclude=None):
exclude_transcripts = set()
exclude_speakers = set()
if exclude is not None:
@ -109,7 +111,8 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character=
print("Importing mp3 files...")
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
samples_with_context = itertools.zip_longest(samples, [], fillvalue=filter_obj)
for i, processed in enumerate(pool.imap_unordered(one_sample, samples_with_context), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i)
@ -160,50 +163,60 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
except sox.core.SoxError:
pass
# Bundles the label-cleaning settings into one object. The sample-processing
# code receives an instance as `filter_obj` and calls its filter() method on
# each transcript.
class LabelFilter:
def __init__(self, normalize, alphabet, validate_fun):
# normalize: when truthy, filter() folds the label to ASCII via Unicode NFKD.
self.normalize = normalize
# alphabet: object whose encode() raises KeyError for unsupported characters,
# or None/falsy to skip the alphabet check in filter().
self.alphabet = alphabet
# validate_fun: callable applied to every label in filter(); its result
# replaces the label.
self.validate_fun = validate_fun
if __name__ == "__main__":
PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
PARSER.add_argument(
"--audio_dir",
help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
)
PARSER.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
PARSER.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
PARSER.add_argument(
"--space_after_every_character",
action="store_true",
help="To help transcript join by white space",
)
PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS)
AUDIO_DIR = (
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
)
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
def label_filter_fun(label):
if PARAMS.normalize:
def filter(self, label):
if self.normalize:
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
label = self.validate_fun(label)
if self.alphabet and label:
try:
ALPHABET.encode(label)
self.alphabet.encode(label)
except KeyError:
label = None
return label
_preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character)
def main():
    """Parse the command-line options and run the CommonVoice v2.0 import.

    Builds the argument parser, derives the audio directory (falling back to
    ``<tsv_dir>/clips``), assembles the label filter from the parsed options
    and hands everything to ``_preprocess_data``.
    """
    arg_parser = get_importers_parser(description="Import CommonVoice v2.0 corpora")
    arg_parser.add_argument("tsv_dir", help="Directory containing tsv files")

    # Options that take a value, registered in their original order.
    value_options = (
        (
            "--audio_dir",
            'Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
        ),
        (
            "--filter_alphabet",
            "Exclude samples with characters not in provided alphabet",
        ),
    )
    for option, help_text in value_options:
        arg_parser.add_argument(option, help=help_text)

    # Boolean switches, registered in their original order.
    switches = (
        ("--normalize", "Converts diacritic characters to their base ones"),
        ("--space_after_every_character", "To help transcript join by white space"),
    )
    for option, help_text in switches:
        arg_parser.add_argument(option, action="store_true", help=help_text)

    cli = arg_parser.parse_args()

    label_validator = get_validate_label(cli)
    clips_dir = cli.audio_dir or os.path.join(cli.tsv_dir, "clips")
    if cli.filter_alphabet:
        alphabet = Alphabet(cli.filter_alphabet)
    else:
        alphabet = None

    _preprocess_data(
        cli.tsv_dir,
        clips_dir,
        LabelFilter(cli.normalize, alphabet, label_validator),
        cli.space_after_every_character,
    )
# Script entry point: run the importer only when this file is executed
# directly, not when the module is imported.
if __name__ == "__main__":
main()