Converting importers from multiprocessing.dummy to multiprocessing

Fixes #2817
This commit is contained in:
Alexandre Lissy 2020-03-11 13:37:18 +01:00
parent ce59228824
commit 7b2a409f9f
9 changed files with 340 additions and 340 deletions

View File

@ -15,10 +15,8 @@ import progressbar
from glob import glob from glob import glob
from os import path from os import path
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
from multiprocessing import cpu_count
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR from util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -53,6 +51,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
for source_csv in glob(path.join(extracted_dir, '*.csv')): for source_csv in glob(path.join(extracted_dir, '*.csv')):
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1])) _maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
def one_sample(sample):
    """Convert one (mp3_filename, transcript) sample to WAV and validate it.

    Runs inside a multiprocessing Pool worker, so all tallying is local:
    returns a (counter, rows) tuple where `counter` (from get_counter())
    records this sample in exactly one rejection bucket ('failed',
    'invalid_label', 'too_short', 'too_long') or none if accepted, plus
    'all' and 'total_time'; `rows` holds at most one
    (wav_filename, file_size, label) entry for an accepted sample.
    """
    mp3_filename = sample[0]
    # Storing wav files next to the mp3 ones - just with a different suffix
    wav_filename = path.splitext(mp3_filename)[0] + ".wav"
    _maybe_convert_wav(mp3_filename, wav_filename)
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        # Probe sox for the sample count only once the WAV is known to
        # exist: probing a missing file would raise CalledProcessError
        # instead of letting the 'failed' bucket below account for it.
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    label = validate_label(sample[1])
    rows = []
    counter = get_counter()
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter['failed'] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif frames/SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((wav_filename, file_size, label))
    counter['all'] += 1
    counter['total_time'] += frames
    return (counter, rows)
def _maybe_convert_set(extracted_dir, source_csv, target_csv): def _maybe_convert_set(extracted_dir, source_csv, target_csv):
print() print()
if path.exists(target_csv): if path.exists(target_csv):
@ -63,48 +93,19 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
with open(source_csv) as source_csv_file: with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file) reader = csv.DictReader(source_csv_file)
for row in reader: for row in reader:
samples.append((row['filename'], row['text'])) samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
# Mutable counters for the concurrent embedded routine # Mutable counters for the concurrent embedded routine
counter = { 'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0 } counter = get_counter()
lock = RLock()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
def one_sample(sample):
mp3_filename = path.join(*(sample[0].split('/')))
mp3_filename = path.join(extracted_dir, mp3_filename)
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
file_size = -1
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = validate_label(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
print('Importing mp3 files...') print('Importing mp3 files...')
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -118,15 +119,11 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
for filename, file_size, transcript in bar(rows): for filename, file_size, transcript in bar(rows):
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript }) writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename): if not path.exists(wav_filename):

View File

@ -21,13 +21,10 @@ import progressbar
import unicodedata import unicodedata
from os import path from os import path
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from util.text import Alphabet from util.text import Alphabet
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -35,15 +32,50 @@ SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, label_filter, space_after_every_character=False): def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
for dataset in ['train', 'test', 'dev', 'validated', 'other']: for dataset in ['train', 'test', 'dev', 'validated', 'other']:
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv") input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
if os.path.isfile(input_tsv): if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv) print("Loading TSV file: ", input_tsv)
_maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character) _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
def one_sample(sample):
    """Take an audio file, and optionally convert it to 16kHz WAV.

    `sample` is a (mp3_path, sentence) pair. Tallies the sample into a
    fresh counter (one rejection bucket or accepted) and, when accepted,
    collects a (wav_basename, file_size, label) row. Runs in a
    multiprocessing Pool worker, so no shared state is mutated.
    """
    mp3_filename = sample[0]
    # Normalize the extension: TSV 'path' entries may omit '.mp3'.
    if not path.splitext(mp3_filename.lower())[1] == '.mp3':
        mp3_filename += ".mp3"
    # Storing wav files next to the mp3 ones - just with a different suffix
    wav_filename = path.splitext(mp3_filename)[0] + ".wav"
    _maybe_convert_wav(mp3_filename, wav_filename)
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        # soxi -s reports the number of samples in the WAV.
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    # NOTE(review): label_filter_fun is resolved as a module-level global in
    # the worker process — confirm it is defined at import time, not only
    # inside the __main__ guard, or Pool workers will fail on it.
    label = label_filter_fun(sample[1])
    rows = []
    counter = get_counter()
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter['failed'] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif frames/SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((os.path.split(wav_filename)[-1], file_size, label))
    counter['all'] += 1
    counter['total_time'] += frames
