Merge remote-tracking branch 'upstream/master'

Daniel 2020-03-29 12:34:03 +02:00
commit a79cc0cee9
128 changed files with 3264 additions and 1048 deletions


@ -7,11 +7,20 @@ before_cache:
python:
- "3.6"
install:
- pip install --upgrade cardboardlint pylint
script:
# Run cardboardlinter, in case of pull requests
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
fi
jobs:
include:
- stage: cardboard linter
install:
- pip install --upgrade cardboardlint pylint
script:
# Run cardboardlinter, in case of pull requests
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
fi
- stage: python unit tests
install:
- pip install --upgrade -r requirements_tests.txt
script:
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
python -m unittest;
fi


@ -33,7 +33,7 @@ from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.helpers import check_ctcdecoder_version, ExceptionBox
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
@ -418,12 +418,17 @@ def train():
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True)
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
@ -433,8 +438,13 @@ def train():
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_csvs = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
@ -540,6 +550,7 @@ def train():
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
@ -547,6 +558,7 @@ def train():
else:
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
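The exception_box threaded through create_dataset above, together with the raise_if_set() calls after session.run and on OutOfRangeError, exists to surface errors raised inside the tf.data feeding threads in the training loop instead of letting them silently end an epoch. A minimal sketch of such a helper, assuming the same interface as util.helpers.ExceptionBox (the real implementation may differ):

    class ExceptionBox:
        """Remembers an exception raised on a background thread so the
        main thread can re-raise it at a well-defined point."""
        def __init__(self):
            self.exception = None

        def raise_if_set(self):
            if self.exception is not None:
                exception, self.exception = self.exception, None
                raise exception

    def remember_exception(generator_fn, exception_box):
        # Hypothetical wrapper for a generator feeding tf.data: catch any
        # error, park it in the box, and let the pipeline end gracefully.
        def do_iterate():
            try:
                yield from generator_fn()
            except Exception as ex:  # pylint: disable=broad-except
                exception_box.exception = ex
        return do_iterate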
@ -586,12 +598,12 @@ def train():
# Validation
dev_loss = 0.0
total_steps = 0
for csv, init_op in zip(dev_csvs, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, csv))
set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
@ -727,6 +739,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
return inputs, outputs, layers
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
@ -768,7 +781,7 @@ def export():
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_name + '.pb'
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
@ -805,21 +818,42 @@ def export():
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
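For illustration, with hypothetical flag values the writes above produce a metadata file along these lines (every value below is made up; the URL lines are the placeholders the log message asks to complete):

    ---
    author: some-author
    model_name: some-model
    model_version: 0.1
    contact_info: speech@example.com
    license: MPL-2.0
    language: en
    runtime: tensorflow
    min_ds_version: 0.7.0
    max_ds_version: 0.7.0
    acoustic_model_url: <replace this with a publicly available URL of the acoustic model>
    scorer_url: <replace this with a publicly available URL of the scorer, if present>
    ---
    A one-line description of the exported model.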
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
with open(os.path.join(export_dir, 'info.json'), 'w') as f:
json.dump({
'name': FLAGS.export_language,
}, f)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
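The trailing-separator trick in package_zip is what maps --export_dir path/to/export/LANG_CODE/ to path/to/export/LANG_CODE.zip: joining the absolute path with an empty component guarantees a trailing slash, so os.path.dirname then yields the directory itself, which shutil.make_archive uses as the archive base name. A small illustration with a hypothetical path:

    import os

    export_dir = os.path.join(os.path.abspath('path/to/export/eng'), '')  # '/abs/path/to/export/eng/'
    zip_filename = os.path.dirname(export_dir)                            # '/abs/path/to/export/eng'
    # shutil.make_archive(zip_filename, 'zip', export_dir) would then write
    # '/abs/path/to/export/eng.zip' containing the contents of export_dir.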
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
@ -895,6 +929,7 @@ def main(_):
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if __name__ == '__main__':
create_flags()
absl.app.run(main)


@ -7,6 +7,7 @@ FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
# Get basic packages
RUN apt-get update && apt-get install -y --no-install-recommends \
apt-utils \
build-essential \
curl \
wget \
@ -44,14 +45,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
RUN ln -s -f /usr/bin/python3 /usr/bin/python
# Install NCCL 2.2
RUN apt-get install -qq -y --allow-downgrades --allow-change-held-packages libnccl2=2.3.7-1+cuda10.0 libnccl-dev=2.3.7-1+cuda10.0
RUN apt-get --no-install-recommends install -qq -y --allow-downgrades --allow-change-held-packages libnccl2=2.3.7-1+cuda10.0 libnccl-dev=2.3.7-1+cuda10.0
# Install Bazel
RUN curl -LO "https://github.com/bazelbuild/bazel/releases/download/0.24.1/bazel_0.24.1-linux-x86_64.deb"
RUN dpkg -i bazel_*.deb
# Install CUDA CLI Tools
RUN apt-get install -qq -y cuda-command-line-tools-10-0
RUN apt-get --no-install-recommends install -qq -y cuda-command-line-tools-10-0
# Install pip
RUN wget https://bootstrap.pypa.io/get-pip.py && \


@ -14,7 +14,7 @@ Project DeepSpeech
DeepSpeech is an open source Speech-To-Text engine, using a model trained by machine learning techniques based on `Baidu's Deep Speech research paper <https://arxiv.org/abs/1412.5567>`_. Project DeepSpeech uses Google's `TensorFlow <https://www.tensorflow.org/>`_ to make the implementation easier.
**NOTE:** This documentation applies to the **MASTER version** of DeepSpeech only. If you're using a stable release, you must use the documentation for the corresponding version by using GitHub's branch switcher button above.
**NOTE:** This documentation applies to the **MASTER version** of DeepSpeech only. **Documentation for the latest stable version** is published on `deepspeech.readthedocs.io <http://deepspeech.readthedocs.io/?badge=latest>`_.
To install and use deepspeech all you have to do is:


@ -1 +1 @@
0.7.0-alpha.2
0.7.0-alpha.3

bin/build_sdb.py (new executable file)

@ -0,0 +1,56 @@
#!/usr/bin/env python
'''
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse
import progressbar
from util.downloader import SIMPLE_BAR
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
from util.sample_collections import samples_from_files, DirectSDBWriter
AUDIO_TYPE_LOOKUP = {
'wav': AUDIO_TYPE_WAV,
'opus': AUDIO_TYPE_OPUS
}
def build_sdb():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
sdb_writer.add(sample)
def handle_args():
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
'from DeepSpeech CSV files and other SDB files')
parser.add_argument('sources', nargs='+',
help='Source CSV and/or SDB files - '
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
'already ordered from shortest to longest.')
parser.add_argument('target', help='SDB file to create')
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
help='Audio representation inside target SDB')
parser.add_argument('--workers', type=int, default=None,
help='Number of encoding SDB workers')
parser.add_argument('--unlabeled', action='store_true',
help='If to build an SDB with unlabeled (audio only) samples - '
'typically used for building noise augmentation corpora')
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
build_sdb()
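A quick way to sanity-check a freshly built SDB is to read it back with the same sample-collection helpers this change introduces (see bin/play.py further down). A minimal sketch, assuming the SDB was built with labeled samples and using a hypothetical path:

    from util.sample_collections import samples_from_file

    samples = samples_from_file('data/ldc93s1/ldc93s1.sdb')  # hypothetical path
    print('{} samples'.format(len(samples)))
    for sample in samples:
        # LabeledSample instances carry a transcript next to the audio.
        print(sample.sample_id, '->', sample.transcript)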


@ -7,7 +7,7 @@ import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse
from util.importers import get_importers_parser
import glob
import pandas
import tarfile
@ -81,7 +81,7 @@ def preprocess_data(tgz_file, target_dir):
def main():
# https://www.openslr.org/62/
parser = argparse.ArgumentParser(description='Import aidatatang_200zh corpus')
parser = get_importers_parser(description='Import aidatatang_200zh corpus')
parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
params = parser.parse_args()
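This and the other simple importers below now build their parser via util.importers.get_importers_parser instead of a bare argparse.ArgumentParser, and the CSV-producing ones resolve their transcript validator through get_validate_label(params). A rough sketch of what such helpers could amount to, assuming a shared option for plugging in locale-specific label validation (the actual util/importers.py may differ):

    import argparse

    def validate_label_eng(label):
        # Fallback validator: reject empty transcripts, pass the rest through.
        label = label.strip().lower()
        return label if label else None

    def get_importers_parser(description=''):
        parser = argparse.ArgumentParser(description=description)
        # Hypothetical shared flag: a Python file providing validate_label().
        parser.add_argument('--validate_label_locale',
                            help='Path to a module with a custom validate_label() function')
        return parser

    def get_validate_label(args):
        # If no locale-specific validator was requested, fall back to English.
        if not getattr(args, 'validate_label_locale', None):
            return validate_label_eng
        # Loading of the user-supplied module is omitted in this sketch.
        return validate_label_eng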


@ -7,7 +7,7 @@ import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse
from util.importers import get_importers_parser
import glob
import tarfile
import pandas
@ -80,7 +80,7 @@ def preprocess_data(tgz_file, target_dir):
def main():
# http://www.openslr.org/33/
parser = argparse.ArgumentParser(description='Import AISHELL corpus')
parser = get_importers_parser(description='Import AISHELL corpus')
parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
params = parser.parse_args()


@ -15,10 +15,8 @@ import progressbar
from glob import glob
from os import path
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.text import validate_label
from multiprocessing import Pool
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
from util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -53,6 +51,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
for source_csv in glob(path.join(extracted_dir, '*.csv')):
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
def one_sample(sample):
mp3_filename = sample[0]
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
file_size = -1
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = validate_label(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
print()
if path.exists(target_csv):
@ -63,48 +93,19 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file)
for row in reader:
samples.append((row['filename'], row['text']))
samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
# Mutable counters for the concurrent embedded routine
counter = { 'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0 }
lock = RLock()
counter = get_counter()
num_samples = len(samples)
rows = []
def one_sample(sample):
mp3_filename = path.join(*(sample[0].split('/')))
mp3_filename = path.join(extracted_dir, mp3_filename)
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
file_size = -1
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = validate_label(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
print('Importing mp3 files...')
pool = Pool(cpu_count())
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i)
bar.update(num_samples)
pool.close()
@ -118,15 +119,11 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
for filename, file_size, transcript in bar(rows):
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
if counter['failed'] > 0:
print('Skipped %d samples that failed upon conversion.' % counter['failed'])
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename):
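The same refactoring recurs in most importers touched by this commit: one_sample moves to module level so a process-based multiprocessing.Pool can pickle it, and instead of mutating shared counters and rows under an RLock it returns its own (counter, rows) pair for the parent process to merge. Stripped down to the pattern itself (names follow the diff, everything else is placeholder):

    from collections import Counter
    from multiprocessing import Pool

    def one_sample(sample):
        counter = Counter()  # stand-in for util.importers.get_counter()
        rows = []
        # ... convert/inspect the sample, incrementing a skip reason such as
        # counter['too_long'] += 1, or keeping it:
        rows.append(sample)
        counter['all'] += 1
        return counter, rows

    if __name__ == '__main__':
        samples = list(range(10))  # placeholder input
        counter, rows = Counter(), []
        with Pool() as pool:
            for sub_counter, sub_rows in pool.imap_unordered(one_sample, samples):
                counter += sub_counter
                rows += sub_rows
        print(counter['all'], len(rows))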


@ -16,18 +16,15 @@ sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import sox
import argparse
import subprocess
import progressbar
import unicodedata
from os import path
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from util.text import Alphabet, validate_label
from util.helpers import secs_to_hours
from util.text import Alphabet
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -35,15 +32,50 @@ SAMPLE_RATE = 16000
MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, label_filter, space_after_every_character=False):
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv)
_maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character)
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = sample[0]
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter_fun(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((os.path.split(wav_filename)[-1], file_size, label))
counter['all'] += 1
counter['total_time'] += frames
def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_character=None):
return (counter, rows)
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
@ -52,51 +84,18 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
with open(input_tsv, encoding='utf-8') as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t')
for row in reader:
samples.append((row['path'], row['sentence']))
samples.append((path.join(audio_dir, row['path']), row['sentence']))
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
counter = get_counter()
num_samples = len(samples)
rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = path.join(audio_dir, sample[0])
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((os.path.split(wav_filename)[-1], file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing mp3 files...")
pool = Pool(cpu_count())
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i)
bar.update(num_samples)
pool.close()
@ -113,16 +112,11 @@ def _maybe_convert_set(input_tsv, audio_dir, label_filter, space_after_every_cha
else:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
if counter['failed'] > 0:
print('Skipped %d samples that failed upon conversion.' % counter['failed'])
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename):
@ -136,7 +130,7 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
if __name__ == "__main__":
PARSER = argparse.ArgumentParser(description='Import CommonVoice v2.0 corpora')
PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora')
PARSER.add_argument('tsv_dir', help='Directory containing tsv files')
PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
@ -144,6 +138,7 @@ if __name__ == "__main__":
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS)
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
@ -161,4 +156,4 @@ if __name__ == "__main__":
label = None
return label
_preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, label_filter_fun, PARAMS.space_after_every_character)
_preprocess_data(PARAMS.tsv_dir, AUDIO_DIR, PARAMS.space_after_every_character)
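The per-category print statements removed above are now centralized in util/importers.py as get_imported_samples and print_import_report. A sketch inferred from the counters and messages visible in this diff (the real helpers may differ in detail, for example in how total_time is formatted):

    def get_imported_samples(counter):
        return (counter['all'] - counter['failed'] - counter['invalid_label']
                - counter['too_short'] - counter['too_long'])

    def print_import_report(counter, sample_rate, max_secs):
        print('Imported %d samples.' % get_imported_samples(counter))
        if counter['failed'] > 0:
            print('Skipped %d samples that failed upon conversion.' % counter['failed'])
        if counter['invalid_label'] > 0:
            print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
        if counter['too_short'] > 0:
            print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
        if counter['too_long'] > 0:
            print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], max_secs))
        # total_time is counted in frames, so divide by the sample rate for seconds.
        print('Final amount of imported audio: %.2f hours.' % (counter['total_time'] / sample_rate / 3600))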


@ -19,7 +19,7 @@ import unicodedata
import librosa
import soundfile # <= Has an external dependency on libsndfile
from util.text import validate_label
from util.importers import validate_label_eng as validate_label
def _download_and_preprocess_data(data_dir):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19


@ -7,7 +7,7 @@ import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse
from util.importers import get_importers_parser
import glob
import numpy as np
import pandas
@ -81,7 +81,7 @@ def preprocess_data(tgz_file, target_dir):
def main():
# https://www.openslr.org/38/
parser = argparse.ArgumentParser(description='Import Free ST Chinese Mandarin corpus')
parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus')
parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
params = parser.parse_args()


@ -1,12 +1,16 @@
#!/usr/bin/env python
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import csv
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import math
import urllib
import logging
import argparse
from util.importers import get_importers_parser, get_validate_label
import subprocess
from os import path
from pathlib import Path
@ -15,8 +19,6 @@ import swifter
import pandas as pd
from sox import Transformer
from util.text import validate_label
__version__ = "0.1.0"
_logger = logging.getLogger(__name__)
@ -38,7 +40,7 @@ def parse_args(args):
Returns:
:obj:`argparse.Namespace`: command line parameters namespace
"""
parser = argparse.ArgumentParser(
parser = get_importers_parser(
description="Imports GramVaani data for Deep Speech"
)
parser.add_argument(
@ -286,6 +288,7 @@ def main(args):
args ([str]): command line parameter list
"""
args = parse_args(args)
validate_label = get_validate_label(args)
setup_logging(args.loglevel)
_logger.info("Starting GramVaani importer...")
_logger.info("Starting loading GramVaani csv...")


@ -3,13 +3,13 @@ from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import argparse
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import argparse
import csv
import re
import sox
@ -18,17 +18,14 @@ import subprocess
import progressbar
import unicodedata
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
from glob import glob
from util.downloader import maybe_download
from util.text import Alphabet, validate_label
from util.helpers import secs_to_hours
from util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000
@ -61,6 +58,41 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else:
print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = sample[0]
# Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
@ -76,49 +108,18 @@ def _maybe_convert_sets(target_dir, extracted_data):
for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '')
if record_filter(record_file):
samples.append((record_file, os.path.splitext(os.path.basename(record_file))[0]))
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
counter = get_counter()
num_samples = len(samples)
rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = path.join(ogg_root_dir, sample[0])
# Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing ogg files...")
pool = Pool(cpu_count())
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i)
bar.update(num_samples)
pool.close()
@ -152,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript,
))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
if counter['failed'] > 0:
print('Skipped %d samples that failed upon conversion.' % counter['failed'])
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(ogg_filename, wav_filename):
if not path.exists(wav_filename):
@ -173,7 +169,7 @@ def _maybe_convert_wav(ogg_filename, wav_filename):
print('SoX processing error', ex, ogg_filename, wav_filename)
def handle_args():
parser = argparse.ArgumentParser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
parser.add_argument(dest='target_dir')
parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId')
parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code')
@ -186,6 +182,7 @@ def handle_args():
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
validate_label = get_validate_label(CLI_ARGS)
bogus_regexes = []
if CLI_ARGS.bogus_records:


@ -4,29 +4,27 @@ from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import argparse
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import subprocess
import progressbar
import unicodedata
import tarfile
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
from glob import glob
from util.downloader import maybe_download
from util.text import Alphabet, validate_label
from util.helpers import secs_to_hours
from util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000
@ -62,6 +60,38 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
counter = get_counter()
rows = []
if file_size == -1:
# Excluding samples that failed upon conversion
print("conversion failure", wav_filename)
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
@ -84,44 +114,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript = re[2]
samples.append((audio, transcript))
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
counter = get_counter()
num_samples = len(samples)
rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing WAV files...")
pool = Pool(cpu_count())
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i)
bar.update(num_samples)
pool.close()
@ -155,20 +157,14 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript,
))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
if counter['failed'] > 0:
print('Skipped %d samples that failed upon conversion.' % counter['failed'])
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args():
parser = argparse.ArgumentParser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
parser.add_argument(dest='target_dir')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
@ -181,6 +177,7 @@ if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
validate_label = get_validate_label(CLI_ARGS)
def label_filter(label):
if CLI_ARGS.normalize:


@ -7,7 +7,7 @@ import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse
from util.importers import get_importers_parser
import glob
import pandas
import tarfile
@ -99,7 +99,7 @@ def preprocess_data(folder_with_archives, target_dir):
def main():
# https://openslr.org/68/
parser = argparse.ArgumentParser(description='Import MAGICDATA corpus')
parser = get_importers_parser(description='Import MAGICDATA corpus')
parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives')
params = parser.parse_args()


@ -7,7 +7,7 @@ import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse
from util.importers import get_importers_parser
import glob
import json
import numpy as np
@ -93,7 +93,7 @@ def preprocess_data(tgz_file, target_dir):
def main():
# https://www.openslr.org/47/
parser = argparse.ArgumentParser(description='Import Primewords Chinese corpus set 1')
parser = get_importers_parser(description='Import Primewords Chinese corpus set 1')
parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
params = parser.parse_args()


@ -3,13 +3,12 @@ from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import argparse
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import re
import sox
@ -19,16 +18,14 @@ import progressbar
import unicodedata
import tarfile
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
from glob import glob
from util.downloader import maybe_download
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -63,6 +60,37 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else:
print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
counter = get_counter()
rows = []
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
@ -113,43 +141,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
samples.append((record, transcripts[record_file]))
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
counter = get_counter()
num_samples = len(samples)
rows = []
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = label_filter(sample[1])
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing WAV files...")
pool = Pool(cpu_count())
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i)
bar.update(num_samples)
pool.close()
@ -183,19 +184,14 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript=transcript,
))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
if counter['failed'] > 0:
print('Skipped %d samples that failed upon conversion.' % counter['failed'])
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args():
parser = argparse.ArgumentParser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
parser.add_argument(dest='target_dir')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
@ -204,6 +200,7 @@ def handle_args():
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
validate_label = get_validate_label(CLI_ARGS)
def label_filter(label):
if CLI_ARGS.normalize:


@ -20,7 +20,7 @@ import wave
import codecs
import tarfile
import requests
from util.text import validate_label
from util.importers import validate_label_eng as validate_label
import librosa
import soundfile # <= Has an external dependency on libsndfile


@ -27,7 +27,8 @@ from os import path
from glob import glob
from collections import Counter
from multiprocessing.pool import ThreadPool
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"


@ -3,14 +3,13 @@ from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import argparse
import os
import re
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import unidecode
import zipfile
@ -18,16 +17,12 @@ import sox
import subprocess
import progressbar
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
from util.downloader import maybe_download
from util.text import validate_label
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
SAMPLE_RATE = 16000
@ -61,6 +56,44 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = sample['path']
# Storing wav files next to the wav ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = sample['text']
rows = []
# Keep track of how many samples are good vs. problematic
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = path.join(target_dir, extracted_data)
# override existing CSV with normalized one
@ -74,49 +107,19 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
if float(d['duration']) <= MAX_SECS
]
# Keep track of how many samples are good vs. problematic
counter = {'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'total_time': 0}
lock = RLock()
for line in data:
line['path'] = os.path.join(extracted_dir, line['path'])
num_samples = len(data)
rows = []
counter = get_counter()
wav_root_dir = extracted_dir
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = path.join(wav_root_dir, sample['path'])
# Storing wav files next to the wav ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = sample['text']
with lock:
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
print("Importing wav files...")
pool = Pool(cpu_count())
print("Importing {} wav files...".format(num_samples))
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, data), start=1):
for i, processed in enumerate(pool.imap_unordered(one_sample, data), start=1):
counter += processed[0]
rows += processed[1]
bar.update(i)
bar.update(num_samples)
pool.close()
@ -133,7 +136,6 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
test_writer.writeheader()
for i, item in enumerate(rows):
print('item', item)
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
if not transcript:
continue
@ -151,16 +153,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
transcript=transcript,
))
print('Imported %d samples.' % (counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long']))
if counter['failed'] > 0:
print('Skipped %d samples that failed upon conversion.' % counter['failed'])
if counter['invalid_label'] > 0:
print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label'])
if counter['too_short'] > 0:
print('Skipped %d samples that were too short to match the transcript.' % counter['too_short'])
if counter['too_long'] > 0:
print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], MAX_SECS))
print('Final amount of imported audio: %s.' % secs_to_hours(counter['total_time'] / SAMPLE_RATE))
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(orig_filename, wav_filename):
if not path.exists(wav_filename):
@ -186,7 +183,7 @@ def cleanup_transcript(text, english_compatible=False):
def handle_args():
parser = argparse.ArgumentParser(description='Importer for TrainingSpeech dataset.')
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.')
parser.add_argument(dest='target_dir')
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diactrics and other non-ascii chars.')
return parser.parse_args()
@ -194,4 +191,5 @@ def handle_args():
if __name__ == "__main__":
cli_args = handle_args()
validate_label = get_validate_label(cli_args)
_download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible)


@ -21,7 +21,8 @@ import xml.etree.cElementTree as ET
from os import path
from collections import Counter
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
TUDA_VERSION = 'v2'


@ -14,13 +14,14 @@ import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
from util.importers import get_counter, get_imported_samples, print_import_report
import re
import librosa
import progressbar
from os import path
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from multiprocessing import Pool
from util.downloader import maybe_download, SIMPLE_BAR
from zipfile import ZipFile
@ -61,47 +62,46 @@ def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data, "wav48")
txt_dir = path.join(target_dir, extracted_data, "txt")
cnt = 1
directory = os.path.expanduser(extracted_dir)
srtd = len(sorted(os.listdir(directory)))
all_samples = []
for target in sorted(os.listdir(directory)):
print(f"\nSpeaker {cnt} of {srtd}")
_maybe_convert_set(path.join(extracted_dir, os.path.split(target)[-1]))
cnt += 1
_write_csv(extracted_dir, txt_dir, target_dir)
def _maybe_convert_set(target_csv):
def one_sample(sample):
if is_audio_file(sample):
sample = os.path.join(target_csv, sample)
y, sr = librosa.load(sample, sr=16000)
# Trim the beginning and ending silence
yt, index = librosa.effects.trim(y) # pylint: disable=unused-variable
duration = librosa.get_duration(yt, sr)
if duration > MAX_SECS or duration < MIN_SECS:
os.remove(sample)
else:
librosa.output.write_wav(sample, yt, sr)
samples = sorted(os.listdir(target_csv))
num_samples = len(samples)
all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
num_samples = len(all_samples)
print(f"Converting wav files to {SAMPLE_RATE}hz...")
pool = Pool(cpu_count())
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, _ in enumerate(pool.imap_unordered(one_sample, samples), start=1):
for i, _ in enumerate(pool.imap_unordered(one_sample, all_samples), start=1):
bar.update(i)
bar.update(num_samples)
pool.close()
pool.join()
_write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample):
if is_audio_file(sample):
y, sr = librosa.load(sample, sr=16000)
# Trim the beginning and ending silence
yt, index = librosa.effects.trim(y) # pylint: disable=unused-variable
duration = librosa.get_duration(yt, sr)
if duration > MAX_SECS or duration < MIN_SECS:
os.remove(sample)
else:
librosa.output.write_wav(sample, yt, sr)
def _maybe_prepare_set(target_csv):
samples = sorted(os.listdir(target_csv))
new_samples = []
for s in samples:
new_samples.append(os.path.join(target_csv, s))
samples = new_samples
return samples
def _write_csv(extracted_dir, txt_dir, target_dir):
print(f"Writing CSV file")
@ -196,8 +196,8 @@ def load_txts(directory):
AUDIO_EXTENSIONS = [".wav", "WAV"]
def is_audio_file(filename):
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)
def is_audio_file(filepath):
return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
if __name__ == "__main__":

74
bin/play.py Executable file
View File

@ -0,0 +1,74 @@
#!/usr/bin/env python
"""
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
Use "python3 play.py -h" for help
"""
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import random
import argparse
from util.sample_collections import samples_from_file, LabeledSample
from util.audio import AUDIO_TYPE_PCM
def play_sample(samples, index):
if index < 0:
index = len(samples) + index
if CLI_ARGS.random:
index = random.randint(0, len(samples) - 1)  # randint is inclusive at both ends
elif index >= len(samples):
print('No sample with index {}'.format(CLI_ARGS.start))
sys.exit(1)
sample = samples[index]
print('Sample "{}"'.format(sample.sample_id))
if isinstance(sample, LabeledSample):
print(' "{}"'.format(sample.transcript))
sample.change_audio_type(AUDIO_TYPE_PCM)
rate, channels, width = sample.audio_format
wave_obj = simpleaudio.WaveObject(sample.audio, channels, width, rate)
play_obj = wave_obj.play()
play_obj.wait_done()
def play_collection():
samples = samples_from_file(CLI_ARGS.collection, buffering=0)
played = 0
index = CLI_ARGS.start
while True:
if 0 <= CLI_ARGS.number <= played:
return
play_sample(samples, index)
played += 1
index = (index + 1) % len(samples)
def handle_args():
parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
'and DeepSpeech CSV files')
parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
parser.add_argument('--start', type=int, default=0,
help='Sample index to start at (negative numbers are relative to the end of the collection)')
parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
return parser.parse_args()
if __name__ == "__main__":
try:
import simpleaudio
except ModuleNotFoundError:
print('play.py requires Python package "simpleaudio"')
sys.exit(1)
CLI_ARGS = handle_args()
try:
play_collection()
except KeyboardInterrupt:
print(' Stopped')
sys.exit(0)

View File

@ -0,0 +1,37 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_sdb} --train_batch_size 1 \
--dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
--test_files ${ldc93s1_sdb} --test_batch_size 1 \
--n_hidden 100 --epochs 1 \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \
--learning_rate 0.001 --dropout_rate 0.05 \
--scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log
if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then
echo "Did not resume training from checkpoint"
exit 1
else
exit 0
fi

34
bin/run-tc-ldc93s1_new_sdb.sh Executable file
View File

@ -0,0 +1,34 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
epoch_count=$1
audio_sample_rate=$2
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_sdb} --train_batch_size 1 \
--dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
--test_files ${ldc93s1_sdb} --test_batch_size 1 \
--n_hidden 100 --epochs $epoch_count \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb' \
--learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb' \
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
--audio_sample_rate ${audio_sample_rate}

View File

@ -0,0 +1,35 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
ldc93s1_sdb="${ldc93s1_dir}/ldc93s1.sdb"
epoch_count=$1
audio_sample_rate=$2
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
if [ ! -f "${ldc93s1_dir}/ldc93s1.sdb" ]; then
echo "Converting LDC93S1 example data, saving to ${ldc93s1_sdb}."
python -u bin/build_sdb.py ${ldc93s1_csv} ${ldc93s1_sdb}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_sdb},${ldc93s1_csv} --train_batch_size 1 \
--feature_cache '/tmp/ldc93s1_cache_sdb_csv' \
--dev_files ${ldc93s1_sdb},${ldc93s1_csv} --dev_batch_size 1 \
--test_files ${ldc93s1_sdb},${ldc93s1_csv} --test_batch_size 1 \
--n_hidden 100 --epochs $epoch_count \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt_sdb_csv' \
--learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_sdb_csv' \
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
--audio_sample_rate ${audio_sample_rate}

View File

@ -34,6 +34,9 @@ C
.. doxygenfunction:: DS_IntermediateDecode
:project: deepspeech-c
.. doxygenfunction:: DS_IntermediateDecodeWithMetadata
:project: deepspeech-c
.. doxygenfunction:: DS_FinishStream
:project: deepspeech-c

79
doc/Decoder.rst Normal file
View File

@ -0,0 +1,79 @@
.. _decoder-docs:
CTC beam search decoder with external scorer
============================================
Introduction
^^^^^^^^^^^^
DeepSpeech uses the `Connectionist Temporal Classification <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC <https://distill.pub/2017/ctc/>`_. This document assumes the reader is familiar with the concepts described in that article, and describes DeepSpeech-specific behaviors that developers building systems with DeepSpeech should be aware of in order to avoid problems.
The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 <https://tools.ietf.org/html/bcp14>`_ when, and only when, they appear in all capitals, as shown here.
External scorer
^^^^^^^^^^^^^^^
DeepSpeech clients support OPTIONAL use of an external language model to improve the accuracy of the predicted transcripts. In the code, command line parameters, and documentation, this is referred to as a "scorer". The scorer is used to compute the likelihood (also called a score, hence the name "scorer") of sequences of words or characters in the output, to guide the decoder towards more likely results. This improves accuracy significantly.
The use of an external scorer is fully optional. When an external scorer is not specified, DeepSpeech still uses a beam search decoding algorithm, but without any outside scoring.
Currently, the DeepSpeech external scorer is implemented with `KenLM <https://kheafield.com/code/kenlm/>`_, plus some tooling to package the necessary files and metadata into a single ``.scorer`` package. The tooling lives in ``data/lm/``. The scripts included in ``data/lm/`` can be used and modified to build your own language model based on your particular use case or language.
The scripts are geared towards replicating the language model files we release as part of `DeepSpeech model releases <https://github.com/mozilla/DeepSpeech/releases/latest>`_, but modifying them to use different datasets or language model construction parameters should be simple.
Decoding modes
^^^^^^^^^^^^^^
DeepSpeech currently supports two modes of operation with significant differences at both training and decoding time.
Default mode (alphabet based)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The default mode uses an alphabet file (specified with ``--alphabet_config_path`` at training and export time) to determine which labels (characters), and how many of them, to predict in the output layer. At decoding time, if using an external scorer, it MUST be word based and MUST be built using the same alphabet file used for training. Word based means the text corpus used to build the scorer should contain words separated by whitespace. For most western languages, this is the default and requires no special steps from the developer when creating the scorer.
UTF-8 mode
^^^^^^^^^^
In UTF-8 mode the model predicts UTF-8 bytes directly instead of letters from an alphabet file. This idea was proposed in the paper `Bytes Are All You Need <https://arxiv.org/abs/1811.09021>`_. This mode is enabled with the ``--utf8`` flag at training and export time. At training time, the alphabet file is not used. Instead, the model is forced to have 256 labels, with labels 0-254 corresponding to UTF-8 byte values 1-255 and label 255 reserved for the CTC blank symbol. If using an external scorer at decoding time, it MUST be built according to the instructions that follow.
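As a small illustration of the mapping described above (an assumption made here for clarity, not code from the DeepSpeech implementation):
.. code-block:: python

   # UTF-8 byte values 1-255 correspond to labels 0-254; label 255 is the CTC blank.
   CTC_BLANK_LABEL = 255

   def byte_to_label(byte_value):
       assert 1 <= byte_value <= 255
       return byte_value - 1

   labels = [byte_to_label(b) for b in '好'.encode('utf-8')]
   print(labels)  # [228, 164, 188] for the UTF-8 bytes E5 A5 BD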
UTF-8 decoding can be useful for languages with very large alphabets, such as Mandarin written with Simplified Chinese characters. It may also be useful for building multi-language models, or as a base for transfer learning. Currently these cases are untested and unsupported. Note that UTF-8 mode makes assumptions that hold for Mandarin written with Simplified Chinese characters and may not hold for other languages.
UTF-8 scorers are character based (more specifically, Unicode codepoint based), but the way they are used is similar to a word based scorer where each "word" is a sequence of UTF-8 bytes representing a single Unicode codepoint. This means that the input text used to create UTF-8 scorers should contain space separated Unicode codepoints. For example, the following input text:
``早 上 好``
corresponds to the following three "words", or UTF-8 byte sequences:
``E6 97 A9``
``E4 B8 8A``
``E5 A5 BD``
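These byte sequences can be reproduced with a few lines of Python, shown here only to make the example concrete:
.. code-block:: python

   # Print each codepoint of the example text together with its UTF-8 bytes in hex
   for codepoint in '早 上 好'.split():
       print(codepoint, ' '.join('{:02X}'.format(b) for b in codepoint.encode('utf-8')))
   # 早 E6 97 A9
   # 上 E4 B8 8A
   # 好 E5 A5 BD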
At decoding time, the scorer is queried every time a Unicode codepoint is predicted, instead of when a space character is predicted. From the language modeling perspective, this is a character based model. From the implementation perspective, this is a word based model, because each character is composed of multiple labels.
**Acoustic models trained with ``--utf8`` MUST NOT be used with an alphabet based scorer. Conversely, acoustic models trained with an alphabet file MUST NOT be used with a UTF-8 scorer.**
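One way to guard against such a mismatch is to query the scorer itself: the ``ds_ctcdecoder`` package exposes ``Scorer.is_utf8_mode()`` (also relied on by ``lm_optimizer.py`` in this changeset). A minimal sketch, assuming local paths for the scorer and alphabet files and placeholder alpha/beta values:
.. code-block:: python

   # Minimal sketch; the paths and alpha/beta values below are placeholders.
   from ds_ctcdecoder import Scorer
   from util.text import Alphabet

   alphabet = Alphabet('data/alphabet.txt')
   scorer = Scorer(0.75, 1.85, 'kenlm.scorer', alphabet)
   if scorer.is_utf8_mode():
       print('UTF-8 scorer: pair it with an acoustic model trained with --utf8')
   else:
       print('Alphabet based scorer: pair it with an alphabet based acoustic model')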
UTF-8 scorers can be built by using an input corpus with space separated codepoints. If your corpus only contains single codepoints separated by spaces, ``data/lm/generate_package.py`` should automatically enable UTF-8 mode, and it should print the message "Looks like a character based model."
If the message "Doesn't look like a character based model." is printed, you should double check your inputs to make sure it only contains single codepoints separated by spaces. UTF-8 mode can be forced by specifying the ``--force_utf8`` flag when running ``data/lm/generate_package.py``, but it is NOT RECOMMENDED.
Because KenLM uses spaces as a word separator, the resulting language model will not include space characters in it. If you wish to use UTF-8 mode but still model spaces, you need to replace spaces in the input corpus with a different character **before** converting it to space separated codepoints. For example:
.. code-block:: python
input_text = 'The quick brown fox jumps over the lazy dog'
spaces_replaced = input_text.replace(' ', '|')
space_separated = ' '.join(spaces_replaced)
print(space_separated)
# T h e | q u i c k | b r o w n | f o x | j u m p s | o v e r | t h e | l a z y | d o g
The separator character ('|' in this case) will then have to be replaced with spaces as a post-processing step after decoding, as shown in the sketch below.
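A minimal sketch of that post-processing step (an illustration, not a DeepSpeech utility):
.. code-block:: python

   # Hypothetical decoder output from a UTF-8 model whose scorer replaced spaces with '|'
   decoded = 'The|quick|brown|fox|jumps|over|the|lazy|dog'
   # Restore the original spaces as a post-processing step
   print(decoded.replace('|', ' '))
   # The quick brown fox jumps over the lazy dog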
Implementation
^^^^^^^^^^^^^^
The decoder source code can be found in ``native_client/ctcdecode``. The decoder is included in the language bindings and clients. In addition, there is a separate Python module which includes just the decoder and is needed for evaluation. In order to build and install this package, see the :github:`native_client README <native_client/README.rst#install-the-ctc-decoder-package>`.

View File

@ -31,13 +31,20 @@ ErrorCodes
Metadata
--------
.. doxygenstruct:: DeepSpeechClient::Structs::Metadata
.. doxygenclass:: DeepSpeechClient::Models::Metadata
:project: deepspeech-dotnet
:members: items, num_items, confidence
:members: Transcripts
MetadataItem
------------
CandidateTranscript
-------------------
.. doxygenstruct:: DeepSpeechClient::Structs::MetadataItem
.. doxygenclass:: DeepSpeechClient::Models::CandidateTranscript
:project: deepspeech-dotnet
:members: character, timestep, start_time
:members: Tokens, Confidence
TokenMetadata
-------------
.. doxygenclass:: DeepSpeechClient::Models::TokenMetadata
:project: deepspeech-dotnet
:members: Text, Timestep, StartTime

View File

@ -13,11 +13,17 @@ Metadata
.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::Metadata
:project: deepspeech-java
:members: getItems, getNum_items, getProbability, getItem
:members: getTranscripts, getNum_transcripts, getTranscript
MetadataItem
------------
CandidateTranscript
-------------------
.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::MetadataItem
.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::CandidateTranscript
:project: deepspeech-java
:members: getCharacter, getTimestep, getStart_time
:members: getTokens, getNum_tokens, getConfidence, getToken
TokenMetadata
-------------
.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::TokenMetadata
:project: deepspeech-java
:members: getText, getTimestep, getStart_time

View File

@ -30,8 +30,14 @@ Metadata
.. js:autoclass:: Metadata
:members:
MetadataItem
------------
CandidateTranscript
-------------------
.. js:autoclass:: MetadataItem
.. js:autoclass:: CandidateTranscript
:members:
TokenMetadata
-------------
.. js:autoclass:: TokenMetadata
:members:

View File

@ -21,8 +21,14 @@ Metadata
.. autoclass:: Metadata
:members:
MetadataItem
------------
CandidateTranscript
-------------------
.. autoclass:: MetadataItem
.. autoclass:: CandidateTranscript
:members:
TokenMetadata
-------------
.. autoclass:: TokenMetadata
:members:

View File

@ -8,9 +8,16 @@ Metadata
:project: deepspeech-c
:members:
MetadataItem
------------
CandidateTranscript
-------------------
.. doxygenstruct:: MetadataItem
.. doxygenstruct:: CandidateTranscript
:project: deepspeech-c
:members:
TokenMetadata
-------------
.. doxygenstruct:: TokenMetadata
:project: deepspeech-c
:members:

View File

@ -219,6 +219,11 @@ Note: the released models were trained with ``--n_hidden 2048``\ , so you need t
Key cudnn_lstm/rnn/multi_rnn_cell/cell_0/cudnn_compatible_lstm_cell/bias/Adam not found in checkpoint
UTF-8 mode
^^^^^^^^^^
DeepSpeech includes a UTF-8 operating mode which can be useful for modeling languages with very large alphabets, such as Mandarin Chinese. For details on how it works and how to use it, see :ref:`decoder-docs`.
Training with augmentation
^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -40,7 +40,7 @@ import semver
# -- Project information -----------------------------------------------------
project = u'DeepSpeech'
copyright = '2019, Mozilla Corporation'
copyright = '2019-2020, Mozilla Corporation'
author = 'Mozilla Corporation'
with open('../VERSION', 'r') as ver:

View File

@ -790,7 +790,7 @@ WARN_LOGFILE =
# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
# Note: If this tag is empty the current directory is searched.
INPUT = native_client/dotnet/DeepSpeechClient/ native_client/dotnet/DeepSpeechClient/Interfaces/ native_client/dotnet/DeepSpeechClient/Enums/ native_client/dotnet/DeepSpeechClient/Structs/
INPUT = native_client/dotnet/DeepSpeechClient/ native_client/dotnet/DeepSpeechClient/Interfaces/ native_client/dotnet/DeepSpeechClient/Enums/ native_client/dotnet/DeepSpeechClient/Models/
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses

@ -1 +1 @@
Subproject commit 3beecad75c6dbe92d0604690014a3dba9fb9c926
Subproject commit 81a06eea64d1dda734f6b97b3005b4416ac2f50a

59
lm_optimizer.py Normal file
View File

@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import sys
import optuna
import absl.app
from ds_ctcdecoder import Scorer
import tensorflow.compat.v1 as tfv1
from DeepSpeech import create_model
from evaluate import evaluate
from util.config import Config, initialize_globals
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.evaluate_tools import wer_cer_batch
def character_based():
is_character_based = False
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
is_character_based = scorer.is_utf8_mode()
return is_character_based
def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
tfv1.reset_default_graph()
samples = evaluate(FLAGS.test_files.split(','), create_model)
is_character_based = trial.study.user_attrs['is_character_based']
wer, cer = wer_cer_batch(samples)
return cer if is_character_based else wer
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
is_character_based = character_based()
study = optuna.create_study()
study.set_user_attr("is_character_based", is_character_based)
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
print('Best params: lm_alpha={} and lm_beta={} with WER={}'.format(study.best_params['lm_alpha'],
study.best_params['lm_beta'],
study.best_value))
if __name__ == '__main__':
create_flags()
absl.app.run(main)

View File

@ -34,6 +34,8 @@ bool extended_metadata = false;
bool json_output = false;
int json_candidate_transcripts = 3;
int stream_size = 0;
void PrintHelp(const char* bin)
@ -43,18 +45,19 @@ void PrintHelp(const char* bin)
"\n"
"Running DeepSpeech inference.\n"
"\n"
"\t--model MODEL\t\tPath to the model (protocol buffer binary file)\n"
"\t--scorer SCORER\t\tPath to the external scorer file\n"
"\t--audio AUDIO\t\tPath to the audio file to run (WAV format)\n"
"\t--beam_width BEAM_WIDTH\tValue for decoder beam width (int)\n"
"\t--lm_alpha LM_ALPHA\tValue for language model alpha param (float)\n"
"\t--lm_beta LM_BETA\tValue for language model beta param (float)\n"
"\t-t\t\t\tRun in benchmark mode, output mfcc & inference time\n"
"\t--extended\t\tOutput string from extended metadata\n"
"\t--json\t\t\tExtended output, shows word timings as JSON\n"
"\t--stream size\t\tRun in stream mode, output intermediate results\n"
"\t--help\t\t\tShow help\n"
"\t--version\t\tPrint version and exits\n";
"\t--model MODEL\t\t\tPath to the model (protocol buffer binary file)\n"
"\t--scorer SCORER\t\t\tPath to the external scorer file\n"
"\t--audio AUDIO\t\t\tPath to the audio file to run (WAV format)\n"
"\t--beam_width BEAM_WIDTH\t\tValue for decoder beam width (int)\n"
"\t--lm_alpha LM_ALPHA\t\tValue for language model alpha param (float)\n"
"\t--lm_beta LM_BETA\t\tValue for language model beta param (float)\n"
"\t-t\t\t\t\tRun in benchmark mode, output mfcc & inference time\n"
"\t--extended\t\t\tOutput string from extended metadata\n"
"\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in output\n"
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
"\t--help\t\t\t\tShow help\n"
"\t--version\t\t\tPrint version and exits\n";
char* version = DS_Version();
std::cerr << "DeepSpeech " << version << "\n";
DS_FreeString(version);
@ -74,6 +77,7 @@ bool ProcessArgs(int argc, char** argv)
{"t", no_argument, nullptr, 't'},
{"extended", no_argument, nullptr, 'e'},
{"json", no_argument, nullptr, 'j'},
{"candidate_transcripts", required_argument, nullptr, 150},
{"stream", required_argument, nullptr, 's'},
{"version", no_argument, nullptr, 'v'},
{"help", no_argument, nullptr, 'h'},
@ -128,6 +132,10 @@ bool ProcessArgs(int argc, char** argv)
json_output = true;
break;
case 150:
json_candidate_transcripts = atoi(optarg);
break;
case 's':
stream_size = atoi(optarg);
break;

View File

@ -44,9 +44,115 @@ struct meta_word {
float duration;
};
char* metadataToString(Metadata* metadata);
std::vector<meta_word> WordsFromMetadata(Metadata* metadata);
char* JSONOutput(Metadata* metadata);
char*
CandidateTranscriptToString(const CandidateTranscript* transcript)
{
std::string retval = "";
for (int i = 0; i < transcript->num_tokens; i++) {
const TokenMetadata& token = transcript->tokens[i];
retval += token.text;
}
return strdup(retval.c_str());
}
std::vector<meta_word>
CandidateTranscriptToWords(const CandidateTranscript* transcript)
{
std::vector<meta_word> word_list;
std::string word = "";
float word_start_time = 0;
// Loop through each token
for (int i = 0; i < transcript->num_tokens; i++) {
const TokenMetadata& token = transcript->tokens[i];
// Append token to word if it's not a space
if (strcmp(token.text, u8" ") != 0) {
// Log the start time of the new word
if (word.length() == 0) {
word_start_time = token.start_time;
}
word.append(token.text);
}
// Word boundary is either a space or the last token in the array
if (strcmp(token.text, u8" ") == 0 || i == transcript->num_tokens-1) {
float word_duration = token.start_time - word_start_time;
if (word_duration < 0) {
word_duration = 0;
}
meta_word w;
w.word = word;
w.start_time = word_start_time;
w.duration = word_duration;
word_list.push_back(w);
// Reset
word = "";
word_start_time = 0;
}
}
return word_list;
}
std::string
CandidateTranscriptToJSON(const CandidateTranscript *transcript)
{
std::ostringstream out_string;
std::vector<meta_word> words = CandidateTranscriptToWords(transcript);
out_string << R"("metadata":{"confidence":)" << transcript->confidence << R"(},"words":[)";
for (int i = 0; i < words.size(); i++) {
meta_word w = words[i];
out_string << R"({"word":")" << w.word << R"(","time":)" << w.start_time << R"(,"duration":)" << w.duration << "}";
if (i < words.size() - 1) {
out_string << ",";
}
}
out_string << "]";
return out_string.str();
}
char*
MetadataToJSON(Metadata* result)
{
std::ostringstream out_string;
out_string << "{\n";
for (int j=0; j < result->num_transcripts; ++j) {
const CandidateTranscript *transcript = &result->transcripts[j];
if (j == 0) {
out_string << CandidateTranscriptToJSON(transcript);
if (result->num_transcripts > 1) {
out_string << ",\n" << R"("alternatives")" << ":[\n";
}
} else {
out_string << "{" << CandidateTranscriptToJSON(transcript) << "}";
if (j < result->num_transcripts - 1) {
out_string << ",\n";
} else {
out_string << "\n]";
}
}
}
out_string << "\n}\n";
return strdup(out_string.str().c_str());
}
ds_result
LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
@ -57,13 +163,13 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
clock_t ds_start_time = clock();
if (extended_output) {
Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize);
res.string = metadataToString(metadata);
DS_FreeMetadata(metadata);
Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, 1);
res.string = CandidateTranscriptToString(&result->transcripts[0]);
DS_FreeMetadata(result);
} else if (json_output) {
Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize);
res.string = JSONOutput(metadata);
DS_FreeMetadata(metadata);
Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, json_candidate_transcripts);
res.string = MetadataToJSON(result);
DS_FreeMetadata(result);
} else if (stream_size > 0) {
StreamingState* ctx;
int status = DS_CreateStream(aCtx, &ctx);
@ -278,87 +384,6 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
}
}
char*
metadataToString(Metadata* metadata)
{
std::string retval = "";
for (int i = 0; i < metadata->num_items; i++) {
MetadataItem item = metadata->items[i];
retval += item.character;
}
return strdup(retval.c_str());
}
std::vector<meta_word>
WordsFromMetadata(Metadata* metadata)
{
std::vector<meta_word> word_list;
std::string word = "";
float word_start_time = 0;
// Loop through each character
for (int i = 0; i < metadata->num_items; i++) {
MetadataItem item = metadata->items[i];
// Append character to word if it's not a space
if (strcmp(item.character, u8" ") != 0) {
// Log the start time of the new word
if (word.length() == 0) {
word_start_time = item.start_time;
}
word.append(item.character);
}
// Word boundary is either a space or the last character in the array
if (strcmp(item.character, " ") == 0
|| strcmp(item.character, u8" ") == 0
|| i == metadata->num_items-1) {
float word_duration = item.start_time - word_start_time;
if (word_duration < 0) {
word_duration = 0;
}
meta_word w;
w.word = word;
w.start_time = word_start_time;
w.duration = word_duration;
word_list.push_back(w);
// Reset
word = "";
word_start_time = 0;
}
}
return word_list;
}
char*
JSONOutput(Metadata* metadata)
{
std::vector<meta_word> words = WordsFromMetadata(metadata);
std::ostringstream out_string;
out_string << R"({"metadata":{"confidence":)" << metadata->confidence << R"(},"words":[)";
for (int i = 0; i < words.size(); i++) {
meta_word w = words[i];
out_string << R"({"word":")" << w.word << R"(","time":)" << w.start_time << R"(,"duration":)" << w.duration << "}";
if (i < words.size() - 1) {
out_string << ",";
}
}
out_string << "]}\n";
return strdup(out_string.str().c_str());
}
int
main(int argc, char **argv)
{

View File

@ -15,14 +15,24 @@ else
GENERATE_DEBUG_SYMS :=
endif
ifeq ($(findstring _NT,$(OS)),_NT)
ARCHIVE_EXT := lib
else
ARCHIVE_EXT := a
endif
FIRST_PARTY := first_party.$(ARCHIVE_EXT)
THIRD_PARTY := third_party.$(ARCHIVE_EXT)
all: bindings
clean-keep-third-party:
rm -rf dist temp_build ds_ctcdecoder.egg-info
rm -f swigwrapper_wrap.cpp swigwrapper.py first_party.a
rm -f swigwrapper_wrap.cpp swigwrapper.py $(FIRST_PARTY)
clean: clean-keep-third-party
rm -f third_party.a
rm -f $(THIRD_PARTY)
rm workspace_status.cc
rm -fr bazel-out/
@ -31,17 +41,19 @@ workspace_status.cc:
../bazel_workspace_status_cmd.sh > bazel-out/stable-status.txt && \
../gen_workspace_status.sh > $@
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings: clean-keep-third-party workspace_status.cc
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
find temp_build -type f -name "*.o" -delete
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
rm -rf temp_build
bindings-debug: clean-keep-third-party workspace_status.cc
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
$(GENERATE_DEBUG_SYMS)
find temp_build -type f -name "*.o" -delete
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
rm -rf temp_build

View File

@ -9,14 +9,23 @@ import sys
from multiprocessing.dummy import Pool
ARGS = ['-DKENLM_MAX_ORDER=6', '-std=c++11', '-Wno-unused-local-typedefs', '-Wno-sign-compare']
OPT_ARGS = ['-O3', '-DNDEBUG']
DBG_ARGS = ['-O0', '-g', '-UNDEBUG', '-DDEBUG']
if sys.platform.startswith('win'):
ARGS = ['/nologo', '/D KENLM_MAX_ORDER=6', '/EHsc', '/source-charset:utf-8']
OPT_ARGS = ['/O2', '/MT', '/D NDEBUG']
DBG_ARGS = ['/Od', '/MTd', '/Zi', '/U NDEBUG', '/D DEBUG']
OPENFST_DIR = 'third_party/openfst-1.6.9-win'
else:
ARGS = ['-fPIC', '-DKENLM_MAX_ORDER=6', '-std=c++11', '-Wno-unused-local-typedefs', '-Wno-sign-compare']
OPT_ARGS = ['-O3', '-DNDEBUG']
DBG_ARGS = ['-O0', '-g', '-UNDEBUG', '-DDEBUG']
OPENFST_DIR = 'third_party/openfst-1.6.7'
INCLUDES = [
'..',
'../kenlm',
'third_party/openfst-1.6.7/src/include',
OPENFST_DIR + '/src/include',
'third_party/ThreadPool'
]
@ -24,7 +33,7 @@ KENLM_FILES = (glob.glob('../kenlm/util/*.cc')
+ glob.glob('../kenlm/lm/*.cc')
+ glob.glob('../kenlm/util/double-conversion/*.cc'))
KENLM_FILES += glob.glob('third_party/openfst-1.6.7/src/lib/*.cc')
KENLM_FILES += glob.glob(OPENFST_DIR + '/src/lib/*.cc')
KENLM_FILES = [
fn for fn in KENLM_FILES
@ -42,7 +51,10 @@ CTC_DECODER_FILES = [
def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):
compiler = os.environ.get('CXX', 'g++')
if sys.platform.startswith('win'):
compiler = '"{}"'.format(compiler)
ar = os.environ.get('AR', 'ar')
libexe = os.environ.get('LIBEXE', 'lib.exe')
libtool = os.environ.get('LIBTOOL', 'libtool')
cflags = os.environ.get('CFLAGS', '') + os.environ.get('CXXFLAGS', '')
args = ARGS + (DBG_ARGS if debug else OPT_ARGS)
@ -59,13 +71,19 @@ def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug
if os.path.exists(outfile):
return
cmd = '{cc} -fPIC -c {cflags} {args} {includes} {infile} -o {outfile}'.format(
if sys.platform.startswith('win'):
file = '"{}"'.format(file.replace('\\', '/'))
output = '/Fo"{}"'.format(outfile.replace('\\', '/'))
else:
output = '-o ' + outfile
cmd = '{cc} -c {cflags} {args} {includes} {infile} {output}'.format(
cc=compiler,
cflags=cflags,
args=' '.join(args),
includes=' '.join('-I' + i for i in INCLUDES),
infile=file,
outfile=outfile,
output=output,
)
print(cmd)
subprocess.check_call(shlex.split(cmd))
@ -82,6 +100,14 @@ def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug
)
print(cmd)
subprocess.check_call(shlex.split(cmd))
elif sys.platform.startswith('win'):
cmd = '"{libexe}" /OUT:"{outfile}" {infiles} /MACHINE:X64 /NOLOGO'.format(
libexe=libexe,
outfile=out_name,
infiles=' '.join(obj_files))
cmd = cmd.replace('\\', '/')
print(cmd)
subprocess.check_call(shlex.split(cmd))
else:
cmd = '{ar} rcs {outfile} {infiles}'.format(
ar=ar,

View File

@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
}
std::vector<Output>
DecoderState::decode() const
DecoderState::decode(size_t num_results) const
{
std::vector<PathTrie*> prefixes_copy = prefixes_;
std::unordered_map<const PathTrie*, float> scores;
@ -181,16 +181,12 @@ DecoderState::decode() const
}
using namespace std::placeholders;
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
size_t num_returned = std::min(prefixes_copy.size(), num_results);
std::partial_sort(prefixes_copy.begin(),
prefixes_copy.begin() + num_prefixes,
prefixes_copy.begin() + num_returned,
prefixes_copy.end(),
std::bind(prefix_compare_external, _1, _2, scores));
//TODO: expose this as an API parameter
const size_t top_paths = 1;
size_t num_returned = std::min(num_prefixes, top_paths);
std::vector<Output> outputs;
outputs.reserve(num_returned);

View File

@ -60,13 +60,16 @@ public:
int time_dim,
int class_dim);
/* Get transcription from current decoder state
/* Get up to num_results transcriptions from current decoder state.
*
* Parameters:
* num_results: Number of beams to return.
*
* Return:
* A vector where each element is a pair of score and decoding result,
* in descending order.
*/
std::vector<Output> decode() const;
std::vector<Output> decode(size_t num_results=1) const;
};

View File

@ -54,8 +54,15 @@ def maybe_rebuild(srcs, out_name, build_dir):
project_version = read('../../VERSION').strip()
build_dir = 'temp_build/temp_build'
third_party_build = 'third_party.a'
ctc_decoder_build = 'first_party.a'
if sys.platform.startswith('win'):
archive_ext = 'lib'
else:
archive_ext = 'a'
third_party_build = 'third_party.{}'.format(archive_ext)
ctc_decoder_build = 'first_party.{}'.format(archive_ext)
maybe_rebuild(KENLM_FILES, third_party_build, build_dir)
maybe_rebuild(CTC_DECODER_FILES, ctc_decoder_build, build_dir)

View File

@ -60,7 +60,7 @@ using std::vector;
When batch_buffer is full, we do a single step through the acoustic model
and accumulate the intermediate decoding state in the DecoderState structure.
When finishStream() is called, we return the corresponding transcription from
When finishStream() is called, we return the corresponding transcript from
the current decoder state.
*/
struct StreamingState {
@ -78,9 +78,10 @@ struct StreamingState {
void feedAudioContent(const short* buffer, unsigned int buffer_size);
char* intermediateDecode() const;
Metadata* intermediateDecodeWithMetadata(unsigned int num_results) const;
void finalizeStream();
char* finishStream();
Metadata* finishStreamWithMetadata();
Metadata* finishStreamWithMetadata(unsigned int num_results);
void processAudioWindow(const vector<float>& buf);
void processMfccWindow(const vector<float>& buf);
@ -136,6 +137,12 @@ StreamingState::intermediateDecode() const
return model_->decode(decoder_state_);
}
Metadata*
StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const
{
return model_->decode_metadata(decoder_state_, num_results);
}
char*
StreamingState::finishStream()
{
@ -144,10 +151,10 @@ StreamingState::finishStream()
}
Metadata*
StreamingState::finishStreamWithMetadata()
StreamingState::finishStreamWithMetadata(unsigned int num_results)
{
finalizeStream();
return model_->decode_metadata(decoder_state_);
return model_->decode_metadata(decoder_state_, num_results);
}
void
@ -402,6 +409,13 @@ DS_IntermediateDecode(const StreamingState* aSctx)
return aSctx->intermediateDecode();
}
Metadata*
DS_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
unsigned int aNumResults)
{
return aSctx->intermediateDecodeWithMetadata(aNumResults);
}
char*
DS_FinishStream(StreamingState* aSctx)
{
@ -411,11 +425,12 @@ DS_FinishStream(StreamingState* aSctx)
}
Metadata*
DS_FinishStreamWithMetadata(StreamingState* aSctx)
DS_FinishStreamWithMetadata(StreamingState* aSctx,
unsigned int aNumResults)
{
Metadata* metadata = aSctx->finishStreamWithMetadata();
Metadata* result = aSctx->finishStreamWithMetadata(aNumResults);
DS_FreeStream(aSctx);
return metadata;
return result;
}
StreamingState*
@ -444,10 +459,11 @@ DS_SpeechToText(ModelState* aCtx,
Metadata*
DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize)
unsigned int aBufferSize,
unsigned int aNumResults)
{
StreamingState* ctx = CreateStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize);
return DS_FinishStreamWithMetadata(ctx);
return DS_FinishStreamWithMetadata(ctx, aNumResults);
}
void
@ -460,11 +476,16 @@ void
DS_FreeMetadata(Metadata* m)
{
if (m) {
for (int i = 0; i < m->num_items; ++i) {
free(m->items[i].character);
for (int i = 0; i < m->num_transcripts; ++i) {
for (int j = 0; j < m->transcripts[i].num_tokens; ++j) {
free((void*)m->transcripts[i].tokens[j].text);
}
free((void*)m->transcripts[i].tokens);
}
delete[] m->items;
delete m;
free((void*)m->transcripts);
free(m);
}
}

View File

@ -20,32 +20,43 @@ typedef struct ModelState ModelState;
typedef struct StreamingState StreamingState;
/**
* @brief Stores each individual character, along with its timing information
* @brief Stores text of an individual token, along with its timing information
*/
typedef struct MetadataItem {
/** The character generated for transcription */
char* character;
typedef struct TokenMetadata {
/** The text corresponding to this token */
const char* const text;
/** Position of the character in units of 20ms */
int timestep;
/** Position of the token in units of 20ms */
const unsigned int timestep;
/** Position of the character in seconds */
float start_time;
} MetadataItem;
/** Position of the token in seconds */
const float start_time;
} TokenMetadata;
/**
* @brief Stores the entire CTC output as an array of character metadata objects
* @brief A single transcript computed by the model, including a confidence
* value and the metadata for its constituent tokens.
*/
typedef struct CandidateTranscript {
/** Array of TokenMetadata objects */
const TokenMetadata* const tokens;
/** Size of the tokens array */
const unsigned int num_tokens;
/** Approximated confidence value for this transcript. This is roughly the
* sum of the acoustic model logit values for each timestep/character that
* contributed to the creation of this transcript.
*/
const double confidence;
} CandidateTranscript;
/**
* @brief An array of CandidateTranscript objects computed by the model.
*/
typedef struct Metadata {
/** List of items */
MetadataItem* items;
/** Size of the list of items */
int num_items;
/** Approximated confidence value for this transcription. This is roughly the
* sum of the acoustic model logit values for each timestep/character that
* contributed to the creation of this transcription.
*/
double confidence;
/** Array of CandidateTranscript objects */
const CandidateTranscript* const transcripts;
/** Size of the transcripts array */
const unsigned int num_transcripts;
} Metadata;
enum DeepSpeech_Error_Codes
@ -164,7 +175,7 @@ int DS_SetScorerAlphaBeta(ModelState* aCtx,
float aBeta);
/**
* @brief Use the DeepSpeech model to perform Speech-To-Text.
* @brief Use the DeepSpeech model to convert speech to text.
*
* @param aCtx The ModelState pointer for the model to use.
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
@ -180,21 +191,25 @@ char* DS_SpeechToText(ModelState* aCtx,
unsigned int aBufferSize);
/**
* @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata
* about the results.
* @brief Use the DeepSpeech model to convert speech to text and output results
* including metadata.
*
* @param aCtx The ModelState pointer for the model to use.
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
* sample rate (matching what the model was trained on).
* @param aBufferSize The number of samples in the audio signal.
* @param aNumResults The maximum number of CandidateTranscript structs to return. The actual number returned may be smaller.
*
* @return Outputs a struct of individual letters along with their timing information.
* The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
* @return Metadata struct containing multiple CandidateTranscript structs. Each
* transcript has per-token metadata including timing information. The
* user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}.
* Returns NULL on error.
*/
DEEPSPEECH_EXPORT
Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize);
unsigned int aBufferSize,
unsigned int aNumResults);
/**
* @brief Create a new streaming inference state. The streaming state returned
@ -236,8 +251,24 @@ DEEPSPEECH_EXPORT
char* DS_IntermediateDecode(const StreamingState* aSctx);
/**
* @brief Signal the end of an audio signal to an ongoing streaming
* inference, returns the STT result over the whole audio signal.
* @brief Compute the intermediate decoding of an ongoing streaming inference,
* return results including metadata.
*
* @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
* @param aNumResults The number of candidate transcripts to return.
*
* @return Metadata struct containing multiple candidate transcripts. Each transcript
* has per-token metadata including timing information. The user is
* responsible for freeing Metadata by calling {@link DS_FreeMetadata()}.
* Returns NULL on error.
*/
DEEPSPEECH_EXPORT
Metadata* DS_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
unsigned int aNumResults);
/**
* @brief Compute the final decoding of an ongoing streaming inference and return
* the result. Signals the end of an ongoing streaming inference.
*
* @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
*
@ -250,18 +281,23 @@ DEEPSPEECH_EXPORT
char* DS_FinishStream(StreamingState* aSctx);
/**
* @brief Signal the end of an audio signal to an ongoing streaming
* inference, returns per-letter metadata.
* @brief Compute the final decoding of an ongoing streaming inference and return
* results including metadata. Signals the end of an ongoing streaming
* inference.
*
* @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
* @param aNumResults The number of candidate transcripts to return.
*
* @return Outputs a struct of individual letters along with their timing information.
* The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
* @return Metadata struct containing multiple candidate transcripts. Each transcript
* has per-token metadata including timing information. The user is
* responsible for freeing Metadata by calling {@link DS_FreeMetadata()}.
* Returns NULL on error.
*
* @note This method will free the state pointer (@p aSctx).
*/
DEEPSPEECH_EXPORT
Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx);
Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx,
unsigned int aNumResults);
/**
* @brief Destroy a streaming state without decoding the computed logits. This

View File

@ -10,6 +10,7 @@ TOOL_CC := gcc
TOOL_CXX := c++
TOOL_LD := ld
TOOL_LDD := ldd
TOOL_LIBEXE :=
DEEPSPEECH_BIN := deepspeech
CFLAGS_DEEPSPEECH := -std=c++11 -o $(DEEPSPEECH_BIN)
@ -37,9 +38,10 @@ endif
ifeq ($(TARGET),host-win)
DEEPSPEECH_BIN := deepspeech.exe
TOOLCHAIN := '$(VCINSTALLDIR)\bin\amd64\'
TOOL_CC := cl.exe
TOOL_CXX := cl.exe
TOOL_LD := link.exe
TOOL_CC := cl.exe
TOOL_CXX := cl.exe
TOOL_LD := link.exe
TOOL_LIBEXE := lib.exe
LINK_DEEPSPEECH := $(TFDIR)\bazel-bin\native_client\libdeepspeech.so.if.lib
LINK_PATH_DEEPSPEECH :=
CFLAGS_DEEPSPEECH := -nologo -Fe$(DEEPSPEECH_BIN)
@ -113,6 +115,7 @@ CC := $(TOOLCHAIN)$(TOOL_CC)
CXX := $(TOOLCHAIN)$(TOOL_CXX)
LD := $(TOOLCHAIN)$(TOOL_LD)
LDD := $(TOOLCHAIN)$(TOOL_LDD) $(TOOLCHAIN_LDD_OPTS)
LIBEXE := $(TOOLCHAIN)$(TOOL_LIBEXE)
RPATH_PYTHON := '-Wl,-rpath,\$$ORIGIN/lib/' $(LDFLAGS_RPATH)
RPATH_NODEJS := '-Wl,-rpath,$$\$$ORIGIN/../'

View File

@ -89,38 +89,9 @@ namespace DeepSpeechClient
/// <param name="resultCode">Native result code.</param>
private void EvaluateResultCode(ErrorCodes resultCode)
{
switch (resultCode)
if (resultCode != ErrorCodes.DS_ERR_OK)
{
case ErrorCodes.DS_ERR_OK:
break;
case ErrorCodes.DS_ERR_NO_MODEL:
throw new ArgumentException("Missing model information.");
case ErrorCodes.DS_ERR_INVALID_ALPHABET:
throw new ArgumentException("Invalid alphabet embedded in model. (Data corruption?)");
case ErrorCodes.DS_ERR_INVALID_SHAPE:
throw new ArgumentException("Invalid model shape.");
case ErrorCodes.DS_ERR_INVALID_SCORER:
throw new ArgumentException("Invalid scorer file.");
case ErrorCodes.DS_ERR_FAIL_INIT_MMAP:
throw new ArgumentException("Failed to initialize memory mapped model.");
case ErrorCodes.DS_ERR_FAIL_INIT_SESS:
throw new ArgumentException("Failed to initialize the session.");
case ErrorCodes.DS_ERR_FAIL_INTERPRETER:
throw new ArgumentException("Interpreter failed.");
case ErrorCodes.DS_ERR_FAIL_RUN_SESS:
throw new ArgumentException("Failed to run the session.");
case ErrorCodes.DS_ERR_FAIL_CREATE_STREAM:
throw new ArgumentException("Error creating the stream.");
case ErrorCodes.DS_ERR_FAIL_READ_PROTOBUF:
throw new ArgumentException("Error reading the proto buffer model file.");
case ErrorCodes.DS_ERR_FAIL_CREATE_SESS:
throw new ArgumentException("Error failed to create session.");
case ErrorCodes.DS_ERR_MODEL_INCOMPATIBLE:
throw new ArgumentException("Error incompatible model.");
case ErrorCodes.DS_ERR_SCORER_NOT_ENABLED:
throw new ArgumentException("External scorer is not enabled.");
default:
throw new ArgumentException("Unknown error, please make sure you are using the correct native binary.");
throw new ArgumentException(NativeImp.DS_ErrorCodeToErrorMessage((int)resultCode).PtrToString());
}
}
@ -140,7 +111,6 @@ namespace DeepSpeechClient
/// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception>
public unsafe void EnableExternalScorer(string aScorerPath)
{
string exceptionMessage = null;
if (string.IsNullOrWhiteSpace(aScorerPath))
{
throw new FileNotFoundException("Path to the scorer file cannot be empty.");
@ -199,13 +169,14 @@ namespace DeepSpeechClient
}
/// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal, including metadata.
/// </summary>
/// <param name="stream">Instance of the stream to finish.</param>
/// <param name="aNumResults">Maximum number of candidate transcripts to return. Returned list might be smaller than this.</param>
/// <returns>The extended metadata result.</returns>
public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream)
public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream, uint aNumResults)
{
return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer()).PtrToMetadata();
return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer(), aNumResults).PtrToMetadata();
}
/// <summary>
@ -218,6 +189,17 @@ namespace DeepSpeechClient
return NativeImp.DS_IntermediateDecode(stream.GetNativePointer()).PtrToString();
}
/// <summary>
/// Computes the intermediate decoding of an ongoing streaming inference, including metadata.
/// </summary>
/// <param name="stream">Instance of the stream to decode.</param>
/// <param name="aNumResults">Maximum number of candidate transcripts to return. Returned list might be smaller than this.</param>
/// <returns>The STT intermediate result.</returns>
public unsafe Metadata IntermediateDecodeWithMetadata(DeepSpeechStream stream, uint aNumResults)
{
return NativeImp.DS_IntermediateDecodeWithMetadata(stream.GetNativePointer(), aNumResults).PtrToMetadata();
}
/// <summary>
/// Return version of this library. The returned version is a semantic version
/// (SemVer 2.0.0).
@ -261,14 +243,15 @@ namespace DeepSpeechClient
}
/// <summary>
/// Use the DeepSpeech model to perform Speech-To-Text.
/// Use the DeepSpeech model to perform Speech-To-Text, return results including metadata.
/// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <param name="aNumResults">Maximum number of candidate transcripts to return. Returned list might be smaller than this.</param>
/// <returns>The extended metadata. Returns NULL on error.</returns>
public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize)
public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize, uint aNumResults)
{
return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize).PtrToMetadata();
return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize, aNumResults).PtrToMetadata();
}
#endregion

View File

@ -50,11 +50,13 @@
<Compile Include="Extensions\NativeExtensions.cs" />
<Compile Include="Models\DeepSpeechStream.cs" />
<Compile Include="Models\Metadata.cs" />
<Compile Include="Models\MetadataItem.cs" />
<Compile Include="Models\CandidateTranscript.cs" />
<Compile Include="Models\TokenMetadata.cs" />
<Compile Include="NativeImp.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Structs\Metadata.cs" />
<Compile Include="Structs\MetadataItem.cs" />
<Compile Include="Structs\CandidateTranscript.cs" />
<Compile Include="Structs\TokenMetadata.cs" />
</ItemGroup>
<ItemGroup />
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />

View File

@ -26,35 +26,68 @@ namespace DeepSpeechClient.Extensions
}
/// <summary>
/// Converts a pointer into managed metadata object.
/// Converts a pointer into managed TokenMetadata object.
/// </summary>
/// <param name="intPtr">Native pointer.</param>
/// <returns>TokenMetadata managed object.</returns>
private static Models.TokenMetadata PtrToTokenMetadata(this IntPtr intPtr)
{
var token = Marshal.PtrToStructure<TokenMetadata>(intPtr);
var managedToken = new Models.TokenMetadata
{
Timestep = token.timestep,
StartTime = token.start_time,
Text = token.text.PtrToString(releasePtr: false)
};
return managedToken;
}
/// <summary>
/// Converts a pointer into managed CandidateTranscript object.
/// </summary>
/// <param name="intPtr">Native pointer.</param>
/// <returns>CandidateTranscript managed object.</returns>
private static Models.CandidateTranscript PtrToCandidateTranscript(this IntPtr intPtr)
{
var managedTranscript = new Models.CandidateTranscript();
var transcript = Marshal.PtrToStructure<CandidateTranscript>(intPtr);
managedTranscript.Tokens = new Models.TokenMetadata[transcript.num_tokens];
managedTranscript.Confidence = transcript.confidence;
//we need to manually read each item from the native ptr using its size
var sizeOfTokenMetadata = Marshal.SizeOf(typeof(TokenMetadata));
for (int i = 0; i < transcript.num_tokens; i++)
{
managedTranscript.Tokens[i] = transcript.tokens.PtrToTokenMetadata();
transcript.tokens += sizeOfTokenMetadata;
}
return managedTranscript;
}
/// <summary>
/// Converts a pointer into managed Metadata object.
/// </summary>
/// <param name="intPtr">Native pointer.</param>
/// <returns>Metadata managed object.</returns>
internal static Models.Metadata PtrToMetadata(this IntPtr intPtr)
{
var managedMetaObject = new Models.Metadata();
var metaData = (Metadata)Marshal.PtrToStructure(intPtr, typeof(Metadata));
managedMetaObject.Items = new Models.MetadataItem[metaData.num_items];
managedMetaObject.Confidence = metaData.confidence;
var managedMetadata = new Models.Metadata();
var metadata = Marshal.PtrToStructure<Metadata>(intPtr);
managedMetadata.Transcripts = new Models.CandidateTranscript[metadata.num_transcripts];
//we need to manually read each item from the native ptr using its size
var sizeOfMetaItem = Marshal.SizeOf(typeof(MetadataItem));
for (int i = 0; i < metaData.num_items; i++)
var sizeOfCandidateTranscript = Marshal.SizeOf(typeof(CandidateTranscript));
for (int i = 0; i < metadata.num_transcripts; i++)
{
var tempItem = Marshal.PtrToStructure<MetadataItem>(metaData.items);
managedMetaObject.Items[i] = new Models.MetadataItem
{
Timestep = tempItem.timestep,
StartTime = tempItem.start_time,
Character = tempItem.character.PtrToString(releasePtr: false)
};
//we keep the offset on each read
metaData.items += sizeOfMetaItem;
managedMetadata.Transcripts[i] = metadata.transcripts.PtrToCandidateTranscript();
metadata.transcripts += sizeOfCandidateTranscript;
}
NativeImp.DS_FreeMetadata(intPtr);
return managedMetaObject;
return managedMetadata;
}
}
}

View File

@ -68,13 +68,15 @@ namespace DeepSpeechClient.Interfaces
uint aBufferSize);
/// <summary>
/// Use the DeepSpeech model to perform Speech-To-Text.
/// Use the DeepSpeech model to perform Speech-To-Text, return results including metadata.
/// </summary>
/// <param name="aBuffer">A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).</param>
/// <param name="aBufferSize">The number of samples in the audio signal.</param>
/// <param name="aNumResults">Maximum number of candidate transcripts to return. Returned list might be smaller than this.</param>
/// <returns>The extended metadata. Returns NULL on error.</returns>
unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer,
uint aBufferSize);
uint aBufferSize,
uint aNumResults);
/// <summary>
/// Destroy a streaming state without decoding the computed logits.
@ -102,6 +104,14 @@ namespace DeepSpeechClient.Interfaces
/// <returns>The STT intermediate result.</returns>
unsafe string IntermediateDecode(DeepSpeechStream stream);
/// <summary>
/// Computes the intermediate decoding of an ongoing streaming inference, including metadata.
/// </summary>
/// <param name="stream">Instance of the stream to decode.</param>
/// <param name="aNumResults">Maximum number of candidate transcripts to return. Returned list might be smaller than this.</param>
/// <returns>The extended metadata result.</returns>
unsafe Metadata IntermediateDecodeWithMetadata(DeepSpeechStream stream, uint aNumResults);
/// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// </summary>
@ -110,10 +120,11 @@ namespace DeepSpeechClient.Interfaces
unsafe string FinishStream(DeepSpeechStream stream);
/// <summary>
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal.
/// Closes the ongoing streaming inference, returns the STT result over the whole audio signal, including metadata.
/// </summary>
/// <param name="stream">Instance of the stream to finish.</param>
/// <param name="aNumResults">Maximum number of candidate transcripts to return. Returned list might be smaller than this.</param>
/// <returns>The extended metadata result.</returns>
unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream);
unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream, uint aNumResults);
}
}

View File

@ -0,0 +1,17 @@
namespace DeepSpeechClient.Models
{
/// <summary>
/// A single candidate transcript computed by the model, including a confidence value and its constituent tokens.
/// </summary>
public class CandidateTranscript
{
/// <summary>
/// Approximated confidence value for this transcription.
/// </summary>
public double Confidence { get; set; }
/// <summary>
/// List of metadata tokens containing text, timestep, and time offset.
/// </summary>
public TokenMetadata[] Tokens { get; set; }
}
}

View File

@ -6,12 +6,8 @@
public class Metadata
{
/// <summary>
/// Approximated confidence value for this transcription.
/// List of candidate transcripts.
/// </summary>
public double Confidence { get; set; }
/// <summary>
/// List of metada items containing char, timespet, and time offset.
/// </summary>
public MetadataItem[] Items { get; set; }
public CandidateTranscript[] Transcripts { get; set; }
}
}

View File

@ -3,12 +3,12 @@
/// <summary>
/// Stores the text of an individual token, along with its timing information.
/// </summary>
public class MetadataItem
public class TokenMetadata
{
/// <summary>
/// Text of the current timestep.
/// </summary>
public string Character;
public string Text;
/// <summary>
/// Position of the character in units of 20ms.
/// </summary>

View File

@ -17,45 +17,49 @@ namespace DeepSpeechClient
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
ref IntPtr** pint);
ref IntPtr** pint);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal unsafe static extern IntPtr DS_ErrorCodeToErrorMessage(int aErrorCode);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx,
uint aBeamWidth);
uint aBeamWidth);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
uint aBeamWidth,
ref IntPtr** pint);
uint aBeamWidth,
ref IntPtr** pint);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal unsafe static extern int DS_GetModelSampleRate(IntPtr** aCtx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx,
string aScorerPath);
string aScorerPath);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_SetScorerAlphaBeta(IntPtr** aCtx,
float aAlpha,
float aBeta);
float aAlpha,
float aBeta);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl,
CharSet = CharSet.Ansi, SetLastError = true)]
internal static unsafe extern IntPtr DS_SpeechToText(IntPtr** aCtx,
short[] aBuffer,
uint aBufferSize);
short[] aBuffer,
uint aBufferSize);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, SetLastError = true)]
internal static unsafe extern IntPtr DS_SpeechToTextWithMetadata(IntPtr** aCtx,
short[] aBuffer,
uint aBufferSize);
short[] aBuffer,
uint aBufferSize,
uint aNumResults);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern void DS_FreeModel(IntPtr** aCtx);
@ -76,18 +80,23 @@ namespace DeepSpeechClient
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl,
CharSet = CharSet.Ansi, SetLastError = true)]
internal static unsafe extern void DS_FeedAudioContent(IntPtr** aSctx,
short[] aBuffer,
uint aBufferSize);
short[] aBuffer,
uint aBufferSize);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern IntPtr DS_IntermediateDecode(IntPtr** aSctx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern IntPtr DS_IntermediateDecodeWithMetadata(IntPtr** aSctx,
uint aNumResults);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl,
CharSet = CharSet.Ansi, SetLastError = true)]
internal static unsafe extern IntPtr DS_FinishStream(IntPtr** aSctx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern IntPtr DS_FinishStreamWithMetadata(IntPtr** aSctx);
internal static unsafe extern IntPtr DS_FinishStreamWithMetadata(IntPtr** aSctx,
uint aNumResults);
#endregion
}
}

View File

@ -0,0 +1,22 @@
using System;
using System.Runtime.InteropServices;
namespace DeepSpeechClient.Structs
{
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct CandidateTranscript
{
/// <summary>
/// Native list of tokens.
/// </summary>
internal unsafe IntPtr tokens;
/// <summary>
/// Count of tokens from the native side.
/// </summary>
internal unsafe int num_tokens;
/// <summary>
/// Approximated confidence value for this transcription.
/// </summary>
internal unsafe double confidence;
}
}

View File

@ -7,16 +7,12 @@ namespace DeepSpeechClient.Structs
internal unsafe struct Metadata
{
/// <summary>
/// Native list of items.
/// Native list of candidate transcripts.
/// </summary>
internal unsafe IntPtr items;
internal unsafe IntPtr transcripts;
/// <summary>
/// Count of items from the native side.
/// Count of transcripts from the native side.
/// </summary>
internal unsafe int num_items;
/// <summary>
/// Approximated confidence value for this transcription.
/// </summary>
internal unsafe double confidence;
internal unsafe int num_transcripts;
}
}

View File

@ -4,12 +4,12 @@ using System.Runtime.InteropServices;
namespace DeepSpeechClient.Structs
{
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct MetadataItem
internal unsafe struct TokenMetadata
{
/// <summary>
/// Native character.
/// Native text.
/// </summary>
internal unsafe IntPtr character;
internal unsafe IntPtr text;
/// <summary>
/// Position of the character in units of 20ms.
/// </summary>

View File

@ -21,14 +21,14 @@ namespace CSharpExamples
static string GetArgument(IEnumerable<string> args, string option)
=> args.SkipWhile(i => i != option).Skip(1).Take(1).FirstOrDefault();
static string MetadataToString(Metadata meta)
static string MetadataToString(CandidateTranscript transcript)
{
var nl = Environment.NewLine;
string retval =
Environment.NewLine + $"Recognized text: {string.Join("", meta?.Items?.Select(x => x.Character))} {nl}"
+ $"Confidence: {meta?.Confidence} {nl}"
+ $"Item count: {meta?.Items?.Length} {nl}"
+ string.Join(nl, meta?.Items?.Select(x => $"Timestep : {x.Timestep} TimeOffset: {x.StartTime} Char: {x.Character}"));
Environment.NewLine + $"Recognized text: {string.Join("", transcript?.Tokens?.Select(x => x.Text))} {nl}"
+ $"Confidence: {transcript?.Confidence} {nl}"
+ $"Item count: {transcript?.Tokens?.Length} {nl}"
+ string.Join(nl, transcript?.Tokens?.Select(x => $"Timestep : {x.Timestep} TimeOffset: {x.StartTime} Char: {x.Text}"));
return retval;
}
@ -75,8 +75,8 @@ namespace CSharpExamples
if (extended)
{
Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer,
Convert.ToUInt32(waveBuffer.MaxSize / 2));
speechResult = MetadataToString(metaResult);
Convert.ToUInt32(waveBuffer.MaxSize / 2), 1);
speechResult = MetadataToString(metaResult.Transcripts[0]);
}
else
{

View File

@ -6,6 +6,8 @@
%}
%include "typemaps.i"
%include "enums.swg"
%javaconst(1);
%include "arrays_java.i"
// apply to DS_FeedAudioContent and DS_SpeechToText
@ -15,21 +17,29 @@
%pointer_functions(ModelState*, modelstatep);
%pointer_functions(StreamingState*, streamingstatep);
%typemap(newfree) char* "DS_FreeString($1);";
%include "carrays.i"
%array_functions(struct MetadataItem, metadataItem_array);
%extend struct CandidateTranscript {
/**
* Retrieve one TokenMetadata element
*
* @param i Array index of the TokenMetadata to get
*
* @return The TokenMetadata requested or null
*/
const TokenMetadata& getToken(int i) {
return self->tokens[i];
}
}
%extend struct Metadata {
/**
* Retrieve one MetadataItem element
* Retrieve one CandidateTranscript element
*
* @param i Array index of the MetadataItem to get
* @param i Array index of the CandidateTranscript to get
*
* @return The MetadataItem requested or null
* @return The CandidateTranscript requested or null
*/
MetadataItem getItem(int i) {
return metadataItem_array_getitem(self->items, i);
const CandidateTranscript& getTranscript(int i) {
return self->transcripts[i];
}
~Metadata() {
@ -37,14 +47,18 @@
}
}
%nodefaultdtor Metadata;
%nodefaultctor Metadata;
%nodefaultctor MetadataItem;
%nodefaultdtor MetadataItem;
%nodefaultdtor Metadata;
%nodefaultctor CandidateTranscript;
%nodefaultdtor CandidateTranscript;
%nodefaultctor TokenMetadata;
%nodefaultdtor TokenMetadata;
%typemap(newfree) char* "DS_FreeString($1);";
%newobject DS_SpeechToText;
%newobject DS_IntermediateDecode;
%newobject DS_FinishStream;
%newobject DS_ErrorCodeToErrorMessage;
%rename ("%(strip:[DS_])s") "";

View File

@ -12,7 +12,7 @@ import org.junit.runners.MethodSorters;
import static org.junit.Assert.*;
import org.mozilla.deepspeech.libdeepspeech.DeepSpeechModel;
import org.mozilla.deepspeech.libdeepspeech.Metadata;
import org.mozilla.deepspeech.libdeepspeech.CandidateTranscript;
import java.io.RandomAccessFile;
import java.io.FileNotFoundException;
@ -61,10 +61,10 @@ public class BasicTest {
m.freeModel();
}
private String metadataToString(Metadata m) {
private String candidateTranscriptToString(CandidateTranscript t) {
String retval = "";
for (int i = 0; i < m.getNum_items(); ++i) {
retval += m.getItem(i).getCharacter();
for (int i = 0; i < t.getNum_tokens(); ++i) {
retval += t.getToken(i).getText();
}
return retval;
}
@ -97,7 +97,7 @@ public class BasicTest {
ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts);
if (extendedMetadata) {
return metadataToString(m.sttWithMetadata(shorts, shorts.length));
return candidateTranscriptToString(m.sttWithMetadata(shorts, shorts.length, 1).getTranscript(0));
} else {
return m.stt(shorts, shorts.length);
}

View File

@ -11,8 +11,15 @@ public class DeepSpeechModel {
}
// FIXME: We should have something better than those SWIGTYPE_*
SWIGTYPE_p_p_ModelState _mspp;
SWIGTYPE_p_ModelState _msp;
private SWIGTYPE_p_p_ModelState _mspp;
private SWIGTYPE_p_ModelState _msp;
private void evaluateErrorCode(int errorCode) {
DeepSpeech_Error_Codes code = DeepSpeech_Error_Codes.swigToEnum(errorCode);
if (code != DeepSpeech_Error_Codes.ERR_OK) {
throw new RuntimeException("Error: " + impl.ErrorCodeToErrorMessage(errorCode) + " (0x" + Integer.toHexString(errorCode) + ").");
}
}
/**
* @brief An object providing an interface to a trained DeepSpeech model.
@ -20,10 +27,12 @@ public class DeepSpeechModel {
* @constructor
*
* @param modelPath The path to the frozen model graph.
*
* @throws RuntimeException on failure.
*/
public DeepSpeechModel(String modelPath) {
this._mspp = impl.new_modelstatep();
impl.CreateModel(modelPath, this._mspp);
evaluateErrorCode(impl.CreateModel(modelPath, this._mspp));
this._msp = impl.modelstatep_value(this._mspp);
}
@ -43,10 +52,10 @@ public class DeepSpeechModel {
* @param aBeamWidth The beam width used by the model. A larger beam width value
* generates better results at the cost of decoding time.
*
* @return Zero on success, non-zero on failure.
* @throws RuntimeException on failure.
*/
public int setBeamWidth(long beamWidth) {
return impl.SetModelBeamWidth(this._msp, beamWidth);
public void setBeamWidth(long beamWidth) {
evaluateErrorCode(impl.SetModelBeamWidth(this._msp, beamWidth));
}
/**
@ -70,19 +79,19 @@ public class DeepSpeechModel {
*
* @param scorer The path to the external scorer file.
*
* @return Zero on success, non-zero on failure (invalid arguments).
* @throws RuntimeException on failure.
*/
public void enableExternalScorer(String scorer) {
impl.EnableExternalScorer(this._msp, scorer);
evaluateErrorCode(impl.EnableExternalScorer(this._msp, scorer));
}
/**
* @brief Disable decoding using an external scorer.
*
* @return Zero on success, non-zero on failure (invalid arguments).
* @throws RuntimeException on failure.
*/
public void disableExternalScorer() {
impl.DisableExternalScorer(this._msp);
evaluateErrorCode(impl.DisableExternalScorer(this._msp));
}
/**
@ -91,10 +100,10 @@ public class DeepSpeechModel {
* @param alpha The alpha hyperparameter of the decoder. Language model weight.
* @param beta The beta hyperparameter of the decoder. Word insertion weight.
*
* @return Zero on success, non-zero on failure (invalid arguments).
* @throws RuntimeException on failure.
*/
public void setScorerAlphaBeta(float alpha, float beta) {
impl.SetScorerAlphaBeta(this._msp, alpha, beta);
evaluateErrorCode(impl.SetScorerAlphaBeta(this._msp, alpha, beta));
}
/*
@ -117,11 +126,13 @@ public class DeepSpeechModel {
* @param buffer A 16-bit, mono raw audio signal at the appropriate
* sample rate (matching what the model was trained on).
* @param buffer_size The number of samples in the audio signal.
* @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this.
*
* @return Outputs a Metadata object of individual letters along with their timing information.
* @return Metadata struct containing multiple candidate transcripts. Each transcript
* has per-token metadata including timing information.
*/
public Metadata sttWithMetadata(short[] buffer, int buffer_size) {
return impl.SpeechToTextWithMetadata(this._msp, buffer, buffer_size);
public Metadata sttWithMetadata(short[] buffer, int buffer_size, int num_results) {
return impl.SpeechToTextWithMetadata(this._msp, buffer, buffer_size, num_results);
}
/**
@ -130,10 +141,12 @@ public class DeepSpeechModel {
* and finishStream().
*
* @return An opaque object that represents the streaming state.
*
* @throws RuntimeException on failure.
*/
public DeepSpeechStreamingState createStream() {
SWIGTYPE_p_p_StreamingState ssp = impl.new_streamingstatep();
impl.CreateStream(this._msp, ssp);
evaluateErrorCode(impl.CreateStream(this._msp, ssp));
return new DeepSpeechStreamingState(impl.streamingstatep_value(ssp));
}
@ -161,8 +174,20 @@ public class DeepSpeechModel {
}
/**
* @brief Signal the end of an audio signal to an ongoing streaming
* inference, returns the STT result over the whole audio signal.
* @brief Compute the intermediate decoding of an ongoing streaming inference, return results including metadata.
*
* @param ctx A streaming state pointer returned by createStream().
* @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this.
*
* @return Metadata object containing multiple candidate transcripts, with per-token metadata including timing information.
*/
public Metadata intermediateDecodeWithMetadata(DeepSpeechStreamingState ctx, int num_results) {
return impl.IntermediateDecodeWithMetadata(ctx.get(), num_results);
}
/**
* @brief Compute the final decoding of an ongoing streaming inference and return
* the result. Signals the end of an ongoing streaming inference.
*
* @param ctx A streaming state pointer returned by createStream().
*
@ -175,16 +200,19 @@ public class DeepSpeechModel {
}
/**
* @brief Signal the end of an audio signal to an ongoing streaming
* inference, returns per-letter metadata.
* @brief Compute the final decoding of an ongoing streaming inference and return
* the results including metadata. Signals the end of an ongoing streaming
* inference.
*
* @param ctx A streaming state pointer returned by createStream().
* @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this.
*
* @return Outputs a Metadata object of individual letters along with their timing information.
* @return Metadata struct containing multiple candidate transcripts. Each transcript
* has per-token metadata including timing information.
*
* @note This method will free the state pointer (@p ctx).
*/
public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx) {
return impl.FinishStreamWithMetadata(ctx.get());
public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx, int num_results) {
return impl.FinishStreamWithMetadata(ctx.get(), num_results);
}
}

View File

@ -0,0 +1,73 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.1
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package org.mozilla.deepspeech.libdeepspeech;
/**
* A single transcript computed by the model, including a confidence<br>
* value and the metadata for its constituent tokens.
*/
public class CandidateTranscript {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected CandidateTranscript(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(CandidateTranscript obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
throw new UnsupportedOperationException("C++ destructor does not have public access");
}
swigCPtr = 0;
}
}
/**
* Array of TokenMetadata objects
*/
public TokenMetadata getTokens() {
long cPtr = implJNI.CandidateTranscript_tokens_get(swigCPtr, this);
return (cPtr == 0) ? null : new TokenMetadata(cPtr, false);
}
/**
* Size of the tokens array
*/
public long getNum_tokens() {
return implJNI.CandidateTranscript_num_tokens_get(swigCPtr, this);
}
/**
* Approximated confidence value for this transcript. This is roughly the<br>
* sum of the acoustic model logit values for each timestep/character that<br>
* contributed to the creation of this transcript.
*/
public double getConfidence() {
return implJNI.CandidateTranscript_confidence_get(swigCPtr, this);
}
/**
* Retrieve one TokenMetadata element<br>
* <br>
* @param i Array index of the TokenMetadata to get<br>
* <br>
* @return The TokenMetadata requested or null
*/
public TokenMetadata getToken(int i) {
return new TokenMetadata(implJNI.CandidateTranscript_getToken(swigCPtr, this, i), false);
}
}

View File

@ -0,0 +1,65 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.1
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package org.mozilla.deepspeech.libdeepspeech;
public enum DeepSpeech_Error_Codes {
ERR_OK(0x0000),
ERR_NO_MODEL(0x1000),
ERR_INVALID_ALPHABET(0x2000),
ERR_INVALID_SHAPE(0x2001),
ERR_INVALID_SCORER(0x2002),
ERR_MODEL_INCOMPATIBLE(0x2003),
ERR_SCORER_NOT_ENABLED(0x2004),
ERR_FAIL_INIT_MMAP(0x3000),
ERR_FAIL_INIT_SESS(0x3001),
ERR_FAIL_INTERPRETER(0x3002),
ERR_FAIL_RUN_SESS(0x3003),
ERR_FAIL_CREATE_STREAM(0x3004),
ERR_FAIL_READ_PROTOBUF(0x3005),
ERR_FAIL_CREATE_SESS(0x3006),
ERR_FAIL_CREATE_MODEL(0x3007);
public final int swigValue() {
return swigValue;
}
public static DeepSpeech_Error_Codes swigToEnum(int swigValue) {
DeepSpeech_Error_Codes[] swigValues = DeepSpeech_Error_Codes.class.getEnumConstants();
if (swigValue < swigValues.length && swigValue >= 0 && swigValues[swigValue].swigValue == swigValue)
return swigValues[swigValue];
for (DeepSpeech_Error_Codes swigEnum : swigValues)
if (swigEnum.swigValue == swigValue)
return swigEnum;
throw new IllegalArgumentException("No enum " + DeepSpeech_Error_Codes.class + " with value " + swigValue);
}
@SuppressWarnings("unused")
private DeepSpeech_Error_Codes() {
this.swigValue = SwigNext.next++;
}
@SuppressWarnings("unused")
private DeepSpeech_Error_Codes(int swigValue) {
this.swigValue = swigValue;
SwigNext.next = swigValue+1;
}
@SuppressWarnings("unused")
private DeepSpeech_Error_Codes(DeepSpeech_Error_Codes swigEnum) {
this.swigValue = swigEnum.swigValue;
SwigNext.next = this.swigValue+1;
}
private final int swigValue;
private static class SwigNext {
private static int next = 0;
}
}

View File

@ -1,6 +1,6 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.2
* Version 4.0.1
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
@ -9,7 +9,7 @@
package org.mozilla.deepspeech.libdeepspeech;
/**
* Stores the entire CTC output as an array of character metadata objects
* An array of CandidateTranscript objects computed by the model.
*/
public class Metadata {
private transient long swigCPtr;
@ -40,61 +40,29 @@ public class Metadata {
}
/**
* List of items
* Array of CandidateTranscript objects
*/
public void setItems(MetadataItem value) {
implJNI.Metadata_items_set(swigCPtr, this, MetadataItem.getCPtr(value), value);
public CandidateTranscript getTranscripts() {
long cPtr = implJNI.Metadata_transcripts_get(swigCPtr, this);
return (cPtr == 0) ? null : new CandidateTranscript(cPtr, false);
}
/**
* List of items
* Size of the transcripts array
*/
public MetadataItem getItems() {
long cPtr = implJNI.Metadata_items_get(swigCPtr, this);
return (cPtr == 0) ? null : new MetadataItem(cPtr, false);
public long getNum_transcripts() {
return implJNI.Metadata_num_transcripts_get(swigCPtr, this);
}
/**
* Size of the list of items
*/
public void setNum_items(int value) {
implJNI.Metadata_num_items_set(swigCPtr, this, value);
}
/**
* Size of the list of items
*/
public int getNum_items() {
return implJNI.Metadata_num_items_get(swigCPtr, this);
}
/**
* Approximated confidence value for this transcription. This is roughly the<br>
* sum of the acoustic model logit values for each timestep/character that<br>
* contributed to the creation of this transcription.
*/
public void setConfidence(double value) {
implJNI.Metadata_confidence_set(swigCPtr, this, value);
}
/**
* Approximated confidence value for this transcription. This is roughly the<br>
* sum of the acoustic model logit values for each timestep/character that<br>
* contributed to the creation of this transcription.
*/
public double getConfidence() {
return implJNI.Metadata_confidence_get(swigCPtr, this);
}
/**
* Retrieve one MetadataItem element<br>
* Retrieve one CandidateTranscript element<br>
* <br>
* @param i Array index of the MetadataItem to get<br>
* @param i Array index of the CandidateTranscript to get<br>
* <br>
* @return The MetadataItem requested or null
* @return The CandidateTranscript requested or null
*/
public MetadataItem getItem(int i) {
return new MetadataItem(implJNI.Metadata_getItem(swigCPtr, this, i), true);
public CandidateTranscript getTranscript(int i) {
return new CandidateTranscript(implJNI.Metadata_getTranscript(swigCPtr, this, i), false);
}
}

View File

@ -4,7 +4,7 @@ Javadoc for Sphinx
This code is only here for reference for documentation generation.
To update, please build SWIG (4.0 at least) and then run from native_client/java:
To update, please install SWIG (4.0 at least) and then run from native_client/java:
.. code-block::

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* This file was automatically generated by SWIG (http://www.swig.org).
* Version 4.0.1
*
* Do not make changes to this file unless you know what you are doing--modify
* the SWIG interface file instead.
* ----------------------------------------------------------------------------- */
package org.mozilla.deepspeech.libdeepspeech;
/**
* Stores text of an individual token, along with its timing information
*/
public class TokenMetadata {
private transient long swigCPtr;
protected transient boolean swigCMemOwn;
protected TokenMetadata(long cPtr, boolean cMemoryOwn) {
swigCMemOwn = cMemoryOwn;
swigCPtr = cPtr;
}
protected static long getCPtr(TokenMetadata obj) {
return (obj == null) ? 0 : obj.swigCPtr;
}
public synchronized void delete() {
if (swigCPtr != 0) {
if (swigCMemOwn) {
swigCMemOwn = false;
throw new UnsupportedOperationException("C++ destructor does not have public access");
}
swigCPtr = 0;
}
}
/**
* The text corresponding to this token
*/
public String getText() {
return implJNI.TokenMetadata_text_get(swigCPtr, this);
}
/**
* Position of the token in units of 20ms
*/
public long getTimestep() {
return implJNI.TokenMetadata_timestep_get(swigCPtr, this);
}
/**
* Position of the token in seconds
*/
public float getStart_time() {
return implJNI.TokenMetadata_start_time_get(swigCPtr, this);
}
}

View File

@ -42,12 +42,11 @@ function totalTime(hrtimeValue) {
return (hrtimeValue[0] + hrtimeValue[1] / 1000000000).toPrecision(4);
}
function metadataToString(metadata) {
function candidateTranscriptToString(transcript) {
var retval = ""
for (var i = 0; i < metadata.num_items; ++i) {
retval += metadata.items[i].character;
for (var i = 0; i < transcript.tokens.length; ++i) {
retval += transcript.tokens[i].text;
}
Ds.FreeMetadata(metadata);
return retval;
}
@ -117,7 +116,9 @@ audioStream.on('finish', () => {
const audioLength = (audioBuffer.length / 2) * (1 / desired_sample_rate);
if (args['extended']) {
console.log(metadataToString(model.sttWithMetadata(audioBuffer)));
let metadata = model.sttWithMetadata(audioBuffer, 1);
console.log(candidateTranscriptToString(metadata.transcripts[0]));
Ds.FreeMetadata(metadata);
} else {
console.log(model.stt(audioBuffer));
}

View File

@ -37,6 +37,7 @@ using namespace node;
%newobject DS_IntermediateDecode;
%newobject DS_FinishStream;
%newobject DS_Version;
%newobject DS_ErrorCodeToErrorMessage;
// convert double pointer retval in CreateModel to an output
%typemap(in, numinputs=0) ModelState **retval (ModelState *ret) {
@ -47,8 +48,8 @@ using namespace node;
%typemap(argout) ModelState **retval {
$result = SWIGV8_ARRAY_NEW();
SWIGV8_AppendOutput($result, SWIG_From_int(result));
// owned by SWIG, ModelState destructor gets called when the JavaScript object is finalized (see below)
%append_output(SWIG_NewPointerObj(%as_voidptr(*$1), $*1_descriptor, SWIG_POINTER_OWN));
// owned by the application. NodeJS does not guarantee the finalizer will be called so applications must call FreeModel themselves.
%append_output(SWIG_NewPointerObj(%as_voidptr(*$1), $*1_descriptor, 0));
}
@ -68,27 +69,29 @@ using namespace node;
%nodefaultctor ModelState;
%nodefaultdtor ModelState;
%typemap(out) MetadataItem* %{
%typemap(out) TokenMetadata* %{
$result = SWIGV8_ARRAY_NEW();
for (int i = 0; i < arg1->num_items; ++i) {
SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_MetadataItem, SWIG_POINTER_OWN));
for (int i = 0; i < arg1->num_tokens; ++i) {
SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_TokenMetadata, 0));
}
%}
%nodefaultdtor Metadata;
%nodefaultctor Metadata;
%nodefaultctor MetadataItem;
%nodefaultdtor MetadataItem;
%extend struct Metadata {
~Metadata() {
DS_FreeMetadata($self);
%typemap(out) CandidateTranscript* %{
$result = SWIGV8_ARRAY_NEW();
for (int i = 0; i < arg1->num_transcripts; ++i) {
SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_CandidateTranscript, 0));
}
}
%}
%extend struct MetadataItem {
~MetadataItem() { }
}
%ignore Metadata::num_transcripts;
%ignore CandidateTranscript::num_tokens;
%nodefaultctor Metadata;
%nodefaultdtor Metadata;
%nodefaultctor CandidateTranscript;
%nodefaultdtor CandidateTranscript;
%nodefaultctor TokenMetadata;
%nodefaultdtor TokenMetadata;
%rename ("%(strip:[DS_])s") "";

View File

@ -2,7 +2,7 @@
const binary = require('node-pre-gyp');
const path = require('path')
// 'lib', 'binding', 'v0.1.1', ['node', 'v' + process.versions.modules, process.platform, process.arch].join('-'), 'deepspeech-bingings.node')
// 'lib', 'binding', 'v0.1.1', ['node', 'v' + process.versions.modules, process.platform, process.arch].join('-'), 'deepspeech-bindings.node')
const binding_path = binary.find(path.resolve(path.join(__dirname, 'package.json')));
// On Windows, we can't rely on RPATH being set to $ORIGIN/../ or on
@ -35,7 +35,7 @@ function Model(aModelPath) {
const status = rets[0];
const impl = rets[1];
if (status !== 0) {
throw "CreateModel failed with error code 0x" + status.toString(16);
throw "CreateModel failed "+binding.ErrorCodeToErrorMessage(status)+" 0x" + status.toString(16);
}
this._impl = impl;
@ -115,15 +115,16 @@ Model.prototype.stt = function(aBuffer) {
}
/**
* Use the DeepSpeech model to perform Speech-To-Text and output metadata
* about the results.
* Use the DeepSpeech model to perform Speech-To-Text and output results including metadata.
*
* @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
* @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified.
*
* @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
* @return {object} :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
*/
Model.prototype.sttWithMetadata = function(aBuffer) {
return binding.SpeechToTextWithMetadata(this._impl, aBuffer);
Model.prototype.sttWithMetadata = function(aBuffer, aNumResults) {
aNumResults = aNumResults || 1;
return binding.SpeechToTextWithMetadata(this._impl, aBuffer, aNumResults);
}
/**
@ -138,7 +139,7 @@ Model.prototype.createStream = function() {
const status = rets[0];
const ctx = rets[1];
if (status !== 0) {
throw "CreateStream failed with error code 0x" + status.toString(16);
throw "CreateStream failed "+binding.ErrorCodeToErrorMessage(status)+" 0x" + status.toString(16);
}
return ctx;
}
@ -172,7 +173,19 @@ Stream.prototype.intermediateDecode = function() {
}
/**
* Signal the end of an audio signal to an ongoing streaming inference, returns the STT result over the whole audio signal.
* Compute the intermediate decoding of an ongoing streaming inference, return results including metadata.
*
* @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified.
*
* @return {object} :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
*/
Stream.prototype.intermediateDecodeWithMetadata = function(aNumResults) {
aNumResults = aNumResults || 1;
return binding.IntermediateDecodeWithMetadata(this._impl, aNumResults);
}
/**
* Compute the final decoding of an ongoing streaming inference and return the result. Signals the end of an ongoing streaming inference.
*
* @return {string} The STT result.
*
@ -185,14 +198,17 @@ Stream.prototype.finishStream = function() {
}
/**
* Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata.
* Compute the final decoding of an ongoing streaming inference and return the results including metadata. Signals the end of an ongoing streaming inference.
*
* @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified.
*
* @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`.
*
* This method will free the stream, it must not be used after this method is called.
*/
Stream.prototype.finishStreamWithMetadata = function() {
result = binding.FinishStreamWithMetadata(this._impl);
Stream.prototype.finishStreamWithMetadata = function(aNumResults) {
aNumResults = aNumResults || 1;
result = binding.FinishStreamWithMetadata(this._impl, aNumResults);
this._impl = null;
return result;
}
@ -236,70 +252,80 @@ function Version() {
}
//// Metadata and MetadataItem are here only for documentation purposes
//// Metadata, CandidateTranscript and TokenMetadata are here only for documentation purposes
/**
* @class
*
* Stores each individual character, along with its timing information
* Stores text of an individual token, along with its timing information
*/
function MetadataItem() {}
function TokenMetadata() {}
/**
* The character generated for transcription
* The text corresponding to this token
*
* @return {string} The character generated
* @return {string} The text generated
*/
MetadataItem.prototype.character = function() {}
TokenMetadata.prototype.text = function() {}
/**
* Position of the character in units of 20ms
* Position of the token in units of 20ms
*
* @return {int} The position of the character
* @return {int} The position of the token
*/
MetadataItem.prototype.timestep = function() {};
TokenMetadata.prototype.timestep = function() {};
/**
* Position of the character in seconds
* Position of the token in seconds
*
* @return {float} The position of the character
* @return {float} The position of the token
*/
MetadataItem.prototype.start_time = function() {};
TokenMetadata.prototype.start_time = function() {};
/**
* @class
*
* Stores the entire CTC output as an array of character metadata objects
* A single transcript computed by the model, including a confidence value and
* the metadata for its constituent tokens.
*/
function Metadata () {}
function CandidateTranscript () {}
/**
* List of items
* Array of tokens
*
* @return {array} List of :js:func:`MetadataItem`
* @return {array} Array of :js:func:`TokenMetadata`
*/
Metadata.prototype.items = function() {}
/**
* Size of the list of items
*
* @return {int} Number of items
*/
Metadata.prototype.num_items = function() {}
CandidateTranscript.prototype.tokens = function() {}
/**
* Approximated confidence value for this transcription. This is roughly the
* sum of the acoustic model logit values for each timestep/character that
* sum of the acoustic model logit values for each timestep/token that
* contributed to the creation of this transcription.
*
* @return {float} Confidence value
*/
Metadata.prototype.confidence = function() {}
CandidateTranscript.prototype.confidence = function() {}
/**
* @class
*
* An array of CandidateTranscript objects computed by the model.
*/
function Metadata () {}
/**
* Array of transcripts
*
* @return {array} Array of :js:func:`CandidateTranscript` objects
*/
Metadata.prototype.transcripts = function() {}
module.exports = {
Model: Model,
Metadata: Metadata,
MetadataItem: MetadataItem,
CandidateTranscript: CandidateTranscript,
TokenMetadata: TokenMetadata,
Version: Version,
FreeModel: FreeModel,
FreeStream: FreeStream,

View File

@ -37,27 +37,39 @@ ModelState::decode(const DecoderState& state) const
}
Metadata*
ModelState::decode_metadata(const DecoderState& state)
ModelState::decode_metadata(const DecoderState& state,
size_t num_results)
{
vector<Output> out = state.decode();
vector<Output> out = state.decode(num_results);
unsigned int num_returned = out.size();
std::unique_ptr<Metadata> metadata(new Metadata());
metadata->num_items = out[0].tokens.size();
metadata->confidence = out[0].confidence;
CandidateTranscript* transcripts = (CandidateTranscript*)malloc(sizeof(CandidateTranscript)*num_returned);
std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
for (int i = 0; i < num_returned; ++i) {
TokenMetadata* tokens = (TokenMetadata*)malloc(sizeof(TokenMetadata)*out[i].tokens.size());
// Loop through each character
for (int i = 0; i < out[0].tokens.size(); ++i) {
items[i].character = strdup(alphabet_.StringFromLabel(out[0].tokens[i]).c_str());
items[i].timestep = out[0].timesteps[i];
items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
if (items[i].start_time < 0) {
items[i].start_time = 0;
for (int j = 0; j < out[i].tokens.size(); ++j) {
TokenMetadata token {
strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str()), // text
static_cast<unsigned int>(out[i].timesteps[j]), // timestep
out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time
};
memcpy(&tokens[j], &token, sizeof(TokenMetadata));
}
CandidateTranscript transcript {
tokens, // tokens
static_cast<unsigned int>(out[i].tokens.size()), // num_tokens
out[i].confidence, // confidence
};
memcpy(&transcripts[i], &transcript, sizeof(CandidateTranscript));
}
metadata->items = items.release();
return metadata.release();
Metadata* ret = (Metadata*)malloc(sizeof(Metadata));
Metadata metadata {
transcripts, // transcripts
num_returned, // num_transcripts
};
memcpy(ret, &metadata, sizeof(Metadata));
return ret;
}

View File

@ -66,11 +66,14 @@ struct ModelState {
* @brief Return metadata for candidate transcripts, including per-token timing information.
*
* @param state Decoder state to use when decoding.
* @param num_results Maximum number of candidate results to return.
*
* @return Metadata struct containing MetadataItem structs for each character.
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
* @return A Metadata struct containing CandidateTranscript structs.
* Each represents a candidate transcript, with the first ranked most probable.
* The user is responsible for freeing the returned Metadata by calling DS_FreeMetadata().
*/
virtual Metadata* decode_metadata(const DecoderState& state);
virtual Metadata* decode_metadata(const DecoderState& state,
size_t num_results);
};
#endif // MODELSTATE_H

View File

@ -6,6 +6,8 @@ bindings-clean:
rm -rf dist temp_build deepspeech.egg-info MANIFEST.in temp_lib
rm -f impl_wrap.cpp impl.py
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings-build:
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python ./setup.py build_ext $(PYTHON_PLATFORM_NAME)

View File

@ -35,7 +35,7 @@ class Model(object):
status, impl = deepspeech.impl.CreateModel(model_path)
if status != 0:
raise RuntimeError("CreateModel failed with error code 0x{:X}".format(status))
raise RuntimeError("CreateModel failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status))
self._impl = impl
def __del__(self):
@ -121,17 +121,20 @@ class Model(object):
"""
return deepspeech.impl.SpeechToText(self._impl, audio_buffer)
def sttWithMetadata(self, audio_buffer):
def sttWithMetadata(self, audio_buffer, num_results=1):
"""
Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results.
Use the DeepSpeech model to perform Speech-To-Text and return results including metadata.
:param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
:type audio_buffer: numpy.int16 array
:return: Outputs a struct of individual letters along with their timing information.
:param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
:type num_results: int
:return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
:type: :func:`Metadata`
"""
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer)
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results)
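# Illustrative usage sketch of the num_results parameter introduced above.
# Assumptions (not part of the diff): a model file at 'output_graph.pbmm' and a
# 16-bit mono WAV at 'audio.wav', both placeholder paths at the model sample rate.
import wave
import numpy as np
from deepspeech import Model

ds = Model('output_graph.pbmm')
with wave.open('audio.wav', 'rb') as w:
    audio = np.frombuffer(w.readframes(w.getnframes()), np.int16)

metadata = ds.sttWithMetadata(audio, num_results=3)   # ask for up to 3 candidates
for transcript in metadata.transcripts:               # returned list may be shorter
    print(transcript.confidence, ''.join(t.text for t in transcript.tokens))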
def createStream(self):
"""
@ -145,7 +148,7 @@ class Model(object):
"""
status, ctx = deepspeech.impl.CreateStream(self._impl)
if status != 0:
raise RuntimeError("CreateStream failed with error code 0x{:X}".format(status))
raise RuntimeError("CreateStream failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status))
return Stream(ctx)
@ -187,10 +190,27 @@ class Stream(object):
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
return deepspeech.impl.IntermediateDecode(self._impl)
def intermediateDecodeWithMetadata(self, num_results=1):
"""
Compute the intermediate decoding of an ongoing streaming inference and return results including metadata.
:param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
:type num_results: int
:return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
:type: :func:`Metadata`
:throws: RuntimeError if the stream object is not valid
"""
if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
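# Illustrative streaming sketch: feed audio in chunks and peek at the best
# candidate along the way. 'ds' and 'audio' are assumed to be set up as in the
# sketch above; the chunk size is arbitrary for the example.
stream = ds.createStream()
for offset in range(0, len(audio), 1024):
    stream.feedAudioContent(audio[offset:offset + 1024])
    partial = stream.intermediateDecodeWithMetadata(num_results=1)
    if partial.transcripts:
        best = partial.transcripts[0]
        print(''.join(token.text for token in best.tokens))
print(stream.finishStream())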
def finishStream(self):
"""
Signal the end of an audio signal to an ongoing streaming inference,
returns the STT result over the whole audio signal.
Compute the final decoding of an ongoing streaming inference and return
the result. Signals the end of an ongoing streaming inference. The underlying
stream object must not be used after this method is called.
:return: The STT result.
:type: str
@ -203,19 +223,24 @@ class Stream(object):
self._impl = None
return result
def finishStreamWithMetadata(self):
def finishStreamWithMetadata(self, num_results=1):
"""
Signal the end of an audio signal to an ongoing streaming inference,
returns per-letter metadata.
Compute the final decoding of an ongoing streaming inference and return
results including metadata. Signals the end of an ongoing streaming
inference. The underlying stream object must not be used after this
method is called.
:return: Outputs a struct of individual letters along with their timing information.
:param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
:type num_results: int
:return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
:type: :func:`Metadata`
:throws: RuntimeError if the stream object is not valid
"""
if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
result = deepspeech.impl.FinishStreamWithMetadata(self._impl)
result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results)
self._impl = None
return result
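# Illustrative sketch: finish a stream and keep per-token timing for the best
# candidate. 'stream' is assumed to be a Stream created by ds.createStream()
# and already fed with audio; it must not be reused after this call.
metadata = stream.finishStreamWithMetadata(num_results=1)
for token in metadata.transcripts[0].tokens:
    print('{:6.2f}s  {}'.format(token.start_time, token.text))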
@ -233,52 +258,43 @@ class Stream(object):
# This is only for documentation purposes
# Metadata and MetadataItem should be in sync with native_client/deepspeech.h
class MetadataItem(object):
# Metadata, CandidateTranscript and TokenMetadata should be in sync with native_client/deepspeech.h
class TokenMetadata(object):
"""
Stores the text of an individual token, along with its timing information
"""
def character(self):
def text(self):
"""
The character generated for transcription
The text for this token
"""
def timestep(self):
"""
Position of the character in units of 20ms
Position of the token in units of 20ms
"""
def start_time(self):
"""
Position of the character in seconds
Position of the token in seconds
"""
class Metadata(object):
class CandidateTranscript(object):
"""
A single transcript computed by the model, including a confidence value and the metadata for its constituent tokens
"""
def items(self):
def tokens(self):
"""
List of items
List of tokens
:return: A list of :func:`MetadataItem` elements
:return: A list of :func:`TokenMetadata` elements
:type: list
"""
def num_items(self):
"""
Size of the list of items
:return: Size of the list of items
:type: int
"""
def confidence(self):
"""
Approximated confidence value for this transcription. This is roughly the
@ -286,3 +302,12 @@ class Metadata(object):
contributed to the creation of this transcription.
"""
class Metadata(object):
def transcripts(self):
"""
List of candidate transcripts
:return: A list of :func:`CandidateTranscript` objects
:type: list
"""

View File

@ -18,6 +18,7 @@ try:
except ImportError:
from pipes import quote
def convert_samplerate(audio_path, desired_sample_rate):
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
try:
@ -31,25 +32,25 @@ def convert_samplerate(audio_path, desired_sample_rate):
def metadata_to_string(metadata):
return ''.join(item.character for item in metadata.items)
return ''.join(token.text for token in metadata.tokens)
def words_from_metadata(metadata):
def words_from_candidate_transcript(metadata):
word = ""
word_list = []
word_start_time = 0
# Loop through each character
for i in range(0, metadata.num_items):
item = metadata.items[i]
for i, token in enumerate(metadata.tokens):
# Append character to word if it's not a space
if item.character != " ":
if token.text != " ":
if len(word) == 0:
# Log the start time of the new word
word_start_time = item.start_time
word_start_time = token.start_time
word = word + item.character
word = word + token.text
# Word boundary is either a space or the last character in the array
if item.character == " " or i == metadata.num_items - 1:
word_duration = item.start_time - word_start_time
if token.text == " " or i == len(metadata.tokens) - 1:
word_duration = token.start_time - word_start_time
if word_duration < 0:
word_duration = 0
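# The hunk above shows the updated word-grouping loop only in fragments, so here
# is a compact, self-contained sketch of the same idea. Field names come from
# TokenMetadata; the dictionary keys are illustrative, not the client's exact
# output format.
def words_with_timing(transcript):
    words, current, start = [], '', 0.0
    for i, token in enumerate(transcript.tokens):
        if token.text != ' ':
            if not current:
                start = token.start_time      # first token of a new word
            current += token.text
        # Word boundary is either a space or the last token in the array
        if token.text == ' ' or i == len(transcript.tokens) - 1:
            if current:
                words.append({'word': current,
                              'start_time': start,
                              'duration': max(token.start_time - start, 0.0)})
            current = ''
    return words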
@ -69,9 +70,11 @@ def words_from_metadata(metadata):
def metadata_json_output(metadata):
json_result = dict()
json_result["words"] = words_from_metadata(metadata)
json_result["confidence"] = metadata.confidence
return json.dumps(json_result)
json_result["transcripts"] = [{
"confidence": transcript.confidence,
"words": words_from_candidate_transcript(transcript),
} for transcript in metadata.transcripts]
return json.dumps(json_result, indent=2)
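# Illustrative consumer of the --json output built above. Only the keys visible
# in this hunk ("transcripts", "confidence", "words") are assumed; 'result.json'
# is a placeholder for wherever the client's stdout was redirected.
import json

with open('result.json') as f:
    result = json.load(f)
for candidate in result['transcripts']:
    print('confidence {:.3f}, {} words'.format(candidate['confidence'],
                                               len(candidate['words'])))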
@ -141,9 +144,9 @@ def main():
print('Running inference.', file=sys.stderr)
inference_start = timer()
if args.extended:
print(metadata_to_string(ds.sttWithMetadata(audio)))
print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0]))
elif args.json:
print(metadata_json_output(ds.sttWithMetadata(audio)))
print(metadata_json_output(ds.sttWithMetadata(audio, 3)))
else:
print(ds.stt(audio))
inference_end = timer() - inference_start

View File

@ -38,30 +38,69 @@ import_array();
%append_output(SWIG_NewPointerObj(%as_voidptr($1), $1_descriptor, SWIG_POINTER_OWN));
}
%typemap(out) MetadataItem* %{
$result = PyList_New(arg1->num_items);
for (int i = 0; i < arg1->num_items; ++i) {
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->items[i]), SWIGTYPE_p_MetadataItem, 0);
%fragment("parent_reference_init", "init") {
// Thread-safe initialization - initialize during Python module initialization
parent_reference();
}
%fragment("parent_reference_function", "header", fragment="parent_reference_init") {
static PyObject *parent_reference() {
static PyObject *parent_reference_string = SWIG_Python_str_FromChar("__parent_reference");
return parent_reference_string;
}
}
%typemap(out, fragment="parent_reference_function") CandidateTranscript* %{
$result = PyList_New(arg1->num_transcripts);
for (int i = 0; i < arg1->num_transcripts; ++i) {
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->transcripts[i]), SWIGTYPE_p_CandidateTranscript, 0);
// Add a reference to Metadata in the returned elements to avoid premature
// garbage collection
PyObject_SetAttr(o, parent_reference(), $self);
PyList_SetItem($result, i, o);
}
%}
%extend struct MetadataItem {
%typemap(out, fragment="parent_reference_function") TokenMetadata* %{
$result = PyList_New(arg1->num_tokens);
for (int i = 0; i < arg1->num_tokens; ++i) {
PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->tokens[i]), SWIGTYPE_p_TokenMetadata, 0);
// Add a reference to CandidateTranscript in the returned elements to avoid premature
// garbage collection
PyObject_SetAttr(o, parent_reference(), $self);
PyList_SetItem($result, i, o);
}
%}
%extend struct TokenMetadata {
%pythoncode %{
def __repr__(self):
return 'MetadataItem(character=\'{}\', timestep={}, start_time={})'.format(self.character, self.timestep, self.start_time)
return 'TokenMetadata(text=\'{}\', timestep={}, start_time={})'.format(self.text, self.timestep, self.start_time)
%}
}
%extend struct CandidateTranscript {
%pythoncode %{
def __repr__(self):
tokens_repr = ',\n'.join(repr(i) for i in self.tokens)
tokens_repr = '\n'.join(' ' + l for l in tokens_repr.split('\n'))
return 'CandidateTranscript(confidence={}, tokens=[\n{}\n])'.format(self.confidence, tokens_repr)
%}
}
%extend struct Metadata {
%pythoncode %{
def __repr__(self):
items_repr = ', \n'.join(' ' + repr(i) for i in self.items)
return 'Metadata(confidence={}, items=[\n{}\n])'.format(self.confidence, items_repr)
transcripts_repr = ',\n'.join(repr(i) for i in self.transcripts)
transcripts_repr = '\n'.join(' ' + l for l in transcripts_repr.split('\n'))
return 'Metadata(transcripts=[\n{}\n])'.format(transcripts_repr)
%}
}
%ignore Metadata::num_items;
%ignore Metadata::num_transcripts;
%ignore CandidateTranscript::num_tokens;
%extend struct Metadata {
~Metadata() {
@ -69,10 +108,12 @@ import_array();
}
}
%nodefaultdtor Metadata;
%nodefaultctor Metadata;
%nodefaultctor MetadataItem;
%nodefaultdtor MetadataItem;
%nodefaultdtor Metadata;
%nodefaultctor CandidateTranscript;
%nodefaultdtor CandidateTranscript;
%nodefaultctor TokenMetadata;
%nodefaultdtor TokenMetadata;
%typemap(newfree) char* "DS_FreeString($1);";
@ -80,6 +121,7 @@ import_array();
%newobject DS_IntermediateDecode;
%newobject DS_FinishStream;
%newobject DS_Version;
%newobject DS_ErrorCodeToErrorMessage;
%rename ("%(strip:[DS_])s") "";

View File

@ -1,13 +1,13 @@
# Main training requirements
tensorflow == 1.15.0
tensorflow == 1.15.2
numpy == 1.18.1
progressbar2
pandas
six
pyxdg
attrdict
absl-py
semver
opuslib == 2.0.0
# Requirements for building native_client files
setuptools
@ -15,6 +15,10 @@ setuptools
# Requirements for importers
sox
bs4
pandas
requests
librosa
soundfile
# Requirements for optimizer
optuna

requirements_tests.txt Normal file
View File

@ -0,0 +1,3 @@
absl-py
argparse
semver

View File

@ -5,6 +5,9 @@ python:
apt: 'python3-virtualenv python3-setuptools python3-pip python3-wheel python3-pkg-resources'
packages_docs_bionic:
apt: 'python3 python3-pip zip doxygen'
training:
packages_trusty:
apt: 'libopus0'
tensorflow:
packages_trusty:
apt: 'make build-essential gfortran git libblas-dev liblapack-dev libsox-dev libmagic-dev libgsm1-dev libltdl-dev libpng-dev python zlib1g-dev'
@ -80,6 +83,18 @@ system:
android_25:
url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-25.4/artifacts/public/android_cache.tar.gz'
namespace: 'project.deepspeech.android_cache.x86_64.android-25.4'
android_26:
url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-26.0/artifacts/public/android_cache.tar.gz'
namespace: 'project.deepspeech.android_cache.x86_64.android-26.0'
android_27:
url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-27.0/artifacts/public/android_cache.tar.gz'
namespace: 'project.deepspeech.android_cache.x86_64.android-27.0'
android_28:
url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-28.0/artifacts/public/android_cache.tar.gz'
namespace: 'project.deepspeech.android_cache.x86_64.android-28.0'
android_29:
url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.x86_64.android-29.0/artifacts/public/android_cache.tar.gz'
namespace: 'project.deepspeech.android_cache.x86_64.android-29.0'
sdk:
android_27:
url: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.android_cache.sdk.android-27.4/artifacts/public/android_cache.tar.gz'

View File

@ -0,0 +1,14 @@
build:
template_file: android_cache-opt-base.tyml
system_setup:
>
${java.packages_trusty.apt}
cache:
url: ${system.android_cache.x86_64.android_26.url}
namespace: ${system.android_cache.x86_64.android_26.namespace}
scripts:
build: "taskcluster/android_cache-build.sh x86_64 android-26"
package: "taskcluster/android_cache-package.sh"
metadata:
name: "Builds Android cache x86_64 / android-26"
description: "Setup an Android SDK / emulator cache for Android / x86_64 android-26"

View File

@ -0,0 +1,14 @@
build:
template_file: android_cache-opt-base.tyml
system_setup:
>
${java.packages_trusty.apt}
cache:
url: ${system.android_cache.x86_64.android_28.url}
namespace: ${system.android_cache.x86_64.android_28.namespace}
scripts:
build: "taskcluster/android_cache-build.sh x86_64 android-28"
package: "taskcluster/android_cache-package.sh"
metadata:
name: "Builds Android cache x86_64 / android-28"
description: "Setup an Android SDK / emulator cache for Android / x86_64 android-28"

View File

@ -0,0 +1,14 @@
build:
template_file: android_cache-opt-base.tyml
system_setup:
>
${java.packages_trusty.apt}
cache:
url: ${system.android_cache.x86_64.android_29.url}
namespace: ${system.android_cache.x86_64.android_29.namespace}
scripts:
build: "taskcluster/android_cache-build.sh x86_64 android-29"
package: "taskcluster/android_cache-package.sh"
metadata:
name: "Builds Android cache x86_64 / android-29"
description: "Setup an Android SDK / emulator cache for Android / x86_64 android-29"

View File

@ -6,6 +6,10 @@ source $(dirname "$0")/tc-tests-utils.sh
source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
export SYSTEM_TARGET=host
if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
export SYSTEM_TARGET=host-win
else
export SYSTEM_TARGET=host
fi;
do_deepspeech_decoder_build

View File

@ -19,6 +19,7 @@ build:
- "android-java-opt"
- "win-amd64-cpu-opt"
- "win-amd64-gpu-opt"
- "win-amd64-ctc-opt"
allowed:
- "tag"
ref_match: "refs/tags/"
@ -38,6 +39,7 @@ build:
- "win-amd64-cpu-opt"
- "win-amd64-tflite-opt"
- "win-amd64-gpu-opt"
- "win-amd64-ctc-opt"
javascript:
# GPU package
- "node-package-gpu"

View File

@ -252,6 +252,25 @@ assert_deepspeech_version()
assert_not_present "$1" "DeepSpeech: unknown"
}
# We need to ensure that running inference really leverages the GPU because
# it might default back to CPU
ensure_cuda_usage()
{
local _maybe_cuda=$1
DS_BINARY_FILE=${DS_BINARY_FILE:-"deepspeech"}
if [ "${_maybe_cuda}" = "cuda" ]; then
set +e
export TF_CPP_MIN_VLOG_LEVEL=1
ds_cuda=$(${DS_BINARY_PREFIX}${DS_BINARY_FILE} --model ${TASKCLUSTER_TMP_DIR}/${model_name} --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>&1 1>/dev/null)
export TF_CPP_MIN_VLOG_LEVEL=
set -e
assert_shows_something "${ds_cuda}" "Successfully opened dynamic library nvcuda.dll"
assert_not_present "${ds_cuda}" "Skipping registering GPU devices"
fi;
}
check_versions()
{
set +e

View File

@ -13,4 +13,6 @@ export PATH=${TASKCLUSTER_TMP_DIR}/ds/:$PATH
check_versions
ensure_cuda_usage "$2"
run_basic_inference_tests

View File

@ -59,6 +59,8 @@ node --version
check_runtime_electronjs
ensure_cuda_usage "$4"
run_electronjs_inference_tests
if [ "${OS}" = "Linux" ]; then

View File

@ -10,7 +10,7 @@ source $(dirname "$0")/tc-tests-utils.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
if [ "${package_option}" = "--cuda" ]; then
if [ "${package_option}" = "cuda" ]; then
PROJECT_NAME="DeepSpeech-GPU"
elif [ "${package_option}" = "--tflite" ]; then
PROJECT_NAME="DeepSpeech-TFLite"
@ -25,4 +25,7 @@ download_data
install_nuget "${PROJECT_NAME}"
DS_BINARY_FILE="DeepSpeechConsole.exe"
ensure_cuda_usage "$2"
run_netframework_inference_tests

View File

@ -29,4 +29,6 @@ npm install --prefix ${NODE_ROOT} --cache ${NODE_CACHE} ${deepspeech_npm_url}
check_runtime_nodejs
ensure_cuda_usage "$3"
run_all_inference_tests

View File

@ -7,8 +7,8 @@ get_dep_npm_pkg_url()
{
local all_deps="$(curl -s https://community-tc.services.mozilla.com/api/queue/v1/task/${TASK_ID} | python -c 'import json; import sys; print(" ".join(json.loads(sys.stdin.read())["dependencies"]));')"
# We try "deepspeech-tflite" first and if we don't find it we try "deepspeech"
for pkg_basename in "deepspeech-tflite" "deepspeech"; do
# We try "deepspeech-tflite" and "deepspeech-gpu" first and if we don't find it we try "deepspeech"
for pkg_basename in "deepspeech-tflite" "deepspeech-gpu" "deepspeech"; do
local deepspeech_pkg="${pkg_basename}-${DS_VERSION}.tgz"
for dep in ${all_deps}; do
local has_artifact=$(curl -s https://community-tc.services.mozilla.com/api/queue/v1/task/${dep}/artifacts | python -c 'import json; import sys; has_artifact = True in [ e["name"].find("'${deepspeech_pkg}'") > 0 for e in json.loads(sys.stdin.read())["artifacts"] ]; print(has_artifact)')

View File

@ -13,12 +13,19 @@ download_data
virtualenv_activate "${pyalias}" "deepspeech"
deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
if [ "$3" = "cuda" ]; then
deepspeech_pkg_url=$(get_python_pkg_url "${pyver_pkg}" "${py_unicode_type}" "deepspeech_gpu")
else
deepspeech_pkg_url=$(get_python_pkg_url "${pyver_pkg}" "${py_unicode_type}")
fi;
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
which deepspeech
deepspeech --version
ensure_cuda_usage "$3"
run_all_inference_tests
virtualenv_deactivate "${pyalias}" "deepspeech"
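
As a usage note on the cuda branch above: the CI installs the wheel from a per-task artifact URL, but the same GPU check can be reproduced by hand. A minimal sketch, assuming the published deepspeech-gpu wheel for the matching ${DS_VERSION} (illustrative only, not part of this change):

# Hypothetical manual equivalent: install the GPU wheel directly instead of the
# TaskCluster artifact URL, then confirm the binding and CLI resolve.
pip install --upgrade "deepspeech-gpu==${DS_VERSION}"
which deepspeech
deepspeech --version
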


@@ -48,6 +48,11 @@ pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-ldc93s1_new.sh 249 "${sample_rate}"
time ./bin/run-tc-ldc93s1_new.sh 1 "${sample_rate}"
time ./bin/run-tc-ldc93s1_tflite.sh "${sample_rate}"
# Testing single SDB source
time ./bin/run-tc-ldc93s1_new_sdb.sh 220 "${sample_rate}"
# Testing interleaved source (SDB+CSV combination) - run twice to test preprocessed features
time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 "${sample_rate}"
time ./bin/run-tc-ldc93s1_new_sdb_csv.sh 1 "${sample_rate}"
popd
cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
@@ -62,6 +67,7 @@ cp /tmp/train/output_graph.pbmm ${TASKCLUSTER_ARTIFACTS}
pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-ldc93s1_checkpoint.sh
time ./bin/run-tc-ldc93s1_checkpoint_sdb.sh
popd
virtualenv_deactivate "${pyalias}" "deepspeech"


@@ -1,12 +1,21 @@
# disabled because too intermittent to be reliable until we have some KVM-backed infra
build:
template_file: test-android-opt-base.tyml
dependencies:
- "android-x86_64-cpu-opt"
- "test-training_16k-linux-amd64-py36m-opt"
- "swig-linux-amd64"
- "gradle-cache"
- "android-cache-x86_64-android-26"
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
system_setup:
>
apt-get -qq -y install curl make python
cache:
url: ${system.android_cache.x86_64.android_26.url}
namespace: ${system.android_cache.x86_64.android_26.namespace}
gradle_cache:
url: ${system.gradle_cache.url}
namespace: ${system.gradle_cache.namespace}
args:
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-android-apk-tests.sh x86_64 android-26"
metadata:


@@ -1,12 +1,21 @@
# disabled because too intermittent to be reliable until we have some KVM-backed infra
build:
template_file: test-android-opt-base.tyml
dependencies:
- "android-x86_64-cpu-opt"
- "test-training_16k-linux-amd64-py36m-opt"
- "swig-linux-amd64"
- "gradle-cache"
- "android-cache-x86_64-android-28"
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
system_setup:
>
apt-get -qq -y install curl make python
cache:
url: ${system.android_cache.x86_64.android_28.url}
namespace: ${system.android_cache.x86_64.android_28.namespace}
gradle_cache:
url: ${system.gradle_cache.url}
namespace: ${system.gradle_cache.namespace}
args:
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-android-apk-tests.sh x86_64 android-28"
metadata:


@@ -0,0 +1,23 @@
build:
template_file: test-android-opt-base.tyml
dependencies:
- "android-x86_64-cpu-opt"
- "test-training_16k-linux-amd64-py36m-opt"
- "swig-linux-amd64"
- "gradle-cache"
- "android-cache-x86_64-android-29"
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
system_setup:
>
apt-get -qq -y install curl make python
cache:
url: ${system.android_cache.x86_64.android_29.url}
namespace: ${system.android_cache.x86_64.android_29.namespace}
gradle_cache:
url: ${system.gradle_cache.url}
namespace: ${system.gradle_cache.namespace}
args:
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-android-apk-tests.sh x86_64 android-29"
metadata:
name: "DeepSpeech Android 10.0 x86_64 Google Pixel APK/Java tests"
description: "Testing DeepSpeech APK/Java for Android 10.0 x86_64 Google Pixel, optimized version"


@@ -0,0 +1,11 @@
build:
template_file: test-win-cuda-opt-base.tyml
dependencies:
- "win-amd64-gpu-opt"
- "test-training_16k-linux-amd64-py36m-opt"
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
args:
tests_cmdline: "$TASKCLUSTER_TASK_DIR/DeepSpeech/ds/taskcluster/tc-cppwin-ds-tests.sh 16k cuda"
metadata:
name: "DeepSpeech Windows AMD64 CUDA C++ tests (16kHz)"
description: "Testing DeepSpeech C++ for Windows/AMD64, CUDA, optimized version (16kHz)"


@@ -0,0 +1,14 @@
build:
template_file: test-win-opt-base.tyml
dependencies:
- "node-package-cpu"
- "test-training_16k-linux-amd64-py36m-opt"
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
system_setup:
>
${system.sox_win} && ${nodejs.win.prep_12}
args:
tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-electron-tests.sh 12.x 8.0.1 16k"
metadata:
name: "DeepSpeech Windows AMD64 CPU ElectronJS MultiArch Package v8.0 tests"
description: "Testing DeepSpeech for Windows/AMD64 on ElectronJS MultiArch Package v8.0, CPU only, optimized version"


@@ -0,0 +1,14 @@
build:
template_file: test-win-cuda-opt-base.tyml
dependencies:
- "node-package-gpu"
- "test-training_16k-linux-amd64-py36m-opt"
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
system_setup:
>
${system.sox_win} && ${nodejs.win.prep_12}
args:
tests_cmdline: "${system.homedir.win}/DeepSpeech/ds/taskcluster/tc-electron-tests.sh 12.x 8.0.1 16k cuda"
metadata:
name: "DeepSpeech Windows AMD64 CUDA ElectronJS MultiArch Package v8.0 tests"
description: "Testing DeepSpeech for Windows/AMD64 on ElectronJS MultiArch Package v8.0, CUDA, optimized version"

Some files were not shown because too many files have changed in this diff.