Merge branch 'import_cv2_multiprocessing' (Fixes #3008)
This commit is contained in:
commit
b9f9b3cedd
@ -7,6 +7,7 @@ DeepSpeech.py
|
||||
Use "python3 import_cv2.py -h" for help
|
||||
"""
|
||||
import csv
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import unicodedata
|
||||
@ -30,19 +31,20 @@ SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
|
||||
|
||||
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
||||
def _preprocess_data(tsv_dir, audio_dir, filter_obj, space_after_every_character=False):
|
||||
exclude = []
|
||||
for dataset in ["test", "dev", "train", "validated", "other"]:
|
||||
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character)
|
||||
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character)
|
||||
if dataset in ["test", "dev"]:
|
||||
exclude += set_samples
|
||||
if dataset == "validated":
|
||||
_maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character,
|
||||
_maybe_convert_set("train-all", tsv_dir, audio_dir, filter_obj, space_after_every_character,
|
||||
rows=set_samples, exclude=exclude)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
def one_sample(args):
|
||||
""" Take an audio file, and optionally convert it to 16kHz WAV """
|
||||
sample, filter_obj = args
|
||||
mp3_filename = sample[0]
|
||||
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
|
||||
mp3_filename += ".mp3"
|
||||
@ -58,7 +60,7 @@ def one_sample(sample):
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter_fun(sample[1])
|
||||
label = filter_obj.filter(sample[1])
|
||||
rows = []
|
||||
counter = get_counter()
|
||||
if file_size == -1:
|
||||
@ -82,7 +84,7 @@ def one_sample(sample):
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character=None, rows=None, exclude=None):
|
||||
def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character=None, rows=None, exclude=None):
|
||||
exclude_transcripts = set()
|
||||
exclude_speakers = set()
|
||||
if exclude is not None:
|
||||
@ -109,7 +111,8 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character=
|
||||
print("Importing mp3 files...")
|
||||
pool = Pool()
|
||||
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
|
||||
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
|
||||
samples_with_context = itertools.zip_longest(samples, [], fillvalue=filter_obj)
|
||||
for i, processed in enumerate(pool.imap_unordered(one_sample, samples_with_context), start=1):
|
||||
counter += processed[0]
|
||||
rows += processed[1]
|
||||
bar.update(i)
|
||||
@ -160,50 +163,60 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
except sox.core.SoxError:
|
||||
pass
|
||||
|
||||
class LabelFilter:
|
||||
def __init__(self, normalize, alphabet, validate_fun):
|
||||
self.normalize = normalize
|
||||
self.alphabet = alphabet
|
||||
self.validate_fun = validate_fun
|
||||
|
||||
if __name__ == "__main__":
|
||||
PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
|
||||
PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
|
||||
PARSER.add_argument(
|
||||
"--audio_dir",
|
||||
help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
|
||||
)
|
||||
PARSER.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
PARSER.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
PARSER.add_argument(
|
||||
"--space_after_every_character",
|
||||
action="store_true",
|
||||
help="To help transcript join by white space",
|
||||
)
|
||||
|
||||
PARAMS = PARSER.parse_args()
|
||||
validate_label = get_validate_label(PARAMS)
|
||||
|
||||
AUDIO_DIR = (
|
||||
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
|
||||
)
|
||||
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
|
||||
|
||||
def label_filter_fun(label):
|
||||
if PARAMS.normalize:
|
||||
def filter(self, label):
|
||||
if self.normalize:
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
label = self.validate_fun(label)
|
||||
if self.alphabet and label:
|
||||
try:
|
||||
ALPHABET.encode(label)
|
||||
self.alphabet.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
return label
|
||||
|
||||
_preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character)
|
||||
def main():
    """Command-line entry point for the CommonVoice v2.0 importer.

    Builds the argument parser, reads the options, assembles the
    ``LabelFilter`` that is handed to the sample-processing code, and then
    starts the import via ``_preprocess_data``.
    """
    arg_parser = get_importers_parser(description="Import CommonVoice v2.0 corpora")
    arg_parser.add_argument("tsv_dir", help="Directory containing tsv files")
    arg_parser.add_argument(
        "--audio_dir",
        help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
    )
    arg_parser.add_argument(
        "--filter_alphabet",
        help="Exclude samples with characters not in provided alphabet",
    )
    # The two boolean switches share the same shape, so register them together
    # (registration order is kept: --normalize first).
    boolean_switches = (
        ("--normalize", "Converts diacritic characters to their base ones"),
        ("--space_after_every_character", "To help transcript join by white space"),
    )
    for switch, description in boolean_switches:
        arg_parser.add_argument(switch, action="store_true", help=description)

    cli_args = arg_parser.parse_args()
    label_validator = get_validate_label(cli_args)

    # Fall back to "<tsv_dir>/clips" when no audio directory was given.
    clips_dir = cli_args.audio_dir or os.path.join(cli_args.tsv_dir, "clips")

    # Only build an alphabet when one was requested; None is passed otherwise.
    allowed_alphabet = None
    if cli_args.filter_alphabet:
        allowed_alphabet = Alphabet(cli_args.filter_alphabet)

    label_filter = LabelFilter(cli_args.normalize, allowed_alphabet, label_validator)
    _preprocess_data(
        cli_args.tsv_dir,
        clips_dir,
        label_filter,
        cli_args.space_after_every_character,
    )
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user