Convert importers from multiprocessing.dummy (thread-backed Pool) to multiprocessing (process-backed Pool)

Fixes #2817
This commit is contained in:
Alexandre Lissy 2020-03-11 13:37:18 +01:00
parent ce59228824
commit 7b2a409f9f
9 changed files with 340 additions and 340 deletions

View File

@ -15,10 +15,8 @@ import progressbar
from glob import glob from glob import glob
from os import path from os import path
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
from multiprocessing import cpu_count
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR from util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -53,27 +51,8 @@ def _maybe_convert_sets(target_dir, extracted_data):
for source_csv in glob(path.join(extracted_dir, '*.csv')): for source_csv in glob(path.join(extracted_dir, '*.csv')):
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1])) _maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
def _maybe_convert_set(extracted_dir, source_csv, target_csv): def one_sample(sample):
print() mp3_filename = sample[0]
if path.exists(target_csv):
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
return
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
samples = []
with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file)
for row in reader:
samples.append((row['filename'], row['text']))
# Mutable counters for the concurrent embedded routine
counter = { 'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0 }
lock = RLock()
num_samples = len(samples)
rows = []
def one_sample(sample):
mp3_filename = path.join(*(sample[0].split('/')))
mp3_filename = path.join(extracted_dir, mp3_filename)
# Storing wav files next to the mp3 ones - just with a different suffix # Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav" wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename) _maybe_convert_wav(mp3_filename, wav_filename)
@ -83,7 +62,8 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
file_size = path.getsize(wav_filename) file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = validate_label(sample[1]) label = validate_label(sample[1])
with lock: rows = []
counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter['failed'] += 1
@ -100,11 +80,32 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
print()
if path.exists(target_csv):
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
return
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
samples = []
with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file)
for row in reader:
samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
# Running totals; the (counter, rows) result returned by one_sample for each sample is merged into these
counter = get_counter()
num_samples = len(samples)
rows = []
print('Importing mp3 files...') print('Importing mp3 files...')
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -118,15 +119,11 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
for filename, file_size, transcript in bar(rows): for filename, file_size, transcript in bar(rows):
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript }) writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename): if not path.exists(wav_filename):

View File

@ -21,13 +21,10 @@ import progressbar
import unicodedata import unicodedata
from os import path from os import path
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from util.text import Alphabet from util.text import Alphabet
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -35,34 +32,16 @@ SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, label_filter, space_after_every_character=False): def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
for dataset in ['train', 'test', 'dev', 'validated', 'other']: for dataset in ['train', 'test', 'dev', 'validated', 'other']:
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv") input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
if os.path.isfile(input_tsv): if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv) print("Loading TSV file: ", input_tsv)
_maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character) _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
def one_sample(sample):
def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character=None):
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
# Get audiofile path and transcript for each sentence in tsv
samples = []
with open(input_tsv, encoding='utf-8') as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t')
for row in reader:
samples.append((row['path'], row['sentence']))
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
num_samples = len(samples)
rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = path.join(audio_dir, sample[0]) mp3_filename = sample[0]
if not path.splitext(mp3_filename.lower())[1] == '.mp3': if not path.splitext(mp3_filename.lower())[1] == '.mp3':
mp3_filename += ".mp3" mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix # Storing wav files next to the mp3 ones - just with a different suffix
@ -73,8 +52,9 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
if path.exists(wav_filename): if path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1]) label = label_filter_fun(sample[1])
with lock: rows = []
counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter['failed'] += 1
@ -93,10 +73,29 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
counter['all'] += 1 counter['all'] += 1
counter['total_time'] += frames counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
# Get audiofile path and transcript for each sentence in tsv
samples = []
with open(input_tsv, encoding='utf-8') as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t')
for row in reader:
samples.append((path.join(audio_dir, row['path']), row['sentence']))
counter = get_counter()
num_samples = len(samples)
rows = []
print("Importing mp3 files...") print("Importing mp3 files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -113,16 +112,11 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
else: else:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript}) writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
@ -162,4 +156,4 @@ if __name__ == "__main__":
label = None label = None
return label return label
_preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, label_filter_fun, PARAMS.space_after_every_character) _preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character)