def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character=None): return (counter, rows)
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv')) output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv) print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
@ -52,51 +84,18 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
with open(input_tsv, encoding='utf-8') as input_tsv_file: with open(input_tsv, encoding='utf-8') as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t') reader = csv.DictReader(input_tsv_file, delimiter='\t')
for row in reader: for row in reader:
samples.append((row['path'], row['sentence'])) samples.append((path.join(audio_dir, row['path']), row['sentence']))
# Keep track of how many samples are good vs. problematic counter = get_counter()
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = path.join(audio_dir, sample[0])
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((os.path.split(wav_filename)[-1], file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing mp3 files...") print("Importing mp3 files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -113,16 +112,11 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
else: else:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript}) writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
@ -162,4 +156,4 @@ if __name__ == "__main__":
label = None label = None
return label return label
_preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, label_filter_fun, PARAMS.space_after_every_character) _preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character)

View File

@ -7,7 +7,7 @@ import os
import sys import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import argparse import argparse
import csv import csv
@ -18,9 +18,7 @@ import subprocess
import progressbar import progressbar
import unicodedata import unicodedata
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
@ -28,7 +26,6 @@ from glob import glob
from util.downloader import maybe_download from util.downloader import maybe_download
from util.text import Alphabet from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -61,6 +58,41 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
    """Convert a single ogg recording to 16kHz WAV and validate it.

    `sample` is a (ogg_path, transcript) pair. Returns (counter, rows):
    the counter tallies this sample into one bucket, and rows holds at
    most one (wav_filename, file_size, label) entry when it is kept.
    """
    ogg_filename, transcript = sample[0], sample[1]
    # The WAV lives next to the source ogg, distinguished only by suffix.
    wav_filename = path.splitext(ogg_filename)[0] + ".wav"
    _maybe_convert_wav(ogg_filename, wav_filename)
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    label = label_filter(transcript)
    rows = []
    counter = get_counter()
    secs = frames / SAMPLE_RATE
    if file_size == -1:
        # Conversion never produced a WAV file.
        counter['failed'] += 1
    elif label is None:
        # Transcript rejected by the label filter.
        counter['invalid_label'] += 1
    elif int(secs * 1000 / 10 / 2) < len(str(label)):
        # Audio too short to plausibly contain the transcript.
        counter['too_short'] += 1
    elif secs > MAX_SECS:
        # Overly long clips are dropped to keep batch sizes reasonable.
        counter['too_long'] += 1
    else:
        # Accepted: record it for the target CSV.
        rows.append((wav_filename, file_size, label))
    counter['all'] += 1
    counter['total_time'] += frames
    return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
@ -76,49 +108,18 @@ def _maybe_convert_sets(target_dir, extracted_data):
for record in glob(glob_dir, recursive=True): for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '') record_file = record.replace(ogg_root_dir + os.path.sep, '')
if record_filter(record_file): if record_filter(record_file):
samples.append((record_file, os.path.splitext(os.path.basename(record_file))[0])) samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
# Keep track of how many samples are good vs. problematic counter = get_counter()
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = path.join(ogg_root_dir, sample[0])
# Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing ogg files...") print("Importing ogg files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -152,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def _maybe_convert_wav(ogg_filename, wav_filename): def _maybe_convert_wav(ogg_filename, wav_filename):
if not path.exists(wav_filename): if not path.exists(wav_filename):

View File

