Fix #1991 - Additional import options for import_cv2.py

This commit is contained in:
Tilman Kamp 2019-04-01 15:31:42 +02:00
parent 7d7a7f7be5
commit 5645285d25

View File

@ -9,49 +9,42 @@ sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import sox
import argparse
import subprocess
import progressbar
import unicodedata
from os import path
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR
from util.text import validate_label
from util.text import Alphabet, validate_label
'''
Broadly speaking, this script takes the audio downloaded from Common Voice
for a certain language, in addition to the *.tsv files output by CorporaCreator,
and the script formats the data and transcripts to be in a state usable by
DeepSpeech.py
Usage:
$ python3 import_cv2.py /path/to/audio/data_dir /path/to/tsv_dir
Input:
(1) audio_dir (string) path to dir of audio downloaded from Common Voice
(2) tsv_dir (string) path to dir containing {train,test,dev}.tsv files
which were generated by CorporaCreator
Ouput:
(1) csv files in format needed by DeepSpeech.py, saved into audio_dir
(2) wav files, saved into audio_dir alongside their mp3s
Use "python3 import_cv2.py -h" for help
'''
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000
MAX_SECS = 10
def _preprocess_data(audio_dir, tsv_dir):
def _preprocess_data(tsv_dir, audio_dir, label_filter):
for dataset in ['train','test','dev']:
input_tsv= path.join(path.abspath(tsv_dir), dataset+".tsv")
if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv)
_maybe_convert_set(audio_dir, input_tsv)
_maybe_convert_set(input_tsv, audio_dir, label_filter)
else:
print("ERROR: no TSV file found: ", input_tsv)
def _maybe_convert_set(audio_dir, input_tsv):
def _maybe_convert_set(input_tsv, audio_dir, label_filter):
output_csv = path.join(audio_dir,os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
@ -80,7 +73,7 @@ def _maybe_convert_set(audio_dir, input_tsv):
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = validate_label(sample[1])
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
@ -126,6 +119,7 @@ def _maybe_convert_set(audio_dir, input_tsv):
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename):
transformer = sox.Transformer()
@ -137,8 +131,28 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
if __name__ == "__main__":
audio_dir = sys.argv[1]
tsv_dir = sys.argv[2]
print('Expecting your audio from Common Voice to be in: ', audio_dir)
print('Looking for *.tsv files (generated by CorporaCreator) in: ', tsv_dir)
_preprocess_data(audio_dir, tsv_dir)
parser = argparse.ArgumentParser(description='Import CommonVoice v2.0 corpora')
parser.add_argument('tsv_dir', help='Directory containing tsv files')
parser.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.set_defaults(normalize=False)
params = parser.parse_args()
audio_dir = params.audio_dir if params.audio_dir else os.path.join(params.tsv_dir, 'clips')
alphabet = Alphabet(params.filter_alphabet) if params.filter_alphabet else None
def label_filter(label):
if params.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
label = validate_label(label)
if alphabet and label:
try:
[alphabet.label_from_string(c) for c in label]
except KeyError:
return None
return label
_preprocess_data(params.tsv_dir, audio_dir, label_filter)