View File

@ -7,7 +7,7 @@ import os
import sys import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import argparse import argparse
import csv import csv
@ -18,9 +18,7 @@ import subprocess
import progressbar import progressbar
import unicodedata import unicodedata
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
@ -28,7 +26,6 @@ from glob import glob
from util.downloader import maybe_download from util.downloader import maybe_download
from util.text import Alphabet from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -61,32 +58,9 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def _maybe_convert_sets(target_dir, extracted_data): def one_sample(sample):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
if os.path.isfile(target_csv_template):
return
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '')
if record_filter(record_file):
samples.append((record_file, os.path.splitext(os.path.basename(record_file))[0]))
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
num_samples = len(samples)
rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = path.join(ogg_root_dir, sample[0]) ogg_filename = sample[0]
# Storing wav files next to the ogg ones - just with a different suffix # Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav" wav_filename = path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename) _maybe_convert_wav(ogg_filename, wav_filename)
@ -96,7 +70,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
file_size = path.getsize(wav_filename) file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1]) label = label_filter(sample[1])
with lock: rows = []
counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter['failed'] += 1
@ -115,10 +91,35 @@ def _maybe_convert_sets(target_dir, extracted_data):
counter['all'] += 1 counter['all'] += 1
counter['total_time'] += frames counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
if os.path.isfile(target_csv_template):
return
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '')
if record_filter(record_file):
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
counter = get_counter()
num_samples = len(samples)
rows = []
print("Importing ogg files...") print("Importing ogg files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -152,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def _maybe_convert_wav(ogg_filename, wav_filename): def _maybe_convert_wav(ogg_filename, wav_filename):
if not path.exists(wav_filename): if not path.exists(wav_filename):

View File

@ -9,7 +9,7 @@ import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import subprocess import subprocess
@ -17,9 +17,7 @@ import progressbar
import unicodedata import unicodedata
import tarfile import tarfile
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
@ -27,7 +25,6 @@ from glob import glob
from util.downloader import maybe_download from util.downloader import maybe_download
from util.text import Alphabet from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -63,6 +60,38 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
counter = get_counter()
rows = []
if file_size == -1:
# Excluding samples that failed upon conversion
print("conversion failure", wav_filename)
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
@ -85,44 +114,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript = re[2] transcript = re[2]
samples.append((audio, transcript)) samples.append((audio, transcript))
# Keep track of how many samples are good vs. problematic counter = get_counter()
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing WAV files...") print("Importing WAV files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -156,17 +157,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.') parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')

View File

@ -7,7 +7,7 @@ import os
import sys import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import re import re
@ -18,9 +18,7 @@ import progressbar
import unicodedata import unicodedata
import tarfile import tarfile
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
@ -62,6 +60,37 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
counter = get_counter()
rows = []
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
@ -112,43 +141,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
samples.append((record, transcripts[record_file])) samples.append((record, transcripts[record_file]))
# Keep track of how many samples are good vs. problematic # Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} counter = get_counter()
lock = RLock()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing WAV files...") print("Importing WAV files...")
pool = Pool(cpu_count()) pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -182,16 +184,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.') parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')

View File