@ -9,7 +9,7 @@ import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import subprocess import subprocess
@ -17,9 +17,7 @@ import progressbar
import unicodedata import unicodedata
import tarfile import tarfile
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
@ -27,7 +25,6 @@ from glob import glob
from util.downloader import maybe_download from util.downloader import maybe_download
from util.text import Alphabet from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -63,6 +60,38 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
    """Validate a single pre-existing WAV sample for import.

    `sample` is a (wav_path, transcript) pair; no conversion is done here.
    Returns (counter, rows) with this sample tallied into exactly one
    bucket, and rows holding the (wav_filename, file_size, label) entry
    only when the sample is accepted.
    """
    wav_filename, transcript = sample[0], sample[1]
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    label = label_filter(transcript)
    counter = get_counter()
    rows = []
    secs = frames / SAMPLE_RATE
    if file_size == -1:
        # The expected WAV is missing entirely.
        print("conversion failure", wav_filename)
        counter['failed'] += 1
    elif label is None:
        # Transcript rejected by the label filter.
        counter['invalid_label'] += 1
    elif int(secs * 1000 / 15 / 2) < len(str(label)):
        # Audio too short to plausibly contain the transcript.
        counter['too_short'] += 1
    elif secs > MAX_SECS:
        # Overly long clips are dropped to keep batch sizes reasonable.
        counter['too_long'] += 1
    else:
        # Accepted: record it for the target CSV.
        rows.append((wav_filename, file_size, label))
    counter['all'] += 1
    counter['total_time'] += frames
    return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
@ -85,44 +114,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript = re[2] transcript = re[2]
samples.append((audio, transcript)) samples.append((audio, transcript))
# Keep track of how many samples are good vs. problematic counter = get_counter()
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing WAV files...") print("Importing WAV files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -156,17 +157,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.') parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')

View File

@ -7,7 +7,7 @@ import os
import sys import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import re import re
@ -18,9 +18,7 @@ import progressbar
import unicodedata import unicodedata
import tarfile import tarfile
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
@ -62,6 +60,37 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
    """Validate a single pre-existing WAV sample for import.

    `sample` is a (wav_path, transcript) pair; no conversion is done here.
    Returns (counter, rows): the counter tallies this sample into one
    bucket, and rows holds at most one (wav_filename, file_size, label)
    entry when the sample is kept.
    """
    wav_filename, transcript = sample[0], sample[1]
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    label = label_filter(transcript)
    counter = get_counter()
    rows = []
    secs = frames / SAMPLE_RATE
    if file_size == -1:
        # The expected WAV is missing entirely.
        counter['failed'] += 1
    elif label is None:
        # Transcript rejected by the label filter.
        counter['invalid_label'] += 1
    elif int(secs * 1000 / 15 / 2) < len(str(label)):
        # Audio too short to plausibly contain the transcript.
        counter['too_short'] += 1
    elif secs > MAX_SECS:
        # Overly long clips are dropped to keep batch sizes reasonable.
        counter['too_long'] += 1
    else:
        # Accepted: record it for the target CSV.
        rows.append((wav_filename, file_size, label))
    counter['all'] += 1
    counter['total_time'] += frames
    return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
@ -112,43 +141,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
samples.append((record, transcripts[record_file])) samples.append((record, transcripts[record_file]))
# Keep track of how many samples are good vs. problematic # Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} counter = get_counter()
lock = RLock()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing WAV files...") print("Importing WAV files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -182,16 +184,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.') parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')

View File

