diff --git a/.travis.yml b/.travis.yml index a45b3e1b..10ca12d6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,11 +7,20 @@ before_cache: python: - "3.6" -install: - - pip install --upgrade cardboardlint pylint - -script: - # Run cardboardlinter, in case of pull requests - - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then - cardboardlinter --refspec $TRAVIS_BRANCH -n auto; - fi +jobs: + include: + - stage: cardboard linter + install: + - pip install --upgrade cardboardlint pylint + script: + # Run cardboardlinter, in case of pull requests + - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then + cardboardlinter --refspec $TRAVIS_BRANCH -n auto; + fi + - stage: python unit tests + install: + - pip install --upgrade -r requirements_tests.txt + script: + - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then + python -m unittest; + fi diff --git a/DeepSpeech.py b/DeepSpeech.py index 8ebd1e25..92404e07 100644 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -33,7 +33,7 @@ from util.config import Config, initialize_globals from util.checkpoints import load_or_init_graph from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from util.flags import create_flags, FLAGS -from util.helpers import check_ctcdecoder_version +from util.helpers import check_ctcdecoder_version, ExceptionBox from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar check_ctcdecoder_version() @@ -418,12 +418,17 @@ def train(): FLAGS.augmentation_sparse_warp): do_cache_dataset = False + exception_box = ExceptionBox() + # Create training and validation datasets train_set = create_dataset(FLAGS.train_files.split(','), batch_size=FLAGS.train_batch_size, enable_cache=FLAGS.feature_cache and do_cache_dataset, cache_path=FLAGS.feature_cache, - train_phase=True) + train_phase=True, + exception_box=exception_box, + process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, + buffering=FLAGS.read_buffer) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), tfv1.data.get_output_shapes(train_set), @@ -433,8 +438,13 @@ def train(): train_init_op = iterator.make_initializer(train_set) if FLAGS.dev_files: - dev_csvs = FLAGS.dev_files.split(',') - dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs] + dev_sources = FLAGS.dev_files.split(',') + dev_sets = [create_dataset([source], + batch_size=FLAGS.dev_batch_size, + train_phase=False, + exception_box=exception_box, + process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + buffering=FLAGS.read_buffer) for source in dev_sources] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] # Dropout @@ -540,6 +550,7 @@ def train(): _, current_step, batch_loss, problem_files, step_summary = \ session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], feed_dict=feed_dict) + exception_box.raise_if_set() except tf.errors.InvalidArgumentError as err: if FLAGS.augmentation_sparse_warp: log_info("Ignoring sparse warp error: {}".format(err)) @@ -547,6 +558,7 @@ def train(): else: raise except tf.errors.OutOfRangeError: + exception_box.raise_if_set() break if problem_files.size > 0: @@ -586,12 +598,12 @@ def train(): # Validation dev_loss = 0.0 total_steps = 0 - for csv, init_op in zip(dev_csvs, dev_init_ops): - log_progress('Validating epoch %d on %s...' % (epoch, csv)) - set_loss, steps = run_set('dev', epoch, init_op, dataset=csv) + for source, init_op in zip(dev_sources, dev_init_ops): + log_progress('Validating epoch %d on %s...' % (epoch, source)) + set_loss, steps = run_set('dev', epoch, init_op, dataset=source) dev_loss += set_loss * steps total_steps += steps - log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss)) + log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss)) dev_loss = dev_loss / total_steps dev_losses.append(dev_loss) @@ -727,6 +739,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): return inputs, outputs, layers + def file_relative_read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() @@ -768,7 +781,7 @@ def export(): method_order = [FLAGS.load] load_or_init_graph(session, method_order) - output_filename = FLAGS.export_name + '.pb' + output_filename = FLAGS.export_file_name + '.pb' if FLAGS.remove_export: if os.path.isdir(FLAGS.export_dir): log_info('Removing old export') @@ -805,21 +818,42 @@ def export(): log_info('Models exported at %s' % (FLAGS.export_dir)) + metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format( + FLAGS.export_author_id, + FLAGS.export_model_name, + FLAGS.export_model_version)) + + model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow' + with open(metadata_fname, 'w') as f: + f.write('---\n') + f.write('author: {}\n'.format(FLAGS.export_author_id)) + f.write('model_name: {}\n'.format(FLAGS.export_model_name)) + f.write('model_version: {}\n'.format(FLAGS.export_model_version)) + f.write('contact_info: {}\n'.format(FLAGS.export_contact_info)) + f.write('license: {}\n'.format(FLAGS.export_license)) + f.write('language: {}\n'.format(FLAGS.export_language)) + f.write('runtime: {}\n'.format(model_runtime)) + f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version)) + f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version)) + f.write('acoustic_model_url: \n') + f.write('scorer_url: \n') + f.write('---\n') + f.write('{}\n'.format(FLAGS.export_description)) + + log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname)) + + def package_zip(): # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/' zip_filename = os.path.dirname(export_dir) - with open(os.path.join(export_dir, 'info.json'), 'w') as f: - json.dump({ - 'name': FLAGS.export_language, - }, f) - shutil.copy(FLAGS.scorer_path, export_dir) archive = shutil.make_archive(zip_filename, 'zip', export_dir) log_info('Exported packaged model {}'.format(archive)) + def do_single_file_inference(input_file_path): with tfv1.Session(config=Config.session_config) as session: inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1) @@ -895,6 +929,7 @@ def main(_): tfv1.reset_default_graph() do_single_file_inference(FLAGS.one_shot_infer) + if __name__ == '__main__': create_flags() absl.app.run(main) diff --git a/Dockerfile b/Dockerfile index 56afdbfc..71fa0705 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,7 @@ FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 # Get basic packages RUN apt-get update && apt-get install -y --no-install-recommends \ + apt-utils \ build-essential \ curl \ wget \ @@ -44,14 +45,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ RUN ln -s -f /usr/bin/python3 /usr/bin/python # Install NCCL 2.2 -RUN apt-get install -qq -y --allow-downgrades --allow-change-held-packages libnccl2=2.3.7-1+cuda10.0 libnccl-dev=2.3.7-1+cuda10.0 +RUN apt-get --no-install-recommends install -qq -y --allow-downgrades --allow-change-held-packages libnccl2=2.3.7-1+cuda10.0 libnccl-dev=2.3.7-1+cuda10.0 # Install Bazel RUN curl -LO "https://github.com/bazelbuild/bazel/releases/download/0.24.1/bazel_0.24.1-linux-x86_64.deb" RUN dpkg -i bazel_*.deb # Install CUDA CLI Tools -RUN apt-get install -qq -y cuda-command-line-tools-10-0 +RUN apt-get --no-install-recommends install -qq -y cuda-command-line-tools-10-0 # Install pip RUN wget https://bootstrap.pypa.io/get-pip.py && \ diff --git a/README.rst b/README.rst index e0ed5ad8..ba007ed9 100644 --- a/README.rst +++ b/README.rst @@ -14,7 +14,7 @@ Project DeepSpeech DeepSpeech is an open source Speech-To-Text engine, using a model trained by machine learning techniques based on `Baidu's Deep Speech research paper `_. Project DeepSpeech uses Google's `TensorFlow `_ to make the implementation easier. -**NOTE:** This documentation applies to the **MASTER version** of DeepSpeech only. If you're using a stable release, you must use the documentation for the corresponding version by using GitHub's branch switcher button above. +**NOTE:** This documentation applies to the **MASTER version** of DeepSpeech only. **Documentation for the latest stable version** is published on `deepspeech.readthedocs.io `_. To install and use deepspeech all you have to do is: diff --git a/VERSION b/VERSION index 15f10e3d..ba984594 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.7.0-alpha.2 +0.7.0-alpha.3 diff --git a/bin/build_sdb.py b/bin/build_sdb.py new file mode 100755 index 00000000..b5fa8d35 --- /dev/null +++ b/bin/build_sdb.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +''' +Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files +Use "python3 build_sdb.py -h" for help +''' +from __future__ import absolute_import, division, print_function + +# Make sure we can import stuff from util/ +# This script needs to be run from the root of the DeepSpeech repository +import os +import sys +sys.path.insert(1, os.path.join(sys.path[0], '..')) + +import argparse +import progressbar + +from util.downloader import SIMPLE_BAR +from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS +from util.sample_collections import samples_from_files, DirectSDBWriter + +AUDIO_TYPE_LOOKUP = { + 'wav': AUDIO_TYPE_WAV, + 'opus': AUDIO_TYPE_OPUS +} + + +def build_sdb(): + audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] + with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer: + samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) + bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR) + for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)): + sdb_writer.add(sample) + + +def handle_args(): + parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) ' + 'from DeepSpeech CSV files and other SDB files') + parser.add_argument('sources', nargs='+', + help='Source CSV and/or SDB files - ' + 'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples ' + 'already ordered from shortest to longest.') + parser.add_argument('target', help='SDB file to create') + parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(), + help='Audio representation inside target SDB') + parser.add_argument('--workers', type=int, default=None, + help='Number of encoding SDB workers') + parser.add_argument('--unlabeled', action='store_true', + help='If to build an SDB with unlabeled (audio only) samples - ' + 'typically used for building noise augmentation corpora') + return parser.parse_args() + + +if __name__ == "__main__": + CLI_ARGS = handle_args() + build_sdb() diff --git a/bin/import_aidatatang.py b/bin/import_aidatatang.py index d1367281..703c570f 100755 --- a/bin/import_aidatatang.py +++ b/bin/import_aidatatang.py @@ -7,7 +7,7 @@ import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -import argparse +from util.importers import get_importers_parser import glob import pandas import tarfile @@ -81,7 +81,7 @@ def preprocess_data(tgz_file, target_dir): def main(): # https://www.openslr.org/62/ - parser = argparse.ArgumentParser(description='Import aidatatang_200zh corpus') + parser = get_importers_parser(description='Import aidatatang_200zh corpus') parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz') parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') params = parser.parse_args() diff --git a/bin/import_aishell.py b/bin/import_aishell.py index 5de1121b..939b5c92 100755 --- a/bin/import_aishell.py +++ b/bin/import_aishell.py @@ -7,7 +7,7 @@ import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -import argparse +from util.importers import get_importers_parser import glob import tarfile import pandas @@ -80,7 +80,7 @@ def preprocess_data(tgz_file, target_dir): def main(): # http://www.openslr.org/33/ - parser = argparse.ArgumentParser(description='Import AISHELL corpus') + parser = get_importers_parser(description='Import AISHELL corpus') parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz') parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') params = parser.parse_args() diff --git a/bin/import_cv.py b/bin/import_cv.py index 7dd04d84..ec326d8c 100755 --- a/bin/import_cv.py +++ b/bin/import_cv.py @@ -15,10 +15,8 @@ import progressbar from glob import glob from os import path -from threading import RLock -from multiprocessing.dummy import Pool -from multiprocessing import cpu_count -from util.text import validate_label +from multiprocessing import Pool +from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report from util.downloader import maybe_download, SIMPLE_BAR FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] @@ -53,6 +51,38 @@ def _maybe_convert_sets(target_dir, extracted_data): for source_csv in glob(path.join(extracted_dir, '*.csv')): _maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1])) +def one_sample(sample): + mp3_filename = sample[0] + # Storing wav files next to the mp3 ones - just with a different suffix + wav_filename = path.splitext(mp3_filename)[0] + ".wav" + _maybe_convert_wav(mp3_filename, wav_filename) + frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) + file_size = -1 + if path.exists(wav_filename): + file_size = path.getsize(wav_filename) + frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) + label = validate_label(sample[1]) + rows = [] + counter = get_counter() + if file_size == -1: + # Excluding samples that failed upon conversion + counter['failed'] += 1 + elif label is None: + # Excluding samples that failed on label validation + counter['invalid_label'] += 1 + elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): + # Excluding samples that are too short to fit the transcript + counter['too_short'] += 1 + elif frames/SAMPLE_RATE > MAX_SECS: + # Excluding very long samples to keep a reasonable batch-size + counter['too_long'] += 1 + else: + # This one is good - keep it for the target CSV + rows.append((wav_filename, file_size, label)) + counter['all'] += 1 + counter['total_time'] += frames + return (counter, rows) + def _maybe_convert_set(extracted_dir, source_csv, target_csv): print() if path.exists(target_csv): @@ -63,48 +93,19 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv): with open(source_csv) as source_csv_file: reader = csv.DictReader(source_csv_file) for row in reader: - samples.append((row['filename'], row['text'])) + samples.append((os.path.join(extracted_dir, row['filename']), row['text'])) # Mutable counters for the concurrent embedded routine - counter = { 'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0 } - lock = RLock() + counter = get_counter() num_samples = len(samples) rows = [] - def one_sample(sample): - mp3_filename = path.join(*(sample[0].split('/'))) - mp3_filename = path.join(extracted_dir, mp3_filename) - # Storing wav files next to the mp3 ones - just with a different suffix - wav_filename = path.splitext(mp3_filename)[0] + ".wav" - _maybe_convert_wav(mp3_filename, wav_filename) - frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) - file_size = -1 - if path.exists(wav_filename): - file_size = path.getsize(wav_filename) - frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) - label = validate_label(sample[1]) - with lock: - if file_size == -1: - # Excluding samples that failed upon conversion - counter['failed'] += 1 - elif label is None: - # Excluding samples that failed on label validation - counter['invalid_label'] += 1 - elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): - # Excluding samples that are too short to fit the transcript - counter['too_short'] += 1 - elif frames/SAMPLE_RATE > MAX_SECS: - # Excluding very long samples to keep a reasonable batch-size - counter['too_long'] += 1 - else: - # This one is good - keep it for the target CSV - rows.append((wav_filename, file_size, label)) - counter['all'] += 1 - print('Importing mp3 files...') - pool = Pool(cpu_count()) + pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): + for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): + counter += processed[0] + rows += processed[1] bar.update(i) bar.update(num_samples) pool.close() @@ -118,15 +119,11 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv): for filename, file_size, transcript in bar(rows): writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript }) - print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) - if counter['failed'] > 0: - print('Skipped %d samples that failed upon conversion.' % counter['failed']) - if counter['invalid_label'] > 0: - print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) - if counter['too_short'] > 0: - print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) - if counter['too_long'] > 0: - print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS)) + imported_samples = get_imported_samples(counter) + assert counter['all'] == num_samples + assert len(rows) == imported_samples + + print_import_report(counter, SAMPLE_RATE, MAX_SECS) def _maybe_convert_wav(mp3_filename, wav_filename): if not path.exists(wav_filename): diff --git a/bin/import_cv2.py b/bin/import_cv2.py index acea122b..474202be 100755 --- a/bin/import_cv2.py +++ b/bin/import_cv2.py @@ -16,18 +16,15 @@ sys.path.insert(1, os.path.join(sys.path[0], '..')) import csv import sox -import argparse import subprocess import progressbar import unicodedata from os import path -from threading import RLock -from multiprocessing.dummy import Pool -from multiprocessing import cpu_count +from multiprocessing import Pool from util.downloader import SIMPLE_BAR -from util.text import Alphabet, validate_label -from util.helpers import secs_to_hours +from util.text import Alphabet +from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] @@ -35,15 +32,50 @@ SAMPLE_RATE = 16000 MAX_SECS = 10 -def _preprocess_data(tsv_dir, audio_dir, label_filter, space_after_every_character=False): +def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False): for dataset in ['train', 'test', 'dev', 'validated', 'other']: input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv") if os.path.isfile(input_tsv): print("Loading TSV file: ", input_tsv) - _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character) + _maybe_convert_set(input_tsv, audio_dir, space_after_every_character) +def one_sample(sample): + """ Take a audio file, and optionally convert it to 16kHz WAV """ + mp3_filename = sample[0] + if not path.splitext(mp3_filename.lower())[1] == '.mp3': + mp3_filename += ".mp3" + # Storing wav files next to the mp3 ones - just with a different suffix + wav_filename = path.splitext(mp3_filename)[0] + ".wav" + _maybe_convert_wav(mp3_filename, wav_filename) + file_size = -1 + frames = 0 + if path.exists(wav_filename): + file_size = path.getsize(wav_filename) + frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) + label = label_filter_fun(sample[1]) + rows = [] + counter = get_counter() + if file_size == -1: + # Excluding samples that failed upon conversion + counter['failed'] += 1 + elif label is None: + # Excluding samples that failed on label validation + counter['invalid_label'] += 1 + elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): + # Excluding samples that are too short to fit the transcript + counter['too_short'] += 1 + elif frames/SAMPLE_RATE > MAX_SECS: + # Excluding very long samples to keep a reasonable batch-size + counter['too_long'] += 1 + else: + # This one is good - keep it for the target CSV + rows.append((os.path.split(wav_filename)[-1], file_size, label)) + counter['all'] += 1 + counter['total_time'] += frames -def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character=None): + return (counter, rows) + +def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None): output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv')) print("Saving new DeepSpeech-formatted CSV file to: ", output_csv) @@ -52,51 +84,18 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha with open(input_tsv, encoding='utf-8') as input_tsv_file: reader = csv.DictReader(input_tsv_file, delimiter='\t') for row in reader: - samples.append((row['path'], row['sentence'])) + samples.append((path.join(audio_dir, row['path']), row['sentence'])) - # Keep track of how many samples are good vs. problematic - counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} - lock = RLock() + counter = get_counter() num_samples = len(samples) rows = [] - def one_sample(sample): - """ Take a audio file, and optionally convert it to 16kHz WAV """ - mp3_filename = path.join(audio_dir, sample[0]) - if not path.splitext(mp3_filename.lower())[1] == '.mp3': - mp3_filename += ".mp3" - # Storing wav files next to the mp3 ones - just with a different suffix - wav_filename = path.splitext(mp3_filename)[0] + ".wav" - _maybe_convert_wav(mp3_filename, wav_filename) - file_size = -1 - frames = 0 - if path.exists(wav_filename): - file_size = path.getsize(wav_filename) - frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) - label = label_filter(sample[1]) - with lock: - if file_size == -1: - # Excluding samples that failed upon conversion - counter['failed'] += 1 - elif label is None: - # Excluding samples that failed on label validation - counter['invalid_label'] += 1 - elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): - # Excluding samples that are too short to fit the transcript - counter['too_short'] += 1 - elif frames/SAMPLE_RATE > MAX_SECS: - # Excluding very long samples to keep a reasonable batch-size - counter['too_long'] += 1 - else: - # This one is good - keep it for the target CSV - rows.append((os.path.split(wav_filename)[-1], file_size, label)) - counter['all'] += 1 - counter['total_time'] += frames - print("Importing mp3 files...") - pool = Pool(cpu_count()) + pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): + for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): + counter += processed[0] + rows += processed[1] bar.update(i) bar.update(num_samples) pool.close() @@ -113,16 +112,11 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha else: writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript}) - print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) - if counter['failed'] > 0: - print('Skipped %d samples that failed upon conversion.' % counter['failed']) - if counter['invalid_label'] > 0: - print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) - if counter['too_short'] > 0: - print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) - if counter['too_long'] > 0: - print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS)) - print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE)) + imported_samples = get_imported_samples(counter) + assert counter['all'] == num_samples + assert len(rows) == imported_samples + + print_import_report(counter, SAMPLE_RATE, MAX_SECS) def _maybe_convert_wav(mp3_filename, wav_filename): @@ -136,7 +130,7 @@ def _maybe_convert_wav(mp3_filename, wav_filename): if __name__ == "__main__": - PARSER = argparse.ArgumentParser(description='Import CommonVoice v2.0 corpora') + PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora') PARSER.add_argument('tsv_dir', help='Directory containing tsv files') PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "/clips"') PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') @@ -144,6 +138,7 @@ if __name__ == "__main__": PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space') PARAMS = PARSER.parse_args() + validate_label = get_validate_label(PARAMS) AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips') ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None @@ -161,4 +156,4 @@ if __name__ == "__main__": label = None return label - _preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, label_filter_fun, PARAMS.space_after_every_character) + _preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character) diff --git a/bin/import_fisher.py b/bin/import_fisher.py index e3340244..dd054765 100755 --- a/bin/import_fisher.py +++ b/bin/import_fisher.py @@ -19,7 +19,7 @@ import unicodedata import librosa import soundfile # <= Has an external dependency on libsndfile -from util.text import validate_label +from util.importers import validate_label_eng as validate_label def _download_and_preprocess_data(data_dir): # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19 diff --git a/bin/import_freestmandarin.py b/bin/import_freestmandarin.py index e600befb..8e6f5615 100755 --- a/bin/import_freestmandarin.py +++ b/bin/import_freestmandarin.py @@ -7,7 +7,7 @@ import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -import argparse +from util.importers import get_importers_parser import glob import numpy as np import pandas @@ -81,7 +81,7 @@ def preprocess_data(tgz_file, target_dir): def main(): # https://www.openslr.org/38/ - parser = argparse.ArgumentParser(description='Import Free ST Chinese Mandarin corpus') + parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus') parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz') parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') params = parser.parse_args() diff --git a/bin/import_gram_vaani.py b/bin/import_gram_vaani.py index e1fdd078..141478b8 100755 --- a/bin/import_gram_vaani.py +++ b/bin/import_gram_vaani.py @@ -1,12 +1,16 @@ #!/usr/bin/env python +# Make sure we can import stuff from util/ +# This script needs to be run from the root of the DeepSpeech repository import os -import csv import sys +sys.path.insert(1, os.path.join(sys.path[0], '..')) + +import csv import math import urllib import logging -import argparse +from util.importers import get_importers_parser, get_validate_label import subprocess from os import path from pathlib import Path @@ -15,8 +19,6 @@ import swifter import pandas as pd from sox import Transformer -from util.text import validate_label - __version__ = "0.1.0" _logger = logging.getLogger(__name__) @@ -38,7 +40,7 @@ def parse_args(args): Returns: :obj:`argparse.Namespace`: command line parameters namespace """ - parser = argparse.ArgumentParser( + parser = get_importers_parser( description="Imports GramVaani data for Deep Speech" ) parser.add_argument( @@ -286,6 +288,7 @@ def main(args): args ([str]): command line parameter list """ args = parse_args(args) + validate_label = get_validate_label(args) setup_logging(args.loglevel) _logger.info("Starting GramVaani importer...") _logger.info("Starting loading GramVaani csv...") diff --git a/bin/import_lingua_libre.py b/bin/import_lingua_libre.py index ae893350..493f28a0 100755 --- a/bin/import_lingua_libre.py +++ b/bin/import_lingua_libre.py @@ -3,13 +3,13 @@ from __future__ import absolute_import, division, print_function # Make sure we can import stuff from util/ # This script needs to be run from the root of the DeepSpeech repository -import argparse import os import sys - - sys.path.insert(1, os.path.join(sys.path[0], '..')) +from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report + +import argparse import csv import re import sox @@ -18,17 +18,14 @@ import subprocess import progressbar import unicodedata -from threading import RLock -from multiprocessing.dummy import Pool -from multiprocessing import cpu_count +from multiprocessing import Pool from util.downloader import SIMPLE_BAR from os import path from glob import glob from util.downloader import maybe_download -from util.text import Alphabet, validate_label -from util.helpers import secs_to_hours +from util.text import Alphabet FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] SAMPLE_RATE = 16000 @@ -61,6 +58,41 @@ def _maybe_extract(target_dir, extracted_data, archive_path): else: print('Found directory "%s" - not extracting it from archive.' % archive_path) +def one_sample(sample): + """ Take a audio file, and optionally convert it to 16kHz WAV """ + ogg_filename = sample[0] + # Storing wav files next to the ogg ones - just with a different suffix + wav_filename = path.splitext(ogg_filename)[0] + ".wav" + _maybe_convert_wav(ogg_filename, wav_filename) + file_size = -1 + frames = 0 + if path.exists(wav_filename): + file_size = path.getsize(wav_filename) + frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) + label = label_filter(sample[1]) + rows = [] + counter = get_counter() + + if file_size == -1: + # Excluding samples that failed upon conversion + counter['failed'] += 1 + elif label is None: + # Excluding samples that failed on label validation + counter['invalid_label'] += 1 + elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): + # Excluding samples that are too short to fit the transcript + counter['too_short'] += 1 + elif frames/SAMPLE_RATE > MAX_SECS: + # Excluding very long samples to keep a reasonable batch-size + counter['too_long'] += 1 + else: + # This one is good - keep it for the target CSV + rows.append((wav_filename, file_size, label)) + counter['all'] += 1 + counter['total_time'] += frames + + return (counter, rows) + def _maybe_convert_sets(target_dir, extracted_data): extracted_dir = path.join(target_dir, extracted_data) # override existing CSV with normalized one @@ -76,49 +108,18 @@ def _maybe_convert_sets(target_dir, extracted_data): for record in glob(glob_dir, recursive=True): record_file = record.replace(ogg_root_dir + os.path.sep, '') if record_filter(record_file): - samples.append((record_file, os.path.splitext(os.path.basename(record_file))[0])) + samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0])) - # Keep track of how many samples are good vs. problematic - counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} - lock = RLock() + counter = get_counter() num_samples = len(samples) rows = [] - def one_sample(sample): - """ Take a audio file, and optionally convert it to 16kHz WAV """ - ogg_filename = path.join(ogg_root_dir, sample[0]) - # Storing wav files next to the ogg ones - just with a different suffix - wav_filename = path.splitext(ogg_filename)[0] + ".wav" - _maybe_convert_wav(ogg_filename, wav_filename) - file_size = -1 - frames = 0 - if path.exists(wav_filename): - file_size = path.getsize(wav_filename) - frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) - label = label_filter(sample[1]) - with lock: - if file_size == -1: - # Excluding samples that failed upon conversion - counter['failed'] += 1 - elif label is None: - # Excluding samples that failed on label validation - counter['invalid_label'] += 1 - elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): - # Excluding samples that are too short to fit the transcript - counter['too_short'] += 1 - elif frames/SAMPLE_RATE > MAX_SECS: - # Excluding very long samples to keep a reasonable batch-size - counter['too_long'] += 1 - else: - # This one is good - keep it for the target CSV - rows.append((wav_filename, file_size, label)) - counter['all'] += 1 - counter['total_time'] += frames - print("Importing ogg files...") - pool = Pool(cpu_count()) + pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): + for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): + counter += processed[0] + rows += processed[1] bar.update(i) bar.update(num_samples) pool.close() @@ -152,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data): transcript=transcript, )) - print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) - if counter['failed'] > 0: - print('Skipped %d samples that failed upon conversion.' % counter['failed']) - if counter['invalid_label'] > 0: - print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) - if counter['too_short'] > 0: - print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) - if counter['too_long'] > 0: - print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS)) - print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE)) + imported_samples = get_imported_samples(counter) + assert counter['all'] == num_samples + assert len(rows) == imported_samples + + print_import_report(counter, SAMPLE_RATE, MAX_SECS) def _maybe_convert_wav(ogg_filename, wav_filename): if not path.exists(wav_filename): @@ -173,7 +169,7 @@ def _maybe_convert_wav(ogg_filename, wav_filename): print('SoX processing error', ex, ogg_filename, wav_filename) def handle_args(): - parser = argparse.ArgumentParser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.') + parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.') parser.add_argument(dest='target_dir') parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId') parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code') @@ -186,6 +182,7 @@ def handle_args(): if __name__ == "__main__": CLI_ARGS = handle_args() ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None + validate_label = get_validate_label(CLI_ARGS) bogus_regexes = [] if CLI_ARGS.bogus_records: diff --git a/bin/import_m-ailabs.py b/bin/import_m-ailabs.py index 060e8f2a..dc5b7cfe 100755 --- a/bin/import_m-ailabs.py +++ b/bin/import_m-ailabs.py @@ -4,29 +4,27 @@ from __future__ import absolute_import, division, print_function # Make sure we can import stuff from util/ # This script needs to be run from the root of the DeepSpeech repository -import argparse import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) +from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report + import csv import subprocess import progressbar import unicodedata import tarfile -from threading import RLock -from multiprocessing.dummy import Pool -from multiprocessing import cpu_count +from multiprocessing import Pool from util.downloader import SIMPLE_BAR from os import path from glob import glob from util.downloader import maybe_download -from util.text import Alphabet, validate_label -from util.helpers import secs_to_hours +from util.text import Alphabet FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] SAMPLE_RATE = 16000 @@ -62,6 +60,38 @@ def _maybe_extract(target_dir, extracted_data, archive_path): print('Found directory "%s" - not extracting it from archive.' % archive_path) +def one_sample(sample): + """ Take a audio file, and optionally convert it to 16kHz WAV """ + wav_filename = sample[0] + file_size = -1 + frames = 0 + if path.exists(wav_filename): + file_size = path.getsize(wav_filename) + frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) + label = label_filter(sample[1]) + counter = get_counter() + rows = [] + + if file_size == -1: + # Excluding samples that failed upon conversion + print("conversion failure", wav_filename) + counter['failed'] += 1 + elif label is None: + # Excluding samples that failed on label validation + counter['invalid_label'] += 1 + elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)): + # Excluding samples that are too short to fit the transcript + counter['too_short'] += 1 + elif frames/SAMPLE_RATE > MAX_SECS: + # Excluding very long samples to keep a reasonable batch-size + counter['too_long'] += 1 + else: + # This one is good - keep it for the target CSV + rows.append((wav_filename, file_size, label)) + counter['all'] += 1 + counter['total_time'] += frames + return (counter, rows) + def _maybe_convert_sets(target_dir, extracted_data): extracted_dir = path.join(target_dir, extracted_data) # override existing CSV with normalized one @@ -84,44 +114,16 @@ def _maybe_convert_sets(target_dir, extracted_data): transcript = re[2] samples.append((audio, transcript)) - # Keep track of how many samples are good vs. problematic - counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} - lock = RLock() + counter = get_counter() num_samples = len(samples) rows = [] - def one_sample(sample): - """ Take a audio file, and optionally convert it to 16kHz WAV """ - wav_filename = sample[0] - file_size = -1 - frames = 0 - if path.exists(wav_filename): - file_size = path.getsize(wav_filename) - frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) - label = label_filter(sample[1]) - with lock: - if file_size == -1: - # Excluding samples that failed upon conversion - counter['failed'] += 1 - elif label is None: - # Excluding samples that failed on label validation - counter['invalid_label'] += 1 - elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)): - # Excluding samples that are too short to fit the transcript - counter['too_short'] += 1 - elif frames/SAMPLE_RATE > MAX_SECS: - # Excluding very long samples to keep a reasonable batch-size - counter['too_long'] += 1 - else: - # This one is good - keep it for the target CSV - rows.append((wav_filename, file_size, label)) - counter['all'] += 1 - counter['total_time'] += frames - print("Importing WAV files...") - pool = Pool(cpu_count()) + pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): + for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): + counter += processed[0] + rows += processed[1] bar.update(i) bar.update(num_samples) pool.close() @@ -155,20 +157,14 @@ def _maybe_convert_sets(target_dir, extracted_data): transcript=transcript, )) - print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) - if counter['failed'] > 0: - print('Skipped %d samples that failed upon conversion.' % counter['failed']) - if counter['invalid_label'] > 0: - print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) - if counter['too_short'] > 0: - print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) - if counter['too_long'] > 0: - print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS)) - print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE)) + imported_samples = get_imported_samples(counter) + assert counter['all'] == num_samples + assert len(rows) == imported_samples + print_import_report(counter, SAMPLE_RATE, MAX_SECS) def handle_args(): - parser = argparse.ArgumentParser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.') + parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.') parser.add_argument(dest='target_dir') parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') @@ -181,6 +177,7 @@ if __name__ == "__main__": CLI_ARGS = handle_args() ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(',')) + validate_label = get_validate_label(CLI_ARGS) def label_filter(label): if CLI_ARGS.normalize: diff --git a/bin/import_magicdata.py b/bin/import_magicdata.py index 2ec01549..27dbf74a 100755 --- a/bin/import_magicdata.py +++ b/bin/import_magicdata.py @@ -7,7 +7,7 @@ import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -import argparse +from util.importers import get_importers_parser import glob import pandas import tarfile @@ -99,7 +99,7 @@ def preprocess_data(folder_with_archives, target_dir): def main(): # https://openslr.org/68/ - parser = argparse.ArgumentParser(description='Import MAGICDATA corpus') + parser = get_importers_parser(description='Import MAGICDATA corpus') parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz') parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives') params = parser.parse_args() diff --git a/bin/import_primewords.py b/bin/import_primewords.py index 63f21cf7..0d6fdc52 100755 --- a/bin/import_primewords.py +++ b/bin/import_primewords.py @@ -7,7 +7,7 @@ import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -import argparse +from util.importers import get_importers_parser import glob import json import numpy as np @@ -93,7 +93,7 @@ def preprocess_data(tgz_file, target_dir): def main(): # https://www.openslr.org/47/ - parser = argparse.ArgumentParser(description='Import Primewords Chinese corpus set 1') + parser = get_importers_parser(description='Import Primewords Chinese corpus set 1') parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz') parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') params = parser.parse_args() diff --git a/bin/import_slr57.py b/bin/import_slr57.py index 5dde767a..f11a78ed 100755 --- a/bin/import_slr57.py +++ b/bin/import_slr57.py @@ -3,13 +3,12 @@ from __future__ import absolute_import, division, print_function # Make sure we can import stuff from util/ # This script needs to be run from the root of the DeepSpeech repository -import argparse import os import sys - - sys.path.insert(1, os.path.join(sys.path[0], '..')) +from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report + import csv import re import sox @@ -19,16 +18,14 @@ import progressbar import unicodedata import tarfile -from threading import RLock -from multiprocessing.dummy import Pool -from multiprocessing import cpu_count +from multiprocessing import Pool from util.downloader import SIMPLE_BAR from os import path from glob import glob from util.downloader import maybe_download -from util.text import Alphabet, validate_label +from util.text import Alphabet from util.helpers import secs_to_hours FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] @@ -63,6 +60,37 @@ def _maybe_extract(target_dir, extracted_data, archive_path): else: print('Found directory "%s" - not extracting it from archive.' % archive_path) +def one_sample(sample): + """ Take a audio file, and optionally convert it to 16kHz WAV """ + wav_filename = sample[0] + file_size = -1 + frames = 0 + if path.exists(wav_filename): + file_size = path.getsize(wav_filename) + frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) + label = label_filter(sample[1]) + counter = get_counter() + rows = [] + if file_size == -1: + # Excluding samples that failed upon conversion + counter['failed'] += 1 + elif label is None: + # Excluding samples that failed on label validation + counter['invalid_label'] += 1 + elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)): + # Excluding samples that are too short to fit the transcript + counter['too_short'] += 1 + elif frames/SAMPLE_RATE > MAX_SECS: + # Excluding very long samples to keep a reasonable batch-size + counter['too_long'] += 1 + else: + # This one is good - keep it for the target CSV + rows.append((wav_filename, file_size, label)) + counter['all'] += 1 + counter['total_time'] += frames + + return (counter, rows) + def _maybe_convert_sets(target_dir, extracted_data): extracted_dir = path.join(target_dir, extracted_data) # override existing CSV with normalized one @@ -113,43 +141,16 @@ def _maybe_convert_sets(target_dir, extracted_data): samples.append((record, transcripts[record_file])) # Keep track of how many samples are good vs. problematic - counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} - lock = RLock() + counter = get_counter() num_samples = len(samples) rows = [] - def one_sample(sample): - """ Take a audio file, and optionally convert it to 16kHz WAV """ - wav_filename = sample[0] - file_size = -1 - frames = 0 - if path.exists(wav_filename): - file_size = path.getsize(wav_filename) - frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) - label = label_filter(sample[1]) - with lock: - if file_size == -1: - # Excluding samples that failed upon conversion - counter['failed'] += 1 - elif label is None: - # Excluding samples that failed on label validation - counter['invalid_label'] += 1 - elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)): - # Excluding samples that are too short to fit the transcript - counter['too_short'] += 1 - elif frames/SAMPLE_RATE > MAX_SECS: - # Excluding very long samples to keep a reasonable batch-size - counter['too_long'] += 1 - else: - # This one is good - keep it for the target CSV - rows.append((wav_filename, file_size, label)) - counter['all'] += 1 - counter['total_time'] += frames - print("Importing WAV files...") - pool = Pool(cpu_count()) + pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): + for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): + counter += processed[0] + rows += processed[1] bar.update(i) bar.update(num_samples) pool.close() @@ -183,19 +184,14 @@ def _maybe_convert_sets(target_dir, extracted_data): transcript=transcript, )) - print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) - if counter['failed'] > 0: - print('Skipped %d samples that failed upon conversion.' % counter['failed']) - if counter['invalid_label'] > 0: - print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) - if counter['too_short'] > 0: - print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) - if counter['too_long'] > 0: - print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS)) - print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE)) + imported_samples = get_imported_samples(counter) + assert counter['all'] == num_samples + assert len(rows) == imported_samples + + print_import_report(counter, SAMPLE_RATE, MAX_SECS) def handle_args(): - parser = argparse.ArgumentParser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.') + parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.') parser.add_argument(dest='target_dir') parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') @@ -204,6 +200,7 @@ def handle_args(): if __name__ == "__main__": CLI_ARGS = handle_args() ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None + validate_label = get_validate_label(CLI_ARGS) def label_filter(label): if CLI_ARGS.normalize: diff --git a/bin/import_swb.py b/bin/import_swb.py index e4261aa2..b682ae30 100755 --- a/bin/import_swb.py +++ b/bin/import_swb.py @@ -20,7 +20,7 @@ import wave import codecs import tarfile import requests -from util.text import validate_label +from util.importers import validate_label_eng as validate_label import librosa import soundfile # <= Has an external dependency on libsndfile diff --git a/bin/import_swc.py b/bin/import_swc.py index 93410805..e5114156 100755 --- a/bin/import_swc.py +++ b/bin/import_swc.py @@ -27,7 +27,8 @@ from os import path from glob import glob from collections import Counter from multiprocessing.pool import ThreadPool -from util.text import Alphabet, validate_label +from util.text import Alphabet +from util.importers import validate_label_eng as validate_label from util.downloader import maybe_download, SIMPLE_BAR SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" diff --git a/bin/import_ts.py b/bin/import_ts.py index 4aaa058c..a1f0d3b9 100755 --- a/bin/import_ts.py +++ b/bin/import_ts.py @@ -3,14 +3,13 @@ from __future__ import absolute_import, division, print_function # Make sure we can import stuff from util/ # This script needs to be run from the root of the DeepSpeech repository -import argparse import os import re import sys - - sys.path.insert(1, os.path.join(sys.path[0], '..')) +from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report + import csv import unidecode import zipfile @@ -18,16 +17,12 @@ import sox import subprocess import progressbar -from threading import RLock -from multiprocessing.dummy import Pool -from multiprocessing import cpu_count +from multiprocessing import Pool from util.downloader import SIMPLE_BAR from os import path from util.downloader import maybe_download -from util.text import validate_label -from util.helpers import secs_to_hours FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] SAMPLE_RATE = 16000 @@ -61,6 +56,44 @@ def _maybe_extract(target_dir, extracted_data, archive_path): print('Found directory "%s" - not extracting it from archive.' % archive_path) +def one_sample(sample): + """ Take a audio file, and optionally convert it to 16kHz WAV """ + orig_filename = sample['path'] + # Storing wav files next to the wav ones - just with a different suffix + wav_filename = path.splitext(orig_filename)[0] + ".converted.wav" + _maybe_convert_wav(orig_filename, wav_filename) + file_size = -1 + frames = 0 + if path.exists(wav_filename): + file_size = path.getsize(wav_filename) + frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) + label = sample['text'] + + rows = [] + + # Keep track of how many samples are good vs. problematic + counter = get_counter() + if file_size == -1: + # Excluding samples that failed upon conversion + counter['failed'] += 1 + elif label is None: + # Excluding samples that failed on label validation + counter['invalid_label'] += 1 + elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): + # Excluding samples that are too short to fit the transcript + counter['too_short'] += 1 + elif frames/SAMPLE_RATE > MAX_SECS: + # Excluding very long samples to keep a reasonable batch-size + counter['too_long'] += 1 + else: + # This one is good - keep it for the target CSV + rows.append((wav_filename, file_size, label)) + counter['all'] += 1 + counter['total_time'] += frames + + return (counter, rows) + + def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): extracted_dir = path.join(target_dir, extracted_data) # override existing CSV with normalized one @@ -74,49 +107,19 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): if float(d['duration']) <= MAX_SECS ] - # Keep track of how many samples are good vs. problematic - counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0} - lock = RLock() + for line in data: + line['path'] = os.path.join(extracted_dir, line['path']) + num_samples = len(data) rows = [] + counter = get_counter() - wav_root_dir = extracted_dir - - def one_sample(sample): - """ Take a audio file, and optionally convert it to 16kHz WAV """ - orig_filename = path.join(wav_root_dir, sample['path']) - # Storing wav files next to the wav ones - just with a different suffix - wav_filename = path.splitext(orig_filename)[0] + ".converted.wav" - _maybe_convert_wav(orig_filename, wav_filename) - file_size = -1 - frames = 0 - if path.exists(wav_filename): - file_size = path.getsize(wav_filename) - frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) - label = sample['text'] - with lock: - if file_size == -1: - # Excluding samples that failed upon conversion - counter['failed'] += 1 - elif label is None: - # Excluding samples that failed on label validation - counter['invalid_label'] += 1 - elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): - # Excluding samples that are too short to fit the transcript - counter['too_short'] += 1 - elif frames/SAMPLE_RATE > MAX_SECS: - # Excluding very long samples to keep a reasonable batch-size - counter['too_long'] += 1 - else: - # This one is good - keep it for the target CSV - rows.append((wav_filename, file_size, label)) - counter['all'] += 1 - counter['total_time'] += frames - - print("Importing wav files...") - pool = Pool(cpu_count()) + print("Importing {} wav files...".format(num_samples)) + pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, _ in enumerate(pool.imap_unordered(one_sample, data), start=1): + for i, processed in enumerate(pool.imap_unordered(one_sample, data), start=1): + counter += processed[0] + rows += processed[1] bar.update(i) bar.update(num_samples) pool.close() @@ -133,7 +136,6 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): test_writer.writeheader() for i, item in enumerate(rows): - print('item', item) transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible)) if not transcript: continue @@ -151,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): transcript=transcript, )) - print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'])) - if counter['failed'] > 0: - print('Skipped %d samples that failed upon conversion.' % counter['failed']) - if counter['invalid_label'] > 0: - print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) - if counter['too_short'] > 0: - print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) - if counter['too_long'] > 0: - print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS)) - print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE)) + imported_samples = get_imported_samples(counter) + assert counter['all'] == num_samples + assert len(rows) == imported_samples + + print_import_report(counter, SAMPLE_RATE, MAX_SECS) def _maybe_convert_wav(orig_filename, wav_filename): if not path.exists(wav_filename): @@ -186,7 +183,7 @@ def cleanup_transcript(text, english_compatible=False): def handle_args(): - parser = argparse.ArgumentParser(description='Importer for TrainingSpeech dataset.') + parser = get_importers_parser(description='Importer for TrainingSpeech dataset.') parser.add_argument(dest='target_dir') parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diactrics and other non-ascii chars.') return parser.parse_args() @@ -194,4 +191,5 @@ def handle_args(): if __name__ == "__main__": cli_args = handle_args() + validate_label = get_validate_label(cli_args) _download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible) diff --git a/bin/import_tuda.py b/bin/import_tuda.py index 89590144..857be405 100755 --- a/bin/import_tuda.py +++ b/bin/import_tuda.py @@ -21,7 +21,8 @@ import xml.etree.cElementTree as ET from os import path from collections import Counter -from util.text import Alphabet, validate_label +from util.text import Alphabet +from util.importers import validate_label_eng as validate_label from util.downloader import maybe_download, SIMPLE_BAR TUDA_VERSION = 'v2' diff --git a/bin/import_vctk.py b/bin/import_vctk.py index 59e1fafe..68477706 100755 --- a/bin/import_vctk.py +++ b/bin/import_vctk.py @@ -14,13 +14,14 @@ import sys sys.path.insert(1, os.path.join(sys.path[0], "..")) +from util.importers import get_counter, get_imported_samples, print_import_report + import re import librosa import progressbar from os import path -from multiprocessing.dummy import Pool -from multiprocessing import cpu_count +from multiprocessing import Pool from util.downloader import maybe_download, SIMPLE_BAR from zipfile import ZipFile @@ -61,47 +62,46 @@ def _maybe_convert_sets(target_dir, extracted_data): extracted_dir = path.join(target_dir, extracted_data, "wav48") txt_dir = path.join(target_dir, extracted_data, "txt") - cnt = 1 directory = os.path.expanduser(extracted_dir) srtd = len(sorted(os.listdir(directory))) + all_samples = [] for target in sorted(os.listdir(directory)): - print(f"\nSpeaker {cnt} of {srtd}") - _maybe_convert_set(path.join(extracted_dir, os.path.split(target)[-1])) - cnt += 1 - - _write_csv(extracted_dir, txt_dir, target_dir) - - -def _maybe_convert_set(target_csv): - def one_sample(sample): - if is_audio_file(sample): - sample = os.path.join(target_csv, sample) - - y, sr = librosa.load(sample, sr=16000) - - # Trim the beginning and ending silence - yt, index = librosa.effects.trim(y) # pylint: disable=unused-variable - - duration = librosa.get_duration(yt, sr) - if duration > MAX_SECS or duration < MIN_SECS: - os.remove(sample) - else: - librosa.output.write_wav(sample, yt, sr) - - samples = sorted(os.listdir(target_csv)) - - num_samples = len(samples) + all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1])) + num_samples = len(all_samples) print(f"Converting wav files to {SAMPLE_RATE}hz...") - pool = Pool(cpu_count()) + pool = Pool() bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1): + for i, _ in enumerate(pool.imap_unordered(one_sample, all_samples), start=1): bar.update(i) bar.update(num_samples) pool.close() pool.join() + _write_csv(extracted_dir, txt_dir, target_dir) + +def one_sample(sample): + if is_audio_file(sample): + y, sr = librosa.load(sample, sr=16000) + + # Trim the beginning and ending silence + yt, index = librosa.effects.trim(y) # pylint: disable=unused-variable + + duration = librosa.get_duration(yt, sr) + if duration > MAX_SECS or duration < MIN_SECS: + os.remove(sample) + else: + librosa.output.write_wav(sample, yt, sr) + + +def _maybe_prepare_set(target_csv): + samples = sorted(os.listdir(target_csv)) + new_samples = [] + for s in samples: + new_samples.append(os.path.join(target_csv, s)) + samples = new_samples + return samples def _write_csv(extracted_dir, txt_dir, target_dir): print(f"Writing CSV file") @@ -196,8 +196,8 @@ def load_txts(directory): AUDIO_EXTENSIONS = [".wav", "WAV"] -def is_audio_file(filename): - return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS) +def is_audio_file(filepath): + return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS) if __name__ == "__main__": diff --git a/bin/play.py b/bin/play.py new file mode 100755 index 00000000..180d0b00 --- /dev/null +++ b/bin/play.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +""" +Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files +Use "python3 build_sdb.py -h" for help +""" +from __future__ import absolute_import, division, print_function + +# Make sure we can import stuff from util/ +# This script needs to be run from the root of the DeepSpeech repository +import os +import sys +sys.path.insert(1, os.path.join(sys.path[0], '..')) + +import random +import argparse + +from util.sample_collections import samples_from_file, LabeledSample +from util.audio import AUDIO_TYPE_PCM + + +def play_sample(samples, index): + if index < 0: + index = len(samples) + index + if CLI_ARGS.random: + index = random.randint(0, len(samples)) + elif index >= len(samples): + print('No sample with index {}'.format(CLI_ARGS.start)) + sys.exit(1) + sample = samples[index] + print('Sample "{}"'.format(sample.sample_id)) + if isinstance(sample, LabeledSample): + print(' "{}"'.format(sample.transcript)) + sample.change_audio_type(AUDIO_TYPE_PCM) + rate, channels, width = sample.audio_format + wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate) + play_obj = wave_obj.play() + play_obj.wait_done() + + +def play_collection(): + samples = samples_from_file(CLI_ARGS.collection, buffering=0) + played = 0 + index = CLI_ARGS.start + while True: + if 0 <= CLI_ARGS.number <= played: + return + play_sample(samples, index) + played += 1 + index = (index + 1) % len(samples) + + +def handle_args(): + parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) ' + 'and DeepSpeech CSV files') + parser.add_argument('collection', help='Sample DB or CSV file to play samples from') + parser.add_argument('--start', type=int, default=0, + help='Sample index to start at (negative numbers are relative to the end of the collection)') + parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)') + parser.add_argument('--random', action='store_true', help='If samples should be played in random order') + return parser.parse_args() + + +if __name__ == "__main__": + try: + import simpleaudio + except ModuleNotFoundError: + print('play.py requires Python package "simpleaudio"') + sys.exit(1) + CLI_ARGS = handle_args() + try: + play_collection() + except KeyboardInterrupt: + print(' Stopped') + sys.exit(0) diff --git a/bin/run-tc-ldc93s1_checkpoint_sdb.sh b/bin/run-tc-ldc93s1_checkpoint_sdb.sh new file mode 100755 index 00000000..6f5c307f --- /dev/null +++ b/bin/run-tc-ldc93s1_checkpoint_sdb.sh @@ -0,0 +1,37 @@ +#!/bin/sh + +set -xe + +ldc93s1_dir="./data/smoke_test" +ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv" +ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb" + +if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then + echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}." + python -u bin/import_ldc93s1.py ${ldc93s1_dir} +fi; + +if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then + echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." + python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} +fi; + +# Force only one visible device because we have a single-sample dataset +# and when trying to run on multiple devices (like GPUs), this will break +export CUDA_VISIBLE_DEVICES=0 + +python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ + --train_files ${ldc93s1_sdb} --train_batch_size 1 \ + --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \ + --test_files ${ldc93s1_sdb} --test_batch_size 1 \ + --n_hidden 100 --epochs 1 \ + --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \ + --learning_rate 0.001 --dropout_rate 0.05 \ + --scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log + +if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then + echo "Did not resume training from checkpoint" + exit 1 +else + exit 0 +fi diff --git a/bin/run-tc-ldc93s1_new_sdb.sh b/bin/run-tc-ldc93s1_new_sdb.sh new file mode 100755 index 00000000..76032aa2 --- /dev/null +++ b/bin/run-tc-ldc93s1_new_sdb.sh @@ -0,0 +1,34 @@ +#!/bin/sh + +set -xe + +ldc93s1_dir="./data/smoke_test" +ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv" +ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb" + +epoch_count=$1 +audio_sample_rate=$2 + +if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then + echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}." + python -u bin/import_ldc93s1.py ${ldc93s1_dir} +fi; + +if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then + echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." + python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} +fi; + +# Force only one visible device because we have a single-sample dataset +# and when trying to run on multiple devices (like GPUs), this will break +export CUDA_VISIBLE_DEVICES=0 + +python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ + --train_files ${ldc93s1_sdb} --train_batch_size 1 \ + --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \ + --test_files ${ldc93s1_sdb} --test_batch_size 1 \ + --n_hidden 100 --epochs $epoch_count \ + --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \ + --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb' \ + --scorer_path 'data/smoke_test/pruned_lm.scorer' \ + --audio_sample_rate ${audio_sample_rate} diff --git a/bin/run-tc-ldc93s1_new_sdb_csv.sh b/bin/run-tc-ldc93s1_new_sdb_csv.sh new file mode 100755 index 00000000..1b0f6d3d --- /dev/null +++ b/bin/run-tc-ldc93s1_new_sdb_csv.sh @@ -0,0 +1,35 @@ +#!/bin/sh + +set -xe + +ldc93s1_dir="./data/smoke_test" +ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv" +ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb" + +epoch_count=$1 +audio_sample_rate=$2 + +if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then + echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}." + python -u bin/import_ldc93s1.py ${ldc93s1_dir} +fi; + +if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then + echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}." + python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb} +fi; + +# Force only one visible device because we have a single-sample dataset +# and when trying to run on multiple devices (like GPUs), this will break +export CUDA_VISIBLE_DEVICES=0 + +python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ + --train_files ${ldc93s1_sdb},${ldc93s1_csv} --train_batch_size 1 \ + --feature_cache '/tmp/ldc93s1_cache_sdb_csv' \ + --dev_files ${ldc93s1_sdb},${ldc93s1_csv} --dev_batch_size 1 \ + --test_files ${ldc93s1_sdb},${ldc93s1_csv} --test_batch_size 1 \ + --n_hidden 100 --epochs $epoch_count \ + --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb_csv' \ + --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb_csv' \ + --scorer_path 'data/smoke_test/pruned_lm.scorer' \ + --audio_sample_rate ${audio_sample_rate} diff --git a/doc/C-API.rst b/doc/C-API.rst index 2506d9b2..2b0e7e05 100644 --- a/doc/C-API.rst +++ b/doc/C-API.rst @@ -34,6 +34,9 @@ C .. doxygenfunction:: DS_IntermediateDecode :project: deepspeech-c +.. doxygenfunction:: DS_IntermediateDecodeWithMetadata + :project: deepspeech-c + .. doxygenfunction:: DS_FinishStream :project: deepspeech-c diff --git a/doc/Decoder.rst b/doc/Decoder.rst new file mode 100644 index 00000000..d7960fad --- /dev/null +++ b/doc/Decoder.rst @@ -0,0 +1,79 @@ +.. _decoder-docs: + +CTC beam search decoder with external scorer +============================================ + +Introduction +^^^^^^^^^^^^ + +DeepSpeech uses the `Connectionist Temporal Classification `_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC `_. This document assumes the reader is familiar with the concepts described in that article, and describes DeepSpeech specific behaviors that developers building systems with DeepSpeech should know to avoid problems. + +The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 `_ when, and only when, they appear in all capitals, as shown here. + + +External scorer +^^^^^^^^^^^^^^^ + +DeepSpeech clients support OPTIONAL use of an external language model to improve the accuracy of the predicted transcripts. In the code, command line parameters, and documentation, this is referred to as a "scorer". The scorer is used to compute the likelihood (also called a score, hence the name "scorer") of sequences of words or characters in the output, to guide the decoder towards more likely results. This improves accuracy significantly. + +The use of an external scorer is fully optional. When an external scorer is not specified, DeepSpeech still uses a beam search decoding algorithm, but without any outside scoring. + +Currently, the DeepSpeech external scorer is implemented with `KenLM `_, plus some tooling to package the necessary files and metadata into a single ``.scorer`` package. The tooling lives in ``data/lm/``. The scripts included in ``data/lm/`` can be used and modified to build your own language model based on your particular use case or language. + +The scripts are geared towards replicating the language model files we release as part of `DeepSpeech model releases `_, but modifying them to use different datasets or language model construction parameters should be simple. + + +Decoding modes +^^^^^^^^^^^^^^ + +DeepSpeech currently supports two modes of operation with significant differences at both training and decoding time. + + +Default mode (alphabet based) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The default mode, which uses an alphabet file (specified with ``--alphabet_config_path`` at training and export time) to determine which labels (characters), and how many of them, to predict in the output layer. At decoding time, if using an external scorer, it MUST be word based and MUST be built using the same alphabet file used for training. Word based means the text corpus used to build the scorer should contain words separated by whitespace. For most western languages, this is the default and requires no special steps from the developer when creating the scorer. + + +UTF-8 mode +^^^^^^^^^^ + +In UTF-8 mode the model predicts UTF-8 bytes directly instead of letters from an alphabet file. This idea was proposed in the paper `Bytes Are All You Need `_. This mode is enabled with the ``--utf8`` flag at training and export time. At training time, the alphabet file is not used. Instead, the model is forced to have 256 labels, with labels 0-254 corresponding to UTF-8 byte values 1-255, and label 255 is used for the CTC blank symbol. If using an external scorer at decoding time, it MUST be built according to the instructions that follow. + +UTF-8 decoding can be useful for languages with very large alphabets, such as Mandarin written with Simplified Chinese characters. It may also be useful for building multi-language models, or as a base for transfer learning. Currently these cases are untested and unsupported. Note that UTF-8 mode makes assumptions that hold for Mandarin written with Simplified Chinese characters and may not hold for other languages. + +UTF-8 scorers are character based (more specifically, Unicode codepoint based), but the way they are used is similar to a word based scorer where each "word" is a sequence of UTF-8 bytes representing a single Unicode codepoint. This means that the input text used to create UTF-8 scorers should contain space separated Unicode codepoints. For example, the following input text: + +``早 上 好`` + +corresponds to the following three "words", or UTF-8 byte sequences: + +``E6 97 A9`` +``E4 B8 8A`` +``E5 A5 BD`` + +At decoding time, the scorer is queried every time a Unicode codepoint is predicted, instead of when a space character is predicted. From the language modeling perspective, this is a character based model. From the implementation perspective, this is a word based model, because each character is composed of multiple labels. + +**Acoustic models trained with ``--utf8`` MUST NOT be used with an alphabet based scorer. Conversely, acoustic models trained with an alphabet file MUST NOT be used with a UTF-8 scorer.** + +UTF-8 scorers can be built by using an input corpus with space separated codepoints. If your corpus only contains single codepoints separated by spaces, ``data/lm/generate_package.py`` should automatically enable UTF-8 mode, and it should print the message "Looks like a character based model." + +If the message "Doesn't look like a character based model." is printed, you should double check your inputs to make sure it only contains single codepoints separated by spaces. UTF-8 mode can be forced by specifying the ``--force_utf8`` flag when running ``data/lm/generate_package.py``, but it is NOT RECOMMENDED. + +Because KenLM uses spaces as a word separator, the resulting language model will not include space characters in it. If you wish to use UTF-8 mode but still model spaces, you need to replace spaces in the input corpus with a different character **before** converting it to space separated codepoints. For example: + +.. code-block:: python + + input_text = 'The quick brown fox jumps over the lazy dog' + spaces_replaced = input_text.replace(' ', '|') + space_separated = ' '.join(spaces_replaced) + print(space_separated) + # T h e | q u i c k | b r o w n | f o x | j u m p s | o v e r | t h e | l a z y | d o g + +The character, '|' in this case, will then have to be replaced with spaces as a post-processing step after decoding. + + +Implementation +^^^^^^^^^^^^^^ + +The decoder source code can be found in ``native_client/ctcdecode``. The decoder is included in the language bindings and clients. In addition, there is a separate Python module which includes just the decoder and is needed for evaluation. In order to build and install this package, see the :github:`native_client README `. diff --git a/doc/DotNet-API.rst b/doc/DotNet-API.rst index 2ba3415f..b4f85dfc 100644 --- a/doc/DotNet-API.rst +++ b/doc/DotNet-API.rst @@ -31,13 +31,20 @@ ErrorCodes Metadata -------- -.. doxygenstruct:: DeepSpeechClient::Structs::Metadata +.. doxygenclass:: DeepSpeechClient::Models::Metadata :project: deepspeech-dotnet - :members: items, num_items, confidence + :members: Transcripts -MetadataItem ------------- +CandidateTranscript +------------------- -.. doxygenstruct:: DeepSpeechClient::Structs::MetadataItem +.. doxygenclass:: DeepSpeechClient::Models::CandidateTranscript :project: deepspeech-dotnet - :members: character, timestep, start_time + :members: Tokens, Confidence + +TokenMetadata +------------- + +.. doxygenclass:: DeepSpeechClient::Models::TokenMetadata + :project: deepspeech-dotnet + :members: Text, Timestep, StartTime diff --git a/doc/Java-API.rst b/doc/Java-API.rst index a485dc02..2986ca97 100644 --- a/doc/Java-API.rst +++ b/doc/Java-API.rst @@ -13,11 +13,17 @@ Metadata .. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::Metadata :project: deepspeech-java - :members: getItems, getNum_items, getProbability, getItem + :members: getTranscripts, getNum_transcripts, getTranscript -MetadataItem ------------- +CandidateTranscript +------------------- -.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::MetadataItem +.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::CandidateTranscript :project: deepspeech-java - :members: getCharacter, getTimestep, getStart_time + :members: getTokens, getNum_tokens, getConfidence, getToken + +TokenMetadata +------------- +.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::TokenMetadata + :project: deepspeech-java + :members: getText, getTimestep, getStart_time diff --git a/doc/NodeJS-API.rst b/doc/NodeJS-API.rst index aaba718c..b6170b5b 100644 --- a/doc/NodeJS-API.rst +++ b/doc/NodeJS-API.rst @@ -30,8 +30,14 @@ Metadata .. js:autoclass:: Metadata :members: -MetadataItem ------------- +CandidateTranscript +------------------- -.. js:autoclass:: MetadataItem +.. js:autoclass:: CandidateTranscript + :members: + +TokenMetadata +------------- + +.. js:autoclass:: TokenMetadata :members: diff --git a/doc/Python-API.rst b/doc/Python-API.rst index b2b3567f..9aec57f0 100644 --- a/doc/Python-API.rst +++ b/doc/Python-API.rst @@ -21,8 +21,14 @@ Metadata .. autoclass:: Metadata :members: -MetadataItem ------------- +CandidateTranscript +------------------- -.. autoclass:: MetadataItem +.. autoclass:: CandidateTranscript + :members: + +TokenMetadata +------------- + +.. autoclass:: TokenMetadata :members: diff --git a/doc/Structs.rst b/doc/Structs.rst index 713e52e0..5d532277 100644 --- a/doc/Structs.rst +++ b/doc/Structs.rst @@ -8,9 +8,16 @@ Metadata :project: deepspeech-c :members: -MetadataItem ------------- +CandidateTranscript +------------------- -.. doxygenstruct:: MetadataItem +.. doxygenstruct:: CandidateTranscript + :project: deepspeech-c + :members: + +TokenMetadata +------------- + +.. doxygenstruct:: TokenMetadata :project: deepspeech-c :members: diff --git a/doc/TRAINING.rst b/doc/TRAINING.rst index 95aec8eb..f4eb56bb 100644 --- a/doc/TRAINING.rst +++ b/doc/TRAINING.rst @@ -219,6 +219,11 @@ Note: the released models were trained with ``--n_hidden 2048``\ , so you need t Key cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias/Adam not found in checkpoint +UTF-8 mode +^^^^^^^^^^ + +DeepSpeech includes a UTF-8 operating mode which can be useful to model languages with very large alphabets, such as Chinese Mandarin. For details on how it works and how to use it, see :ref:`decoder-docs`. + Training with augmentation ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/doc/conf.py b/doc/conf.py index 1ee683e6..a3761170 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -40,7 +40,7 @@ import semver # -- Project information ----------------------------------------------------- project = u'DeepSpeech' -copyright = '2019, Mozilla Corporation' +copyright = '2019-2020, Mozilla Corporation' author = 'Mozilla Corporation' with open('../VERSION', 'r') as ver: diff --git a/doc/doxygen-dotnet.conf b/doc/doxygen-dotnet.conf index ad64cfcb..74c2c5bb 100644 --- a/doc/doxygen-dotnet.conf +++ b/doc/doxygen-dotnet.conf @@ -790,7 +790,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = native_client/dotnet/DeepSpeechClient/ native_client/dotnet/DeepSpeechClient/Interfaces/ native_client/dotnet/DeepSpeechClient/Enums/ native_client/dotnet/DeepSpeechClient/Structs/ +INPUT = native_client/dotnet/DeepSpeechClient/ native_client/dotnet/DeepSpeechClient/Interfaces/ native_client/dotnet/DeepSpeechClient/Enums/ native_client/dotnet/DeepSpeechClient/Models/ # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/examples b/doc/examples index 3beecad7..81a06eea 160000 --- a/doc/examples +++ b/doc/examples @@ -1 +1 @@ -Subproject commit 3beecad75c6dbe92d0604690014a3dba9fb9c926 +Subproject commit 81a06eea64d1dda734f6b97b3005b4416ac2f50a diff --git a/lm_optimizer.py b/lm_optimizer.py new file mode 100644 index 00000000..9e01ab96 --- /dev/null +++ b/lm_optimizer.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function + +import sys + +import optuna +import absl.app +from ds_ctcdecoder import Scorer +import tensorflow.compat.v1 as tfv1 + +from DeepSpeech import create_model +from evaluate import evaluate +from util.config import Config, initialize_globals +from util.flags import create_flags, FLAGS +from util.logging import log_error +from util.evaluate_tools import wer_cer_batch + + +def character_based(): + is_character_based = False + if FLAGS.scorer_path: + scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) + is_character_based = scorer.is_utf8_mode() + return is_character_based + +def objective(trial): + FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max) + FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max) + + tfv1.reset_default_graph() + samples = evaluate(FLAGS.test_files.split(','), create_model) + + is_character_based = trial.study.user_attrs['is_character_based'] + + wer, cer = wer_cer_batch(samples) + return cer if is_character_based else wer + +def main(_): + initialize_globals() + + if not FLAGS.test_files: + log_error('You need to specify what files to use for evaluation via ' + 'the --test_files flag.') + sys.exit(1) + + is_character_based = character_based() + + study = optuna.create_study() + study.set_user_attr("is_character_based", is_character_based) + study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials) + print('Best params: lm_alpha={} and lm_beta={} with WER={}'.format(study.best_params['lm_alpha'], + study.best_params['lm_beta'], + study.best_value)) + + +if __name__ == '__main__': + create_flags() + absl.app.run(main) diff --git a/native_client/args.h b/native_client/args.h index 33b9b8fe..ca28bfb7 100644 --- a/native_client/args.h +++ b/native_client/args.h @@ -34,6 +34,8 @@ bool extended_metadata = false; bool json_output = false; +int json_candidate_transcripts = 3; + int stream_size = 0; void PrintHelp(const char* bin) @@ -43,18 +45,19 @@ void PrintHelp(const char* bin) "\n" "Running DeepSpeech inference.\n" "\n" - "\t--model MODEL\t\tPath to the model (protocol buffer binary file)\n" - "\t--scorer SCORER\t\tPath to the external scorer file\n" - "\t--audio AUDIO\t\tPath to the audio file to run (WAV format)\n" - "\t--beam_width BEAM_WIDTH\tValue for decoder beam width (int)\n" - "\t--lm_alpha LM_ALPHA\tValue for language model alpha param (float)\n" - "\t--lm_beta LM_BETA\tValue for language model beta param (float)\n" - "\t-t\t\t\tRun in benchmark mode, output mfcc & inference time\n" - "\t--extended\t\tOutput string from extended metadata\n" - "\t--json\t\t\tExtended output, shows word timings as JSON\n" - "\t--stream size\t\tRun in stream mode, output intermediate results\n" - "\t--help\t\t\tShow help\n" - "\t--version\t\tPrint version and exits\n"; + "\t--model MODEL\t\t\tPath to the model (protocol buffer binary file)\n" + "\t--scorer SCORER\t\t\tPath to the external scorer file\n" + "\t--audio AUDIO\t\t\tPath to the audio file to run (WAV format)\n" + "\t--beam_width BEAM_WIDTH\t\tValue for decoder beam width (int)\n" + "\t--lm_alpha LM_ALPHA\t\tValue for language model alpha param (float)\n" + "\t--lm_beta LM_BETA\t\tValue for language model beta param (float)\n" + "\t-t\t\t\t\tRun in benchmark mode, output mfcc & inference time\n" + "\t--extended\t\t\tOutput string from extended metadata\n" + "\t--json\t\t\t\tExtended output, shows word timings as JSON\n" + "\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in output\n" + "\t--stream size\t\t\tRun in stream mode, output intermediate results\n" + "\t--help\t\t\t\tShow help\n" + "\t--version\t\t\tPrint version and exits\n"; char* version = DS_Version(); std::cerr << "DeepSpeech " << version << "\n"; DS_FreeString(version); @@ -74,6 +77,7 @@ bool ProcessArgs(int argc, char** argv) {"t", no_argument, nullptr, 't'}, {"extended", no_argument, nullptr, 'e'}, {"json", no_argument, nullptr, 'j'}, + {"candidate_transcripts", required_argument, nullptr, 150}, {"stream", required_argument, nullptr, 's'}, {"version", no_argument, nullptr, 'v'}, {"help", no_argument, nullptr, 'h'}, @@ -128,6 +132,10 @@ bool ProcessArgs(int argc, char** argv) json_output = true; break; + case 150: + json_candidate_transcripts = atoi(optarg); + break; + case 's': stream_size = atoi(optarg); break; diff --git a/native_client/client.cc b/native_client/client.cc index abcadd8d..1f7f78eb 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -44,9 +44,115 @@ struct meta_word { float duration; }; -char* metadataToString(Metadata* metadata); -std::vector WordsFromMetadata(Metadata* metadata); -char* JSONOutput(Metadata* metadata); +char* +CandidateTranscriptToString(const CandidateTranscript* transcript) +{ + std::string retval = ""; + for (int i = 0; i < transcript->num_tokens; i++) { + const TokenMetadata& token = transcript->tokens[i]; + retval += token.text; + } + return strdup(retval.c_str()); +} + +std::vector +CandidateTranscriptToWords(const CandidateTranscript* transcript) +{ + std::vector word_list; + + std::string word = ""; + float word_start_time = 0; + + // Loop through each token + for (int i = 0; i < transcript->num_tokens; i++) { + const TokenMetadata& token = transcript->tokens[i]; + + // Append token to word if it's not a space + if (strcmp(token.text, u8" ") != 0) { + // Log the start time of the new word + if (word.length() == 0) { + word_start_time = token.start_time; + } + word.append(token.text); + } + + // Word boundary is either a space or the last token in the array + if (strcmp(token.text, u8" ") == 0 || i == transcript->num_tokens-1) { + float word_duration = token.start_time - word_start_time; + + if (word_duration < 0) { + word_duration = 0; + } + + meta_word w; + w.word = word; + w.start_time = word_start_time; + w.duration = word_duration; + + word_list.push_back(w); + + // Reset + word = ""; + word_start_time = 0; + } + } + + return word_list; +} + +std::string +CandidateTranscriptToJSON(const CandidateTranscript *transcript) +{ + std::ostringstream out_string; + + std::vector words = CandidateTranscriptToWords(transcript); + + out_string << R"("metadata":{"confidence":)" << transcript->confidence << R"(},"words":[)"; + + for (int i = 0; i < words.size(); i++) { + meta_word w = words[i]; + out_string << R"({"word":")" << w.word << R"(","time":)" << w.start_time << R"(,"duration":)" << w.duration << "}"; + + if (i < words.size() - 1) { + out_string << ","; + } + } + + out_string << "]"; + + return out_string.str(); +} + +char* +MetadataToJSON(Metadata* result) +{ + std::ostringstream out_string; + out_string << "{\n"; + + for (int j=0; j < result->num_transcripts; ++j) { + const CandidateTranscript *transcript = &result->transcripts[j]; + + if (j == 0) { + out_string << CandidateTranscriptToJSON(transcript); + + if (result->num_transcripts > 1) { + out_string << ",\n" << R"("alternatives")" << ":[\n"; + } + } else { + out_string << "{" << CandidateTranscriptToJSON(transcript) << "}"; + + if (j < result->num_transcripts - 1) { + out_string << ",\n"; + } else { + out_string << "\n]"; + } + } + } + + out_string << "\n}\n"; + + return strdup(out_string.str().c_str()); +} ds_result LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize, @@ -57,13 +163,13 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize, clock_t ds_start_time = clock(); if (extended_output) { - Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize); - res.string = metadataToString(metadata); - DS_FreeMetadata(metadata); + Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, 1); + res.string = CandidateTranscriptToString(&result->transcripts[0]); + DS_FreeMetadata(result); } else if (json_output) { - Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize); - res.string = JSONOutput(metadata); - DS_FreeMetadata(metadata); + Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, json_candidate_transcripts); + res.string = MetadataToJSON(result); + DS_FreeMetadata(result); } else if (stream_size > 0) { StreamingState* ctx; int status = DS_CreateStream(aCtx, &ctx); @@ -278,87 +384,6 @@ ProcessFile(ModelState* context, const char* path, bool show_times) } } -char* -metadataToString(Metadata* metadata) -{ - std::string retval = ""; - for (int i = 0; i < metadata->num_items; i++) { - MetadataItem item = metadata->items[i]; - retval += item.character; - } - return strdup(retval.c_str()); -} - -std::vector -WordsFromMetadata(Metadata* metadata) -{ - std::vector word_list; - - std::string word = ""; - float word_start_time = 0; - - // Loop through each character - for (int i = 0; i < metadata->num_items; i++) { - MetadataItem item = metadata->items[i]; - - // Append character to word if it's not a space - if (strcmp(item.character, u8" ") != 0) { - // Log the start time of the new word - if (word.length() == 0) { - word_start_time = item.start_time; - } - word.append(item.character); - } - - // Word boundary is either a space or the last character in the array - if (strcmp(item.character, " ") == 0 - || strcmp(item.character, u8" ") == 0 - || i == metadata->num_items-1) { - - float word_duration = item.start_time - word_start_time; - - if (word_duration < 0) { - word_duration = 0; - } - - meta_word w; - w.word = word; - w.start_time = word_start_time; - w.duration = word_duration; - - word_list.push_back(w); - - // Reset - word = ""; - word_start_time = 0; - } - } - - return word_list; -} - -char* -JSONOutput(Metadata* metadata) -{ - std::vector words = WordsFromMetadata(metadata); - - std::ostringstream out_string; - out_string << R"({"metadata":{"confidence":)" << metadata->confidence << R"(},"words":[)"; - - for (int i = 0; i < words.size(); i++) { - meta_word w = words[i]; - out_string << R"({"word":")" << w.word << R"(","time":)" << w.start_time << R"(,"duration":)" << w.duration << "}"; - - if (i < words.size() - 1) { - out_string << ","; - } - } - - out_string << "]}\n"; - - return strdup(out_string.str().c_str()); -} - int main(int argc, char **argv) { diff --git a/native_client/ctcdecode/Makefile b/native_client/ctcdecode/Makefile index fb3ca6a4..1e8aee0c 100644 --- a/native_client/ctcdecode/Makefile +++ b/native_client/ctcdecode/Makefile @@ -15,14 +15,24 @@ else GENERATE_DEBUG_SYMS := endif +ifeq ($(findstring _NT,$(OS)),_NT) + ARCHIVE_EXT := lib +else + ARCHIVE_EXT := a +endif + +FIRST_PARTY := first_party.$(ARCHIVE_EXT) +THIRD_PARTY := third_party.$(ARCHIVE_EXT) + + all: bindings clean-keep-third-party: rm -rf dist temp_build ds_ctcdecoder.egg-info - rm -f swigwrapper_wrap.cpp swigwrapper.py first_party.a + rm -f swigwrapper_wrap.cpp swigwrapper.py $(FIRST_PARTY) clean: clean-keep-third-party - rm -f third_party.a + rm -f $(THIRD_PARTY) rm workspace_status.cc rm -fr bazel-out/ @@ -31,17 +41,19 @@ workspace_status.cc: ../bazel_workspace_status_cmd.sh > bazel-out/stable-status.txt && \ ../gen_workspace_status.sh > $@ +# Enforce PATH here because swig calls from build_ext looses track of some +# variables over several runs bindings: clean-keep-third-party workspace_status.cc pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0 - AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) + PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) find temp_build -type f -name "*.o" -delete - AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) + AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) rm -rf temp_build bindings-debug: clean-keep-third-party workspace_status.cc pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0 - AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) + PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) $(GENERATE_DEBUG_SYMS) find temp_build -type f -name "*.o" -delete - AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) + AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS) rm -rf temp_build diff --git a/native_client/ctcdecode/build_archive.py b/native_client/ctcdecode/build_archive.py index 6b36ea45..c379d6b3 100644 --- a/native_client/ctcdecode/build_archive.py +++ b/native_client/ctcdecode/build_archive.py @@ -9,14 +9,23 @@ import sys from multiprocessing.dummy import Pool -ARGS = ['-DKENLM_MAX_ORDER=6', '-std=c++11', '-Wno-unused-local-typedefs', '-Wno-sign-compare'] -OPT_ARGS = ['-O3', '-DNDEBUG'] -DBG_ARGS = ['-O0', '-g', '-UNDEBUG', '-DDEBUG'] +if sys.platform.startswith('win'): + ARGS = ['/nologo', '/D KENLM_MAX_ORDER=6', '/EHsc', '/source-charset:utf-8'] + OPT_ARGS = ['/O2', '/MT', '/D NDEBUG'] + DBG_ARGS = ['/Od', '/MTd', '/Zi', '/U NDEBUG', '/D DEBUG'] + OPENFST_DIR = 'third_party/openfst-1.6.9-win' +else: + ARGS = ['-fPIC', '-DKENLM_MAX_ORDER=6', '-std=c++11', '-Wno-unused-local-typedefs', '-Wno-sign-compare'] + OPT_ARGS = ['-O3', '-DNDEBUG'] + DBG_ARGS = ['-O0', '-g', '-UNDEBUG', '-DDEBUG'] + OPENFST_DIR = 'third_party/openfst-1.6.7' + + INCLUDES = [ '..', '../kenlm', - 'third_party/openfst-1.6.7/src/include', + OPENFST_DIR + '/src/include', 'third_party/ThreadPool' ] @@ -24,7 +33,7 @@ KENLM_FILES = (glob.glob('../kenlm/util/*.cc') + glob.glob('../kenlm/lm/*.cc') + glob.glob('../kenlm/util/double-conversion/*.cc')) -KENLM_FILES += glob.glob('third_party/openfst-1.6.7/src/lib/*.cc') +KENLM_FILES += glob.glob(OPENFST_DIR + '/src/lib/*.cc') KENLM_FILES = [ fn for fn in KENLM_FILES @@ -42,7 +51,10 @@ CTC_DECODER_FILES = [ def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1): compiler = os.environ.get('CXX', 'g++') + if sys.platform.startswith('win'): + compiler = '"{}"'.format(compiler) ar = os.environ.get('AR', 'ar') + libexe = os.environ.get('LIBEXE', 'lib.exe') libtool = os.environ.get('LIBTOOL', 'libtool') cflags = os.environ.get('CFLAGS', '') + os.environ.get('CXXFLAGS', '') args = ARGS + (DBG_ARGS if debug else OPT_ARGS) @@ -59,13 +71,19 @@ def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug if os.path.exists(outfile): return - cmd = '{cc} -fPIC -c {cflags} {args} {includes} {infile} -o {outfile}'.format( + if sys.platform.startswith('win'): + file = '"{}"'.format(file.replace('\\', '/')) + output = '/Fo"{}"'.format(outfile.replace('\\', '/')) + else: + output = '-o ' + outfile + + cmd = '{cc} -c {cflags} {args} {includes} {infile} {output}'.format( cc=compiler, cflags=cflags, args=' '.join(args), includes=' '.join('-I' + i for i in INCLUDES), infile=file, - outfile=outfile, + output=output, ) print(cmd) subprocess.check_call(shlex.split(cmd)) @@ -82,6 +100,14 @@ def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug ) print(cmd) subprocess.check_call(shlex.split(cmd)) + elif sys.platform.startswith('win'): + cmd = '"{libexe}" /OUT:"{outfile}" {infiles} /MACHINE:X64 /NOLOGO'.format( + libexe=libexe, + outfile=out_name, + infiles=' '.join(obj_files)) + cmd = cmd.replace('\\', '/') + print(cmd) + subprocess.check_call(shlex.split(cmd)) else: cmd = '{ar} rcs {outfile} {infiles}'.format( ar=ar, diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 5dadd57f..8a072c53 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -157,7 +157,7 @@ DecoderState::next(const double *probs, } std::vector -DecoderState::decode() const +DecoderState::decode(size_t num_results) const { std::vector prefixes_copy = prefixes_; std::unordered_map scores; @@ -181,16 +181,12 @@ DecoderState::decode() const } using namespace std::placeholders; - size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_); + size_t num_returned = std::min(prefixes_copy.size(), num_results); std::partial_sort(prefixes_copy.begin(), - prefixes_copy.begin() + num_prefixes, + prefixes_copy.begin() + num_returned, prefixes_copy.end(), std::bind(prefix_compare_external, _1, _2, scores)); - //TODO: expose this as an API parameter - const size_t top_paths = 1; - size_t num_returned = std::min(num_prefixes, top_paths); - std::vector outputs; outputs.reserve(num_returned); diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index a3d5c480..b785e097 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -60,13 +60,16 @@ public: int time_dim, int class_dim); - /* Get transcription from current decoder state + /* Get up to num_results transcriptions from current decoder state. + * + * Parameters: + * num_results: Number of beams to return. * * Return: * A vector where each element is a pair of score and decoding result, * in descending order. */ - std::vector decode() const; + std::vector decode(size_t num_results=1) const; }; diff --git a/native_client/ctcdecode/setup.py b/native_client/ctcdecode/setup.py index fb5a7114..8a3876c9 100644 --- a/native_client/ctcdecode/setup.py +++ b/native_client/ctcdecode/setup.py @@ -54,8 +54,15 @@ def maybe_rebuild(srcs, out_name, build_dir): project_version = read('../../VERSION').strip() build_dir = 'temp_build/temp_build' -third_party_build = 'third_party.a' -ctc_decoder_build = 'first_party.a' + +if sys.platform.startswith('win'): + archive_ext = 'lib' +else: + archive_ext = 'a' + +third_party_build = 'third_party.{}'.format(archive_ext) +ctc_decoder_build = 'first_party.{}'.format(archive_ext) + maybe_rebuild(KENLM_FILES, third_party_build, build_dir) maybe_rebuild(CTC_DECODER_FILES, ctc_decoder_build, build_dir) diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index dd2a95ea..96989e04 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -60,7 +60,7 @@ using std::vector; When batch_buffer is full, we do a single step through the acoustic model and accumulate the intermediate decoding state in the DecoderState structure. - When finishStream() is called, we return the corresponding transcription from + When finishStream() is called, we return the corresponding transcript from the current decoder state. */ struct StreamingState { @@ -78,9 +78,10 @@ struct StreamingState { void feedAudioContent(const short* buffer, unsigned int buffer_size); char* intermediateDecode() const; + Metadata* intermediateDecodeWithMetadata(unsigned int num_results) const; void finalizeStream(); char* finishStream(); - Metadata* finishStreamWithMetadata(); + Metadata* finishStreamWithMetadata(unsigned int num_results); void processAudioWindow(const vector& buf); void processMfccWindow(const vector& buf); @@ -136,6 +137,12 @@ StreamingState::intermediateDecode() const return model_->decode(decoder_state_); } +Metadata* +StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const +{ + return model_->decode_metadata(decoder_state_, num_results); +} + char* StreamingState::finishStream() { @@ -144,10 +151,10 @@ StreamingState::finishStream() } Metadata* -StreamingState::finishStreamWithMetadata() +StreamingState::finishStreamWithMetadata(unsigned int num_results) { finalizeStream(); - return model_->decode_metadata(decoder_state_); + return model_->decode_metadata(decoder_state_, num_results); } void @@ -402,6 +409,13 @@ DS_IntermediateDecode(const StreamingState* aSctx) return aSctx->intermediateDecode(); } +Metadata* +DS_IntermediateDecodeWithMetadata(const StreamingState* aSctx, + unsigned int aNumResults) +{ + return aSctx->intermediateDecodeWithMetadata(aNumResults); +} + char* DS_FinishStream(StreamingState* aSctx) { @@ -411,11 +425,12 @@ DS_FinishStream(StreamingState* aSctx) } Metadata* -DS_FinishStreamWithMetadata(StreamingState* aSctx) +DS_FinishStreamWithMetadata(StreamingState* aSctx, + unsigned int aNumResults) { - Metadata* metadata = aSctx->finishStreamWithMetadata(); + Metadata* result = aSctx->finishStreamWithMetadata(aNumResults); DS_FreeStream(aSctx); - return metadata; + return result; } StreamingState* @@ -444,10 +459,11 @@ DS_SpeechToText(ModelState* aCtx, Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx, const short* aBuffer, - unsigned int aBufferSize) + unsigned int aBufferSize, + unsigned int aNumResults) { StreamingState* ctx = CreateStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize); - return DS_FinishStreamWithMetadata(ctx); + return DS_FinishStreamWithMetadata(ctx, aNumResults); } void @@ -460,11 +476,16 @@ void DS_FreeMetadata(Metadata* m) { if (m) { - for (int i = 0; i < m->num_items; ++i) { - free(m->items[i].character); + for (int i = 0; i < m->num_transcripts; ++i) { + for (int j = 0; j < m->transcripts[i].num_tokens; ++j) { + free((void*)m->transcripts[i].tokens[j].text); + } + + free((void*)m->transcripts[i].tokens); } - delete[] m->items; - delete m; + + free((void*)m->transcripts); + free(m); } } diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 6dad59db..a8c29c93 100644 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -20,32 +20,43 @@ typedef struct ModelState ModelState; typedef struct StreamingState StreamingState; /** - * @brief Stores each individual character, along with its timing information + * @brief Stores text of an individual token, along with its timing information */ -typedef struct MetadataItem { - /** The character generated for transcription */ - char* character; +typedef struct TokenMetadata { + /** The text corresponding to this token */ + const char* const text; - /** Position of the character in units of 20ms */ - int timestep; + /** Position of the token in units of 20ms */ + const unsigned int timestep; - /** Position of the character in seconds */ - float start_time; -} MetadataItem; + /** Position of the token in seconds */ + const float start_time; +} TokenMetadata; /** - * @brief Stores the entire CTC output as an array of character metadata objects + * @brief A single transcript computed by the model, including a confidence + * value and the metadata for its constituent tokens. + */ +typedef struct CandidateTranscript { + /** Array of TokenMetadata objects */ + const TokenMetadata* const tokens; + /** Size of the tokens array */ + const unsigned int num_tokens; + /** Approximated confidence value for this transcript. This is roughly the + * sum of the acoustic model logit values for each timestep/character that + * contributed to the creation of this transcript. + */ + const double confidence; +} CandidateTranscript; + +/** + * @brief An array of CandidateTranscript objects computed by the model. */ typedef struct Metadata { - /** List of items */ - MetadataItem* items; - /** Size of the list of items */ - int num_items; - /** Approximated confidence value for this transcription. This is roughly the - * sum of the acoustic model logit values for each timestep/character that - * contributed to the creation of this transcription. - */ - double confidence; + /** Array of CandidateTranscript objects */ + const CandidateTranscript* const transcripts; + /** Size of the transcripts array */ + const unsigned int num_transcripts; } Metadata; enum DeepSpeech_Error_Codes @@ -164,7 +175,7 @@ int DS_SetScorerAlphaBeta(ModelState* aCtx, float aBeta); /** - * @brief Use the DeepSpeech model to perform Speech-To-Text. + * @brief Use the DeepSpeech model to convert speech to text. * * @param aCtx The ModelState pointer for the model to use. * @param aBuffer A 16-bit, mono raw audio signal at the appropriate @@ -180,21 +191,25 @@ char* DS_SpeechToText(ModelState* aCtx, unsigned int aBufferSize); /** - * @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata - * about the results. + * @brief Use the DeepSpeech model to convert speech to text and output results + * including metadata. * * @param aCtx The ModelState pointer for the model to use. * @param aBuffer A 16-bit, mono raw audio signal at the appropriate * sample rate (matching what the model was trained on). * @param aBufferSize The number of samples in the audio signal. + * @param aNumResults The maximum number of CandidateTranscript structs to return. Returned value might be smaller than this. * - * @return Outputs a struct of individual letters along with their timing information. - * The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error. + * @return Metadata struct containing multiple CandidateTranscript structs. Each + * transcript has per-token metadata including timing information. The + * user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. + * Returns NULL on error. */ DEEPSPEECH_EXPORT Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx, const short* aBuffer, - unsigned int aBufferSize); + unsigned int aBufferSize, + unsigned int aNumResults); /** * @brief Create a new streaming inference state. The streaming state returned @@ -236,8 +251,24 @@ DEEPSPEECH_EXPORT char* DS_IntermediateDecode(const StreamingState* aSctx); /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns the STT result over the whole audio signal. + * @brief Compute the intermediate decoding of an ongoing streaming inference, + * return results including metadata. + * + * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}. + * @param aNumResults The number of candidate transcripts to return. + * + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. The user is + * responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. + * Returns NULL on error. + */ +DEEPSPEECH_EXPORT +Metadata* DS_IntermediateDecodeWithMetadata(const StreamingState* aSctx, + unsigned int aNumResults); + +/** + * @brief Compute the final decoding of an ongoing streaming inference and return + * the result. Signals the end of an ongoing streaming inference. * * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}. * @@ -250,18 +281,23 @@ DEEPSPEECH_EXPORT char* DS_FinishStream(StreamingState* aSctx); /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns per-letter metadata. + * @brief Compute the final decoding of an ongoing streaming inference and return + * results including metadata. Signals the end of an ongoing streaming + * inference. * * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}. + * @param aNumResults The number of candidate transcripts to return. * - * @return Outputs a struct of individual letters along with their timing information. - * The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error. + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. The user is + * responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. + * Returns NULL on error. * * @note This method will free the state pointer (@p aSctx). */ DEEPSPEECH_EXPORT -Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx); +Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx, + unsigned int aNumResults); /** * @brief Destroy a streaming state without decoding the computed logits. This diff --git a/native_client/definitions.mk b/native_client/definitions.mk index 04e6d0c5..fd9358a2 100644 --- a/native_client/definitions.mk +++ b/native_client/definitions.mk @@ -10,6 +10,7 @@ TOOL_CC := gcc TOOL_CXX := c++ TOOL_LD := ld TOOL_LDD := ldd +TOOL_LIBEXE := DEEPSPEECH_BIN := deepspeech CFLAGS_DEEPSPEECH := -std=c++11 -o $(DEEPSPEECH_BIN) @@ -37,9 +38,10 @@ endif ifeq ($(TARGET),host-win) DEEPSPEECH_BIN := deepspeech.exe TOOLCHAIN := '$(VCINSTALLDIR)\bin\amd64\' -TOOL_CC := cl.exe -TOOL_CXX := cl.exe -TOOL_LD := link.exe +TOOL_CC := cl.exe +TOOL_CXX := cl.exe +TOOL_LD := link.exe +TOOL_LIBEXE := lib.exe LINK_DEEPSPEECH := $(TFDIR)\bazel-bin\native_client\libdeepspeech.so.if.lib LINK_PATH_DEEPSPEECH := CFLAGS_DEEPSPEECH := -nologo -Fe$(DEEPSPEECH_BIN) @@ -113,6 +115,7 @@ CC := $(TOOLCHAIN)$(TOOL_CC) CXX := $(TOOLCHAIN)$(TOOL_CXX) LD := $(TOOLCHAIN)$(TOOL_LD) LDD := $(TOOLCHAIN)$(TOOL_LDD) $(TOOLCHAIN_LDD_OPTS) +LIBEXE := $(TOOLCHAIN)$(TOOL_LIBEXE) RPATH_PYTHON := '-Wl,-rpath,\$$ORIGIN/lib/' $(LDFLAGS_RPATH) RPATH_NODEJS := '-Wl,-rpath,$$\$$ORIGIN/../' diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs index 576ed308..a30bd4de 100644 --- a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs @@ -89,38 +89,9 @@ namespace DeepSpeechClient /// Native result code. private void EvaluateResultCode(ErrorCodes resultCode) { - switch (resultCode) + if (resultCode != ErrorCodes.DS_ERR_OK) { - case ErrorCodes.DS_ERR_OK: - break; - case ErrorCodes.DS_ERR_NO_MODEL: - throw new ArgumentException("Missing model information."); - case ErrorCodes.DS_ERR_INVALID_ALPHABET: - throw new ArgumentException("Invalid alphabet embedded in model. (Data corruption?)"); - case ErrorCodes.DS_ERR_INVALID_SHAPE: - throw new ArgumentException("Invalid model shape."); - case ErrorCodes.DS_ERR_INVALID_SCORER: - throw new ArgumentException("Invalid scorer file."); - case ErrorCodes.DS_ERR_FAIL_INIT_MMAP: - throw new ArgumentException("Failed to initialize memory mapped model."); - case ErrorCodes.DS_ERR_FAIL_INIT_SESS: - throw new ArgumentException("Failed to initialize the session."); - case ErrorCodes.DS_ERR_FAIL_INTERPRETER: - throw new ArgumentException("Interpreter failed."); - case ErrorCodes.DS_ERR_FAIL_RUN_SESS: - throw new ArgumentException("Failed to run the session."); - case ErrorCodes.DS_ERR_FAIL_CREATE_STREAM: - throw new ArgumentException("Error creating the stream."); - case ErrorCodes.DS_ERR_FAIL_READ_PROTOBUF: - throw new ArgumentException("Error reading the proto buffer model file."); - case ErrorCodes.DS_ERR_FAIL_CREATE_SESS: - throw new ArgumentException("Error failed to create session."); - case ErrorCodes.DS_ERR_MODEL_INCOMPATIBLE: - throw new ArgumentException("Error incompatible model."); - case ErrorCodes.DS_ERR_SCORER_NOT_ENABLED: - throw new ArgumentException("External scorer is not enabled."); - default: - throw new ArgumentException("Unknown error, please make sure you are using the correct native binary."); + throw new ArgumentException(NativeImp.DS_ErrorCodeToErrorMessage((int)resultCode).PtrToString()); } } @@ -140,7 +111,6 @@ namespace DeepSpeechClient /// Thrown when cannot find the scorer file. public unsafe void EnableExternalScorer(string aScorerPath) { - string exceptionMessage = null; if (string.IsNullOrWhiteSpace(aScorerPath)) { throw new FileNotFoundException("Path to the scorer file cannot be empty."); @@ -199,13 +169,14 @@ namespace DeepSpeechClient } /// - /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. + /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal, including metadata. /// /// Instance of the stream to finish. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. /// The extended metadata result. - public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream) + public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream, uint aNumResults) { - return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer()).PtrToMetadata(); + return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer(), aNumResults).PtrToMetadata(); } /// @@ -218,6 +189,17 @@ namespace DeepSpeechClient return NativeImp.DS_IntermediateDecode(stream.GetNativePointer()).PtrToString(); } + /// + /// Computes the intermediate decoding of an ongoing streaming inference, including metadata. + /// + /// Instance of the stream to decode. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. + /// The STT intermediate result. + public unsafe Metadata IntermediateDecodeWithMetadata(DeepSpeechStream stream, uint aNumResults) + { + return NativeImp.DS_IntermediateDecodeWithMetadata(stream.GetNativePointer(), aNumResults).PtrToMetadata(); + } + /// /// Return version of this library. The returned version is a semantic version /// (SemVer 2.0.0). @@ -261,14 +243,15 @@ namespace DeepSpeechClient } /// - /// Use the DeepSpeech model to perform Speech-To-Text. + /// Use the DeepSpeech model to perform Speech-To-Text, return results including metadata. /// /// A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). /// The number of samples in the audio signal. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. /// The extended metadata. Returns NULL on error. - public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize) + public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize, uint aNumResults) { - return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize).PtrToMetadata(); + return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize, aNumResults).PtrToMetadata(); } #endregion diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj b/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj index b9077361..0139b3e8 100644 --- a/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj +++ b/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj @@ -50,11 +50,13 @@ - + + - + + diff --git a/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs b/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs index 6b7f4c6a..9325f4b8 100644 --- a/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs +++ b/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs @@ -26,35 +26,68 @@ namespace DeepSpeechClient.Extensions } /// - /// Converts a pointer into managed metadata object. + /// Converts a pointer into managed TokenMetadata object. + /// + /// Native pointer. + /// TokenMetadata managed object. + private static Models.TokenMetadata PtrToTokenMetadata(this IntPtr intPtr) + { + var token = Marshal.PtrToStructure(intPtr); + var managedToken = new Models.TokenMetadata + { + Timestep = token.timestep, + StartTime = token.start_time, + Text = token.text.PtrToString(releasePtr: false) + }; + return managedToken; + } + + /// + /// Converts a pointer into managed CandidateTranscript object. + /// + /// Native pointer. + /// CandidateTranscript managed object. + private static Models.CandidateTranscript PtrToCandidateTranscript(this IntPtr intPtr) + { + var managedTranscript = new Models.CandidateTranscript(); + var transcript = Marshal.PtrToStructure(intPtr); + + managedTranscript.Tokens = new Models.TokenMetadata[transcript.num_tokens]; + managedTranscript.Confidence = transcript.confidence; + + //we need to manually read each item from the native ptr using its size + var sizeOfTokenMetadata = Marshal.SizeOf(typeof(TokenMetadata)); + for (int i = 0; i < transcript.num_tokens; i++) + { + managedTranscript.Tokens[i] = transcript.tokens.PtrToTokenMetadata(); + transcript.tokens += sizeOfTokenMetadata; + } + + return managedTranscript; + } + + /// + /// Converts a pointer into managed Metadata object. /// /// Native pointer. /// Metadata managed object. internal static Models.Metadata PtrToMetadata(this IntPtr intPtr) { - var managedMetaObject = new Models.Metadata(); - var metaData = (Metadata)Marshal.PtrToStructure(intPtr, typeof(Metadata)); - - managedMetaObject.Items = new Models.MetadataItem[metaData.num_items]; - managedMetaObject.Confidence = metaData.confidence; + var managedMetadata = new Models.Metadata(); + var metadata = Marshal.PtrToStructure(intPtr); + managedMetadata.Transcripts = new Models.CandidateTranscript[metadata.num_transcripts]; //we need to manually read each item from the native ptr using its size - var sizeOfMetaItem = Marshal.SizeOf(typeof(MetadataItem)); - for (int i = 0; i < metaData.num_items; i++) + var sizeOfCandidateTranscript = Marshal.SizeOf(typeof(CandidateTranscript)); + for (int i = 0; i < metadata.num_transcripts; i++) { - var tempItem = Marshal.PtrToStructure(metaData.items); - managedMetaObject.Items[i] = new Models.MetadataItem - { - Timestep = tempItem.timestep, - StartTime = tempItem.start_time, - Character = tempItem.character.PtrToString(releasePtr: false) - }; - //we keep the offset on each read - metaData.items += sizeOfMetaItem; + managedMetadata.Transcripts[i] = metadata.transcripts.PtrToCandidateTranscript(); + metadata.transcripts += sizeOfCandidateTranscript; } + NativeImp.DS_FreeMetadata(intPtr); - return managedMetaObject; + return managedMetadata; } } } diff --git a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs index 18677abc..37d6ce59 100644 --- a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs @@ -68,13 +68,15 @@ namespace DeepSpeechClient.Interfaces uint aBufferSize); /// - /// Use the DeepSpeech model to perform Speech-To-Text. + /// Use the DeepSpeech model to perform Speech-To-Text, return results including metadata. /// /// A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). /// The number of samples in the audio signal. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. /// The extended metadata. Returns NULL on error. unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, - uint aBufferSize); + uint aBufferSize, + uint aNumResults); /// /// Destroy a streaming state without decoding the computed logits. @@ -102,6 +104,14 @@ namespace DeepSpeechClient.Interfaces /// The STT intermediate result. unsafe string IntermediateDecode(DeepSpeechStream stream); + /// + /// Computes the intermediate decoding of an ongoing streaming inference, including metadata. + /// + /// Instance of the stream to decode. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. + /// The extended metadata result. + unsafe Metadata IntermediateDecodeWithMetadata(DeepSpeechStream stream, uint aNumResults); + /// /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. /// @@ -110,10 +120,11 @@ namespace DeepSpeechClient.Interfaces unsafe string FinishStream(DeepSpeechStream stream); /// - /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. + /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal, including metadata. /// /// Instance of the stream to finish. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. /// The extended metadata result. - unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream); + unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream, uint aNumResults); } } diff --git a/native_client/dotnet/DeepSpeechClient/Models/CandidateTranscript.cs b/native_client/dotnet/DeepSpeechClient/Models/CandidateTranscript.cs new file mode 100644 index 00000000..cc6b5d28 --- /dev/null +++ b/native_client/dotnet/DeepSpeechClient/Models/CandidateTranscript.cs @@ -0,0 +1,17 @@ +namespace DeepSpeechClient.Models +{ + /// + /// Stores the entire CTC output as an array of character metadata objects. + /// + public class CandidateTranscript + { + /// + /// Approximated confidence value for this transcription. + /// + public double Confidence { get; set; } + /// + /// List of metada tokens containing text, timestep, and time offset. + /// + public TokenMetadata[] Tokens { get; set; } + } +} \ No newline at end of file diff --git a/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs b/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs index 870eb162..fb6c613d 100644 --- a/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs +++ b/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs @@ -6,12 +6,8 @@ public class Metadata { /// - /// Approximated confidence value for this transcription. + /// List of candidate transcripts. /// - public double Confidence { get; set; } - /// - /// List of metada items containing char, timespet, and time offset. - /// - public MetadataItem[] Items { get; set; } + public CandidateTranscript[] Transcripts { get; set; } } } \ No newline at end of file diff --git a/native_client/dotnet/DeepSpeechClient/Models/MetadataItem.cs b/native_client/dotnet/DeepSpeechClient/Models/TokenMetadata.cs similarity index 89% rename from native_client/dotnet/DeepSpeechClient/Models/MetadataItem.cs rename to native_client/dotnet/DeepSpeechClient/Models/TokenMetadata.cs index e329c6cb..5f2dea56 100644 --- a/native_client/dotnet/DeepSpeechClient/Models/MetadataItem.cs +++ b/native_client/dotnet/DeepSpeechClient/Models/TokenMetadata.cs @@ -3,12 +3,12 @@ /// /// Stores each individual character, along with its timing information. /// - public class MetadataItem + public class TokenMetadata { /// /// Char of the current timestep. /// - public string Character; + public string Text; /// /// Position of the character in units of 20ms. /// diff --git a/native_client/dotnet/DeepSpeechClient/NativeImp.cs b/native_client/dotnet/DeepSpeechClient/NativeImp.cs index 6c3494b6..bc77cf1b 100644 --- a/native_client/dotnet/DeepSpeechClient/NativeImp.cs +++ b/native_client/dotnet/DeepSpeechClient/NativeImp.cs @@ -17,45 +17,49 @@ namespace DeepSpeechClient [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, - ref IntPtr** pint); + ref IntPtr** pint); + + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal unsafe static extern IntPtr DS_ErrorCodeToErrorMessage(int aErrorCode); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx, - uint aBeamWidth); + uint aBeamWidth); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, - uint aBeamWidth, - ref IntPtr** pint); + uint aBeamWidth, + ref IntPtr** pint); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern int DS_GetModelSampleRate(IntPtr** aCtx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx, - string aScorerPath); + string aScorerPath); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern ErrorCodes DS_SetScorerAlphaBeta(IntPtr** aCtx, - float aAlpha, - float aBeta); + float aAlpha, + float aBeta); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Ansi, SetLastError = true)] internal static unsafe extern IntPtr DS_SpeechToText(IntPtr** aCtx, - short[] aBuffer, - uint aBufferSize); + short[] aBuffer, + uint aBufferSize); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, SetLastError = true)] internal static unsafe extern IntPtr DS_SpeechToTextWithMetadata(IntPtr** aCtx, - short[] aBuffer, - uint aBufferSize); + short[] aBuffer, + uint aBufferSize, + uint aNumResults); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern void DS_FreeModel(IntPtr** aCtx); @@ -76,18 +80,23 @@ namespace DeepSpeechClient [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Ansi, SetLastError = true)] internal static unsafe extern void DS_FeedAudioContent(IntPtr** aSctx, - short[] aBuffer, - uint aBufferSize); + short[] aBuffer, + uint aBufferSize); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern IntPtr DS_IntermediateDecode(IntPtr** aSctx); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal static unsafe extern IntPtr DS_IntermediateDecodeWithMetadata(IntPtr** aSctx, + uint aNumResults); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Ansi, SetLastError = true)] internal static unsafe extern IntPtr DS_FinishStream(IntPtr** aSctx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] - internal static unsafe extern IntPtr DS_FinishStreamWithMetadata(IntPtr** aSctx); + internal static unsafe extern IntPtr DS_FinishStreamWithMetadata(IntPtr** aSctx, + uint aNumResults); #endregion } } diff --git a/native_client/dotnet/DeepSpeechClient/Structs/CandidateTranscript.cs b/native_client/dotnet/DeepSpeechClient/Structs/CandidateTranscript.cs new file mode 100644 index 00000000..54581f6f --- /dev/null +++ b/native_client/dotnet/DeepSpeechClient/Structs/CandidateTranscript.cs @@ -0,0 +1,22 @@ +using System; +using System.Runtime.InteropServices; + +namespace DeepSpeechClient.Structs +{ + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct CandidateTranscript + { + /// + /// Native list of tokens. + /// + internal unsafe IntPtr tokens; + /// + /// Count of tokens from the native side. + /// + internal unsafe int num_tokens; + /// + /// Approximated confidence value for this transcription. + /// + internal unsafe double confidence; + } +} diff --git a/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs b/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs index 411da9f2..0a9beddc 100644 --- a/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs +++ b/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs @@ -7,16 +7,12 @@ namespace DeepSpeechClient.Structs internal unsafe struct Metadata { /// - /// Native list of items. + /// Native list of candidate transcripts. /// - internal unsafe IntPtr items; + internal unsafe IntPtr transcripts; /// - /// Count of items from the native side. + /// Count of transcripts from the native side. /// - internal unsafe int num_items; - /// - /// Approximated confidence value for this transcription. - /// - internal unsafe double confidence; + internal unsafe int num_transcripts; } } diff --git a/native_client/dotnet/DeepSpeechClient/Structs/MetadataItem.cs b/native_client/dotnet/DeepSpeechClient/Structs/TokenMetadata.cs similarity index 80% rename from native_client/dotnet/DeepSpeechClient/Structs/MetadataItem.cs rename to native_client/dotnet/DeepSpeechClient/Structs/TokenMetadata.cs index 10092742..1c660c71 100644 --- a/native_client/dotnet/DeepSpeechClient/Structs/MetadataItem.cs +++ b/native_client/dotnet/DeepSpeechClient/Structs/TokenMetadata.cs @@ -4,12 +4,12 @@ using System.Runtime.InteropServices; namespace DeepSpeechClient.Structs { [StructLayout(LayoutKind.Sequential)] - internal unsafe struct MetadataItem + internal unsafe struct TokenMetadata { /// - /// Native character. + /// Native text. /// - internal unsafe IntPtr character; + internal unsafe IntPtr text; /// /// Position of the character in units of 20ms. /// diff --git a/native_client/dotnet/DeepSpeechConsole/Program.cs b/native_client/dotnet/DeepSpeechConsole/Program.cs index b35c7046..a08e44b6 100644 --- a/native_client/dotnet/DeepSpeechConsole/Program.cs +++ b/native_client/dotnet/DeepSpeechConsole/Program.cs @@ -21,14 +21,14 @@ namespace CSharpExamples static string GetArgument(IEnumerable args, string option) => args.SkipWhile(i => i != option).Skip(1).Take(1).FirstOrDefault(); - static string MetadataToString(Metadata meta) + static string MetadataToString(CandidateTranscript transcript) { var nl = Environment.NewLine; string retval = - Environment.NewLine + $"Recognized text: {string.Join("", meta?.Items?.Select(x => x.Character))} {nl}" - + $"Confidence: {meta?.Confidence} {nl}" - + $"Item count: {meta?.Items?.Length} {nl}" - + string.Join(nl, meta?.Items?.Select(x => $"Timestep : {x.Timestep} TimeOffset: {x.StartTime} Char: {x.Character}")); + Environment.NewLine + $"Recognized text: {string.Join("", transcript?.Tokens?.Select(x => x.Text))} {nl}" + + $"Confidence: {transcript?.Confidence} {nl}" + + $"Item count: {transcript?.Tokens?.Length} {nl}" + + string.Join(nl, transcript?.Tokens?.Select(x => $"Timestep : {x.Timestep} TimeOffset: {x.StartTime} Char: {x.Text}")); return retval; } @@ -75,8 +75,8 @@ namespace CSharpExamples if (extended) { Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer, - Convert.ToUInt32(waveBuffer.MaxSize / 2)); - speechResult = MetadataToString(metaResult); + Convert.ToUInt32(waveBuffer.MaxSize / 2), 1); + speechResult = MetadataToString(metaResult.Transcripts[0]); } else { diff --git a/native_client/java/jni/deepspeech.i b/native_client/java/jni/deepspeech.i index ded18439..c028714c 100644 --- a/native_client/java/jni/deepspeech.i +++ b/native_client/java/jni/deepspeech.i @@ -6,6 +6,8 @@ %} %include "typemaps.i" +%include "enums.swg" +%javaconst(1); %include "arrays_java.i" // apply to DS_FeedAudioContent and DS_SpeechToText @@ -15,21 +17,29 @@ %pointer_functions(ModelState*, modelstatep); %pointer_functions(StreamingState*, streamingstatep); -%typemap(newfree) char* "DS_FreeString($1);"; - -%include "carrays.i" -%array_functions(struct MetadataItem, metadataItem_array); +%extend struct CandidateTranscript { + /** + * Retrieve one TokenMetadata element + * + * @param i Array index of the TokenMetadata to get + * + * @return The TokenMetadata requested or null + */ + const TokenMetadata& getToken(int i) { + return self->tokens[i]; + } +} %extend struct Metadata { /** - * Retrieve one MetadataItem element + * Retrieve one CandidateTranscript element * - * @param i Array index of the MetadataItem to get + * @param i Array index of the CandidateTranscript to get * - * @return The MetadataItem requested or null + * @return The CandidateTranscript requested or null */ - MetadataItem getItem(int i) { - return metadataItem_array_getitem(self->items, i); + const CandidateTranscript& getTranscript(int i) { + return self->transcripts[i]; } ~Metadata() { @@ -37,14 +47,18 @@ } } -%nodefaultdtor Metadata; %nodefaultctor Metadata; -%nodefaultctor MetadataItem; -%nodefaultdtor MetadataItem; +%nodefaultdtor Metadata; +%nodefaultctor CandidateTranscript; +%nodefaultdtor CandidateTranscript; +%nodefaultctor TokenMetadata; +%nodefaultdtor TokenMetadata; +%typemap(newfree) char* "DS_FreeString($1);"; %newobject DS_SpeechToText; %newobject DS_IntermediateDecode; %newobject DS_FinishStream; +%newobject DS_ErrorCodeToErrorMessage; %rename ("%(strip:[DS_])s") ""; diff --git a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java index 2957b2e7..f7eccf00 100644 --- a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java +++ b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java @@ -12,7 +12,7 @@ import org.junit.runners.MethodSorters; import static org.junit.Assert.*; import org.mozilla.deepspeech.libdeepspeech.DeepSpeechModel; -import org.mozilla.deepspeech.libdeepspeech.Metadata; +import org.mozilla.deepspeech.libdeepspeech.CandidateTranscript; import java.io.RandomAccessFile; import java.io.FileNotFoundException; @@ -61,10 +61,10 @@ public class BasicTest { m.freeModel(); } - private String metadataToString(Metadata m) { + private String candidateTranscriptToString(CandidateTranscript t) { String retval = ""; - for (int i = 0; i < m.getNum_items(); ++i) { - retval += m.getItem(i).getCharacter(); + for (int i = 0; i < t.getNum_tokens(); ++i) { + retval += t.getToken(i).getText(); } return retval; } @@ -97,7 +97,7 @@ public class BasicTest { ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts); if (extendedMetadata) { - return metadataToString(m.sttWithMetadata(shorts, shorts.length)); + return candidateTranscriptToString(m.sttWithMetadata(shorts, shorts.length, 1).getTranscript(0)); } else { return m.stt(shorts, shorts.length); } diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java index 6d0a316b..eafa11e2 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java @@ -11,8 +11,15 @@ public class DeepSpeechModel { } // FIXME: We should have something better than those SWIGTYPE_* - SWIGTYPE_p_p_ModelState _mspp; - SWIGTYPE_p_ModelState _msp; + private SWIGTYPE_p_p_ModelState _mspp; + private SWIGTYPE_p_ModelState _msp; + + private void evaluateErrorCode(int errorCode) { + DeepSpeech_Error_Codes code = DeepSpeech_Error_Codes.swigToEnum(errorCode); + if (code != DeepSpeech_Error_Codes.ERR_OK) { + throw new RuntimeException("Error: " + impl.ErrorCodeToErrorMessage(errorCode) + " (0x" + Integer.toHexString(errorCode) + ")."); + } + } /** * @brief An object providing an interface to a trained DeepSpeech model. @@ -20,10 +27,12 @@ public class DeepSpeechModel { * @constructor * * @param modelPath The path to the frozen model graph. + * + * @throws RuntimeException on failure. */ public DeepSpeechModel(String modelPath) { this._mspp = impl.new_modelstatep(); - impl.CreateModel(modelPath, this._mspp); + evaluateErrorCode(impl.CreateModel(modelPath, this._mspp)); this._msp = impl.modelstatep_value(this._mspp); } @@ -43,10 +52,10 @@ public class DeepSpeechModel { * @param aBeamWidth The beam width used by the model. A larger beam width value * generates better results at the cost of decoding time. * - * @return Zero on success, non-zero on failure. + * @throws RuntimeException on failure. */ - public int setBeamWidth(long beamWidth) { - return impl.SetModelBeamWidth(this._msp, beamWidth); + public void setBeamWidth(long beamWidth) { + evaluateErrorCode(impl.SetModelBeamWidth(this._msp, beamWidth)); } /** @@ -70,19 +79,19 @@ public class DeepSpeechModel { * * @param scorer The path to the external scorer file. * - * @return Zero on success, non-zero on failure (invalid arguments). + * @throws RuntimeException on failure. */ public void enableExternalScorer(String scorer) { - impl.EnableExternalScorer(this._msp, scorer); + evaluateErrorCode(impl.EnableExternalScorer(this._msp, scorer)); } /** * @brief Disable decoding using an external scorer. * - * @return Zero on success, non-zero on failure (invalid arguments). + * @throws RuntimeException on failure. */ public void disableExternalScorer() { - impl.DisableExternalScorer(this._msp); + evaluateErrorCode(impl.DisableExternalScorer(this._msp)); } /** @@ -91,10 +100,10 @@ public class DeepSpeechModel { * @param alpha The alpha hyperparameter of the decoder. Language model weight. * @param beta The beta hyperparameter of the decoder. Word insertion weight. * - * @return Zero on success, non-zero on failure (invalid arguments). + * @throws RuntimeException on failure. */ public void setScorerAlphaBeta(float alpha, float beta) { - impl.SetScorerAlphaBeta(this._msp, alpha, beta); + evaluateErrorCode(impl.SetScorerAlphaBeta(this._msp, alpha, beta)); } /* @@ -117,11 +126,13 @@ public class DeepSpeechModel { * @param buffer A 16-bit, mono raw audio signal at the appropriate * sample rate (matching what the model was trained on). * @param buffer_size The number of samples in the audio signal. + * @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this. * - * @return Outputs a Metadata object of individual letters along with their timing information. + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. */ - public Metadata sttWithMetadata(short[] buffer, int buffer_size) { - return impl.SpeechToTextWithMetadata(this._msp, buffer, buffer_size); + public Metadata sttWithMetadata(short[] buffer, int buffer_size, int num_results) { + return impl.SpeechToTextWithMetadata(this._msp, buffer, buffer_size, num_results); } /** @@ -130,10 +141,12 @@ public class DeepSpeechModel { * and finishStream(). * * @return An opaque object that represents the streaming state. + * + * @throws RuntimeException on failure. */ public DeepSpeechStreamingState createStream() { SWIGTYPE_p_p_StreamingState ssp = impl.new_streamingstatep(); - impl.CreateStream(this._msp, ssp); + evaluateErrorCode(impl.CreateStream(this._msp, ssp)); return new DeepSpeechStreamingState(impl.streamingstatep_value(ssp)); } @@ -161,8 +174,20 @@ public class DeepSpeechModel { } /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns the STT result over the whole audio signal. + * @brief Compute the intermediate decoding of an ongoing streaming inference. + * + * @param ctx A streaming state pointer returned by createStream(). + * @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this. + * + * @return The STT intermediate result. + */ + public Metadata intermediateDecodeWithMetadata(DeepSpeechStreamingState ctx, int num_results) { + return impl.IntermediateDecodeWithMetadata(ctx.get(), num_results); + } + + /** + * @brief Compute the final decoding of an ongoing streaming inference and return + * the result. Signals the end of an ongoing streaming inference. * * @param ctx A streaming state pointer returned by createStream(). * @@ -175,16 +200,19 @@ public class DeepSpeechModel { } /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns per-letter metadata. + * @brief Compute the final decoding of an ongoing streaming inference and return + * the results including metadata. Signals the end of an ongoing streaming + * inference. * * @param ctx A streaming state pointer returned by createStream(). + * @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this. * - * @return Outputs a Metadata object of individual letters along with their timing information. + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. * * @note This method will free the state pointer (@p ctx). */ - public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx) { - return impl.FinishStreamWithMetadata(ctx.get()); + public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx, int num_results) { + return impl.FinishStreamWithMetadata(ctx.get(), num_results); } } diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/CandidateTranscript.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/CandidateTranscript.java new file mode 100644 index 00000000..fa13c474 --- /dev/null +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/CandidateTranscript.java @@ -0,0 +1,73 @@ +/* ---------------------------------------------------------------------------- + * This file was automatically generated by SWIG (http://www.swig.org). + * Version 4.0.1 + * + * Do not make changes to this file unless you know what you are doing--modify + * the SWIG interface file instead. + * ----------------------------------------------------------------------------- */ + +package org.mozilla.deepspeech.libdeepspeech; + +/** + * A single transcript computed by the model, including a confidence
+ * value and the metadata for its constituent tokens. + */ +public class CandidateTranscript { + private transient long swigCPtr; + protected transient boolean swigCMemOwn; + + protected CandidateTranscript(long cPtr, boolean cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = cPtr; + } + + protected static long getCPtr(CandidateTranscript obj) { + return (obj == null) ? 0 : obj.swigCPtr; + } + + public synchronized void delete() { + if (swigCPtr != 0) { + if (swigCMemOwn) { + swigCMemOwn = false; + throw new UnsupportedOperationException("C++ destructor does not have public access"); + } + swigCPtr = 0; + } + } + + /** + * Array of TokenMetadata objects + */ + public TokenMetadata getTokens() { + long cPtr = implJNI.CandidateTranscript_tokens_get(swigCPtr, this); + return (cPtr == 0) ? null : new TokenMetadata(cPtr, false); + } + + /** + * Size of the tokens array + */ + public long getNum_tokens() { + return implJNI.CandidateTranscript_num_tokens_get(swigCPtr, this); + } + + /** + * Approximated confidence value for this transcript. This is roughly the
+ * sum of the acoustic model logit values for each timestep/character that
+ * contributed to the creation of this transcript. + */ + public double getConfidence() { + return implJNI.CandidateTranscript_confidence_get(swigCPtr, this); + } + + /** + * Retrieve one TokenMetadata element
+ *
+ * @param i Array index of the TokenMetadata to get
+ *
+ * @return The TokenMetadata requested or null + */ + public TokenMetadata getToken(int i) { + return new TokenMetadata(implJNI.CandidateTranscript_getToken(swigCPtr, this, i), false); + } + +} diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java new file mode 100644 index 00000000..ed47183e --- /dev/null +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java @@ -0,0 +1,65 @@ +/* ---------------------------------------------------------------------------- + * This file was automatically generated by SWIG (http://www.swig.org). + * Version 4.0.1 + * + * Do not make changes to this file unless you know what you are doing--modify + * the SWIG interface file instead. + * ----------------------------------------------------------------------------- */ + +package org.mozilla.deepspeech.libdeepspeech; + +public enum DeepSpeech_Error_Codes { + ERR_OK(0x0000), + ERR_NO_MODEL(0x1000), + ERR_INVALID_ALPHABET(0x2000), + ERR_INVALID_SHAPE(0x2001), + ERR_INVALID_SCORER(0x2002), + ERR_MODEL_INCOMPATIBLE(0x2003), + ERR_SCORER_NOT_ENABLED(0x2004), + ERR_FAIL_INIT_MMAP(0x3000), + ERR_FAIL_INIT_SESS(0x3001), + ERR_FAIL_INTERPRETER(0x3002), + ERR_FAIL_RUN_SESS(0x3003), + ERR_FAIL_CREATE_STREAM(0x3004), + ERR_FAIL_READ_PROTOBUF(0x3005), + ERR_FAIL_CREATE_SESS(0x3006), + ERR_FAIL_CREATE_MODEL(0x3007); + + public final int swigValue() { + return swigValue; + } + + public static DeepSpeech_Error_Codes swigToEnum(int swigValue) { + DeepSpeech_Error_Codes[] swigValues = DeepSpeech_Error_Codes.class.getEnumConstants(); + if (swigValue < swigValues.length && swigValue >= 0 && swigValues[swigValue].swigValue == swigValue) + return swigValues[swigValue]; + for (DeepSpeech_Error_Codes swigEnum : swigValues) + if (swigEnum.swigValue == swigValue) + return swigEnum; + throw new IllegalArgumentException("No enum " + DeepSpeech_Error_Codes.class + " with value " + swigValue); + } + + @SuppressWarnings("unused") + private DeepSpeech_Error_Codes() { + this.swigValue = SwigNext.next++; + } + + @SuppressWarnings("unused") + private DeepSpeech_Error_Codes(int swigValue) { + this.swigValue = swigValue; + SwigNext.next = swigValue+1; + } + + @SuppressWarnings("unused") + private DeepSpeech_Error_Codes(DeepSpeech_Error_Codes swigEnum) { + this.swigValue = swigEnum.swigValue; + SwigNext.next = this.swigValue+1; + } + + private final int swigValue; + + private static class SwigNext { + private static int next = 0; + } +} + diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java index 482b7c58..d2831bc4 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- * This file was automatically generated by SWIG (http://www.swig.org). - * Version 4.0.2 + * Version 4.0.1 * * Do not make changes to this file unless you know what you are doing--modify * the SWIG interface file instead. @@ -9,7 +9,7 @@ package org.mozilla.deepspeech.libdeepspeech; /** - * Stores the entire CTC output as an array of character metadata objects + * An array of CandidateTranscript objects computed by the model. */ public class Metadata { private transient long swigCPtr; @@ -40,61 +40,29 @@ public class Metadata { } /** - * List of items + * Array of CandidateTranscript objects */ - public void setItems(MetadataItem value) { - implJNI.Metadata_items_set(swigCPtr, this, MetadataItem.getCPtr(value), value); + public CandidateTranscript getTranscripts() { + long cPtr = implJNI.Metadata_transcripts_get(swigCPtr, this); + return (cPtr == 0) ? null : new CandidateTranscript(cPtr, false); } /** - * List of items + * Size of the transcripts array */ - public MetadataItem getItems() { - long cPtr = implJNI.Metadata_items_get(swigCPtr, this); - return (cPtr == 0) ? null : new MetadataItem(cPtr, false); + public long getNum_transcripts() { + return implJNI.Metadata_num_transcripts_get(swigCPtr, this); } /** - * Size of the list of items - */ - public void setNum_items(int value) { - implJNI.Metadata_num_items_set(swigCPtr, this, value); - } - - /** - * Size of the list of items - */ - public int getNum_items() { - return implJNI.Metadata_num_items_get(swigCPtr, this); - } - - /** - * Approximated confidence value for this transcription. This is roughly the
- * sum of the acoustic model logit values for each timestep/character that
- * contributed to the creation of this transcription. - */ - public void setConfidence(double value) { - implJNI.Metadata_confidence_set(swigCPtr, this, value); - } - - /** - * Approximated confidence value for this transcription. This is roughly the
- * sum of the acoustic model logit values for each timestep/character that
- * contributed to the creation of this transcription. - */ - public double getConfidence() { - return implJNI.Metadata_confidence_get(swigCPtr, this); - } - - /** - * Retrieve one MetadataItem element
+ * Retrieve one CandidateTranscript element
*
- * @param i Array index of the MetadataItem to get
+ * @param i Array index of the CandidateTranscript to get
*
- * @return The MetadataItem requested or null + * @return The CandidateTranscript requested or null */ - public MetadataItem getItem(int i) { - return new MetadataItem(implJNI.Metadata_getItem(swigCPtr, this, i), true); + public CandidateTranscript getTranscript(int i) { + return new CandidateTranscript(implJNI.Metadata_getTranscript(swigCPtr, this, i), false); } } diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst index 1279d717..bd89f9b8 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst @@ -4,7 +4,7 @@ Javadoc for Sphinx This code is only here for reference for documentation generation. -To update, please build SWIG (4.0 at least) and then run from native_client/java: +To update, please install SWIG (4.0 at least) and then run from native_client/java: .. code-block:: diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/TokenMetadata.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/TokenMetadata.java new file mode 100644 index 00000000..d14fc161 --- /dev/null +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/TokenMetadata.java @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + * This file was automatically generated by SWIG (http://www.swig.org). + * Version 4.0.1 + * + * Do not make changes to this file unless you know what you are doing--modify + * the SWIG interface file instead. + * ----------------------------------------------------------------------------- */ + +package org.mozilla.deepspeech.libdeepspeech; + +/** + * Stores text of an individual token, along with its timing information + */ +public class TokenMetadata { + private transient long swigCPtr; + protected transient boolean swigCMemOwn; + + protected TokenMetadata(long cPtr, boolean cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = cPtr; + } + + protected static long getCPtr(TokenMetadata obj) { + return (obj == null) ? 0 : obj.swigCPtr; + } + + public synchronized void delete() { + if (swigCPtr != 0) { + if (swigCMemOwn) { + swigCMemOwn = false; + throw new UnsupportedOperationException("C++ destructor does not have public access"); + } + swigCPtr = 0; + } + } + + /** + * The text corresponding to this token + */ + public String getText() { + return implJNI.TokenMetadata_text_get(swigCPtr, this); + } + + /** + * Position of the token in units of 20ms + */ + public long getTimestep() { + return implJNI.TokenMetadata_timestep_get(swigCPtr, this); + } + + /** + * Position of the token in seconds + */ + public float getStart_time() { + return implJNI.TokenMetadata_start_time_get(swigCPtr, this); + } + +} diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js index abbfe59e..16dd19e8 100644 --- a/native_client/javascript/client.js +++ b/native_client/javascript/client.js @@ -42,12 +42,11 @@ function totalTime(hrtimeValue) { return (hrtimeValue[0] + hrtimeValue[1] / 1000000000).toPrecision(4); } -function metadataToString(metadata) { +function candidateTranscriptToString(transcript) { var retval = "" - for (var i = 0; i < metadata.num_items; ++i) { - retval += metadata.items[i].character; + for (var i = 0; i < transcript.tokens.length; ++i) { + retval += transcript.tokens[i].text; } - Ds.FreeMetadata(metadata); return retval; } @@ -117,7 +116,9 @@ audioStream.on('finish', () => { const audioLength = (audioBuffer.length / 2) * (1 / desired_sample_rate); if (args['extended']) { - console.log(metadataToString(model.sttWithMetadata(audioBuffer))); + let metadata = model.sttWithMetadata(audioBuffer, 1); + console.log(candidateTranscriptToString(metadata.transcripts[0])); + Ds.FreeMetadata(metadata); } else { console.log(model.stt(audioBuffer)); } diff --git a/native_client/javascript/deepspeech.i b/native_client/javascript/deepspeech.i index efbaa360..e311a41b 100644 --- a/native_client/javascript/deepspeech.i +++ b/native_client/javascript/deepspeech.i @@ -37,6 +37,7 @@ using namespace node; %newobject DS_IntermediateDecode; %newobject DS_FinishStream; %newobject DS_Version; +%newobject DS_ErrorCodeToErrorMessage; // convert double pointer retval in CreateModel to an output %typemap(in, numinputs=0) ModelState **retval (ModelState *ret) { @@ -47,8 +48,8 @@ using namespace node; %typemap(argout) ModelState **retval { $result = SWIGV8_ARRAY_NEW(); SWIGV8_AppendOutput($result, SWIG_From_int(result)); - // owned by SWIG, ModelState destructor gets called when the JavaScript object is finalized (see below) - %append_output(SWIG_NewPointerObj(%as_voidptr(*$1), $*1_descriptor, SWIG_POINTER_OWN)); + // owned by the application. NodeJS does not guarantee the finalizer will be called so applications must call FreeMetadata themselves. + %append_output(SWIG_NewPointerObj(%as_voidptr(*$1), $*1_descriptor, 0)); } @@ -68,27 +69,29 @@ using namespace node; %nodefaultctor ModelState; %nodefaultdtor ModelState; -%typemap(out) MetadataItem* %{ +%typemap(out) TokenMetadata* %{ $result = SWIGV8_ARRAY_NEW(); - for (int i = 0; i < arg1->num_items; ++i) { - SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_MetadataItem, SWIG_POINTER_OWN)); + for (int i = 0; i < arg1->num_tokens; ++i) { + SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_TokenMetadata, 0)); } %} -%nodefaultdtor Metadata; -%nodefaultctor Metadata; -%nodefaultctor MetadataItem; -%nodefaultdtor MetadataItem; - -%extend struct Metadata { - ~Metadata() { - DS_FreeMetadata($self); +%typemap(out) CandidateTranscript* %{ + $result = SWIGV8_ARRAY_NEW(); + for (int i = 0; i < arg1->num_transcripts; ++i) { + SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_CandidateTranscript, 0)); } -} +%} -%extend struct MetadataItem { - ~MetadataItem() { } -} +%ignore Metadata::num_transcripts; +%ignore CandidateTranscript::num_tokens; + +%nodefaultctor Metadata; +%nodefaultdtor Metadata; +%nodefaultctor CandidateTranscript; +%nodefaultdtor CandidateTranscript; +%nodefaultctor TokenMetadata; +%nodefaultdtor TokenMetadata; %rename ("%(strip:[DS_])s") ""; diff --git a/native_client/javascript/index.js b/native_client/javascript/index.js index cca483f1..30bcb690 100644 --- a/native_client/javascript/index.js +++ b/native_client/javascript/index.js @@ -2,7 +2,7 @@ const binary = require('node-pre-gyp'); const path = require('path') -// 'lib', 'binding', 'v0.1.1', ['node', 'v' + process.versions.modules, process.platform, process.arch].join('-'), 'deepspeech-bingings.node') +// 'lib', 'binding', 'v0.1.1', ['node', 'v' + process.versions.modules, process.platform, process.arch].join('-'), 'deepspeech-bindings.node') const binding_path = binary.find(path.resolve(path.join(__dirname, 'package.json'))); // On Windows, we can't rely on RPATH being set to $ORIGIN/../ or on @@ -35,7 +35,7 @@ function Model(aModelPath) { const status = rets[0]; const impl = rets[1]; if (status !== 0) { - throw "CreateModel failed with error code 0x" + status.toString(16); + throw "CreateModel failed "+binding.ErrorCodeToErrorMessage(status)+" 0x" + status.toString(16); } this._impl = impl; @@ -115,15 +115,16 @@ Model.prototype.stt = function(aBuffer) { } /** - * Use the DeepSpeech model to perform Speech-To-Text and output metadata - * about the results. + * Use the DeepSpeech model to perform Speech-To-Text and output results including metadata. * * @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). + * @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified. * - * @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error. + * @return {object} :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error. */ -Model.prototype.sttWithMetadata = function(aBuffer) { - return binding.SpeechToTextWithMetadata(this._impl, aBuffer); +Model.prototype.sttWithMetadata = function(aBuffer, aNumResults) { + aNumResults = aNumResults || 1; + return binding.SpeechToTextWithMetadata(this._impl, aBuffer, aNumResults); } /** @@ -138,7 +139,7 @@ Model.prototype.createStream = function() { const status = rets[0]; const ctx = rets[1]; if (status !== 0) { - throw "CreateStream failed with error code 0x" + status.toString(16); + throw "CreateStream failed "+binding.ErrorCodeToErrorMessage(status)+" 0x" + status.toString(16); } return ctx; } @@ -172,7 +173,19 @@ Stream.prototype.intermediateDecode = function() { } /** - * Signal the end of an audio signal to an ongoing streaming inference, returns the STT result over the whole audio signal. + * Compute the intermediate decoding of an ongoing streaming inference, return results including metadata. + * + * @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified. + * + * @return {object} :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error. + */ +Stream.prototype.intermediateDecodeWithMetadata = function(aNumResults) { + aNumResults = aNumResults || 1; + return binding.IntermediateDecode(this._impl, aNumResults); +} + +/** + * Compute the final decoding of an ongoing streaming inference and return the result. Signals the end of an ongoing streaming inference. * * @return {string} The STT result. * @@ -185,14 +198,17 @@ Stream.prototype.finishStream = function() { } /** - * Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata. + * Compute the final decoding of an ongoing streaming inference and return the results including metadata. Signals the end of an ongoing streaming inference. + * + * @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified. * * @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. * * This method will free the stream, it must not be used after this method is called. */ -Stream.prototype.finishStreamWithMetadata = function() { - result = binding.FinishStreamWithMetadata(this._impl); +Stream.prototype.finishStreamWithMetadata = function(aNumResults) { + aNumResults = aNumResults || 1; + result = binding.FinishStreamWithMetadata(this._impl, aNumResults); this._impl = null; return result; } @@ -236,70 +252,80 @@ function Version() { } -//// Metadata and MetadataItem are here only for documentation purposes +//// Metadata, CandidateTranscript and TokenMetadata are here only for documentation purposes /** * @class * - * Stores each individual character, along with its timing information + * Stores text of an individual token, along with its timing information */ -function MetadataItem() {} +function TokenMetadata() {} /** - * The character generated for transcription + * The text corresponding to this token * - * @return {string} The character generated + * @return {string} The text generated */ -MetadataItem.prototype.character = function() {} +TokenMetadata.prototype.text = function() {} /** - * Position of the character in units of 20ms + * Position of the token in units of 20ms * - * @return {int} The position of the character + * @return {int} The position of the token */ -MetadataItem.prototype.timestep = function() {}; +TokenMetadata.prototype.timestep = function() {}; /** - * Position of the character in seconds + * Position of the token in seconds * - * @return {float} The position of the character + * @return {float} The position of the token */ -MetadataItem.prototype.start_time = function() {}; +TokenMetadata.prototype.start_time = function() {}; /** * @class * - * Stores the entire CTC output as an array of character metadata objects + * A single transcript computed by the model, including a confidence value and + * the metadata for its constituent tokens. */ -function Metadata () {} +function CandidateTranscript () {} /** - * List of items + * Array of tokens * - * @return {array} List of :js:func:`MetadataItem` + * @return {array} Array of :js:func:`TokenMetadata` */ -Metadata.prototype.items = function() {} - -/** - * Size of the list of items - * - * @return {int} Number of items - */ -Metadata.prototype.num_items = function() {} +CandidateTranscript.prototype.tokens = function() {} /** * Approximated confidence value for this transcription. This is roughly the - * sum of the acoustic model logit values for each timestep/character that + * sum of the acoustic model logit values for each timestep/token that * contributed to the creation of this transcription. * * @return {float} Confidence value */ -Metadata.prototype.confidence = function() {} +CandidateTranscript.prototype.confidence = function() {} + +/** + * @class + * + * An array of CandidateTranscript objects computed by the model. + */ +function Metadata () {} + +/** + * Array of transcripts + * + * @return {array} Array of :js:func:`CandidateTranscript` objects + */ +Metadata.prototype.transcripts = function() {} + module.exports = { Model: Model, Metadata: Metadata, - MetadataItem: MetadataItem, + CandidateTranscript: CandidateTranscript, + TokenMetadata: TokenMetadata, Version: Version, FreeModel: FreeModel, FreeStream: FreeStream, diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index ea8928bd..3cb06ac2 100644 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -37,27 +37,39 @@ ModelState::decode(const DecoderState& state) const } Metadata* -ModelState::decode_metadata(const DecoderState& state) +ModelState::decode_metadata(const DecoderState& state, + size_t num_results) { - vector out = state.decode(); + vector out = state.decode(num_results); + unsigned int num_returned = out.size(); - std::unique_ptr metadata(new Metadata()); - metadata->num_items = out[0].tokens.size(); - metadata->confidence = out[0].confidence; + CandidateTranscript* transcripts = (CandidateTranscript*)malloc(sizeof(CandidateTranscript)*num_returned); - std::unique_ptr items(new MetadataItem[metadata->num_items]()); + for (int i = 0; i < num_returned; ++i) { + TokenMetadata* tokens = (TokenMetadata*)malloc(sizeof(TokenMetadata)*out[i].tokens.size()); - // Loop through each character - for (int i = 0; i < out[0].tokens.size(); ++i) { - items[i].character = strdup(alphabet_.StringFromLabel(out[0].tokens[i]).c_str()); - items[i].timestep = out[0].timesteps[i]; - items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_); - - if (items[i].start_time < 0) { - items[i].start_time = 0; + for (int j = 0; j < out[i].tokens.size(); ++j) { + TokenMetadata token { + strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str()), // text + static_cast(out[i].timesteps[j]), // timestep + out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time + }; + memcpy(&tokens[j], &token, sizeof(TokenMetadata)); } + + CandidateTranscript transcript { + tokens, // tokens + static_cast(out[i].tokens.size()), // num_tokens + out[i].confidence, // confidence + }; + memcpy(&transcripts[i], &transcript, sizeof(CandidateTranscript)); } - metadata->items = items.release(); - return metadata.release(); + Metadata* ret = (Metadata*)malloc(sizeof(Metadata)); + Metadata metadata { + transcripts, // transcripts + num_returned, // num_transcripts + }; + memcpy(ret, &metadata, sizeof(Metadata)); + return ret; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h index 25251e15..0dbe108a 100644 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -66,11 +66,14 @@ struct ModelState { * @brief Return character-level metadata including letter timings. * * @param state Decoder state to use when decoding. + * @param num_results Maximum number of candidate results to return. * - * @return Metadata struct containing MetadataItem structs for each character. - * The user is responsible for freeing Metadata by calling DS_FreeMetadata(). + * @return A Metadata struct containing CandidateTranscript structs. + * Each represents an candidate transcript, with the first ranked most probable. + * The user is responsible for freeing Result by calling DS_FreeMetadata(). */ - virtual Metadata* decode_metadata(const DecoderState& state); + virtual Metadata* decode_metadata(const DecoderState& state, + size_t num_results); }; #endif // MODELSTATE_H diff --git a/native_client/python/Makefile b/native_client/python/Makefile index 1735fe77..5dfba785 100644 --- a/native_client/python/Makefile +++ b/native_client/python/Makefile @@ -6,6 +6,8 @@ bindings-clean: rm -rf dist temp_build deepspeech.egg-info MANIFEST.in temp_lib rm -f impl_wrap.cpp impl.py +# Enforce PATH here because swig calls from build_ext looses track of some +# variables over several runs bindings-build: pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0 PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python ./setup.py build_ext $(PYTHON_PLATFORM_NAME) diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index a6511efe..a6af56f1 100644 --- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -35,7 +35,7 @@ class Model(object): status, impl = deepspeech.impl.CreateModel(model_path) if status != 0: - raise RuntimeError("CreateModel failed with error code 0x{:X}".format(status)) + raise RuntimeError("CreateModel failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status)) self._impl = impl def __del__(self): @@ -121,17 +121,20 @@ class Model(object): """ return deepspeech.impl.SpeechToText(self._impl, audio_buffer) - def sttWithMetadata(self, audio_buffer): + def sttWithMetadata(self, audio_buffer, num_results=1): """ - Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results. + Use the DeepSpeech model to perform Speech-To-Text and return results including metadata. :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). :type audio_buffer: numpy.int16 array - :return: Outputs a struct of individual letters along with their timing information. + :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this. + :type num_results: int + + :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. :type: :func:`Metadata` """ - return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer) + return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results) def createStream(self): """ @@ -145,7 +148,7 @@ class Model(object): """ status, ctx = deepspeech.impl.CreateStream(self._impl) if status != 0: - raise RuntimeError("CreateStream failed with error code 0x{:X}".format(status)) + raise RuntimeError("CreateStream failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status)) return Stream(ctx) @@ -187,10 +190,27 @@ class Stream(object): raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") return deepspeech.impl.IntermediateDecode(self._impl) + def intermediateDecodeWithMetadata(self, num_results=1): + """ + Compute the intermediate decoding of an ongoing streaming inference and return results including metadata. + + :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this. + :type num_results: int + + :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. + :type: :func:`Metadata` + + :throws: RuntimeError if the stream object is not valid + """ + if not self._impl: + raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") + return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results) + def finishStream(self): """ - Signal the end of an audio signal to an ongoing streaming inference, - returns the STT result over the whole audio signal. + Compute the final decoding of an ongoing streaming inference and return + the result. Signals the end of an ongoing streaming inference. The underlying + stream object must not be used after this method is called. :return: The STT result. :type: str @@ -203,19 +223,24 @@ class Stream(object): self._impl = None return result - def finishStreamWithMetadata(self): + def finishStreamWithMetadata(self, num_results=1): """ - Signal the end of an audio signal to an ongoing streaming inference, - returns per-letter metadata. + Compute the final decoding of an ongoing streaming inference and return + results including metadata. Signals the end of an ongoing streaming + inference. The underlying stream object must not be used after this + method is called. - :return: Outputs a struct of individual letters along with their timing information. + :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this. + :type num_results: int + + :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. :type: :func:`Metadata` :throws: RuntimeError if the stream object is not valid """ if not self._impl: raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?") - result = deepspeech.impl.FinishStreamWithMetadata(self._impl) + result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results) self._impl = None return result @@ -233,52 +258,43 @@ class Stream(object): # This is only for documentation purpose -# Metadata and MetadataItem should be in sync with native_client/deepspeech.h -class MetadataItem(object): +# Metadata, CandidateTranscript and TokenMetadata should be in sync with native_client/deepspeech.h +class TokenMetadata(object): """ Stores each individual character, along with its timing information """ - def character(self): + def text(self): """ - The character generated for transcription + The text for this token """ def timestep(self): """ - Position of the character in units of 20ms + Position of the token in units of 20ms """ def start_time(self): """ - Position of the character in seconds + Position of the token in seconds """ -class Metadata(object): +class CandidateTranscript(object): """ Stores the entire CTC output as an array of character metadata objects """ - def items(self): + def tokens(self): """ - List of items + List of tokens - :return: A list of :func:`MetadataItem` elements + :return: A list of :func:`TokenMetadata` elements :type: list """ - def num_items(self): - """ - Size of the list of items - - :return: Size of the list of items - :type: int - """ - - def confidence(self): """ Approximated confidence value for this transcription. This is roughly the @@ -286,3 +302,12 @@ class Metadata(object): contributed to the creation of this transcription. """ + +class Metadata(object): + def transcripts(self): + """ + List of candidate transcripts + + :return: A list of :func:`CandidateTranscript` objects + :type: list + """ diff --git a/native_client/python/client.py b/native_client/python/client.py index 671968b9..00fa2ff6 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -18,6 +18,7 @@ try: except ImportError: from pipes import quote + def convert_samplerate(audio_path, desired_sample_rate): sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate) try: @@ -31,25 +32,25 @@ def convert_samplerate(audio_path, desired_sample_rate): def metadata_to_string(metadata): - return ''.join(item.character for item in metadata.items) + return ''.join(token.text for token in metadata.tokens) -def words_from_metadata(metadata): + +def words_from_candidate_transcript(metadata): word = "" word_list = [] word_start_time = 0 # Loop through each character - for i in range(0, metadata.num_items): - item = metadata.items[i] + for i, token in enumerate(metadata.tokens): # Append character to word if it's not a space - if item.character != " ": + if token.text != " ": if len(word) == 0: # Log the start time of the new word - word_start_time = item.start_time + word_start_time = token.start_time - word = word + item.character + word = word + token.text # Word boundary is either a space or the last character in the array - if item.character == " " or i == metadata.num_items - 1: - word_duration = item.start_time - word_start_time + if token.text == " " or i == len(metadata.tokens) - 1: + word_duration = token.start_time - word_start_time if word_duration < 0: word_duration = 0 @@ -69,9 +70,11 @@ def words_from_metadata(metadata): def metadata_json_output(metadata): json_result = dict() - json_result["words"] = words_from_metadata(metadata) - json_result["confidence"] = metadata.confidence - return json.dumps(json_result) + json_result["transcripts"] = [{ + "confidence": transcript.confidence, + "words": words_from_candidate_transcript(transcript), + } for transcript in metadata.transcripts] + return json.dumps(json_result, indent=2) @@ -141,9 +144,9 @@ def main(): print('Running inference.', file=sys.stderr) inference_start = timer() if args.extended: - print(metadata_to_string(ds.sttWithMetadata(audio))) + print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0])) elif args.json: - print(metadata_json_output(ds.sttWithMetadata(audio))) + print(metadata_json_output(ds.sttWithMetadata(audio, 3))) else: print(ds.stt(audio)) inference_end = timer() - inference_start diff --git a/native_client/python/impl.i b/native_client/python/impl.i index d6c7ba19..3ee4b516 100644 --- a/native_client/python/impl.i +++ b/native_client/python/impl.i @@ -38,30 +38,69 @@ import_array(); %append_output(SWIG_NewPointerObj(%as_voidptr($1), $1_descriptor, SWIG_POINTER_OWN)); } -%typemap(out) MetadataItem* %{ - $result = PyList_New(arg1->num_items); - for (int i = 0; i < arg1->num_items; ++i) { - PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->items[i]), SWIGTYPE_p_MetadataItem, 0); +%fragment("parent_reference_init", "init") { + // Thread-safe initialization - initialize during Python module initialization + parent_reference(); +} + +%fragment("parent_reference_function", "header", fragment="parent_reference_init") { + +static PyObject *parent_reference() { + static PyObject *parent_reference_string = SWIG_Python_str_FromChar("__parent_reference"); + return parent_reference_string; +} + +} + +%typemap(out, fragment="parent_reference_function") CandidateTranscript* %{ + $result = PyList_New(arg1->num_transcripts); + for (int i = 0; i < arg1->num_transcripts; ++i) { + PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->transcripts[i]), SWIGTYPE_p_CandidateTranscript, 0); + // Add a reference to Metadata in the returned elements to avoid premature + // garbage collection + PyObject_SetAttr(o, parent_reference(), $self); PyList_SetItem($result, i, o); } %} -%extend struct MetadataItem { +%typemap(out, fragment="parent_reference_function") TokenMetadata* %{ + $result = PyList_New(arg1->num_tokens); + for (int i = 0; i < arg1->num_tokens; ++i) { + PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->tokens[i]), SWIGTYPE_p_TokenMetadata, 0); + // Add a reference to CandidateTranscript in the returned elements to avoid premature + // garbage collection + PyObject_SetAttr(o, parent_reference(), $self); + PyList_SetItem($result, i, o); + } +%} + +%extend struct TokenMetadata { %pythoncode %{ def __repr__(self): - return 'MetadataItem(character=\'{}\', timestep={}, start_time={})'.format(self.character, self.timestep, self.start_time) + return 'TokenMetadata(text=\'{}\', timestep={}, start_time={})'.format(self.text, self.timestep, self.start_time) +%} +} + +%extend struct CandidateTranscript { +%pythoncode %{ + def __repr__(self): + tokens_repr = ',\n'.join(repr(i) for i in self.tokens) + tokens_repr = '\n'.join(' ' + l for l in tokens_repr.split('\n')) + return 'CandidateTranscript(confidence={}, tokens=[\n{}\n])'.format(self.confidence, tokens_repr) %} } %extend struct Metadata { %pythoncode %{ def __repr__(self): - items_repr = ', \n'.join(' ' + repr(i) for i in self.items) - return 'Metadata(confidence={}, items=[\n{}\n])'.format(self.confidence, items_repr) + transcripts_repr = ',\n'.join(repr(i) for i in self.transcripts) + transcripts_repr = '\n'.join(' ' + l for l in transcripts_repr.split('\n')) + return 'Metadata(transcripts=[\n{}\n])'.format(transcripts_repr) %} } -%ignore Metadata::num_items; +%ignore Metadata::num_transcripts; +%ignore CandidateTranscript::num_tokens; %extend struct Metadata { ~Metadata() { @@ -69,10 +108,12 @@ import_array(); } } -%nodefaultdtor Metadata; %nodefaultctor Metadata; -%nodefaultctor MetadataItem; -%nodefaultdtor MetadataItem; +%nodefaultdtor Metadata; +%nodefaultctor CandidateTranscript; +%nodefaultdtor CandidateTranscript; +%nodefaultctor TokenMetadata; +%nodefaultdtor TokenMetadata; %typemap(newfree) char* "DS_FreeString($1);"; @@ -80,6 +121,7 @@ import_array(); %newobject DS_IntermediateDecode; %newobject DS_FinishStream; %newobject DS_Version; +%newobject DS_ErrorCodeToErrorMessage; %rename ("%(strip:[DS_])s") ""; diff --git a/requirements.txt b/requirements.txt index 742b8244..e05793c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ # Main training requirements -tensorflow == 1.15.0 +tensorflow == 1.15.2 numpy == 1.18.1 progressbar2 -pandas six pyxdg attrdict absl-py semver +opuslib == 2.0.0 # Requirements for building native_client files setuptools @@ -15,6 +15,10 @@ setuptools # Requirements for importers sox bs4 +pandas requests librosa soundfile + +# Requirements for optimizer +optuna diff --git a/requirements_tests.txt b/requirements_tests.txt new file mode 100644 index 00000000..de689076 --- /dev/null +++ b/requirements_tests.txt @@ -0,0 +1,3 @@ +absl-py +argparse +semver diff --git a/taskcluster/.shared.yml b/taskcluster/.shared.yml index e762ebd8..31ca23c0 100644 --- a/taskcluster/.shared.yml +++ b/taskcluster/.shared.yml @@ -5,6 +5,9 @@ python: apt: 'python3-virtualenv python3-setuptools python3-pip python3-wheel python3-pkg-resources' packages_docs_bionic: apt: 'python3 python3-pip zip doxygen' +training: + packages_trusty: + apt: 'libopus0' tensorflow: packages_trusty: apt: 'make build-essential gfortran git libblas-dev liblapack-dev libsox-dev libmagic-dev libgsm1-dev libltdl-dev libpng-dev python zlib1g-dev' @@ -80,6 +83,18 @@ system: android_25: url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-25.4/artifacts/public/android_cache.tar.gz' namespace: 'project.deepspeech.android_cache.x86_64.android-25.4' + android_26: + url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-26.0/artifacts/public/android_cache.tar.gz' + namespace: 'project.deepspeech.android_cache.x86_64.android-26.0' + android_27: + url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-27.0/artifacts/public/android_cache.tar.gz' + namespace: 'project.deepspeech.android_cache.x86_64.android-27.0' + android_28: + url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-28.0/artifacts/public/android_cache.tar.gz' + namespace: 'project.deepspeech.android_cache.x86_64.android-28.0' + android_29: + url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-29.0/artifacts/public/android_cache.tar.gz' + namespace: 'project.deepspeech.android_cache.x86_64.android-29.0' sdk: android_27: url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.sdk.android-27.4/artifacts/public/android_cache.tar.gz' diff --git a/taskcluster/android-cache-x86_64-android-26.yml b/taskcluster/android-cache-x86_64-android-26.yml new file mode 100644 index 00000000..ec711ede --- /dev/null +++ b/taskcluster/android-cache-x86_64-android-26.yml @@ -0,0 +1,14 @@ +build: + template_file: android_cache-opt-base.tyml + system_setup: + > + ${java.packages_trusty.apt} + cache: + url: ${system.android_cache.x86_64.android_26.url} + namespace: ${system.android_cache.x86_64.android_26.namespace} + scripts: + build: "taskcluster/android_cache-build.sh x86_64 android-26" + package: "taskcluster/android_cache-package.sh" + metadata: + name: "Builds Android cache x86_64 / android-26" + description: "Setup an Android SDK / emulator cache for Android / x86_64 android-26" diff --git a/taskcluster/android-cache-x86_64-android-28.yml b/taskcluster/android-cache-x86_64-android-28.yml new file mode 100644 index 00000000..471f33b9 --- /dev/null +++ b/taskcluster/android-cache-x86_64-android-28.yml @@ -0,0 +1,14 @@ +build: + template_file: android_cache-opt-base.tyml + system_setup: + > + ${java.packages_trusty.apt} + cache: + url: ${system.android_cache.x86_64.android_28.url} + namespace: ${system.android_cache.x86_64.android_28.namespace} + scripts: + build: "taskcluster/android_cache-build.sh x86_64 android-28" + package: "taskcluster/android_cache-package.sh" + metadata: + name: "Builds Android cache x86_64 / android-28" + description: "Setup an Android SDK / emulator cache for Android / x86_64 android-28" diff --git a/taskcluster/android-cache-x86_64-android-29.yml b/taskcluster/android-cache-x86_64-android-29.yml new file mode 100644 index 00000000..835453f9 --- /dev/null +++ b/taskcluster/android-cache-x86_64-android-29.yml @@ -0,0 +1,14 @@ +build: + template_file: android_cache-opt-base.tyml + system_setup: + > + ${java.packages_trusty.apt} + cache: + url: ${system.android_cache.x86_64.android_29.url} + namespace: ${system.android_cache.x86_64.android_29.namespace} + scripts: + build: "taskcluster/android_cache-build.sh x86_64 android-29" + package: "taskcluster/android_cache-package.sh" + metadata: + name: "Builds Android cache x86_64 / android-29" + description: "Setup an Android SDK / emulator cache for Android / x86_64 android-29" diff --git a/taskcluster/decoder-build.sh b/taskcluster/decoder-build.sh index 4bae5e27..240a57ea 100755 --- a/taskcluster/decoder-build.sh +++ b/taskcluster/decoder-build.sh @@ -6,6 +6,10 @@ source $(dirname "$0")/tc-tests-utils.sh source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh -export SYSTEM_TARGET=host +if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then + export SYSTEM_TARGET=host-win +else + export SYSTEM_TARGET=host +fi; do_deepspeech_decoder_build diff --git a/taskcluster/scriptworker-task-github.yml b/taskcluster/scriptworker-task-github.yml index 26326305..75799d40 100644 --- a/taskcluster/scriptworker-task-github.yml +++ b/taskcluster/scriptworker-task-github.yml @@ -19,6 +19,7 @@ build: - "android-java-opt" - "win-amd64-cpu-opt" - "win-amd64-gpu-opt" + - "win-amd64-ctc-opt" allowed: - "tag" ref_match: "refs/tags/" @@ -38,6 +39,7 @@ build: - "win-amd64-cpu-opt" - "win-amd64-tflite-opt" - "win-amd64-gpu-opt" + - "win-amd64-ctc-opt" javascript: # GPU package - "node-package-gpu" diff --git a/taskcluster/tc-asserts.sh b/taskcluster/tc-asserts.sh index fd720557..b12b9683 100755 --- a/taskcluster/tc-asserts.sh +++ b/taskcluster/tc-asserts.sh @@ -252,6 +252,25 @@ assert_deepspeech_version() assert_not_present "$1" "DeepSpeech: unknown" } +# We need to ensure that running on inference really leverages GPU because +# it might default back to CPU +ensure_cuda_usage() +{ + local _maybe_cuda=$1 + DS_BINARY_FILE=${DS_BINARY_FILE:-"deepspeech"} + + if [ "${_maybe_cuda}" = "cuda" ]; then + set +e + export TF_CPP_MIN_VLOG_LEVEL=1 + ds_cuda=$(${DS_BINARY_PREFIX}${DS_BINARY_FILE} --model ${TASKCLUSTER_TMP_DIR}/${model_name} --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>&1 1>/dev/null) + export TF_CPP_MIN_VLOG_LEVEL= + set -e + + assert_shows_something "${ds_cuda}" "Successfully opened dynamic library nvcuda.dll" + assert_not_present "${ds_cuda}" "Skipping registering GPU devices" + fi; +} + check_versions() { set +e diff --git a/taskcluster/tc-cppwin-ds-tests.sh b/taskcluster/tc-cppwin-ds-tests.sh index 6f177a39..64fa8387 100644 --- a/taskcluster/tc-cppwin-ds-tests.sh +++ b/taskcluster/tc-cppwin-ds-tests.sh @@ -13,4 +13,6 @@ export PATH=${TASKCLUSTER_TMP_DIR}/ds/:$PATH check_versions +ensure_cuda_usage "$2" + run_basic_inference_tests diff --git a/taskcluster/tc-electron-tests.sh b/taskcluster/tc-electron-tests.sh index 566775f4..ef0f2ac2 100755 --- a/taskcluster/tc-electron-tests.sh +++ b/taskcluster/tc-electron-tests.sh @@ -59,6 +59,8 @@ node --version check_runtime_electronjs +ensure_cuda_usage "$4" + run_electronjs_inference_tests if [ "${OS}" = "Linux" ]; then diff --git a/taskcluster/tc-netframework-ds-tests.sh b/taskcluster/tc-netframework-ds-tests.sh index 552a4204..e74d0cb8 100644 --- a/taskcluster/tc-netframework-ds-tests.sh +++ b/taskcluster/tc-netframework-ds-tests.sh @@ -10,7 +10,7 @@ source $(dirname "$0")/tc-tests-utils.sh bitrate=$1 set_ldc_sample_filename "${bitrate}" -if [ "${package_option}" = "--cuda" ]; then +if [ "${package_option}" = "cuda" ]; then PROJECT_NAME="DeepSpeech-GPU" elif [ "${package_option}" = "--tflite" ]; then PROJECT_NAME="DeepSpeech-TFLite" @@ -25,4 +25,7 @@ download_data install_nuget "${PROJECT_NAME}" +DS_BINARY_FILE="DeepSpeechConsole.exe" +ensure_cuda_usage "$2" + run_netframework_inference_tests diff --git a/taskcluster/tc-node-tests.sh b/taskcluster/tc-node-tests.sh index 17548022..4085a816 100644 --- a/taskcluster/tc-node-tests.sh +++ b/taskcluster/tc-node-tests.sh @@ -29,4 +29,6 @@ npm install --prefix ${NODE_ROOT} --cache ${NODE_CACHE} ${deepspeech_npm_url} check_runtime_nodejs +ensure_cuda_usage "$3" + run_all_inference_tests diff --git a/taskcluster/tc-node-utils.sh b/taskcluster/tc-node-utils.sh index a4bb1d13..8c33cb39 100755 --- a/taskcluster/tc-node-utils.sh +++ b/taskcluster/tc-node-utils.sh @@ -7,8 +7,8 @@ get_dep_npm_pkg_url() { local all_deps="$(curl -s https://community-tc.services.mozilla.com/api/queue/v1/task/${TASK_ID} | python -c 'import json; import sys; print(" ".join(json.loads(sys.stdin.read())["dependencies"]));')" - # We try "deepspeech-tflite" first and if we don't find it we try "deepspeech" - for pkg_basename in "deepspeech-tflite" "deepspeech"; do + # We try "deepspeech-tflite" and "deepspeech-gpu" first and if we don't find it we try "deepspeech" + for pkg_basename in "deepspeech-tflite" "deepspeech-gpu" "deepspeech"; do local deepspeech_pkg="${pkg_basename}-${DS_VERSION}.tgz" for dep in ${all_deps}; do local has_artifact=$(curl -s https://community-tc.services.mozilla.com/api/queue/v1/task/${dep}/artifacts | python -c 'import json; import sys; has_artifact = True in [ e["name"].find("'${deepspeech_pkg}'") > 0 for e in json.loads(sys.stdin.read())["artifacts"] ]; print(has_artifact)') diff --git a/taskcluster/tc-python-tests.sh b/taskcluster/tc-python-tests.sh index 92a2b792..d55a3097 100644 --- a/taskcluster/tc-python-tests.sh +++ b/taskcluster/tc-python-tests.sh @@ -13,12 +13,19 @@ download_data virtualenv_activate "${pyalias}" "deepspeech" -deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type}) +if [ "$3" = "cuda" ]; then + deepspeech_pkg_url=$(get_python_pkg_url "${pyver_pkg}" "${py_unicode_type}" "deepspeech_gpu") +else + deepspeech_pkg_url=$(get_python_pkg_url "${pyver_pkg}" "${py_unicode_type}") +fi; + LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat which deepspeech deepspeech --version +ensure_cuda_usage "$3" + run_all_inference_tests virtualenv_deactivate "${pyalias}" "deepspeech" diff --git a/taskcluster/tc-train-tests.sh b/taskcluster/tc-train-tests.sh index 1be6533b..2273405a 100644 --- a/taskcluster/tc-train-tests.sh +++ b/taskcluster/tc-train-tests.sh @@ -48,6 +48,11 @@ pushd ${HOME}/DeepSpeech/ds/ time ./bin/run-tc-ldc93s1_new.sh 249 "${sample_rate}" time ./bin/run-tc-ldc93s1_new.sh 1 "${sample_rate}" time ./bin/run-tc-ldc93s1_tflite.sh "${sample_rate}" + # Testing single SDB source + time ./bin/run-tc-ldc93s1_new_sdb.sh 220 "${sample_rate}" + # Testing interleaved source (SDB+CSV combination) - run twice to test preprocessed features + time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 "${sample_rate}" + time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 1 "${sample_rate}" popd cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS} @@ -62,6 +67,7 @@ cp /tmp/train/output_graph.pbmm ${TASKCLUSTER_ARTIFACTS} pushd ${HOME}/DeepSpeech/ds/ time ./bin/run-tc-ldc93s1_checkpoint.sh + time ./bin/run-tc-ldc93s1_checkpoint_sdb.sh popd virtualenv_deactivate "${pyalias}" "deepspeech" diff --git a/taskcluster/test-apk-android-26-x86_64-opt.yml.disabled b/taskcluster/test-apk-android-26-x86_64-opt.yml similarity index 57% rename from taskcluster/test-apk-android-26-x86_64-opt.yml.disabled rename to taskcluster/test-apk-android-26-x86_64-opt.yml index cddd767e..4fbc90e4 100644 --- a/taskcluster/test-apk-android-26-x86_64-opt.yml.disabled +++ b/taskcluster/test-apk-android-26-x86_64-opt.yml @@ -1,12 +1,21 @@ -# disabled because too intermittent to be reliable until we have some KVM-backed infra build: template_file: test-android-opt-base.tyml dependencies: - "android-x86_64-cpu-opt" - "test-training_16k-linux-amd64-py36m-opt" + - "swig-linux-amd64" + - "gradle-cache" + - "android-cache-x86_64-android-26" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" system_setup: > apt-get -qq -y install curl make python + cache: + url: ${system.android_cache.x86_64.android_26.url} + namespace: ${system.android_cache.x86_64.android_26.namespace} + gradle_cache: + url: ${system.gradle_cache.url} + namespace: ${system.gradle_cache.namespace} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-android-apk-tests.sh x86_64 android-26" metadata: diff --git a/taskcluster/test-apk-android-28-x86_64-opt.yml.disabled b/taskcluster/test-apk-android-28-x86_64-opt.yml similarity index 57% rename from taskcluster/test-apk-android-28-x86_64-opt.yml.disabled rename to taskcluster/test-apk-android-28-x86_64-opt.yml index ebf1d996..793b5472 100644 --- a/taskcluster/test-apk-android-28-x86_64-opt.yml.disabled +++ b/taskcluster/test-apk-android-28-x86_64-opt.yml @@ -1,12 +1,21 @@ -# disabled because too intermittent to be reliable until we have some KVM-backed infra build: template_file: test-android-opt-base.tyml dependencies: - "android-x86_64-cpu-opt" - "test-training_16k-linux-amd64-py36m-opt" + - "swig-linux-amd64" + - "gradle-cache" + - "android-cache-x86_64-android-28" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" system_setup: > apt-get -qq -y install curl make python + cache: + url: ${system.android_cache.x86_64.android_28.url} + namespace: ${system.android_cache.x86_64.android_28.namespace} + gradle_cache: + url: ${system.gradle_cache.url} + namespace: ${system.gradle_cache.namespace} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-android-apk-tests.sh x86_64 android-28" metadata: diff --git a/taskcluster/test-apk-android-29-x86_64-opt.yml b/taskcluster/test-apk-android-29-x86_64-opt.yml new file mode 100644 index 00000000..00aa6ef8 --- /dev/null +++ b/taskcluster/test-apk-android-29-x86_64-opt.yml @@ -0,0 +1,23 @@ +build: + template_file: test-android-opt-base.tyml + dependencies: + - "android-x86_64-cpu-opt" + - "test-training_16k-linux-amd64-py36m-opt" + - "swig-linux-amd64" + - "gradle-cache" + - "android-cache-x86_64-android-29" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + apt-get -qq -y install curl make python + cache: + url: ${system.android_cache.x86_64.android_29.url} + namespace: ${system.android_cache.x86_64.android_29.namespace} + gradle_cache: + url: ${system.gradle_cache.url} + namespace: ${system.gradle_cache.namespace} + args: + tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-android-apk-tests.sh x86_64 android-29" + metadata: + name: "DeepSpeech Android 10.0 x86_64 Google Pixel APK/Java tests" + description: "Testing DeepSpeech APK/Java for Android 10.0 x86_64 Google Pixel, optimized version" diff --git a/taskcluster/test-cpp_16k-win-cuda-opt.yml b/taskcluster/test-cpp_16k-win-cuda-opt.yml new file mode 100644 index 00000000..6685832a --- /dev/null +++ b/taskcluster/test-cpp_16k-win-cuda-opt.yml @@ -0,0 +1,11 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "win-amd64-gpu-opt" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + args: + tests_cmdline: "$TASKCLUSTER_TASK_DIR/DeepSpeech/ds/taskcluster/tc-cppwin-ds-tests.sh 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA C++ tests (16kHz)" + description: "Testing DeepSpeech C++ for Windows/AMD64, CUDA, optimized version (16kHz)" diff --git a/taskcluster/test-electronjs_v8.0_multiarchpkg-win-amd64-opt.yml b/taskcluster/test-electronjs_v8.0_multiarchpkg-win-amd64-opt.yml new file mode 100644 index 00000000..d086e347 --- /dev/null +++ b/taskcluster/test-electronjs_v8.0_multiarchpkg-win-amd64-opt.yml @@ -0,0 +1,14 @@ +build: + template_file: test-win-opt-base.tyml + dependencies: + - "node-package-cpu" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} && ${nodejs.win.prep_12} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-electron-tests.sh 12.x 8.0.1 16k" + metadata: + name: "DeepSpeech Windows AMD64 CPU ElectronJS MultiArch Package v8.0 tests" + description: "Testing DeepSpeech for Windows/AMD64 on ElectronJS MultiArch Package v8.0, CPU only, optimized version" diff --git a/taskcluster/test-electronjs_v8.0_multiarchpkg-win-cuda-opt.yml b/taskcluster/test-electronjs_v8.0_multiarchpkg-win-cuda-opt.yml new file mode 100644 index 00000000..ac35c001 --- /dev/null +++ b/taskcluster/test-electronjs_v8.0_multiarchpkg-win-cuda-opt.yml @@ -0,0 +1,14 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "node-package-gpu" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} && ${nodejs.win.prep_12} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-electron-tests.sh 12.x 8.0.1 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA ElectronJS MultiArch Package v8.0 tests" + description: "Testing DeepSpeech for Windows/AMD64 on ElectronJS MultiArch Package v8.0, CUDA, optimized version" diff --git a/taskcluster/test-electronjs_v8.0_multiarchpkg-win-tflite-opt.yml.disabled b/taskcluster/test-electronjs_v8.0_multiarchpkg-win-tflite-opt.yml.disabled new file mode 100644 index 00000000..ef2f693a --- /dev/null +++ b/taskcluster/test-electronjs_v8.0_multiarchpkg-win-tflite-opt.yml.disabled @@ -0,0 +1,14 @@ +build: + template_file: test-win-opt-base.tyml + dependencies: + - "node-package-tflite" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} && ${nodejs.win.prep_12} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-electron-tests.sh 12.x 8.0.1 16k" + metadata: + name: "DeepSpeech Windows AMD64 TFLite ElectronJS MultiArch Package v8.0 tests" + description: "Testing DeepSpeech for Windows/AMD64 on ElectronJS MultiArch Package v8.0, TFLite only, optimized version" diff --git a/taskcluster/test-netframework-win-cuda-opt.yml b/taskcluster/test-netframework-win-cuda-opt.yml new file mode 100644 index 00000000..c23664d7 --- /dev/null +++ b/taskcluster/test-netframework-win-cuda-opt.yml @@ -0,0 +1,11 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "win-amd64-gpu-opt" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + args: + tests_cmdline: "$TASKCLUSTER_TASK_DIR/DeepSpeech/ds/taskcluster/tc-netframework-ds-tests.sh 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA .Net Framework tests" + description: "Testing DeepSpeech .Net Framework for Windows/AMD64, CUDA, optimized version" diff --git a/taskcluster/test-nodejs_13x_multiarchpkg-win-cuda-opt.yml b/taskcluster/test-nodejs_13x_multiarchpkg-win-cuda-opt.yml new file mode 100644 index 00000000..499462a4 --- /dev/null +++ b/taskcluster/test-nodejs_13x_multiarchpkg-win-cuda-opt.yml @@ -0,0 +1,14 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "node-package-gpu" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} && ${nodejs.win.prep_13} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-node-tests.sh 13.x 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA NodeJS MultiArch Package 13.x tests" + description: "Testing DeepSpeech for Windows/AMD64 on NodeJS MultiArch Package v13.x, CUDA, optimized version" diff --git a/taskcluster/test-python_35-win-cuda-opt.yml b/taskcluster/test-python_35-win-cuda-opt.yml new file mode 100644 index 00000000..0c54f0b0 --- /dev/null +++ b/taskcluster/test-python_35-win-cuda-opt.yml @@ -0,0 +1,14 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "win-amd64-gpu-opt" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-python-tests.sh 3.5.4:m 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA Python v3.5 tests" + description: "Testing DeepSpeech for Windows/AMD64 on Python v3.5, CUDA, optimized version" diff --git a/taskcluster/test-python_36-win-cuda-opt.yml b/taskcluster/test-python_36-win-cuda-opt.yml new file mode 100644 index 00000000..640f6a1b --- /dev/null +++ b/taskcluster/test-python_36-win-cuda-opt.yml @@ -0,0 +1,14 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "win-amd64-gpu-opt" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-python-tests.sh 3.6.8:m 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA Python v3.6 tests" + description: "Testing DeepSpeech for Windows/AMD64 on Python v3.6, CUDA, optimized version" diff --git a/taskcluster/test-python_37-win-cuda-opt.yml b/taskcluster/test-python_37-win-cuda-opt.yml new file mode 100644 index 00000000..a10a06a4 --- /dev/null +++ b/taskcluster/test-python_37-win-cuda-opt.yml @@ -0,0 +1,14 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "win-amd64-gpu-opt" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-python-tests.sh 3.7.6:m 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA Python v3.7 tests" + description: "Testing DeepSpeech for Windows/AMD64 on Python v3.7, CUDA, optimized version" diff --git a/taskcluster/test-python_38-win-cuda-opt.yml b/taskcluster/test-python_38-win-cuda-opt.yml new file mode 100644 index 00000000..73140672 --- /dev/null +++ b/taskcluster/test-python_38-win-cuda-opt.yml @@ -0,0 +1,14 @@ +build: + template_file: test-win-cuda-opt-base.tyml + dependencies: + - "win-amd64-gpu-opt" + - "test-training_16k-linux-amd64-py36m-opt" + test_model_task: "test-training_16k-linux-amd64-py36m-opt" + system_setup: + > + ${system.sox_win} + args: + tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-python-tests.sh 3.8.1: 16k cuda" + metadata: + name: "DeepSpeech Windows AMD64 CUDA Python v3.8 tests" + description: "Testing DeepSpeech for Windows/AMD64 on Python v3.8, CUDA, optimized version" diff --git a/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml b/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml index e950969f..3f68fea3 100644 --- a/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml +++ b/taskcluster/test-training_16k-linux-amd64-py35m-opt.yml @@ -2,6 +2,9 @@ build: template_file: test-linux-opt-base.tyml dependencies: - "linux-amd64-ctc-opt" + system_setup: + > + apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.5.8:m 16k" metadata: diff --git a/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml b/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml index 0bb84191..9fa9791b 100644 --- a/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml +++ b/taskcluster/test-training_16k-linux-amd64-py36m-opt.yml @@ -2,6 +2,9 @@ build: template_file: test-linux-opt-base.tyml dependencies: - "linux-amd64-ctc-opt" + system_setup: + > + apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 16k" metadata: diff --git a/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml b/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml index e4164a9b..dc2b486f 100644 --- a/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml +++ b/taskcluster/test-training_8k-linux-amd64-py36m-opt.yml @@ -2,6 +2,9 @@ build: template_file: test-linux-opt-base.tyml dependencies: - "linux-amd64-ctc-opt" + system_setup: + > + apt-get -qq update && apt-get -qq -y install ${training.packages_trusty.apt} args: tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-train-tests.sh 3.6.10:m 8k" metadata: diff --git a/taskcluster/test-win-cuda-opt-base.tyml b/taskcluster/test-win-cuda-opt-base.tyml new file mode 100644 index 00000000..c5bedf30 --- /dev/null +++ b/taskcluster/test-win-cuda-opt-base.tyml @@ -0,0 +1,80 @@ +$if: '(event.event != "push") && (event.event != "tag")' +then: + taskId: ${taskcluster.taskId} + provisionerId: ${taskcluster.docker.provisionerId} + workerType: ${taskcluster.docker.workerTypeCuda} + taskGroupId: ${taskcluster.taskGroupId} + schedulerId: ${taskcluster.schedulerId} + dependencies: + $map: { $eval: build.dependencies } + each(b): + $eval: as_slugid(b) + created: { $fromNow: '0 sec' } + deadline: { $fromNow: '1 day' } + expires: { $fromNow: '7 days' } + + extra: + github: + { $eval: taskcluster.github_events.pull_request } + + payload: + maxRunTime: { $eval: to_int(build.maxRunTime) } + + env: + $let: + training: { $eval: as_slugid(build.test_model_task) } + win_amd64_build: { $eval: as_slugid("win-amd64-gpu-opt") } + in: + DEEPSPEECH_ARTIFACTS_ROOT: https://community-tc.services.mozilla.com/api/queue/v1/task/${win_amd64_build}/artifacts/public + DEEPSPEECH_TEST_MODEL: https://community-tc.services.mozilla.com/api/queue/v1/task/${training}/artifacts/public/output_graph.pb + DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.15/output_graph.pb + DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.15/output_graph.pbmm + EXPECTED_TENSORFLOW_VERSION: "${build.tensorflow_git_desc}" + TC_MSYS_VERSION: 'MSYS_NT-6.3' + MSYS: 'winsymlinks:nativestrict' + + command: + - >- + "C:\Program Files\7-zip\7z.exe" x -txz -so msys2-base-x86_64.tar.xz | + "C:\Program Files\7-zip\7z.exe" x -o%USERPROFILE% -ttar -aoa -si + - .\msys64\usr\bin\bash.exe --login -cx "exit" + - .\msys64\usr\bin\bash.exe --login -cx "pacman --noconfirm -Syu" + - $let: + extraSystemSetup: { $eval: strip(str(build.system_setup)) } + in: > + .\msys64\usr\bin\bash.exe --login -cxe "export LC_ALL=C && + export PATH=\"/c/builds/tc-workdir/msys64/usr/bin:/c/Python36:/c/Program Files/Git/bin:/c/Program Files/7-Zip/:/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.0/bin/:$PATH\" && + export TASKCLUSTER_ARTIFACTS=\"$USERPROFILE/public\" && + export TASKCLUSTER_TASK_DIR=\"/c/builds/tc-workdir/\" && + export TASKCLUSTER_NODE_DIR=\"$(cygpath -w $TASKCLUSTER_TASK_DIR/bin)\" && + export TASKCLUSTER_TMP_DIR="$TASKCLUSTER_TASK_DIR/tmp" && + export PIP_DEFAULT_TIMEOUT=60 && + (mkdir $TASKCLUSTER_TASK_DIR || rm -fr $TASKCLUSTER_TASK_DIR/*) && cd $TASKCLUSTER_TASK_DIR && + env && + ln -s $USERPROFILE/msys64 $TASKCLUSTER_TASK_DIR/msys64 && + git clone --quiet ${event.head.repo.url} $TASKCLUSTER_TASK_DIR/DeepSpeech/ds/ && + cd $TASKCLUSTER_TASK_DIR/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && + cd $TASKCLUSTER_TASK_DIR && + (mkdir pyenv-root/ && 7z x -so $USERPROFILE/pyenv.tar.gz | 7z x -opyenv-root/ -aoa -ttar -si ) && + pacman --noconfirm -R bsdtar && + pacman --noconfirm -S tar vim && + ${extraSystemSetup} && + /bin/bash ${build.args.tests_cmdline} ; + export TASKCLUSTER_TASK_EXIT_CODE=$? && + cd $TASKCLUSTER_TASK_DIR/../ && rm -fr tc-workdir/ && exit $TASKCLUSTER_TASK_EXIT_CODE" + + mounts: + - file: msys2-base-x86_64.tar.xz + content: + sha256: 4e799b5c3efcf9efcb84923656b7bcff16f75a666911abd6620ea8e5e1e9870c + url: >- + https://sourceforge.net/projects/msys2/files/Base/x86_64/msys2-base-x86_64-20180531.tar.xz/download + - file: pyenv.tar.gz + content: + url: ${system.pyenv.win.url} + + metadata: + name: ${build.metadata.name} + description: ${build.metadata.description} + owner: ${event.head.user.email} + source: ${event.head.repo.url} diff --git a/taskcluster/win-amd64-ctc-opt.yml b/taskcluster/win-amd64-ctc-opt.yml new file mode 100644 index 00000000..ebd37445 --- /dev/null +++ b/taskcluster/win-amd64-ctc-opt.yml @@ -0,0 +1,17 @@ +build: + template_file: win-opt-base.tyml + dependencies: + - "swig-win-amd64" + - "node-gyp-cache" + - "pyenv-win-amd64" + routes: + - "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.win-ctc" + - "index.project.deepspeech.deepspeech.native_client.${event.head.branchortag}.${event.head.sha}.win-ctc" + - "index.project.deepspeech.deepspeech.native_client.win-ctc.${event.head.sha}" + tensorflow: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r1.15.ceb46aae5836a0f648a2c3da5942af2b7d1b98bf.win/artifacts/public/home.tar.xz" + scripts: + build: 'taskcluster/decoder-build.sh' + package: 'taskcluster/decoder-package.sh' + metadata: + name: "DeepSpeech CTC Decoder Windows AMD64 CPU" + description: "Building DeepSpeech CTC Decoder for Windows/AMD64, CPU only, optimized version" diff --git a/taskcluster/worker.cyml b/taskcluster/worker.cyml index 809343fd..9ef5a85e 100644 --- a/taskcluster/worker.cyml +++ b/taskcluster/worker.cyml @@ -5,6 +5,7 @@ taskcluster: workerType: ci workerTypeKvm: kvm workerTypeWin: win-b + workerTypeCuda: win-gpu dockerrpi3: provisionerId: proj-deepspeech workerType: ds-rpi3 diff --git a/util/audio.py b/util/audio.py index e713ca7c..9c6ed94e 100644 --- a/util/audio.py +++ b/util/audio.py @@ -1,34 +1,148 @@ import os -import sox +import io import wave import tempfile import collections +import numpy as np + +from util.helpers import LimitingPool DEFAULT_RATE = 16000 DEFAULT_CHANNELS = 1 DEFAULT_WIDTH = 2 DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH) +AUDIO_TYPE_NP = 'application/vnd.mozilla.np' +AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm' +AUDIO_TYPE_WAV = 'audio/wav' +AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus' +SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS] -def get_audio_format(wav_file): +OPUS_PCM_LEN_SIZE = 4 +OPUS_RATE_SIZE = 4 +OPUS_CHANNELS_SIZE = 1 +OPUS_WIDTH_SIZE = 1 +OPUS_CHUNK_LEN_SIZE = 2 + + +class Sample: + """ + Represents in-memory audio data of a certain (convertible) representation. + + Attributes + ---------- + audio_type : str + See `__init__`. + audio_format : tuple:(int, int, int) + See `__init__`. + audio : binary + Audio data represented as indicated by `audio_type` + duration : float + Audio duration of the sample in seconds + """ + def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None): + """ + Parameters + ---------- + audio_type : str + Audio data representation type + Supported types: + - util.audio.AUDIO_TYPE_OPUS: Memory file representation (BytesIO) of Opus encoded audio + wrapped by a custom container format (used in SDBs) + - util.audio.AUDIO_TYPE_WAV: Memory file representation (BytesIO) of a Wave file + - util.audio.AUDIO_TYPE_PCM: Binary representation (bytearray) of PCM encoded audio data (Wave file without header) + - util.audio.AUDIO_TYPE_NP: NumPy representation of audio data (np.float32) - typically used for GPU feeding + raw_data : binary + Audio data in the form of the provided representation type (see audio_type). + For types util.audio.AUDIO_TYPE_OPUS or util.audio.AUDIO_TYPE_WAV data can also be passed as a bytearray. + audio_format : tuple + Tuple of sample-rate, number of channels and sample-width. + Required in case of audio_type = util.audio.AUDIO_TYPE_PCM or util.audio.AUDIO_TYPE_NP, + as this information cannot be derived from raw audio data. + sample_id : str + Tracking ID - should indicate sample's origin as precisely as possible + """ + self.audio_type = audio_type + self.audio_format = audio_format + self.sample_id = sample_id + if audio_type in SERIALIZABLE_AUDIO_TYPES: + self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) + self.duration = read_duration(audio_type, self.audio) + else: + self.audio = raw_data + if self.audio_format is None: + raise ValueError('For audio type "{}" parameter "audio_format" is mandatory'.format(self.audio_type)) + if audio_type == AUDIO_TYPE_PCM: + self.duration = get_pcm_duration(len(self.audio), self.audio_format) + elif audio_type == AUDIO_TYPE_NP: + self.duration = get_np_duration(len(self.audio), self.audio_format) + else: + raise ValueError('Unsupported audio type: {}'.format(self.audio_type)) + + def change_audio_type(self, new_audio_type): + """ + In-place conversion of audio data into a different representation. + + Parameters + ---------- + new_audio_type : str + New audio-type - see `__init__`. + Not supported: Converting from AUDIO_TYPE_NP into any other type. + """ + if self.audio_type == new_audio_type: + return + if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in SERIALIZABLE_AUDIO_TYPES: + self.audio_format, audio = read_audio(self.audio_type, self.audio) + self.audio.close() + self.audio = audio + elif new_audio_type == AUDIO_TYPE_NP: + self.change_audio_type(AUDIO_TYPE_PCM) + self.audio = pcm_to_np(self.audio_format, self.audio) + elif new_audio_type in SERIALIZABLE_AUDIO_TYPES: + self.change_audio_type(AUDIO_TYPE_PCM) + audio_bytes = io.BytesIO() + write_audio(new_audio_type, audio_bytes, self.audio_format, self.audio) + audio_bytes.seek(0) + self.audio = audio_bytes + else: + raise RuntimeError('Changing audio representation type from "{}" to "{}" not supported' + .format(self.audio_type, new_audio_type)) + self.audio_type = new_audio_type + + +def _change_audio_type(sample_and_audio_type): + sample, audio_type = sample_and_audio_type + sample.change_audio_type(audio_type) + return sample + + +def change_audio_types(samples, audio_type=AUDIO_TYPE_PCM, processes=None, process_ahead=None): + with LimitingPool(processes=processes, process_ahead=process_ahead) as pool: + yield from pool.imap(_change_audio_type, map(lambda s: (s, audio_type), samples)) + + +def read_audio_format_from_wav_file(wav_file): return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth() -def get_num_samples(audio_data, audio_format=DEFAULT_FORMAT): +def get_num_samples(pcm_buffer_size, audio_format=DEFAULT_FORMAT): _, channels, width = audio_format - return len(audio_data) // (channels * width) + return pcm_buffer_size // (channels * width) -def get_duration(audio_data, audio_format=DEFAULT_FORMAT): - return get_num_samples(audio_data, audio_format) / audio_format[0] +def get_pcm_duration(pcm_buffer_size, audio_format=DEFAULT_FORMAT): + """Calculates duration in seconds of a binary PCM buffer (typically read from a WAV file)""" + return get_num_samples(pcm_buffer_size, audio_format) / audio_format[0] -def get_duration_ms(audio_data, audio_format=DEFAULT_FORMAT): - return get_duration(audio_data, audio_format) * 1000 +def get_np_duration(np_len, audio_format=DEFAULT_FORMAT): + """Calculates duration in seconds of NumPy audio data""" + return np_len / audio_format[0] def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT): sample_rate, channels, width = audio_format + import sox transformer = sox.Transformer() transformer.set_output_format(file_type=file_type, rate=sample_rate, channels=channels, bits=width*8) transformer.build(src_audio_path, dst_audio_path) @@ -45,7 +159,7 @@ class AudioFile: def __enter__(self): if self.audio_path.endswith('.wav'): self.open_file = wave.open(self.audio_path, 'r') - if get_audio_format(self.open_file) == self.audio_format: + if read_audio_format_from_wav_file(self.open_file) == self.audio_format: if self.as_path: self.open_file.close() return self.audio_path @@ -66,12 +180,12 @@ class AudioFile: def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False): - audio_format = get_audio_format(wav_file) + audio_format = read_audio_format_from_wav_file(wav_file) frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0)) while True: try: data = wav_file.readframes(frame_size) - if not yield_remainder and get_duration_ms(data, audio_format) < frame_duration_ms: + if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms: break yield data except EOFError: @@ -106,7 +220,7 @@ def vad_split(audio_frames, frame_duration_ms = 0 frame_index = 0 for frame_index, frame in enumerate(audio_frames): - frame_duration_ms = get_duration_ms(frame, audio_format) + frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000 if int(frame_duration_ms) not in [10, 20, 30]: raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms') is_speech = vad.is_speech(frame, sample_rate) @@ -133,3 +247,123 @@ def vad_split(audio_frames, yield b''.join(voiced_frames), \ frame_duration_ms * (frame_index - len(voiced_frames)), \ frame_duration_ms * (frame_index + 1) + + +def pack_number(n, num_bytes): + return n.to_bytes(num_bytes, 'big', signed=False) + + +def unpack_number(data): + return int.from_bytes(data, 'big', signed=False) + + +def get_opus_frame_size(rate): + return 60 * rate // 1000 + + +def write_opus(opus_file, audio_format, audio_data): + rate, channels, width = audio_format + frame_size = get_opus_frame_size(rate) + import opuslib # pylint: disable=import-outside-toplevel + encoder = opuslib.Encoder(rate, channels, 'audio') + chunk_size = frame_size * channels * width + opus_file.write(pack_number(len(audio_data), OPUS_PCM_LEN_SIZE)) + opus_file.write(pack_number(rate, OPUS_RATE_SIZE)) + opus_file.write(pack_number(channels, OPUS_CHANNELS_SIZE)) + opus_file.write(pack_number(width, OPUS_WIDTH_SIZE)) + for i in range(0, len(audio_data), chunk_size): + chunk = audio_data[i:i + chunk_size] + # Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer + if len(chunk) < chunk_size: + chunk = chunk + bytearray(chunk_size - len(chunk)) + encoded = encoder.encode(chunk, frame_size) + opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE)) + opus_file.write(encoded) + + +def read_opus_header(opus_file): + opus_file.seek(0) + pcm_buffer_size = unpack_number(opus_file.read(OPUS_PCM_LEN_SIZE)) + rate = unpack_number(opus_file.read(OPUS_RATE_SIZE)) + channels = unpack_number(opus_file.read(OPUS_CHANNELS_SIZE)) + width = unpack_number(opus_file.read(OPUS_WIDTH_SIZE)) + return pcm_buffer_size, (rate, channels, width) + + +def read_opus(opus_file): + pcm_buffer_size, audio_format = read_opus_header(opus_file) + rate, channels, _ = audio_format + frame_size = get_opus_frame_size(rate) + import opuslib # pylint: disable=import-outside-toplevel + decoder = opuslib.Decoder(rate, channels) + audio_data = bytearray() + while len(audio_data) < pcm_buffer_size: + chunk_len = unpack_number(opus_file.read(OPUS_CHUNK_LEN_SIZE)) + chunk = opus_file.read(chunk_len) + decoded = decoder.decode(chunk, frame_size) + audio_data.extend(decoded) + audio_data = audio_data[:pcm_buffer_size] + return audio_format, audio_data + + +def write_wav(wav_file, audio_format, pcm_data): + with wave.open(wav_file, 'wb') as wav_file_writer: + rate, channels, width = audio_format + wav_file_writer.setframerate(rate) + wav_file_writer.setnchannels(channels) + wav_file_writer.setsampwidth(width) + wav_file_writer.writeframes(pcm_data) + + +def read_wav(wav_file): + wav_file.seek(0) + with wave.open(wav_file, 'rb') as wav_file_reader: + audio_format = read_audio_format_from_wav_file(wav_file_reader) + pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes()) + return audio_format, pcm_data + + +def read_audio(audio_type, audio_file): + if audio_type == AUDIO_TYPE_WAV: + return read_wav(audio_file) + if audio_type == AUDIO_TYPE_OPUS: + return read_opus(audio_file) + raise ValueError('Unsupported audio type: {}'.format(audio_type)) + + +def write_audio(audio_type, audio_file, audio_format, pcm_data): + if audio_type == AUDIO_TYPE_WAV: + return write_wav(audio_file, audio_format, pcm_data) + if audio_type == AUDIO_TYPE_OPUS: + return write_opus(audio_file, audio_format, pcm_data) + raise ValueError('Unsupported audio type: {}'.format(audio_type)) + + +def read_wav_duration(wav_file): + wav_file.seek(0) + with wave.open(wav_file, 'rb') as wav_file_reader: + return wav_file_reader.getnframes() / wav_file_reader.getframerate() + + +def read_opus_duration(opus_file): + pcm_buffer_size, audio_format = read_opus_header(opus_file) + return get_pcm_duration(pcm_buffer_size, audio_format) + + +def read_duration(audio_type, audio_file): + if audio_type == AUDIO_TYPE_WAV: + return read_wav_duration(audio_file) + if audio_type == AUDIO_TYPE_OPUS: + return read_opus_duration(audio_file) + raise ValueError('Unsupported audio type: {}'.format(audio_type)) + + +def pcm_to_np(audio_format, pcm_data): + _, channels, width = audio_format + if width not in [1, 2, 4]: + raise ValueError('Unsupported sample width: {}'.format(width)) + dtype = [None, np.int8, np.int16, None, np.int32][width] + samples = np.frombuffer(pcm_data, dtype=dtype) + assert channels == 1 # only mono supported for now + samples = samples.astype(np.float32) / np.iinfo(dtype).max + return np.expand_dims(samples, axis=1) diff --git a/util/config.py b/util/config.py index 0e3a719b..bc9255dc 100755 --- a/util/config.py +++ b/util/config.py @@ -12,6 +12,7 @@ from util.flags import FLAGS from util.gpu import get_available_gpus from util.logging import log_error from util.text import Alphabet, UTF8Alphabet +from util.helpers import parse_file_size class ConfigSingleton: _config = None @@ -29,6 +30,9 @@ Config = ConfigSingleton() # pylint: disable=invalid-name def initialize_globals(): c = AttrDict() + # Read-buffer + FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer) + # Set default dropout rates if FLAGS.dropout_rate2 < 0: FLAGS.dropout_rate2 = FLAGS.dropout_rate diff --git a/util/feeding.py b/util/feeding.py index 2c33d2ae..09a0904c 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -1,12 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function -import os - from functools import partial import numpy as np -import pandas import tensorflow as tf from tensorflow.python.ops import gen_audio_ops as contrib_audio @@ -15,24 +12,20 @@ from util.config import Config from util.text import text_to_char_array from util.flags import FLAGS from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp -from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT +from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP +from util.sample_collections import samples_from_files +from util.helpers import remember_exception, MEGABYTE -def read_csvs(csv_files): - sets = [] - for csv in csv_files: - file = pandas.read_csv(csv, encoding='utf-8', na_filter=False) - #FIXME: not cross-platform - csv_dir = os.path.dirname(os.path.abspath(csv)) - file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) # pylint: disable=cell-var-from-loop - sets.append(file) - # Concat all sets, drop any extra columns, re-index the final result as 0..N - return pandas.concat(sets, join='inner', ignore_index=True) - - -def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None): - if train_phase and sample_rate != FLAGS.audio_sample_rate: - tf.print('WARNING: sample rate of file', wav_filename, '(', sample_rate, ') does not match FLAGS.audio_sample_rate. This can lead to incorrect results.') +def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None): + if train_phase: + # We need the lambdas to make TensorFlow happy. + # pylint: disable=unnecessary-lambda + tf.cond(tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate), + lambda: tf.print('WARNING: sample rate of sample', sample_id, '(', sample_rate, ') ' + 'does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'), + lambda: tf.no_op(), + name='matching_sample_rate') spectrogram = contrib_audio.audio_spectrogram(samples, window_size=Config.audio_window_samples, @@ -79,10 +72,8 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None) return mfccs, tf.shape(input=mfccs)[0] -def audiofile_to_features(wav_filename, train_phase=False): - samples = tf.io.read_file(wav_filename) - decoded = contrib_audio.decode_wav(samples, desired_channels=1) - features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase, wav_filename=wav_filename) +def audio_to_features(audio, sample_rate, train_phase=False, sample_id=None): + features, features_len = samples_to_mfccs(audio, sample_rate, train_phase=train_phase, sample_id=sample_id) if train_phase: if FLAGS.data_aug_features_multiplicative > 0: @@ -94,10 +85,17 @@ def audiofile_to_features(wav_filename, train_phase=False): return features, features_len -def entry_to_features(wav_filename, transcript, train_phase): +def audiofile_to_features(wav_filename, train_phase=False): + samples = tf.io.read_file(wav_filename) + decoded = contrib_audio.decode_wav(samples, desired_channels=1) + return audio_to_features(decoded.audio, decoded.sample_rate, train_phase=train_phase, sample_id=wav_filename) + + +def entry_to_features(sample_id, audio, sample_rate, transcript, train_phase=False): # https://bugs.python.org/issue32117 - features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase) - return wav_filename, features, features_len, tf.SparseTensor(*transcript) + features, features_len = audio_to_features(audio, sample_rate, train_phase=train_phase, sample_id=sample_id) + sparse_transcript = tf.SparseTensor(*transcript) + return sample_id, features, features_len, sparse_transcript def to_sparse_tuple(sequence): @@ -109,15 +107,22 @@ def to_sparse_tuple(sequence): return indices, sequence, shape -def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_phase=False): - df = read_csvs(csvs) - df.sort_values(by='wav_filesize', inplace=True) - - df['transcript'] = df.apply(text_to_char_array, alphabet=Config.alphabet, result_type='reduce', axis=1) - +def create_dataset(sources, + batch_size, + enable_cache=False, + cache_path=None, + train_phase=False, + exception_box=None, + process_ahead=None, + buffering=1 * MEGABYTE): def generate_values(): - for _, row in df.iterrows(): - yield row.wav_filename, to_sparse_tuple(row.transcript) + samples = samples_from_files(sources, buffering=buffering, labeled=True) + for sample in change_audio_types(samples, + AUDIO_TYPE_NP, + process_ahead=2 * batch_size if process_ahead is None else process_ahead): + transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id) + transcript = to_sparse_tuple(transcript) + yield sample.sample_id, sample.audio, sample.audio_format[0], transcript # Batching a dataset of 2D SparseTensors creates 3D batches, which fail # when passed to tf.nn.ctc_loss, so we reshape them to remove the extra @@ -126,27 +131,23 @@ def create_dataset(csvs, batch_size, enable_cache=False, cache_path=None, train_ shape = sparse.dense_shape return tf.sparse.reshape(sparse, [shape[0], shape[2]]) - def batch_fn(wav_filenames, features, features_len, transcripts): + def batch_fn(sample_ids, features, features_len, transcripts): features = tf.data.Dataset.zip((features, features_len)) - features = features.padded_batch(batch_size, - padded_shapes=([None, Config.n_input], [])) + features = features.padded_batch(batch_size, padded_shapes=([None, Config.n_input], [])) transcripts = transcripts.batch(batch_size).map(sparse_reshape) - wav_filenames = wav_filenames.batch(batch_size) - return tf.data.Dataset.zip((wav_filenames, features, transcripts)) + sample_ids = sample_ids.batch(batch_size) + return tf.data.Dataset.zip((sample_ids, features, transcripts)) - num_gpus = len(Config.available_devices) process_fn = partial(entry_to_features, train_phase=train_phase) - dataset = (tf.data.Dataset.from_generator(generate_values, - output_types=(tf.string, (tf.int64, tf.int32, tf.int64))) + dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), + output_types=(tf.string, tf.float32, tf.int32, + (tf.int64, tf.int32, tf.int64))) .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)) - if enable_cache: dataset = dataset.cache(cache_path) - dataset = (dataset.window(batch_size, drop_remainder=True).flat_map(batch_fn) - .prefetch(num_gpus)) - + .prefetch(len(Config.available_devices))) return dataset @@ -155,27 +156,24 @@ def split_audio_file(audio_path, batch_size=1, aggressiveness=3, outlier_duration_ms=10000, - outlier_batch_size=1): - sample_rate, _, sample_width = audio_format - multiplier = 1.0 / (1 << (8 * sample_width - 1)) - + outlier_batch_size=1, + exception_box=None): def generate_values(): frames = read_frames_from_file(audio_path) segments = vad_split(frames, aggressiveness=aggressiveness) for segment in segments: segment_buffer, time_start, time_end = segment - samples = np.frombuffer(segment_buffer, dtype=np.int16) - samples = samples * multiplier - samples = np.expand_dims(samples, axis=1) + samples = pcm_to_np(audio_format, segment_buffer) yield time_start, time_end, samples def to_mfccs(time_start, time_end, samples): - features, features_len = samples_to_mfccs(samples, sample_rate) + features, features_len = samples_to_mfccs(samples, audio_format[0]) return time_start, time_end, features, features_len def create_batch_set(bs, criteria): return (tf.data.Dataset - .from_generator(generate_values, output_types=(tf.int32, tf.int32, tf.float32)) + .from_generator(remember_exception(generate_values, exception_box), + output_types=(tf.int32, tf.int32, tf.float32)) .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE) .filter(criteria) .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], []))) @@ -187,9 +185,3 @@ def split_audio_file(audio_path, dataset = nds.concatenate(ods) dataset = dataset.prefetch(len(Config.available_devices)) return dataset - - -def secs_to_hours(secs): - hours, remainder = divmod(secs, 3600) - minutes, seconds = divmod(remainder, 60) - return '%d:%02d:%02d' % (hours, minutes, seconds) diff --git a/util/flags.py b/util/flags.py index 5057d76c..89494221 100644 --- a/util/flags.py +++ b/util/flags.py @@ -15,7 +15,8 @@ def create_flags(): f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.') f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.') - f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs ont he same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.') + f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)') + f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.') f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds') f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds') @@ -112,11 +113,26 @@ def create_flags(): f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models') f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine') f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency') - f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.') f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json') - f.DEFINE_string('export_name', 'output_graph', 'name for the export model') + f.DEFINE_string('export_file_name', 'output_graph', 'name for the exported model file name') f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph') + # Model metadata + + f.DEFINE_string('export_author_id', 'author', 'author of the exported model. GitHub user or organization name used to uniquely identify the author of this model') + f.DEFINE_string('export_model_name', 'model', 'name of the exported model. Must not contain forward slashes.') + f.DEFINE_string('export_model_version', '0.0.1', 'semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with DeepSpeech versions') + + def str_val_equals_help(name, val_desc): + f.DEFINE_string(name, '<{}>'.format(val_desc), val_desc) + + str_val_equals_help('export_contact_info', 'public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors') + str_val_equals_help('export_license', 'SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name.') + str_val_equals_help('export_language', 'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".') + str_val_equals_help('export_min_ds_version', 'minimum DeepSpeech version (inclusive) the exported model is compatible with') + str_val_equals_help('export_max_ds_version', 'maximum DeepSpeech version (inclusive) the exported model is compatible with') + str_val_equals_help('export_description', 'Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.') + # Reporting f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR') @@ -166,6 +182,12 @@ def create_flags(): f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.') + # Optimizer mode + + f.DEFINE_float('lm_alpha_max', 5, 'the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.') + f.DEFINE_float('lm_beta_max', 5, 'the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.') + f.DEFINE_integer('n_trials', 2400, 'the number of trials to run during hyperparameter optimization.') + # Register validators for paths which require a file to be specified f.register_validator('alphabet_config_path', diff --git a/util/helpers.py b/util/helpers.py index cd4f4b03..c158066d 100644 --- a/util/helpers.py +++ b/util/helpers.py @@ -1,10 +1,32 @@ import os -import semver import sys +import time +import heapq +import semver + +from multiprocessing import Pool + +KILO = 1024 +KILOBYTE = 1 * KILO +MEGABYTE = KILO * KILOBYTE +GIGABYTE = KILO * MEGABYTE +TERABYTE = KILO * GIGABYTE +SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE} + + +def parse_file_size(file_size): + file_size = file_size.lower().strip() + if len(file_size) == 0: + return 0 + n = int(keep_only_digits(file_size)) + if file_size[-1] == 'b': + file_size = file_size[:-1] + e = file_size[-1] + return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n def keep_only_digits(txt): - return ''.join(filter(lambda c: c.isdigit(), txt)) + return ''.join(filter(str.isdigit, txt)) def secs_to_hours(secs): @@ -21,7 +43,8 @@ def check_ctcdecoder_version(): from ds_ctcdecoder import __version__ as decoder_version except ImportError as e: if e.msg.find('__version__') > 0: - print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s)) + print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. " + "Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s)) sys.exit(1) raise e @@ -29,7 +52,79 @@ def check_ctcdecoder_version(): rv = semver.compare(ds_version_s, decoder_version_s) if rv != 0: - print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s)) + print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. " + "Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s)) sys.exit(1) return rv + + +class Interleaved: + """Collection that lazily combines sorted collections in an interleaving fashion. + During iteration the next smallest element from all the sorted collections is always picked. + The collections must support iter() and len().""" + def __init__(self, *iterables, key=lambda obj: obj): + self.iterables = iterables + self.key = key + self.len = sum(map(len, iterables)) + + def __iter__(self): + return heapq.merge(*self.iterables, key=self.key) + + def __len__(self): + return self.len + + +class LimitingPool: + """Limits unbound ahead-processing of multiprocessing.Pool's imap method + before items get consumed by the iteration caller. + This prevents OOM issues in situations where items represent larger memory allocations.""" + def __init__(self, processes=None, process_ahead=None, sleeping_for=0.1): + self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead + self.sleeping_for = sleeping_for + self.processed = 0 + self.pool = Pool(processes=processes) + + def __enter__(self): + return self + + def _limit(self, it): + for obj in it: + while self.processed >= self.process_ahead: + time.sleep(self.sleeping_for) + self.processed += 1 + yield obj + + def imap(self, fun, it): + for obj in self.pool.imap(fun, self._limit(it)): + self.processed -= 1 + yield obj + + def __exit__(self, exc_type, exc_value, traceback): + self.pool.close() + + +class ExceptionBox: + """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator. + Used in conjunction with `remember_exception`.""" + def __init__(self): + self.exception = None + + def raise_if_set(self): + if self.exception is not None: + exception = self.exception + self.exception = None + raise exception # pylint: disable = raising-bad-type + + +def remember_exception(iterable, exception_box=None): + """Wraps a TensorFlow dataset generator for catching its actual exceptions + that would otherwise just interrupt iteration w/o bubbling up.""" + def do_iterate(): + try: + yield from iterable() + except StopIteration: + return + except Exception as ex: # pylint: disable = broad-except + exception_box.exception = ex + return iterable if exception_box is None else do_iterate diff --git a/util/importers.py b/util/importers.py new file mode 100644 index 00000000..50b87fa0 --- /dev/null +++ b/util/importers.py @@ -0,0 +1,77 @@ +import argparse +import importlib +import os +import re +import sys + +from util.helpers import secs_to_hours +from collections import Counter + +def get_counter(): + return Counter({'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}) + +def get_imported_samples(counter): + return counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'] - counter['invalid_label'] + +def print_import_report(counter, sample_rate, max_secs): + print('Imported %d samples.' % (get_imported_samples(counter))) + if counter['failed'] > 0: + print('Skipped %d samples that failed upon conversion.' % counter['failed']) + if counter['invalid_label'] > 0: + print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) + if counter['too_short'] > 0: + print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) + if counter['too_long'] > 0: + print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], max_secs)) + print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / sample_rate)) + +def get_importers_parser(description): + parser = argparse.ArgumentParser(description=description) + parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.') + return parser + +def get_validate_label(args): + """ + Expects an argparse.Namespace argument to search for validate_label_locale parameter. + If found, this will modify Python's library search path and add the directory of the + file pointed by the validate_label_locale argument. + + :param args: The importer's CLI argument object + :type args: argparse.Namespace + + :return: The user-supplied validate_label function + :type: function + """ + if 'validate_label_locale' not in args or (args.validate_label_locale is None): + print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.') + return validate_label_eng + if not os.path.exists(os.path.abspath(args.validate_label_locale)): + print('ERROR: Inexistent --validate_label_locale specified. Please check.') + return None + module_dir = os.path.abspath(os.path.dirname(args.validate_label_locale)) + sys.path.insert(1, module_dir) + fname = os.path.basename(args.validate_label_locale).replace('.py', '') + locale_module = importlib.import_module(fname, package=None) + return locale_module.validate_label + +# Validate and normalize transcriptions. Returns a cleaned version of the label +# or None if it's invalid. +def validate_label_eng(label): + # For now we can only handle [a-z '] + if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None: + return None + + label = label.replace("-", " ") + label = label.replace("_", " ") + label = re.sub("[ ]{2,}", " ", label) + label = label.replace(".", "") + label = label.replace(",", "") + label = label.replace(";", "") + label = label.replace("?", "") + label = label.replace("!", "") + label = label.replace(":", "") + label = label.replace("\"", "") + label = label.strip() + label = label.lower() + + return label if label else None diff --git a/util/sample_collections.py b/util/sample_collections.py new file mode 100644 index 00000000..7009db18 --- /dev/null +++ b/util/sample_collections.py @@ -0,0 +1,365 @@ +# -*- coding: utf-8 -*- +import os +import csv +import json + +from pathlib import Path +from functools import partial +from util.helpers import MEGABYTE, GIGABYTE, Interleaved +from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES + +BIG_ENDIAN = 'big' +INT_SIZE = 4 +BIGINT_SIZE = 2 * INT_SIZE +MAGIC = b'SAMPLEDB' + +BUFFER_SIZE = 1 * MEGABYTE +CACHE_SIZE = 1 * GIGABYTE + +SCHEMA_KEY = 'schema' +CONTENT_KEY = 'content' +MIME_TYPE_KEY = 'mime-type' +MIME_TYPE_TEXT = 'text/plain' +CONTENT_TYPE_SPEECH = 'speech' +CONTENT_TYPE_TRANSCRIPT = 'transcript' + + +class LabeledSample(Sample): + """In-memory labeled audio sample representing an utterance. + Derived from util.audio.Sample and used by sample collection readers and writers.""" + def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None): + """ + Parameters + ---------- + audio_type : str + See util.audio.Sample.__init__ . + raw_data : binary + See util.audio.Sample.__init__ . + transcript : str + Transcript of the sample's utterance + audio_format : tuple + See util.audio.Sample.__init__ . + sample_id : str + Tracking ID - should indicate sample's origin as precisely as possible. + It is typically assigned by collection readers. + """ + super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id) + self.transcript = transcript + + +class DirectSDBWriter: + """Sample collection writer for creating a Sample DB (SDB) file""" + def __init__(self, sdb_filename, buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_OPUS, id_prefix=None, labeled=True): + """ + Parameters + ---------- + sdb_filename : str + Path to the SDB file to write + buffering : int + Write-buffer size to use while writing the SDB file + audio_type : str + See util.audio.Sample.__init__ . + id_prefix : str + Prefix for IDs of written samples - defaults to sdb_filename + labeled : bool or None + If True: Writes labeled samples (util.sample_collections.LabeledSample) only. + If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. + """ + self.sdb_filename = sdb_filename + self.id_prefix = sdb_filename if id_prefix is None else id_prefix + self.labeled = labeled + if audio_type not in SERIALIZABLE_AUDIO_TYPES: + raise ValueError('Audio type "{}" not supported'.format(audio_type)) + self.audio_type = audio_type + self.sdb_file = open(sdb_filename, 'wb', buffering=buffering) + self.offsets = [] + self.num_samples = 0 + + self.sdb_file.write(MAGIC) + + schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}] + if self.labeled: + schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}) + meta_data = {SCHEMA_KEY: schema_entries} + meta_data = json.dumps(meta_data).encode() + self.write_big_int(len(meta_data)) + self.sdb_file.write(meta_data) + + self.offset_samples = self.sdb_file.tell() + self.sdb_file.seek(2 * BIGINT_SIZE, 1) + + def write_int(self, n): + return self.sdb_file.write(n.to_bytes(INT_SIZE, BIG_ENDIAN)) + + def write_big_int(self, n): + return self.sdb_file.write(n.to_bytes(BIGINT_SIZE, BIG_ENDIAN)) + + def __enter__(self): + return self + + def add(self, sample): + def to_bytes(n): + return n.to_bytes(INT_SIZE, BIG_ENDIAN) + sample.change_audio_type(self.audio_type) + opus = sample.audio.getbuffer() + opus_len = to_bytes(len(opus)) + if self.labeled: + transcript = sample.transcript.encode() + transcript_len = to_bytes(len(transcript)) + entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript)) + buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript]) + else: + entry_len = to_bytes(len(opus_len) + len(opus)) + buffer = b''.join([entry_len, opus_len, opus]) + self.offsets.append(self.sdb_file.tell()) + self.sdb_file.write(buffer) + sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples) + self.num_samples += 1 + return sample.sample_id + + def close(self): + if self.sdb_file is None: + return + offset_index = self.sdb_file.tell() + self.sdb_file.seek(self.offset_samples) + self.write_big_int(offset_index - self.offset_samples - BIGINT_SIZE) + self.write_big_int(self.num_samples) + + self.sdb_file.seek(offset_index + BIGINT_SIZE) + self.write_big_int(self.num_samples) + for offset in self.offsets: + self.write_big_int(offset) + offset_end = self.sdb_file.tell() + self.sdb_file.seek(offset_index) + self.write_big_int(offset_end - offset_index - BIGINT_SIZE) + self.sdb_file.close() + self.sdb_file = None + + def __len__(self): + return len(self.offsets) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class SDB: # pylint: disable=too-many-instance-attributes + """Sample collection reader for reading a Sample DB (SDB) file""" + def __init__(self, sdb_filename, buffering=BUFFER_SIZE, id_prefix=None, labeled=True): + """ + Parameters + ---------- + sdb_filename : str + Path to the SDB file to read samples from + buffering : int + Read-buffer size to use while reading the SDB file + id_prefix : str + Prefix for IDs of read samples - defaults to sdb_filename + labeled : bool or None + If True: Reads util.sample_collections.LabeledSample instances. Fails, if SDB file provides no transcripts. + If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. + If None: Automatically determines if SDB schema has transcripts + (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + """ + self.sdb_filename = sdb_filename + self.id_prefix = sdb_filename if id_prefix is None else id_prefix + self.sdb_file = open(sdb_filename, 'rb', buffering=buffering) + self.offsets = [] + if self.sdb_file.read(len(MAGIC)) != MAGIC: + raise RuntimeError('No Sample Database') + meta_chunk_len = self.read_big_int() + self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode()) + if SCHEMA_KEY not in self.meta: + raise RuntimeError('Missing schema') + self.schema = self.meta[SCHEMA_KEY] + + speech_columns = self.find_columns(content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES) + if not speech_columns: + raise RuntimeError('No speech data (missing in schema)') + self.speech_index = speech_columns[0] + self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY] + + self.transcript_index = None + if labeled is not False: + transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT) + if transcript_columns: + self.transcript_index = transcript_columns[0] + else: + if labeled is True: + raise RuntimeError('No transcript data (missing in schema)') + + sample_chunk_len = self.read_big_int() + self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1) + num_samples = self.read_big_int() + for _ in range(num_samples): + self.offsets.append(self.read_big_int()) + + def read_int(self): + return int.from_bytes(self.sdb_file.read(INT_SIZE), BIG_ENDIAN) + + def read_big_int(self): + return int.from_bytes(self.sdb_file.read(BIGINT_SIZE), BIG_ENDIAN) + + def find_columns(self, content=None, mime_type=None): + criteria = [] + if content is not None: + criteria.append((CONTENT_KEY, content)) + if mime_type is not None: + criteria.append((MIME_TYPE_KEY, mime_type)) + if len(criteria) == 0: + raise ValueError('At least one of "content" or "mime-type" has to be provided') + matches = [] + for index, column in enumerate(self.schema): + matched = 0 + for field, value in criteria: + if column[field] == value or (isinstance(value, list) and column[field] in value): + matched += 1 + if matched == len(criteria): + matches.append(index) + return matches + + def read_row(self, row_index, *columns): + columns = list(columns) + column_data = [None] * len(columns) + found = 0 + if not 0 <= row_index < len(self.offsets): + raise ValueError('Wrong sample index: {} - has to be between 0 and {}' + .format(row_index, len(self.offsets) - 1)) + self.sdb_file.seek(self.offsets[row_index] + INT_SIZE) + for index in range(len(self.schema)): + chunk_len = self.read_int() + if index in columns: + column_data[columns.index(index)] = self.sdb_file.read(chunk_len) + found += 1 + if found == len(columns): + return tuple(column_data) + else: + self.sdb_file.seek(chunk_len, 1) + return tuple(column_data) + + def __getitem__(self, i): + sample_id = '{}:{}'.format(self.id_prefix, i) + if self.transcript_index is None: + [audio_data] = self.read_row(i, self.speech_index) + return Sample(self.audio_type, audio_data, sample_id=sample_id) + audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index) + transcript = transcript.decode() + return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id) + + def __iter__(self): + for i in range(len(self.offsets)): + yield self[i] + + def __len__(self): + return len(self.offsets) + + def close(self): + if self.sdb_file is not None: + self.sdb_file.close() + + def __del__(self): + self.close() + + +class CSV: + """Sample collection reader for reading a DeepSpeech CSV file + Automatically orders samples by CSV column wav_filesize (if available).""" + def __init__(self, csv_filename, labeled=None): + """ + Parameters + ---------- + csv_filename : str + Path to the CSV file containing sample audio paths and transcripts + labeled : bool or None + If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column. + If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. + If None: Automatically determines if CSV file has a transcript column + (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + """ + self.csv_filename = csv_filename + self.labeled = labeled + self.rows = [] + csv_dir = Path(csv_filename).parent + with open(csv_filename, 'r', encoding='utf8') as csv_file: + reader = csv.DictReader(csv_file) + if 'transcript' in reader.fieldnames: + if self.labeled is None: + self.labeled = True + elif self.labeled: + raise RuntimeError('No transcript data (missing CSV column)') + for row in reader: + wav_filename = Path(row['wav_filename']) + if not wav_filename.is_absolute(): + wav_filename = csv_dir / wav_filename + wav_filename = str(wav_filename) + wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0 + if self.labeled: + self.rows.append((wav_filename, wav_filesize, row['transcript'])) + else: + self.rows.append((wav_filename, wav_filesize)) + self.rows.sort(key=lambda r: r[1]) + + def __getitem__(self, i): + row = self.rows[i] + wav_filename = row[0] + with open(wav_filename, 'rb') as wav_file: + if self.labeled: + return LabeledSample(AUDIO_TYPE_WAV, wav_file.read(), row[2], sample_id=wav_filename) + return Sample(AUDIO_TYPE_WAV, wav_file.read(), sample_id=wav_filename) + + def __iter__(self): + for i in range(len(self.rows)): + yield self[i] + + def __len__(self): + return len(self.rows) + + +def samples_from_file(filename, buffering=BUFFER_SIZE, labeled=None): + """ + Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances + loaded from a sample source file. + + Parameters + ---------- + filename : str + Path to the sample source file (SDB or CSV) + buffering : int + Read-buffer size to use while reading files + labeled : bool or None + If True: Reads LabeledSample instances. Fails, if source provides no transcripts. + If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances. + If None: Automatically determines if source provides transcripts + (reading util.sample_collections.LabeledSample instances) or not (reading util.audio.Sample instances). + """ + ext = os.path.splitext(filename)[1].lower() + if ext == '.sdb': + return SDB(filename, buffering=buffering, labeled=labeled) + if ext == '.csv': + return CSV(filename, labeled=labeled) + raise ValueError('Unknown file type: "{}"'.format(ext)) + + +def samples_from_files(filenames, buffering=BUFFER_SIZE, labeled=None): + """ + Returns an iterable of util.sample_collections.LabeledSample or util.audio.Sample instances + loaded from a collection of sample source files. + + Parameters + ---------- + filenames : list of str + Paths to sample source files (SDBs or CSVs) + buffering : int + Read-buffer size to use while reading files + labeled : bool or None + If True: Reads LabeledSample instances. Fails, if not all sources provide transcripts. + If False: Ignores transcripts (if available) and always reads (unlabeled) util.audio.Sample instances. + If None: Reads util.sample_collections.LabeledSample instances from sources with transcripts and + util.audio.Sample instances from sources with no transcripts. + """ + filenames = list(filenames) + if len(filenames) == 0: + raise ValueError('No files') + if len(filenames) == 1: + return samples_from_file(filenames[0], buffering=buffering, labeled=labeled) + cols = list(map(partial(samples_from_file, buffering=buffering, labeled=labeled), filenames)) + return Interleaved(*cols, key=lambda s: s.duration) diff --git a/util/taskcluster.py b/util/taskcluster.py index 9116a759..a5d7dba9 100644 --- a/util/taskcluster.py +++ b/util/taskcluster.py @@ -118,7 +118,7 @@ def main(): ds_version = parse_version(version_string) args.branch = "v{}".format(version_string) else: - ds_version = args.branch.lstrip('v') + ds_version = parse_version(args.branch) if args.decoder: plat = platform.system().lower() diff --git a/util/test_data/alphabet_macos.txt b/util/test_data/alphabet_macos.txt new file mode 100644 index 00000000..f3fc2856 --- /dev/null +++ b/util/test_data/alphabet_macos.txt @@ -0,0 +1 @@ +a b c diff --git a/util/test_data/alphabet_unix.txt b/util/test_data/alphabet_unix.txt new file mode 100644 index 00000000..de980441 --- /dev/null +++ b/util/test_data/alphabet_unix.txt @@ -0,0 +1,3 @@ +a +b +c diff --git a/util/test_data/alphabet_windows.txt b/util/test_data/alphabet_windows.txt new file mode 100644 index 00000000..61b1b24d --- /dev/null +++ b/util/test_data/alphabet_windows.txt @@ -0,0 +1,4 @@ +a +b +c + diff --git a/util/test_data/validate_locale_fra.py b/util/test_data/validate_locale_fra.py new file mode 100644 index 00000000..4265fcde --- /dev/null +++ b/util/test_data/validate_locale_fra.py @@ -0,0 +1,2 @@ +def validate_label(label): + return label diff --git a/util/test_importers.py b/util/test_importers.py new file mode 100644 index 00000000..281e4ee1 --- /dev/null +++ b/util/test_importers.py @@ -0,0 +1,38 @@ +import unittest + +from argparse import Namespace +from .importers import validate_label_eng, get_validate_label + +class TestValidateLabelEng(unittest.TestCase): + + def test_numbers(self): + label = validate_label_eng("this is a 1 2 3 test") + self.assertEqual(label, None) + +class TestGetValidateLabel(unittest.TestCase): + + def test_no_validate_label_locale(self): + f = get_validate_label(Namespace()) + self.assertEqual(f('toto'), 'toto') + self.assertEqual(f('toto1234'), None) + self.assertEqual(f('toto1234[{[{[]'), None) + + def test_validate_label_locale_default(self): + f = get_validate_label(Namespace(validate_label_locale=None)) + self.assertEqual(f('toto'), 'toto') + self.assertEqual(f('toto1234'), None) + self.assertEqual(f('toto1234[{[{[]'), None) + + def test_get_validate_label_missing(self): + args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py') + f = get_validate_label(args) + self.assertEqual(f, None) + + def test_get_validate_label(self): + args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py') + f = get_validate_label(args) + l = f('toto') + self.assertEqual(l, 'toto') + +if __name__ == '__main__': + unittest.main() diff --git a/util/test_text.py b/util/test_text.py new file mode 100644 index 00000000..174a3eac --- /dev/null +++ b/util/test_text.py @@ -0,0 +1,34 @@ +import unittest +import os + +from .text import Alphabet + +class TestAlphabetParsing(unittest.TestCase): + + def _ending_tester(self, file, expected): + alphabet = Alphabet(os.path.join(os.path.dirname(__file__), 'test_data', file)) + label = '' + label_id = -1 + for expected_label, expected_label_id in expected: + try: + label_id = alphabet.encode(expected_label) + except KeyError: + pass + self.assertEqual(label_id, [expected_label_id]) + try: + label = alphabet.decode([expected_label_id]) + except KeyError: + pass + self.assertEqual(label, expected_label) + + def test_macos_ending(self): + self._ending_tester('alphabet_macos.txt', [('a', 0), ('b', 1), ('c', 2)]) + + def test_unix_ending(self): + self._ending_tester('alphabet_unix.txt', [('a', 0), ('b', 1), ('c', 2)]) + + def test_windows_ending(self): + self._ending_tester('alphabet_windows.txt', [('a', 0), ('b', 1), ('c', 2)]) + +if __name__ == '__main__': + unittest.main() diff --git a/util/text.py b/util/text.py index d3be7eb8..60bfe9f1 100644 --- a/util/text.py +++ b/util/text.py @@ -1,11 +1,8 @@ from __future__ import absolute_import, division, print_function -import codecs import numpy as np -import re import struct -from util.flags import FLAGS from six.moves import range class Alphabet(object): @@ -15,7 +12,7 @@ class Alphabet(object): self._str_to_label = {} self._size = 0 if config_file: - with codecs.open(config_file, 'r', 'utf-8') as fin: + with open(config_file, 'r', encoding='utf-8') as fin: for line in fin: if line[0:2] == '\\#': line = '#\n' @@ -33,9 +30,9 @@ class Alphabet(object): return self._str_to_label[string] except KeyError as e: raise KeyError( - 'ERROR: Your transcripts contain characters (e.g. \'{}\') which do not occur in data/alphabet.txt! Use ' \ + 'ERROR: Your transcripts contain characters (e.g. \'{}\') which do not occur in \'{}\'! Use ' \ 'util/check_characters.py to see what characters are in your [train,dev,test].csv transcripts, and ' \ - 'then add all these to data/alphabet.txt.'.format(string) + 'then add all these to \'{}\'.'.format(string, self._config_file, self._config_file) ).with_traceback(e.__traceback__) def has_char(self, char): @@ -121,19 +118,22 @@ class UTF8Alphabet(object): return '' -def text_to_char_array(series, alphabet): +def text_to_char_array(transcript, alphabet, context=''): r""" - Given a Pandas Series containing transcript string, map characters to + Given a transcript string, map characters to integers and return a numpy array representing the processed string. + Use a string in `context` for adding text to raised exceptions. """ try: - transcript = np.asarray(alphabet.encode(series['transcript'])) + transcript = alphabet.encode(transcript) if len(transcript) == 0: - raise ValueError('While processing: {}\nFound an empty transcript! You must include a transcript for all training data.'.format(series['wav_filename'])) + raise ValueError('While processing {}: Found an empty transcript! ' + 'You must include a transcript for all training data.' + .format(context)) return transcript except KeyError as e: # Provide the row context (especially wav_filename) for alphabet errors - raise ValueError('While processing: {}\n{}'.format(series['wav_filename'], e)) + raise ValueError('While processing: {}\n{}'.format(context, e)) # The following code is from: http://hetland.org/coding/python/levenshtein.py @@ -165,25 +165,3 @@ def levenshtein(a, b): current[j] = min(add, delete, change) return current[n] - -# Validate and normalize transcriptions. Returns a cleaned version of the label -# or None if it's invalid. -def validate_label(label): - # For now we can only handle [a-z '] - if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None: - return None - - label = label.replace("-", " ") - label = label.replace("_", " ") - label = re.sub("[ ]{2,}", " ", label) - label = label.replace(".", "") - label = label.replace(",", "") - label = label.replace(";", "") - label = label.replace("?", "") - label = label.replace("!", "") - label = label.replace(":", "") - label = label.replace("\"", "") - label = label.strip() - label = label.lower() - - return label if label else None