@ -8,7 +8,7 @@ import re
import sys import sys
sys.path.insert(1, os.path.join(sys.path[0], '..')) sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import unidecode import unidecode
@ -17,15 +17,12 @@ import sox
import subprocess import subprocess
import progressbar import progressbar
from threading import RLock from multiprocessing import Pool
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR from util.downloader import SIMPLE_BAR
from os import path from os import path
from util.downloader import maybe_download from util.downloader import maybe_download
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -59,30 +56,9 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): def one_sample(sample):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
if os.path.isfile(target_csv_template):
return
path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
with open(path_to_original_csv) as csv_f:
data = [
d for d in csv.DictReader(csv_f, delimiter=',')
if float(d['duration']) <= MAX_SECS
]
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
num_samples = len(data)
rows = []
wav_root_dir = extracted_dir
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = path.join(wav_root_dir, sample['path']) orig_filename = sample['path']
# Storing wav files next to the wav ones - just with a different suffix # Storing wav files next to the wav ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav" wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename) _maybe_convert_wav(orig_filename, wav_filename)
@ -92,7 +68,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
file_size = path.getsize(wav_filename) file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = sample['text'] label = sample['text']
with lock:
rows = []
# Keep track of how many samples are good vs. problematic
counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter['failed'] += 1
@ -111,10 +91,35 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
counter['all'] += 1 counter['all'] += 1
counter['total_time'] += frames counter['total_time'] += frames
print("Importing wav files...") return (counter, rows)
pool = Pool(cpu_count())
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
if os.path.isfile(target_csv_template):
return
path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
with open(path_to_original_csv) as csv_f:
data = [
d for d in csv.DictReader(csv_f, delimiter=',')
if float(d['duration']) <= MAX_SECS
]
for line in data:
line['path'] = os.path.join(extracted_dir, line['path'])
num_samples = len(data)
rows = []
counter = get_counter()
print("Importing {} wav files...".format(num_samples))
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, data), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, data), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i) bar.update(i)
bar.update(num_samples) bar.update(num_samples)
pool.close() pool.close()
@ -131,7 +136,6 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
test_writer.writeheader() test_writer.writeheader()
for i, item in enumerate(rows): for i, item in enumerate(rows):
print('item', item)
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible)) transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
if not transcript: if not transcript:
continue continue
@ -149,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
transcript=transcript, transcript=transcript,
)) ))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) imported_samples = get_imported_samples(counter)
if counter['failed'] > 0: assert counter['all'] == num_samples
print('Skipped %d samples that failed upon conversion.' % counter['failed']) assert len(rows) == imported_samples
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
def _maybe_convert_wav(orig_filename, wav_filename): def _maybe_convert_wav(orig_filename, wav_filename):
if not path.exists(wav_filename): if not path.exists(wav_filename):

View File

@ -14,13 +14,14 @@ import sys
sys.path.insert(1, os.path.join(sys.path[0], "..")) sys.path.insert(1, os.path.join(sys.path[0], ".."))
from util.importers import get_counter, get_imported_samples, print_import_report
import re import re
import librosa import librosa
import progressbar import progressbar
from os import path from os import path
from multiprocessing.dummy import Pool from multiprocessing import Pool
from multiprocessing import cpu_count
from util.downloader import maybe_download, SIMPLE_BAR from util.downloader import maybe_download, SIMPLE_BAR
from zipfile import ZipFile from zipfile import ZipFile
@ -61,23 +62,27 @@ def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data, "wav48") extracted_dir = path.join(target_dir, extracted_data, "wav48")
txt_dir = path.join(target_dir, extracted_data, "txt") txt_dir = path.join(target_dir, extracted_data, "txt")
cnt = 1
directory = os.path.expanduser(extracted_dir) directory = os.path.expanduser(extracted_dir)
srtd = len(sorted(os.listdir(directory))) srtd = len(sorted(os.listdir(directory)))
all_samples = []
for target in sorted(os.listdir(directory)): for target in sorted(os.listdir(directory)):
print(f"\nSpeaker {cnt} of {srtd}") all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
_maybe_convert_set(path.join(extracted_dir, os.path.split(target)[-1]))
cnt += 1 num_samples = len(all_samples)
print(f"Converting wav files to {SAMPLE_RATE}hz...")
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, all_samples), start=1):
bar.update(i)
bar.update(num_samples)
pool.close()
pool.join()
_write_csv(extracted_dir, txt_dir, target_dir) _write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample):
def _maybe_convert_set(target_csv):
def one_sample(sample):
if is_audio_file(sample): if is_audio_file(sample):
sample = os.path.join(target_csv, sample)
y, sr = librosa.load(sample, sr=16000) y, sr = librosa.load(sample, sr=16000)
# Trim the beginning and ending silence # Trim the beginning and ending silence
@ -89,19 +94,14 @@ def _maybe_convert_set(target_csv):
else: else:
librosa.output.write_wav(sample, yt, sr) librosa.output.write_wav(sample, yt, sr)
def _maybe_prepare_set(target_csv):
samples = sorted(os.listdir(target_csv)) samples = sorted(os.listdir(target_csv))
new_samples = []
num_samples = len(samples) for s in samples:
new_samples.append(os.path.join(target_csv, s))
print(f"Converting wav files to {SAMPLE_RATE}hz...") samples = new_samples
pool = Pool(cpu_count()) return samples
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
bar.update(i)
bar.update(num_samples)
pool.close()
pool.join()
def _write_csv(extracted_dir, txt_dir, target_dir): def _write_csv(extracted_dir, txt_dir, target_dir):
print(f"Writing CSV file") print(f"Writing CSV file")
@ -196,8 +196,8 @@ def load_txts(directory):
AUDIO_EXTENSIONS = [".wav", "WAV"] AUDIO_EXTENSIONS = [".wav", "WAV"]
def is_audio_file(filename): def is_audio_file(filepath):
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS) return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,2 +1,3 @@
absl-py absl-py
argparse argparse
semver