@ -8,7 +8,7 @@ import re
import sys import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import unidecode import unidecode
@ -17,15 +17,12 @@ import sox
import subprocess import subprocess
import progressbar import progressbar
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
from util.downloader import maybe_download from util.downloader import maybe_download
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -59,6 +56,44 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
    """Convert one sample's audio to a 16 kHz WAV and classify the result.

    Takes a dict with at least 'path' (source audio file) and 'text'
    (transcript), converts the audio next to the original file, and sorts the
    sample into one of the exclusion buckets or keeps it.

    Returns a (counter, rows) tuple: *counter* is a fresh stats Counter with
    exactly one exclusion bucket (or none) incremented plus 'all'/'total_time'
    updated, and *rows* holds at most one (wav_filename, file_size, label)
    entry when the sample is kept.
    """
    src_filename = sample['path']
    # Converted files live next to the originals - just with a different suffix.
    converted_filename = path.splitext(src_filename)[0] + ".converted.wav"
    _maybe_convert_wav(src_filename, converted_filename)

    # size stays -1 when conversion produced no file; frames stays 0 then too.
    size_bytes = -1
    num_frames = 0
    if path.exists(converted_filename):
        size_bytes = path.getsize(converted_filename)
        num_frames = int(subprocess.check_output(['soxi', '-s', converted_filename], stderr=subprocess.STDOUT))

    transcript = sample['text']
    kept_rows = []
    counter = get_counter()
    duration_secs = num_frames / SAMPLE_RATE
    if size_bytes == -1:
        # Excluding samples that failed upon conversion
        counter['failed'] += 1
    elif transcript is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(duration_secs * 1000 / 10 / 2) < len(str(transcript)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif duration_secs > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
    else:
        # This one is good - keep it for the target CSV
        kept_rows.append((converted_filename, size_bytes, transcript))
    counter['all'] += 1
    counter['total_time'] += num_frames
    return (counter, kept_rows)
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
@ -72,49 +107,19 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
if float(d['duration']) <= MAX_SECS if float(d['duration']) <= MAX_SECS
] ]
# Keep track of how many samples are good vs. problematic for line in data:
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} line['path'] = os.path.join(extracted_dir, line['path'])
lock = RLock()
num_samples = len(data) num_samples = len(data)
rows = [] rows = []
counter = get_counter()
wav_root_dir = extracted_dir print("Importing {} wav files...".format(num_samples))
pool = Pool()
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = path.join(wav_root_dir, sample['path'])
# Storing wav files next to the wav ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = sample['text']
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing wav files...")
pool = Pool(cpu_count())
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, data), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, data), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -131,7 +136,6 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
test_writer.writeheader() test_writer.writeheader()
for i, item in enumerate(rows): for i, item in enumerate(rows):
print('item', item)
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible)) transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
if not transcript: if not transcript:
continue continue
@ -149,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def _maybe_convert_wav(orig_filename, wav_filename): def _maybe_convert_wav(orig_filename, wav_filename):
if not path.exists(wav_filename): if not path.exists(wav_filename):

View File

