better handling of empty TSV dir / comments
parent e6e33eb3a7
commit e26aefc39b
@@ -21,46 +21,54 @@ from util.downloader import SIMPLE_BAR

 '''
-This script will do the following:
+Broadly speaking, this script takes the audio downloaded from Common Voice
+for a certain language, in addition to the *.tsv files output by CorporaCreator,
+and the script formats the data and transcripts to be in a state usable by
+DeepSpeech.py

 Input:
-audio_dir (string) path to dir of audio downloaded from Common Voice
-tsv_dir (string) path to dir containing tsv files generated by CorporaCreator
+(1) audio_dir (string) path to dir of audio downloaded from Common Voice
+(2) tsv_dir (string) path to dir containing tsv files generated by CorporaCreator

 Output:
-csv files in format needed by DeepSpeech.py, saved into audio_dir
-optionally converted wav files, saved into audio_dir alongside their mp3s
+(1) csv files in format needed by DeepSpeech.py, saved into audio_dir
+(2) wav files, saved into audio_dir alongside their mp3s
 '''

 FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
 SAMPLE_RATE = 16000
 MAX_SECS = 10


 def _preprocess_data(audio_dir, tsv_dir):
-    for source_tsv in glob(path.join(path.abspath(tsv_dir), '*.tsv')):
-        if not os.path.isfile(source_tsv):
-            print("ERROR: TSV file not found:", source_tsv)
-        else:
-            _maybe_convert_set(audio_dir, source_tsv)
+    try:
+        # Check if there is at least one TSV file in tsv_dir
+        os.path.isfile(glob(path.join(path.abspath(tsv_dir), '*.tsv'))[0])
+        for input_tsv in glob(path.join(path.abspath(tsv_dir), '*.tsv')):
+            print("Loading in TSV file: ", input_tsv)
+            _maybe_convert_set(audio_dir, input_tsv)
+    except IndexError:
+        print("ERROR: no TSV file found in: ", tsv_dir)


-def _maybe_convert_set(audio_dir, source_tsv):
-    # save new csv to the dir where the data is
-    target_csv = path.join(audio_dir, os.path.split(source_tsv)[-1].replace('tsv', 'csv'))
-    print("Saving new DeepSpeech-formatted CSV file to: ", target_csv)
+def _maybe_convert_set(audio_dir, input_tsv):
+    output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
+    print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)

     # Get audiofile path and transcript for each sentence in tsv
     samples = []
-    with open(source_tsv) as source_tsv_file:
-        reader = csv.DictReader(source_tsv_file, delimiter='\t')
+    with open(input_tsv) as input_tsv_file:
+        reader = csv.DictReader(input_tsv_file, delimiter='\t')
         for row in reader:
             samples.append((row['path'], row['sentence']))

-    # Mutable counters for the concurrent embedded routine
+    # Keep track of how many samples are good vs. problematic
     counter = { 'all': 0, 'too_short': 0, 'too_long': 0 }
     lock = RLock()
     num_samples = len(samples)
     rows = []

     def one_sample(sample):
         """ Take an audio file, and optionally convert it to 16kHz WAV """
         mp3_filename = path.join(audio_dir, sample[0])
         # Storing wav files next to the mp3 ones - just with a different suffix
         wav_filename = path.splitext(mp3_filename)[0] + ".wav"
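A note on the empty-directory handling above: glob() returns an empty list when tsv_dir contains no *.tsv files, so indexing [0] raises IndexError, which the new code catches and reports instead of silently doing nothing. Below is a minimal standalone sketch of that check together with the (path, sentence) collection; the helper name and the directory path are placeholders, not part of the commit.

import csv
import os
from glob import glob
from os import path


def collect_samples(tsv_dir):
    """Hypothetical helper mirroring the check above: bail out if tsv_dir has no *.tsv, else read (path, sentence) pairs."""
    samples = []
    try:
        tsv_files = glob(path.join(path.abspath(tsv_dir), '*.tsv'))
        # An empty directory yields an empty list, so [0] raises IndexError
        os.path.isfile(tsv_files[0])
        for input_tsv in tsv_files:
            print("Loading in TSV file: ", input_tsv)
            with open(input_tsv) as input_tsv_file:
                reader = csv.DictReader(input_tsv_file, delimiter='\t')
                for row in reader:
                    samples.append((row['path'], row['sentence']))
    except IndexError:
        print("ERROR: no TSV file found in: ", tsv_dir)
    return samples


print(collect_samples('/tmp/empty_dir'))  # placeholder path; prints the ERROR line and []

An explicit "if not tsv_files:" test would work just as well; the sketch keeps the IndexError approach so it matches the committed code.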
@@ -78,8 +86,8 @@ def _maybe_convert_set(audio_dir, source_tsv):
         # This one is good - keep it for the target CSV
         rows.append((wav_filename, file_size, sample[1]))
         counter['all'] += 1

-    print('Importing mp3 files...')
+    print("Importing mp3 files...")
     pool = Pool(cpu_count())
     bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
     for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
@@ -87,15 +95,15 @@ def _maybe_convert_set(audio_dir, source_tsv):
     bar.update(num_samples)
     pool.close()
     pool.join()

-    with open(target_csv, 'w') as target_csv_file:
-        print('Writing out ', target_csv)
-        writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
+    with open(output_csv, 'w') as output_csv_file:
+        print('Writing CSV file for DeepSpeech.py as: ', output_csv)
+        writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
         writer.writeheader()
         bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
         for filename, file_size, transcript in bar(rows):
             writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })

     print('Imported %d samples.' % (counter['all'] - counter['too_short'] - counter['too_long']))
     if counter['too_short'] > 0:
         print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
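The import loop above fans one_sample out over a pool and advances the progress bar as results come back in arbitrary order via imap_unordered. The Pool import and the SIMPLE_BAR widgets are outside these hunks, so the sketch below assumes a thread pool from multiprocessing.dummy and the default progressbar2 widgets, with a dummy work function standing in for one_sample.

from multiprocessing import cpu_count
from multiprocessing.dummy import Pool  # assumption: a thread pool; the real import is not shown in this diff

import progressbar


def one_sample(sample):
    # Stand-in for the per-clip mp3 -> wav conversion and validation
    return sample


samples = [('clip_%d.mp3' % i, 'a transcript') for i in range(100)]  # dummy data
num_samples = len(samples)

pool = Pool(cpu_count())
bar = progressbar.ProgressBar(max_value=num_samples)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
    bar.update(i)  # update as each result arrives, regardless of order
bar.update(num_samples)
pool.close()
pool.join()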
@@ -107,13 +115,10 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
     transformer = Transformer()
     transformer.convert(samplerate=SAMPLE_RATE)
     transformer.build(mp3_filename, wav_filename)


 if __name__ == "__main__":
     audio_dir = sys.argv[1]
     tsv_dir = sys.argv[2]
-    print('Expecting your audio from Common Voice to be in ', audio_dir)
-    print('Looking for *.tsv files (generated by CorporaCreator) in ', tsv_dir)
+    print('Expecting your audio from Common Voice to be in: ', audio_dir)
+    print('Looking for *.tsv files (generated by CorporaCreator) in: ', tsv_dir)
     _preprocess_data(audio_dir, tsv_dir)
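The mp3-to-wav step in _maybe_convert_wav above drives the conversion through a Transformer, presumably sox's (the import is outside these hunks): convert() sets the target sample rate and build() writes the output file. A minimal sketch of the same call sequence on a single file; the wrapper name and paths are placeholders.

from sox import Transformer

SAMPLE_RATE = 16000


def mp3_to_wav(mp3_filename, wav_filename):
    # Resample to 16 kHz and write a WAV alongside the source mp3
    transformer = Transformer()
    transformer.convert(samplerate=SAMPLE_RATE)
    transformer.build(mp3_filename, wav_filename)


mp3_to_wav('clips/sample.mp3', 'clips/sample.wav')  # placeholder paths

Per the __main__ block, the script itself expects the Common Voice audio directory as the first argument and the CorporaCreator TSV directory as the second, i.e. python <this_script>.py <audio_dir> <tsv_dir>; the file name of the script is not shown in this diff.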