View File

@ -4,6 +4,27 @@ import os
import re import re
import sys import sys
from util.helpers import secs_to_hours
from collections import Counter
def get_counter():
    """Return a fresh ``Counter`` with every import statistic set to zero.

    The statistics tracked are 'all', 'failed', 'invalid_label',
    'too_short', 'too_long' and 'total_time'. A new object is built on
    every call, so callers can mutate or add up the results freely.
    """
    stat_names = ('all', 'failed', 'invalid_label', 'too_short', 'too_long', 'total_time')
    return Counter(dict.fromkeys(stat_names, 0))
def get_imported_samples(counter):
    """Return how many samples were kept, according to ``counter``.

    Starts from ``counter['all']`` and subtracts, in order, the samples
    counted under 'failed', 'too_short', 'too_long' and 'invalid_label'.
    """
    remaining = counter['all']
    for skip_reason in ('failed', 'too_short', 'too_long', 'invalid_label'):
        remaining = remaining - counter[skip_reason]
    return remaining
def print_import_report(counter, sample_rate, max_secs):
    """Print a summary of an import run to stdout.

    Prints the number of imported samples (via ``get_imported_samples``),
    then one line for each skip category whose count is greater than zero,
    and finally the amount of imported audio: ``counter['total_time']``
    divided by ``sample_rate``, rendered by ``secs_to_hours``.

    ``max_secs`` is only used in the message for samples that were too long.
    """
    print('Imported %d samples.' % (get_imported_samples(counter)))

    # (counter key, message template, extra format arguments after the count)
    skip_messages = (
        ('failed', 'Skipped %d samples that failed upon conversion.', ()),
        ('invalid_label', 'Skipped %d samples that failed on transcript validation.', ()),
        ('too_short', 'Skipped %d samples that were too short to match the transcript.', ()),
        ('too_long', 'Skipped %d samples that were longer than %d seconds.', (max_secs,)),
    )
    for key, template, extra_args in skip_messages:
        skipped = counter[key]
        if skipped > 0:
            print(template % ((skipped,) + extra_args))

    imported_audio = secs_to_hours(counter['total_time'] / sample_rate)
    print('Final amount of imported audio: %s.' % imported_audio)
def get_importers_parser(description): def get_importers_parser(description):
parser = argparse.ArgumentParser(description=description) parser = argparse.ArgumentParser(description=description)
parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.') parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.')