Merge pull request #2818 from lissyx/validate_label_locale+multiprocessing.notDummy

Validate label locale+multiprocessing.not dummy
lissyx, 2020-03-19 10:14:06 +01:00, committed by GitHub
commit ff9a720764
22 changed files with 480 additions and 398 deletions
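
The same refactor repeats across the importer scripts below: each per-sample worker (one_sample) moves from a closure nested inside the conversion routine to module level, so that multiprocessing.Pool can pickle it and dispatch it to real worker processes (multiprocessing.dummy.Pool only wrapped threads), and instead of mutating shared counters under an RLock, every worker returns its own (counter, rows) pair for the parent to merge. A minimal sketch of the pattern, with a made-up existence check standing in for the real WAV conversion (the names below are illustrative, not taken from the diff):

    import os
    from collections import Counter
    from multiprocessing import Pool

    def one_sample(wav_filename):
        # Each worker builds its own tallies; nothing is shared across
        # processes, so no lock is needed and the function pickles cleanly.
        counter = Counter({'all': 0, 'failed': 0})
        rows = []
        if os.path.exists(wav_filename):
            rows.append((wav_filename, os.path.getsize(wav_filename)))
        else:
            counter['failed'] += 1
        counter['all'] += 1
        return (counter, rows)

    def convert_all(samples):
        counter = Counter({'all': 0, 'failed': 0})
        rows = []
        with Pool() as pool:  # Pool() defaults to os.cpu_count() workers
            for processed in pool.imap_unordered(one_sample, samples):
                counter += processed[0]  # Counter addition merges the tallies
                rows += processed[1]
        return counter, rows

One side effect of hoisting the workers: they now reach helpers such as validate_label or label_filter through module globals set in the __main__ block, which relies on the fork start method (the default on Linux) to make those globals visible in the children.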

----------------------------------------

@@ -7,7 +7,7 @@ import os
 import sys
 sys.path.insert(1, os.path.join(sys.path[0], '..'))

-import argparse
+from util.importers import get_importers_parser
 import glob
 import pandas
 import tarfile
@@ -81,7 +81,7 @@ def preprocess_data(tgz_file, target_dir):
 def main():
     # https://www.openslr.org/62/
-    parser = argparse.ArgumentParser(description='Import aidatatang_200zh corpus')
+    parser = get_importers_parser(description='Import aidatatang_200zh corpus')
     parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz')
     parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
     params = parser.parse_args()

----------------------------------------

@@ -7,7 +7,7 @@ import os
 import sys
 sys.path.insert(1, os.path.join(sys.path[0], '..'))

-import argparse
+from util.importers import get_importers_parser
 import glob
 import tarfile
 import pandas
@@ -80,7 +80,7 @@ def preprocess_data(tgz_file, target_dir):
 def main():
     # http://www.openslr.org/33/
-    parser = argparse.ArgumentParser(description='Import AISHELL corpus')
+    parser = get_importers_parser(description='Import AISHELL corpus')
     parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz')
     parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
     params = parser.parse_args()

----------------------------------------

@@ -15,10 +15,8 @@ import progressbar
 from glob import glob
 from os import path
-from threading import RLock
-from multiprocessing.dummy import Pool
-from multiprocessing import cpu_count
-from util.text import validate_label
+from multiprocessing import Pool
+from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
 from util.downloader import maybe_download, SIMPLE_BAR

 FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -53,27 +51,8 @@ def _maybe_convert_sets(target_dir, extracted_data):
     for source_csv in glob(path.join(extracted_dir, '*.csv')):
         _maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))

-def _maybe_convert_set(extracted_dir, source_csv, target_csv):
-    print()
-    if path.exists(target_csv):
-        print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
-        return
-    print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
-    samples = []
-    with open(source_csv) as source_csv_file:
-        reader = csv.DictReader(source_csv_file)
-        for row in reader:
-            samples.append((row['filename'], row['text']))
-
-    # Mutable counters for the concurrent embedded routine
-    counter = { 'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0 }
-    lock = RLock()
-    num_samples = len(samples)
-    rows = []
-
 def one_sample(sample):
-    mp3_filename = path.join(*(sample[0].split('/')))
-    mp3_filename = path.join(extracted_dir, mp3_filename)
+    mp3_filename = sample[0]
     # Storing wav files next to the mp3 ones - just with a different suffix
     wav_filename = path.splitext(mp3_filename)[0] + ".wav"
     _maybe_convert_wav(mp3_filename, wav_filename)
@@ -83,7 +62,8 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
         file_size = path.getsize(wav_filename)
         frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
     label = validate_label(sample[1])
-    with lock:
+    rows = []
+    counter = get_counter()
     if file_size == -1:
         # Excluding samples that failed upon conversion
         counter['failed'] += 1
@@ -100,11 +80,32 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
         # This one is good - keep it for the target CSV
         rows.append((wav_filename, file_size, label))
     counter['all'] += 1
+    counter['total_time'] += frames
+    return (counter, rows)
+
+def _maybe_convert_set(extracted_dir, source_csv, target_csv):
+    print()
+    if path.exists(target_csv):
+        print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
+        return
+    print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
+    samples = []
+    with open(source_csv) as source_csv_file:
+        reader = csv.DictReader(source_csv_file)
+        for row in reader:
+            samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
+
+    # Mutable counters for the concurrent embedded routine
+    counter = get_counter()
+    num_samples = len(samples)
+    rows = []
+
     print('Importing mp3 files...')
-    pool = Pool(cpu_count())
+    pool = Pool()
     bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+        counter += processed[0]
+        rows += processed[1]
         bar.update(i)
     bar.update(num_samples)
     pool.close()
@@ -118,15 +119,11 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
         for filename, file_size, transcript in bar(rows):
             writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })

-    print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
-    if counter['failed'] > 0:
-        print('Skipped %d samples that failed upon conversion.' % counter['failed'])
-    if counter['invalid_label'] > 0:
-        print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
-    if counter['too_short'] > 0:
-        print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
-    if counter['too_long'] > 0:
-        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
+    imported_samples = get_imported_samples(counter)
+    assert counter['all'] == num_samples
+    assert len(rows) == imported_samples
+
+    print_import_report(counter, SAMPLE_RATE, MAX_SECS)

 def _maybe_convert_wav(mp3_filename, wav_filename):
     if not path.exists(wav_filename):

----------------------------------------

@@ -16,18 +16,15 @@ sys.path.insert(1, os.path.join(sys.path[0], '..'))
 import csv
 import sox
-import argparse
 import subprocess
 import progressbar
 import unicodedata

 from os import path
-from threading import RLock
-from multiprocessing.dummy import Pool
-from multiprocessing import cpu_count
+from multiprocessing import Pool
 from util.downloader import SIMPLE_BAR
-from util.text import Alphabet, validate_label
-from util.helpers import secs_to_hours
+from util.text import Alphabet
+from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report

 FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -35,34 +32,16 @@ SAMPLE_RATE = 16000
 MAX_SECS = 10

-def _preprocess_data(tsv_dir, audio_dir, label_filter, space_after_every_character=False):
+def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
     for dataset in ['train', 'test', 'dev', 'validated', 'other']:
         input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
         if os.path.isfile(input_tsv):
             print("Loading TSV file: ", input_tsv)
-            _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character)
+            _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)

-def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character=None):
-    output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
-    print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
-
-    # Get audiofile path and transcript for each sentence in tsv
-    samples = []
-    with open(input_tsv, encoding='utf-8') as input_tsv_file:
-        reader = csv.DictReader(input_tsv_file, delimiter='\t')
-        for row in reader:
-            samples.append((row['path'], row['sentence']))
-
-    # Keep track of how many samples are good vs. problematic
-    counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
-    lock = RLock()
-    num_samples = len(samples)
-    rows = []
-
 def one_sample(sample):
     """ Take a audio file, and optionally convert it to 16kHz WAV """
-    mp3_filename = path.join(audio_dir, sample[0])
+    mp3_filename = sample[0]
     if not path.splitext(mp3_filename.lower())[1] == '.mp3':
         mp3_filename += ".mp3"
     # Storing wav files next to the mp3 ones - just with a different suffix
@@ -73,8 +52,9 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
     if path.exists(wav_filename):
         file_size = path.getsize(wav_filename)
         frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
-    label = label_filter(sample[1])
-    with lock:
+    label = label_filter_fun(sample[1])
+    rows = []
+    counter = get_counter()
     if file_size == -1:
         # Excluding samples that failed upon conversion
         counter['failed'] += 1
@@ -93,10 +73,29 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
     counter['all'] += 1
     counter['total_time'] += frames
+    return (counter, rows)
+
+def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
+    output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
+    print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
+
+    # Get audiofile path and transcript for each sentence in tsv
+    samples = []
+    with open(input_tsv, encoding='utf-8') as input_tsv_file:
+        reader = csv.DictReader(input_tsv_file, delimiter='\t')
+        for row in reader:
+            samples.append((path.join(audio_dir, row['path']), row['sentence']))
+
+    counter = get_counter()
+    num_samples = len(samples)
+    rows = []
+
     print("Importing mp3 files...")
-    pool = Pool(cpu_count())
+    pool = Pool()
     bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+        counter += processed[0]
+        rows += processed[1]
         bar.update(i)
     bar.update(num_samples)
     pool.close()
@@ -113,16 +112,11 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
             else:
                 writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})

-    print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
-    if counter['failed'] > 0:
-        print('Skipped %d samples that failed upon conversion.' % counter['failed'])
-    if counter['invalid_label'] > 0:
-        print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
-    if counter['too_short'] > 0:
-        print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
-    if counter['too_long'] > 0:
-        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
-    print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
+    imported_samples = get_imported_samples(counter)
+    assert counter['all'] == num_samples
+    assert len(rows) == imported_samples
+
+    print_import_report(counter, SAMPLE_RATE, MAX_SECS)

 def _maybe_convert_wav(mp3_filename, wav_filename):
@@ -136,7 +130,7 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
 if __name__ == "__main__":
-    PARSER = argparse.ArgumentParser(description='Import CommonVoice v2.0 corpora')
+    PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora')
     PARSER.add_argument('tsv_dir', help='Directory containing tsv files')
     PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
     PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
@@ -144,6 +138,7 @@ if __name__ == "__main__":
     PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')

     PARAMS = PARSER.parse_args()
+    validate_label = get_validate_label(PARAMS)

     AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
     ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
@@ -161,4 +156,4 @@ if __name__ == "__main__":
             label = None
         return label

-    _preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, label_filter_fun, PARAMS.space_after_every_character)
+    _preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character)

----------------------------------------

@@ -19,7 +19,7 @@ import unicodedata
 import librosa
 import soundfile # <= Has an external dependency on libsndfile

-from util.text import validate_label
+from util.importers import validate_label_eng as validate_label

 def _download_and_preprocess_data(data_dir):
     # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19

----------------------------------------

@@ -7,7 +7,7 @@ import os
 import sys
 sys.path.insert(1, os.path.join(sys.path[0], '..'))

-import argparse
+from util.importers import get_importers_parser
 import glob
 import numpy as np
 import pandas
@@ -81,7 +81,7 @@ def preprocess_data(tgz_file, target_dir):
 def main():
     # https://www.openslr.org/38/
-    parser = argparse.ArgumentParser(description='Import Free ST Chinese Mandarin corpus')
+    parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus')
     parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz')
     parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
     params = parser.parse_args()

----------------------------------------

@@ -1,12 +1,16 @@
 #!/usr/bin/env python
+
+# Make sure we can import stuff from util/
+# This script needs to be run from the root of the DeepSpeech repository
 import os
-import csv
 import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+
+import csv
 import math
 import urllib
 import logging
-import argparse
+from util.importers import get_importers_parser, get_validate_label
 import subprocess
 from os import path
 from pathlib import Path
@@ -15,8 +19,6 @@ import swifter
 import pandas as pd
 from sox import Transformer

-from util.text import validate_label
-
 __version__ = "0.1.0"
 _logger = logging.getLogger(__name__)
@@ -38,7 +40,7 @@ def parse_args(args):
     Returns:
       :obj:`argparse.Namespace`: command line parameters namespace
     """
-    parser = argparse.ArgumentParser(
+    parser = get_importers_parser(
         description="Imports GramVaani data for Deep Speech"
     )
     parser.add_argument(
@@ -286,6 +288,7 @@ def main(args):
       args ([str]): command line parameter list
     """
     args = parse_args(args)
+    validate_label = get_validate_label(args)
     setup_logging(args.loglevel)
     _logger.info("Starting GramVaani importer...")
     _logger.info("Starting loading GramVaani csv...")

----------------------------------------

@@ -3,13 +3,13 @@ from __future__ import absolute_import, division, print_function
 # Make sure we can import stuff from util/
 # This script needs to be run from the root of the DeepSpeech repository
-import argparse
 import os
 import sys

 sys.path.insert(1, os.path.join(sys.path[0], '..'))

+from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
+
+import argparse
 import csv
 import re
 import sox
@@ -18,17 +18,14 @@ import subprocess
 import progressbar
 import unicodedata

-from threading import RLock
-from multiprocessing.dummy import Pool
-from multiprocessing import cpu_count
+from multiprocessing import Pool
 from util.downloader import SIMPLE_BAR
 from os import path
 from glob import glob
 from util.downloader import maybe_download
-from util.text import Alphabet, validate_label
-from util.helpers import secs_to_hours
+from util.text import Alphabet

 FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
 SAMPLE_RATE = 16000
@@ -61,32 +58,9 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
     else:
         print('Found directory "%s" - not extracting it from archive.' % archive_path)

-def _maybe_convert_sets(target_dir, extracted_data):
-    extracted_dir = path.join(target_dir, extracted_data)
-    # override existing CSV with normalized one
-    target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
-    if os.path.isfile(target_csv_template):
-        return
-    ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
-
-    # Get audiofile path and transcript for each sentence in tsv
-    samples = []
-    glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
-    for record in glob(glob_dir, recursive=True):
-        record_file = record.replace(ogg_root_dir + os.path.sep, '')
-        if record_filter(record_file):
-            samples.append((record_file, os.path.splitext(os.path.basename(record_file))[0]))
-
-    # Keep track of how many samples are good vs. problematic
-    counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
-    lock = RLock()
-    num_samples = len(samples)
-    rows = []
-
 def one_sample(sample):
     """ Take a audio file, and optionally convert it to 16kHz WAV """
-    ogg_filename = path.join(ogg_root_dir, sample[0])
+    ogg_filename = sample[0]
     # Storing wav files next to the ogg ones - just with a different suffix
     wav_filename = path.splitext(ogg_filename)[0] + ".wav"
     _maybe_convert_wav(ogg_filename, wav_filename)
@@ -96,7 +70,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
         file_size = path.getsize(wav_filename)
         frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
     label = label_filter(sample[1])
-    with lock:
+    rows = []
+    counter = get_counter()
     if file_size == -1:
         # Excluding samples that failed upon conversion
         counter['failed'] += 1
@@ -115,10 +91,35 @@ def _maybe_convert_sets(target_dir, extracted_data):
     counter['all'] += 1
     counter['total_time'] += frames
+    return (counter, rows)
+
+def _maybe_convert_sets(target_dir, extracted_data):
+    extracted_dir = path.join(target_dir, extracted_data)
+    # override existing CSV with normalized one
+    target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
+    if os.path.isfile(target_csv_template):
+        return
+    ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
+
+    # Get audiofile path and transcript for each sentence in tsv
+    samples = []
+    glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
+    for record in glob(glob_dir, recursive=True):
+        record_file = record.replace(ogg_root_dir + os.path.sep, '')
+        if record_filter(record_file):
+            samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
+
+    counter = get_counter()
+    num_samples = len(samples)
+    rows = []
+
     print("Importing ogg files...")
-    pool = Pool(cpu_count())
+    pool = Pool()
     bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+        counter += processed[0]
+        rows += processed[1]
         bar.update(i)
     bar.update(num_samples)
     pool.close()
@@ -152,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
                 transcript=transcript,
             ))

-    print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
-    if counter['failed'] > 0:
-        print('Skipped %d samples that failed upon conversion.' % counter['failed'])
-    if counter['invalid_label'] > 0:
-        print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
-    if counter['too_short'] > 0:
-        print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
-    if counter['too_long'] > 0:
-        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
-    print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
+    imported_samples = get_imported_samples(counter)
+    assert counter['all'] == num_samples
+    assert len(rows) == imported_samples
+
+    print_import_report(counter, SAMPLE_RATE, MAX_SECS)

 def _maybe_convert_wav(ogg_filename, wav_filename):
     if not path.exists(wav_filename):
@@ -173,7 +169,7 @@ def _maybe_convert_wav(ogg_filename, wav_filename):
         print('SoX processing error', ex, ogg_filename, wav_filename)

 def handle_args():
-    parser = argparse.ArgumentParser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
+    parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
     parser.add_argument(dest='target_dir')
     parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId')
     parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code')
@@ -186,6 +182,7 @@ def handle_args():
 if __name__ == "__main__":
     CLI_ARGS = handle_args()
     ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
+    validate_label = get_validate_label(CLI_ARGS)

     bogus_regexes = []
     if CLI_ARGS.bogus_records:

----------------------------------------

@@ -4,29 +4,27 @@ from __future__ import absolute_import, division, print_function
 # Make sure we can import stuff from util/
 # This script needs to be run from the root of the DeepSpeech repository
-import argparse
 import os
 import sys

 sys.path.insert(1, os.path.join(sys.path[0], '..'))

+from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
+
 import csv
 import subprocess
 import progressbar
 import unicodedata
 import tarfile

-from threading import RLock
-from multiprocessing.dummy import Pool
-from multiprocessing import cpu_count
+from multiprocessing import Pool
 from util.downloader import SIMPLE_BAR
 from os import path
 from glob import glob
 from util.downloader import maybe_download
-from util.text import Alphabet, validate_label
-from util.helpers import secs_to_hours
+from util.text import Alphabet

 FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
 SAMPLE_RATE = 16000
@@ -62,6 +60,38 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
         print('Found directory "%s" - not extracting it from archive.' % archive_path)

+def one_sample(sample):
+    """ Take a audio file, and optionally convert it to 16kHz WAV """
+    wav_filename = sample[0]
+    file_size = -1
+    frames = 0
+    if path.exists(wav_filename):
+        file_size = path.getsize(wav_filename)
+        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
+    label = label_filter(sample[1])
+    counter = get_counter()
+    rows = []
+
+    if file_size == -1:
+        # Excluding samples that failed upon conversion
+        print("conversion failure", wav_filename)
+        counter['failed'] += 1
+    elif label is None:
+        # Excluding samples that failed on label validation
+        counter['invalid_label'] += 1
+    elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
+        # Excluding samples that are too short to fit the transcript
+        counter['too_short'] += 1
+    elif frames/SAMPLE_RATE > MAX_SECS:
+        # Excluding very long samples to keep a reasonable batch-size
+        counter['too_long'] += 1
+    else:
+        # This one is good - keep it for the target CSV
+        rows.append((wav_filename, file_size, label))
+    counter['all'] += 1
+    counter['total_time'] += frames
+    return (counter, rows)
+
 def _maybe_convert_sets(target_dir, extracted_data):
     extracted_dir = path.join(target_dir, extracted_data)
     # override existing CSV with normalized one
@@ -84,44 +114,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
             transcript = re[2]
             samples.append((audio, transcript))

-    # Keep track of how many samples are good vs. problematic
-    counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
-    lock = RLock()
+    counter = get_counter()
     num_samples = len(samples)
     rows = []

-    def one_sample(sample):
-        """ Take a audio file, and optionally convert it to 16kHz WAV """
-        wav_filename = sample[0]
-        file_size = -1
-        frames = 0
-        if path.exists(wav_filename):
-            file_size = path.getsize(wav_filename)
-            frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
-        label = label_filter(sample[1])
-        with lock:
-            if file_size == -1:
-                # Excluding samples that failed upon conversion
-                counter['failed'] += 1
-            elif label is None:
-                # Excluding samples that failed on label validation
-                counter['invalid_label'] += 1
-            elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
-                # Excluding samples that are too short to fit the transcript
-                counter['too_short'] += 1
-            elif frames/SAMPLE_RATE > MAX_SECS:
-                # Excluding very long samples to keep a reasonable batch-size
-                counter['too_long'] += 1
-            else:
-                # This one is good - keep it for the target CSV
-                rows.append((wav_filename, file_size, label))
-            counter['all'] += 1
-            counter['total_time'] += frames
-
     print("Importing WAV files...")
-    pool = Pool(cpu_count())
+    pool = Pool()
     bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+        counter += processed[0]
+        rows += processed[1]
         bar.update(i)
     bar.update(num_samples)
     pool.close()
@@ -155,20 +157,14 @@ def _maybe_convert_sets(target_dir, extracted_data):
                 transcript=transcript,
             ))

-    print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
-    if counter['failed'] > 0:
-        print('Skipped %d samples that failed upon conversion.' % counter['failed'])
-    if counter['invalid_label'] > 0:
-        print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
-    if counter['too_short'] > 0:
-        print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
-    if counter['too_long'] > 0:
-        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
-    print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
+    imported_samples = get_imported_samples(counter)
+    assert counter['all'] == num_samples
+    assert len(rows) == imported_samples
+
+    print_import_report(counter, SAMPLE_RATE, MAX_SECS)

 def handle_args():
-    parser = argparse.ArgumentParser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
+    parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
     parser.add_argument(dest='target_dir')
     parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
     parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
@@ -181,6 +177,7 @@ if __name__ == "__main__":
     CLI_ARGS = handle_args()
     ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
     SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
+    validate_label = get_validate_label(CLI_ARGS)

     def label_filter(label):
         if CLI_ARGS.normalize:

----------------------------------------

@@ -7,7 +7,7 @@ import os
 import sys
 sys.path.insert(1, os.path.join(sys.path[0], '..'))

-import argparse
+from util.importers import get_importers_parser
 import glob
 import pandas
 import tarfile
@@ -99,7 +99,7 @@ def preprocess_data(folder_with_archives, target_dir):
 def main():
     # https://openslr.org/68/
-    parser = argparse.ArgumentParser(description='Import MAGICDATA corpus')
+    parser = get_importers_parser(description='Import MAGICDATA corpus')
     parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz')
     parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives')
     params = parser.parse_args()

----------------------------------------

@@ -7,7 +7,7 @@ import os
 import sys
 sys.path.insert(1, os.path.join(sys.path[0], '..'))

-import argparse
+from util.importers import get_importers_parser
 import glob
 import json
 import numpy as np
@@ -93,7 +93,7 @@ def preprocess_data(tgz_file, target_dir):
 def main():
     # https://www.openslr.org/47/
-    parser = argparse.ArgumentParser(description='Import Primewords Chinese corpus set 1')
+    parser = get_importers_parser(description='Import Primewords Chinese corpus set 1')
     parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz')
    parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
     params = parser.parse_args()

----------------------------------------

@@ -3,13 +3,12 @@ from __future__ import absolute_import, division, print_function
 # Make sure we can import stuff from util/
 # This script needs to be run from the root of the DeepSpeech repository
-import argparse
 import os
 import sys

 sys.path.insert(1, os.path.join(sys.path[0], '..'))

+from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
+
 import csv
 import re
 import sox
@@ -19,16 +18,14 @@ import progressbar
 import unicodedata
 import tarfile

-from threading import RLock
-from multiprocessing.dummy import Pool
-from multiprocessing import cpu_count
+from multiprocessing import Pool
 from util.downloader import SIMPLE_BAR
 from os import path
 from glob import glob
 from util.downloader import maybe_download
-from util.text import Alphabet, validate_label
+from util.text import Alphabet
 from util.helpers import secs_to_hours

 FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@@ -63,6 +60,37 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
     else:
         print('Found directory "%s" - not extracting it from archive.' % archive_path)

+def one_sample(sample):
+    """ Take a audio file, and optionally convert it to 16kHz WAV """
+    wav_filename = sample[0]
+    file_size = -1
+    frames = 0
+    if path.exists(wav_filename):
+        file_size = path.getsize(wav_filename)
+        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
+    label = label_filter(sample[1])
+    counter = get_counter()
+    rows = []
+
+    if file_size == -1:
+        # Excluding samples that failed upon conversion
+        counter['failed'] += 1
+    elif label is None:
+        # Excluding samples that failed on label validation
+        counter['invalid_label'] += 1
+    elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
+        # Excluding samples that are too short to fit the transcript
+        counter['too_short'] += 1
+    elif frames/SAMPLE_RATE > MAX_SECS:
+        # Excluding very long samples to keep a reasonable batch-size
+        counter['too_long'] += 1
+    else:
+        # This one is good - keep it for the target CSV
+        rows.append((wav_filename, file_size, label))
+    counter['all'] += 1
+    counter['total_time'] += frames
+    return (counter, rows)
+
 def _maybe_convert_sets(target_dir, extracted_data):
     extracted_dir = path.join(target_dir, extracted_data)
     # override existing CSV with normalized one
@@ -113,43 +141,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
             samples.append((record, transcripts[record_file]))

     # Keep track of how many samples are good vs. problematic
-    counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
-    lock = RLock()
+    counter = get_counter()
     num_samples = len(samples)
     rows = []

-    def one_sample(sample):
-        """ Take a audio file, and optionally convert it to 16kHz WAV """
-        wav_filename = sample[0]
-        file_size = -1
-        frames = 0
-        if path.exists(wav_filename):
-            file_size = path.getsize(wav_filename)
-            frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
-        label = label_filter(sample[1])
-        with lock:
-            if file_size == -1:
-                # Excluding samples that failed upon conversion
-                counter['failed'] += 1
-            elif label is None:
-                # Excluding samples that failed on label validation
-                counter['invalid_label'] += 1
-            elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
-                # Excluding samples that are too short to fit the transcript
-                counter['too_short'] += 1
-            elif frames/SAMPLE_RATE > MAX_SECS:
-                # Excluding very long samples to keep a reasonable batch-size
-                counter['too_long'] += 1
-            else:
-                # This one is good - keep it for the target CSV
-                rows.append((wav_filename, file_size, label))
-            counter['all'] += 1
-            counter['total_time'] += frames
-
     print("Importing WAV files...")
-    pool = Pool(cpu_count())
+    pool = Pool()
     bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
+        counter += processed[0]
+        rows += processed[1]
         bar.update(i)
     bar.update(num_samples)
     pool.close()
@@ -183,19 +184,14 @@ def _maybe_convert_sets(target_dir, extracted_data):
                 transcript=transcript,
             ))

-    print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
-    if counter['failed'] > 0:
-        print('Skipped %d samples that failed upon conversion.' % counter['failed'])
-    if counter['invalid_label'] > 0:
-        print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
-    if counter['too_short'] > 0:
-        print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
-    if counter['too_long'] > 0:
-        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
-    print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
+    imported_samples = get_imported_samples(counter)
+    assert counter['all'] == num_samples
+    assert len(rows) == imported_samples
+
+    print_import_report(counter, SAMPLE_RATE, MAX_SECS)

 def handle_args():
-    parser = argparse.ArgumentParser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
+    parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
     parser.add_argument(dest='target_dir')
     parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
     parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
@@ -204,6 +200,7 @@ def handle_args():
 if __name__ == "__main__":
     CLI_ARGS = handle_args()
     ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
+    validate_label = get_validate_label(CLI_ARGS)

     def label_filter(label):
         if CLI_ARGS.normalize:

----------------------------------------

@@ -20,7 +20,7 @@ import wave
 import codecs
 import tarfile
 import requests
-from util.text import validate_label
+from util.importers import validate_label_eng as validate_label
 import librosa
 import soundfile # <= Has an external dependency on libsndfile

----------------------------------------

@@ -27,7 +27,8 @@ from os import path
 from glob import glob
 from collections import Counter
 from multiprocessing.pool import ThreadPool
-from util.text import Alphabet, validate_label
+from util.text import Alphabet
+from util.importers import validate_label_eng as validate_label
 from util.downloader import maybe_download, SIMPLE_BAR

 SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"

----------------------------------------

@@ -3,14 +3,13 @@ from __future__ import absolute_import, division, print_function
 # Make sure we can import stuff from util/
 # This script needs to be run from the root of the DeepSpeech repository
-import argparse
 import os
 import re
 import sys

 sys.path.insert(1, os.path.join(sys.path[0], '..'))

+from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
+
 import csv
 import unidecode
 import zipfile
@@ -18,16 +17,12 @@ import sox
 import subprocess
 import progressbar

-from threading import RLock
-from multiprocessing.dummy import Pool
-from multiprocessing import cpu_count
+from multiprocessing import Pool
 from util.downloader import SIMPLE_BAR
 from os import path
 from util.downloader import maybe_download
-from util.text import validate_label
-from util.helpers import secs_to_hours

 FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
 SAMPLE_RATE = 16000
@@ -61,30 +56,9 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
     print('Found directory "%s" - not extracting it from archive.' % archive_path)

-def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
-    extracted_dir = path.join(target_dir, extracted_data)
-    # override existing CSV with normalized one
-    target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
-    if os.path.isfile(target_csv_template):
-        return
-    path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
-    with open(path_to_original_csv) as csv_f:
-        data = [
-            d for d in csv.DictReader(csv_f, delimiter=',')
-            if float(d['duration']) <= MAX_SECS
-        ]
-
-    # Keep track of how many samples are good vs. problematic
-    counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
-    lock = RLock()
-    num_samples = len(data)
-    rows = []
-    wav_root_dir = extracted_dir
-
 def one_sample(sample):
     """ Take a audio file, and optionally convert it to 16kHz WAV """
-    orig_filename = path.join(wav_root_dir, sample['path'])
+    orig_filename = sample['path']
     # Storing wav files next to the wav ones - just with a different suffix
     wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
     _maybe_convert_wav(orig_filename, wav_filename)
@@ -94,7 +68,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
         file_size = path.getsize(wav_filename)
         frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
     label = sample['text']
-    with lock:
+    rows = []
+    # Keep track of how many samples are good vs. problematic
+    counter = get_counter()
     if file_size == -1:
         # Excluding samples that failed upon conversion
         counter['failed'] += 1
@@ -113,10 +91,35 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
     counter['all'] += 1
     counter['total_time'] += frames
-    print("Importing wav files...")
-    pool = Pool(cpu_count())
+    return (counter, rows)
+
+def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
+    extracted_dir = path.join(target_dir, extracted_data)
+    # override existing CSV with normalized one
+    target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
+    if os.path.isfile(target_csv_template):
+        return
+    path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
+    with open(path_to_original_csv) as csv_f:
+        data = [
+            d for d in csv.DictReader(csv_f, delimiter=',')
+            if float(d['duration']) <= MAX_SECS
+        ]
+
+    for line in data:
+        line['path'] = os.path.join(extracted_dir, line['path'])
+
+    num_samples = len(data)
+    rows = []
+    counter = get_counter()
+
+    print("Importing {} wav files...".format(num_samples))
+    pool = Pool()
     bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, _ in enumerate(pool.imap_unordered(one_sample, data), start=1):
+    for i, processed in enumerate(pool.imap_unordered(one_sample, data), start=1):
+        counter += processed[0]
+        rows += processed[1]
         bar.update(i)
     bar.update(num_samples)
     pool.close()
@@ -133,7 +136,6 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
         test_writer.writeheader()

     for i, item in enumerate(rows):
-        print('item', item)
         transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
         if not transcript:
             continue
@@ -151,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
             transcript=transcript,
         ))

-    print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
-    if counter['failed'] > 0:
-        print('Skipped %d samples that failed upon conversion.' % counter['failed'])
-    if counter['invalid_label'] > 0:
-        print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
-    if counter['too_short'] > 0:
-        print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
-    if counter['too_long'] > 0:
-        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
-    print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
+    imported_samples = get_imported_samples(counter)
+    assert counter['all'] == num_samples
+    assert len(rows) == imported_samples
+
+    print_import_report(counter, SAMPLE_RATE, MAX_SECS)

 def _maybe_convert_wav(orig_filename, wav_filename):
     if not path.exists(wav_filename):
@@ -186,7 +183,7 @@ def cleanup_transcript(text, english_compatible=False):
 def handle_args():
-    parser = argparse.ArgumentParser(description='Importer for TrainingSpeech dataset.')
+    parser = get_importers_parser(description='Importer for TrainingSpeech dataset.')
     parser.add_argument(dest='target_dir')
     parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diactrics and other non-ascii chars.')
     return parser.parse_args()
@@ -194,4 +191,5 @@ def handle_args():
 if __name__ == "__main__":
     cli_args = handle_args()
+    validate_label = get_validate_label(cli_args)
     _download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible)

----------------------------------------

@@ -21,7 +21,8 @@ import xml.etree.cElementTree as ET
 from os import path
 from collections import Counter

-from util.text import Alphabet, validate_label
+from util.text import Alphabet
+from util.importers import validate_label_eng as validate_label
 from util.downloader import maybe_download, SIMPLE_BAR

 TUDA_VERSION = 'v2'

----------------------------------------

@@ -14,13 +14,14 @@ import sys
 sys.path.insert(1, os.path.join(sys.path[0], ".."))

+from util.importers import get_counter, get_imported_samples, print_import_report
+
 import re
 import librosa
 import progressbar

 from os import path
-from multiprocessing.dummy import Pool
-from multiprocessing import cpu_count
+from multiprocessing import Pool
 from util.downloader import maybe_download, SIMPLE_BAR
 from zipfile import ZipFile
@@ -61,23 +62,27 @@ def _maybe_convert_sets(target_dir, extracted_data):
     extracted_dir = path.join(target_dir, extracted_data, "wav48")
     txt_dir = path.join(target_dir, extracted_data, "txt")

-    cnt = 1
     directory = os.path.expanduser(extracted_dir)
     srtd = len(sorted(os.listdir(directory)))
+    all_samples = []

     for target in sorted(os.listdir(directory)):
-        print(f"\nSpeaker {cnt} of {srtd}")
-        _maybe_convert_set(path.join(extracted_dir, os.path.split(target)[-1]))
-        cnt += 1
+        all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
+
+    num_samples = len(all_samples)
+    print(f"Converting wav files to {SAMPLE_RATE}hz...")
+    pool = Pool()
+    bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
+    for i, _ in enumerate(pool.imap_unordered(one_sample, all_samples), start=1):
+        bar.update(i)
+    bar.update(num_samples)
+    pool.close()
+    pool.join()

     _write_csv(extracted_dir, txt_dir, target_dir)

-def _maybe_convert_set(target_csv):
 def one_sample(sample):
     if is_audio_file(sample):
-            sample = os.path.join(target_csv, sample)
         y, sr = librosa.load(sample, sr=16000)

         # Trim the beginning and ending silence
@@ -89,19 +94,14 @@ def _maybe_convert_set(target_csv):
     else:
         librosa.output.write_wav(sample, yt, sr)

+def _maybe_prepare_set(target_csv):
     samples = sorted(os.listdir(target_csv))
-
-    num_samples = len(samples)
-
-    print(f"Converting wav files to {SAMPLE_RATE}hz...")
-    pool = Pool(cpu_count())
-    bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
-    for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
-        bar.update(i)
-    bar.update(num_samples)
-    pool.close()
-    pool.join()
+    new_samples = []
+    for s in samples:
+        new_samples.append(os.path.join(target_csv, s))
+    samples = new_samples
+    return samples

 def _write_csv(extracted_dir, txt_dir, target_dir):
     print(f"Writing CSV file")
@@ -196,8 +196,8 @@ def load_txts(directory):
 AUDIO_EXTENSIONS = [".wav", "WAV"]

-def is_audio_file(filename):
-    return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)
+def is_audio_file(filepath):
+    return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)

 if __name__ == "__main__":

----------------------------------------

@@ -1 +1,3 @@
 absl-py
+argparse
+semver

----------------------------------------
util/importers.py (new file, 77 lines)

@@ -0,0 +1,77 @@
+import argparse
+import importlib
+import os
+import re
+import sys
+
+from util.helpers import secs_to_hours
+from collections import Counter
+
+def get_counter():
+    return Counter({'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0})
+
+def get_imported_samples(counter):
+    return counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'] - counter['invalid_label']
+
+def print_import_report(counter, sample_rate, max_secs):
+    print('Imported %d samples.' % (get_imported_samples(counter)))
+    if counter['failed'] > 0:
+        print('Skipped %d samples that failed upon conversion.' % counter['failed'])
+    if counter['invalid_label'] > 0:
+        print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
+    if counter['too_short'] > 0:
+        print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
+    if counter['too_long'] > 0:
+        print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], max_secs))
+    print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / sample_rate))
+
+def get_importers_parser(description):
+    parser = argparse.ArgumentParser(description=description)
+    parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.')
+    return parser
+
+def get_validate_label(args):
+    """
+    Expects an argparse.Namespace argument to search for validate_label_locale parameter.
+    If found, this will modify Python's library search path and add the directory of the
+    file pointed by the validate_label_locale argument.
+
+    :param args: The importer's CLI argument object
+    :type args: argparse.Namespace
+
+    :return: The user-supplied validate_label function
+    :type: function
+    """
+    if 'validate_label_locale' not in args or (args.validate_label_locale is None):
+        print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.')
+        return validate_label_eng
+    if not os.path.exists(os.path.abspath(args.validate_label_locale)):
+        print('ERROR: Inexistent --validate_label_locale specified. Please check.')
+        return None
+    module_dir = os.path.abspath(os.path.dirname(args.validate_label_locale))
+    sys.path.insert(1, module_dir)
+    fname = os.path.basename(args.validate_label_locale).replace('.py', '')
+    locale_module = importlib.import_module(fname, package=None)
+    return locale_module.validate_label
+
+# Validate and normalize transcriptions. Returns a cleaned version of the label
+# or None if it's invalid.
+def validate_label_eng(label):
+    # For now we can only handle [a-z ']
+    if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
+        return None
+    label = label.replace("-", " ")
+    label = label.replace("_", " ")
+    label = re.sub("[ ]{2,}", " ", label)
+    label = label.replace(".", "")
+    label = label.replace(",", "")
+    label = label.replace(";", "")
+    label = label.replace("?", "")
+    label = label.replace("!", "")
+    label = label.replace(":", "")
+    label = label.replace("\"", "")
+    label = label.strip()
+    label = label.lower()
+    return label if label else None
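
The --validate_label_locale flag added by get_importers_parser is what the PR title refers to: an importer can be pointed at a small Python file exporting a validate_label(label) function for the target locale. A hedged sketch of such a file; the file name and the normalization rule are invented for illustration, only the validate_label entry point is required by get_validate_label above:

    # validate_locale_xyz.py - hypothetical locale validator
    def validate_label(label):
        # Return the cleaned transcript, or None to reject the sample.
        label = label.strip().lower()
        return label if label else None

An importer would then be run with --validate_label_locale path/to/validate_locale_xyz.py. When the flag is omitted, get_validate_label falls back to validate_label_eng with a warning; when the file does not exist, it prints an error and returns None.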

----------------------------------------
util/test_data/validate_locale_fra.py (new file, 2 lines)

@@ -0,0 +1,2 @@
+def validate_label(label):
+    return label

----------------------------------------
util/test_importers.py (new file, 38 lines)

@@ -0,0 +1,38 @@
+import unittest
+
+from argparse import Namespace
+
+from .importers import validate_label_eng, get_validate_label
+
+class TestValidateLabelEng(unittest.TestCase):
+
+    def test_numbers(self):
+        label = validate_label_eng("this is a 1 2 3 test")
+        self.assertEqual(label, None)
+
+class TestGetValidateLabel(unittest.TestCase):
+
+    def test_no_validate_label_locale(self):
+        f = get_validate_label(Namespace())
+        self.assertEqual(f('toto'), 'toto')
+        self.assertEqual(f('toto1234'), None)
+        self.assertEqual(f('toto1234[{[{[]'), None)
+
+    def test_validate_label_locale_default(self):
+        f = get_validate_label(Namespace(validate_label_locale=None))
+        self.assertEqual(f('toto'), 'toto')
+        self.assertEqual(f('toto1234'), None)
+        self.assertEqual(f('toto1234[{[{[]'), None)
+
+    def test_get_validate_label_missing(self):
+        args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
+        f = get_validate_label(args)
+        self.assertEqual(f, None)
+
+    def test_get_validate_label(self):
+        args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
+        f = get_validate_label(args)
+        l = f('toto')
+        self.assertEqual(l, 'toto')
+
+if __name__ == '__main__':
+    unittest.main()
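
These are plain unittest cases; assuming the package layout above, they should run from the repository root with the built-in runner:

    python -m unittest util.test_importers -v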

----------------------------------------
util/text.py

@@ -1,7 +1,6 @@
 from __future__ import absolute_import, division, print_function

 import numpy as np
-import re
 import struct

 from six.moves import range
@@ -166,25 +165,3 @@ def levenshtein(a, b):
             current[j] = min(add, delete, change)

     return current[n]
-
-# Validate and normalize transcriptions. Returns a cleaned version of the label
-# or None if it's invalid.
-def validate_label(label):
-    # For now we can only handle [a-z ']
-    if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
-        return None
-    label = label.replace("-", " ")
-    label = label.replace("_", " ")
-    label = re.sub("[ ]{2,}", " ", label)
-    label = label.replace(".", "")
-    label = label.replace(",", "")
-    label = label.replace(";", "")
-    label = label.replace("?", "")
-    label = label.replace("!", "")
-    label = label.replace(":", "")
-    label = label.replace("\"", "")
-    label = label.strip()
-    label = label.lower()
-    return label if label else None