@ -14,13 +14,14 @@ import sys
sys.path.insert(1, os.path.join(sys.path[0], "..")) sys.path.insert(1, os.path.join(sys.path[0], ".."))
from util.importers import get_counter, get_imported_samples, print_import_report
import re import re
import librosa import librosa
import progressbar import progressbar
from os import path from os import path
from multiprocessing.dummy import Pool from multiprocessing import Pool
from multiprocessing import cpu_count
from util.downloader import maybe_download, SIMPLE_BAR from util.downloader import maybe_download, SIMPLE_BAR
from zipfile import ZipFile from zipfile import ZipFile
@ -61,47 +62,46 @@ def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data, "wav48") extracted_dir = path.join(target_dir, extracted_data, "wav48")
txt_dir = path.join(target_dir, extracted_data, "txt") txt_dir = path.join(target_dir, extracted_data, "txt")
cnt = 1
directory = os.path.expanduser(extracted_dir) directory = os.path.expanduser(extracted_dir)
srtd = len(sorted(os.listdir(directory))) srtd = len(sorted(os.listdir(directory)))
all_samples = []
for target in sorted(os.listdir(directory)): for target in sorted(os.listdir(directory)):
print(f"\nSpeaker {cnt} of {srtd}") all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
_maybe_convert_set(path.join(extracted_dir, os.path.split(target)[-1]))
cnt += 1
_write_csv(extracted_dir, txt_dir, target_dir)
def _maybe_convert_set(target_csv):
def one_sample(sample):
if is_audio_file(sample):
sample = os.path.join(target_csv, sample)
y, sr = librosa.load(sample, sr=16000)
# Trim the beginning and ending silence
yt, index = librosa.effects.trim(y) # pylint: disable=unused-variable
duration = librosa.get_duration(yt, sr)
if duration > MAX_SECS or duration < MIN_SECS:
os.remove(sample)
else:
librosa.output.write_wav(sample, yt, sr)
samples = sorted(os.listdir(target_csv))
num_samples = len(samples)
num_samples = len(all_samples)
print(f"Converting wav files to {SAMPLE_RATE}hz...") print(f"Converting wav files to {SAMPLE_RATE}hz...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, _ in enumerate(pool.imap_unordered(one_sample, all_samples), start=1):
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
pool.join() pool.join()
_write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample):
    """Trim silence from one audio file in place, or delete it.

    Non-audio paths are ignored. The file is loaded at 16 kHz, leading and
    trailing silence is stripped, and the trimmed audio is written back over
    the original file - unless the trimmed duration falls outside the
    [MIN_SECS, MAX_SECS] window, in which case the file is removed.
    """
    if not is_audio_file(sample):
        return
    audio, rate = librosa.load(sample, sr=16000)
    # Measure duration only after stripping leading/trailing silence.
    trimmed, _ = librosa.effects.trim(audio)
    duration = librosa.get_duration(trimmed, rate)
    if MIN_SECS <= duration <= MAX_SECS:
        librosa.output.write_wav(sample, trimmed, rate)
    else:
        # Too short or too long to be a usable training sample - drop it.
        os.remove(sample)
def _maybe_prepare_set(target_csv):
samples = sorted(os.listdir(target_csv))
new_samples = []
for s in samples:
new_samples.append(os.path.join(target_csv, s))
samples = new_samples
return samples
def _write_csv(extracted_dir, txt_dir, target_dir): def _write_csv(extracted_dir, txt_dir, target_dir):
print(f"Writing CSV file") print(f"Writing CSV file")
@ -196,8 +196,8 @@ def load_txts(directory):
AUDIO_EXTENSIONS = [".wav", "WAV"] AUDIO_EXTENSIONS = [".wav", "WAV"]
def is_audio_file(filename): def is_audio_file(filepath):
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS) return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,2 +1,3 @@
absl-py absl-py
argparse argparse
semver

View File

@ -4,6 +4,27 @@ import os
import re import re
import sys import sys
from util.helpers import secs_to_hours
from collections import Counter
def get_counter():
    """Return a fresh Counter for import statistics, every category zeroed."""
    return Counter(all=0, failed=0, invalid_label=0, too_short=0, too_long=0, total_time=0)
def get_imported_samples(counter):
    """Return how many samples were kept after subtracting every exclusion bucket."""
    excluded = ('failed', 'too_short', 'too_long', 'invalid_label')
    return counter['all'] - sum(counter[key] for key in excluded)
def print_import_report(counter, sample_rate, max_secs):
    """Print a human-readable summary of an import run.

    Reports the kept-sample count, one line per non-empty exclusion bucket
    (in a fixed order), and the total imported audio duration derived from
    counter['total_time'] frames at *sample_rate*.
    """
    print('Imported %d samples.' % (get_imported_samples(counter)))
    # Only mention an exclusion bucket when it actually skipped something.
    skip_messages = (
        ('failed', 'Skipped %d samples that failed upon conversion.'),
        ('invalid_label', 'Skipped %d samples that failed on transcript validation.'),
        ('too_short', 'Skipped %d samples that were too short to match the transcript.'),
    )
    for bucket, message in skip_messages:
        if counter[bucket] > 0:
            print(message % counter[bucket])
    if counter['too_long'] > 0:
        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], max_secs))
    print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / sample_rate))
def get_importers_parser(description): def get_importers_parser(description):
parser = argparse.ArgumentParser(description=description) parser = argparse.ArgumentParser(description=description)
parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.') parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.')