Merge remote-tracking branch 'upstream/master'

Commit 00e4dbe3fd by Daniel, 2020-04-08 20:27:43 +02:00
115 changed files with 3939 additions and 3015 deletions

View File

@ -2,15 +2,15 @@
set -xe
apt-get install -y python3-venv
apt-get install -y python3-venv libopus0
python3 -m venv /tmp/venv
source /tmp/venv/bin/activate
pip install -r <(grep -v tensorflow requirements.txt)
pip install tensorflow-gpu==1.15.0
# Install ds_ctcdecoder package from TaskCluster
pip install $(python3 util/taskcluster.py --decoder)
pip install -U setuptools wheel pip
pip install .
pip uninstall -y tensorflow
pip install tensorflow-gpu==1.14
mkdir -p ../keep/summaries
@ -18,17 +18,22 @@ data="${SHARED_DIR}/data"
fis="${data}/LDC/fisher"
swb="${data}/LDC/LDC97S62/swb"
lbs="${data}/OpenSLR/LibriSpeech/librivox"
cv="${data}/mozilla/CommonVoice/en_1087h_2019-06-12/clips"
npr="${data}/NPR/WAMU/sets/v0.3"
python -u DeepSpeech.py \
--train_files "${fis}-train.csv","${swb}-train.csv","${lbs}-train-clean-100.csv","${lbs}-train-clean-360.csv","${lbs}-train-other-500.csv" \
--dev_files "${lbs}-dev-clean.csv"\
--test_files "${lbs}-test-clean.csv" \
--train_files "${npr}/best-train.sdb","${npr}/good-train.sdb","${cv}/train.sdb","${fis}-train.sdb","${swb}-train.sdb","${lbs}-train-clean-100.sdb","${lbs}-train-clean-360.sdb","${lbs}-train-other-500.sdb" \
--dev_files "${lbs}-dev-clean.sdb" \
--test_files "${lbs}-test-clean.sdb" \
--train_batch_size 24 \
--dev_batch_size 48 \
--test_batch_size 48 \
--train_cudnn \
--n_hidden 2048 \
--learning_rate 0.0001 \
--dropout_rate 0.2 \
--epoch 13 \
--dropout_rate 0.40 \
--epochs 150 \
--noearly_stop \
--feature_cache "../tmp/feature.cache" \
--checkpoint_dir "../keep" \
--summary_dir "../keep/summaries"

1
.gitignore vendored
View File

@ -19,6 +19,7 @@
/native_client/python/model_wrap.cpp
/native_client/python/utils_wrap.cpp
/native_client/javascript/build
/native_client/javascript/client.js
/native_client/javascript/deepspeech_wrap.cxx
/doc/.build/
/doc/xml-c/

4
.isort.cfg Normal file
View File

@ -0,0 +1,4 @@
[settings]
line_length=80
multi_line_output=3
default_section=FIRSTPARTY
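The new .isort.cfg centralizes the project's import-sorting settings. A hedged usage sketch, assuming isort is installed (this file does not install it):
# Hypothetical check - isort discovers .isort.cfg when run from the repository root
pip install isort
isort --check-only bin/ util/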

View File

@ -9,7 +9,7 @@ python:
jobs:
include:
- stage: cardboard linter
- name: cardboard linter
install:
- pip install --upgrade cardboardlint pylint
script:
@ -17,9 +17,10 @@ jobs:
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
fi
- stage: python unit tests
- name: python unit tests
install:
- pip install --upgrade -r requirements_tests.txt
- pip install --upgrade -r requirements_tests.txt;
pip install --upgrade .
script:
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
python -m unittest;

View File

@ -2,934 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version, ExceptionBox
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get the appropriate variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structured in an old-school
# TensorFlow style where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note, that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine computes the CTC loss for a mini-batch.
It returns the average loss across the batch together with a list of the
files in the batch that produced a non-finite loss.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to non finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss
return avg_loss, non_finite_files
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
# Towers
# ======
# In order to properly make use of multiple GPUs, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPUs.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can, for each GPU, introduce a
tower, for whose batch we calculate and return the optimization gradients
and the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
r'''
A routine for computing, for each variable, the average of the gradients obtained from the GPUs.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/stddev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
if FLAGS.load == 'auto':
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
else:
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size <= 0:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
from tensorflow.python.framework.ops import Tensor, Operation
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def main(_):
initialize_globals()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if __name__ == '__main__':
create_flags()
absl.app.run(main)
try:
from deepspeech_training import train as ds_train
except ImportError:
print('Training package is not installed. See training documentation.')
raise
ds_train.run_script()
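With DeepSpeech.py reduced to the thin wrapper above, the actual training code lives in the installable deepspeech_training package. A minimal sketch of the intended workflow, based on the pip install steps that appear elsewhere in this diff (file paths are placeholders):
# Install the training package so DeepSpeech.py can import deepspeech_training
pip install -U setuptools wheel pip
pip install --upgrade .
# The entry point is then invoked exactly as before
python3 DeepSpeech.py --train_files train.sdb --dev_files dev.sdb --test_files test.sdb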

View File

@ -24,7 +24,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libsox-fmt-mp3 \
htop \
nano \
swig \
cmake \
libboost-all-dev \
zlib1g-dev \
@ -66,7 +65,7 @@ RUN wget https://bootstrap.pypa.io/get-pip.py && \
# >> START Configure Tensorflow Build
# Clone TensoFlow from Mozilla repo
# Clone TensorFlow from Mozilla repo
RUN git clone https://github.com/mozilla/tensorflow/
WORKDIR /tensorflow
RUN git checkout r1.15
@ -150,7 +149,7 @@ COPY . /DeepSpeech/
WORKDIR /DeepSpeech
RUN pip3 --no-cache-dir install -r requirements.txt
RUN pip3 --no-cache-dir install .
# Link DeepSpeech native_client libs to tf folder
RUN ln -s /DeepSpeech/native_client /tensorflow
@ -201,10 +200,10 @@ WORKDIR /DeepSpeech/native_client
RUN make deepspeech
WORKDIR /DeepSpeech/native_client/python
RUN make bindings
RUN pip3 install dist/deepspeech*
RUN pip3 install --upgrade dist/deepspeech*
WORKDIR /DeepSpeech/native_client/ctcdecode
RUN make
RUN pip3 install dist/*.whl
RUN make bindings
RUN pip3 install --upgrade dist/*.whl
# << END Build and bind

View File

@ -1,53 +1,69 @@
#!/usr/bin/env python
'''
"""
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
"""
import argparse
import progressbar
from util.downloader import SIMPLE_BAR
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
from util.sample_collections import samples_from_files, DirectSDBWriter
from deepspeech_training.util.audio import (
AUDIO_TYPE_OPUS,
AUDIO_TYPE_WAV,
change_audio_types,
)
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.sample_collections import (
DirectSDBWriter,
samples_from_files,
)
AUDIO_TYPE_LOOKUP = {
'wav': AUDIO_TYPE_WAV,
'opus': AUDIO_TYPE_OPUS
}
AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
def build_sdb():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
with DirectSDBWriter(
CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
for sample in bar(
change_audio_types(
samples, audio_type=audio_type, processes=CLI_ARGS.workers
)
):
sdb_writer.add(sample)
def handle_args():
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
'from DeepSpeech CSV files and other SDB files')
parser.add_argument('sources', nargs='+',
help='Source CSV and/or SDB files - '
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
'already ordered from shortest to longest.')
parser.add_argument('target', help='SDB file to create')
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
help='Audio representation inside target SDB')
parser.add_argument('--workers', type=int, default=None,
help='Number of encoding SDB workers')
parser.add_argument('--unlabeled', action='store_true',
help='If to build an SDB with unlabeled (audio only) samples - '
'typically used for building noise augmentation corpora')
parser = argparse.ArgumentParser(
description="Tool for building Sample Databases (SDB files) "
"from DeepSpeech CSV files and other SDB files"
)
parser.add_argument(
"sources",
nargs="+",
help="Source CSV and/or SDB files - "
"Note: For getting a correctly ordered target SDB, source SDBs have to have their samples "
"already ordered from shortest to longest.",
)
parser.add_argument("target", help="SDB file to create")
parser.add_argument(
"--audio-type",
default="opus",
choices=AUDIO_TYPE_LOOKUP.keys(),
help="Audio representation inside target SDB",
)
parser.add_argument(
"--workers", type=int, default=None, help="Number of encoding SDB workers"
)
parser.add_argument(
"--unlabeled",
action="store_true",
help="If to build an SDB with unlabeled (audio only) samples - "
"typically used for building noise augmentation corpora",
)
return parser.parse_args()
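A hypothetical invocation using the arguments defined above (file names are placeholders):
# Combine two CSV sets into a single Opus-encoded SDB using four encoder workers
python3 bin/build_sdb.py --audio-type opus --workers 4 train-a.csv train-b.csv train.sdb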

View File

@ -1,11 +0,0 @@
#!/usr/bin/env python
import sys
import os
sys.path.append(os.path.abspath('.'))
from util.gpu_usage import GPUUsage
gu = GPUUsage()
gu.start()

View File

@ -1,10 +0,0 @@
#!/usr/bin/env python
import sys
import os
sys.path.append(os.path.abspath('.'))
from util.gpu_usage import GPUUsageChart
GPUUsageChart(sys.argv[1], sys.argv[2])

View File

@ -1,20 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tfv1
import sys
import tensorflow.compat.v1 as tfv1
from google.protobuf import text_format
def main():
# Load and export as string
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read())
with tfv1.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout:
with tfv1.gfile.FastGFile(sys.argv[1] + "txt", "w") as fout:
fout.write(text_format.MessageToString(graph_def))
if __name__ == '__main__':
if __name__ == "__main__":
main()
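A hedged usage sketch: the script reads the binary GraphDef named by its first argument and writes the text form next to it, so output_graph.pb becomes output_graph.pbtxt (the file name is illustrative):
# Dump a frozen graph to a human-readable protobuf text file
python3 bin/graphdef_binary_to_text.py output_graph.pb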

View File

@ -1,23 +1,17 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import pandas
import os
import tarfile
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -25,9 +19,9 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'aidatatang_200zh')
main_folder = os.path.join(target_dir, "aidatatang_200zh")
for targz in glob.glob(os.path.join(main_folder, 'corpus', '*', '*.tar.gz')):
for targz in glob.glob(os.path.join(main_folder, "corpus", "*", "*.tar.gz")):
extract(targz, os.path.dirname(targz))
# Folder structure is now:
@ -46,9 +40,11 @@ def preprocess_data(tgz_file, target_dir):
# Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aidatatang_200_zh_transcript.txt')
transcripts_path = os.path.join(
main_folder, "transcript", "aidatatang_200_zh_transcript.txt"
)
with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
def load_set(glob_path):
set_files = []
@ -57,33 +53,39 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n')
transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
subset_files = load_set(os.path.join(main_folder, 'corpus', subset, '*', '*.wav'))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(
os.path.join(main_folder, "corpus", subset, "*", "*.wav")
)
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
# Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'aidatatang_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
dest_csv = os.path.join(target_dir, "aidatatang_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)
def main():
# https://www.openslr.org/62/
parser = get_importers_parser(description='Import aidatatang_200zh corpus')
parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import aidatatang_200zh corpus")
parser.add_argument("tgz_file", help="Path to aidatatang_200zh.tgz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:
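A hypothetical run of this importer, based on the arguments above (the archive path is a placeholder):
# Extract aidatatang_200zh.tgz and write aidatatang_{train,dev,test}.csv into the target folder
python3 bin/import_aidatatang.py /data/aidatatang_200zh.tgz --target_dir /data/aidatatang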

View File

@ -1,23 +1,17 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import os
import tarfile
import pandas
from deepspeech_training.util.importers import get_importers_parser
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -25,10 +19,10 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'data_aishell')
main_folder = os.path.join(target_dir, "data_aishell")
wav_archives_folder = os.path.join(main_folder, 'wav')
for targz in glob.glob(os.path.join(wav_archives_folder, '*.tar.gz')):
wav_archives_folder = os.path.join(main_folder, "wav")
for targz in glob.glob(os.path.join(wav_archives_folder, "*.tar.gz")):
extract(targz, main_folder)
# Folder structure is now:
@ -45,9 +39,11 @@ def preprocess_data(tgz_file, target_dir):
# Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aishell_transcript_v0.8.txt')
transcripts_path = os.path.join(
main_folder, "transcript", "aishell_transcript_v0.8.txt"
)
with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
def load_set(glob_path):
set_files = []
@ -56,33 +52,37 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n')
transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
subset_files = load_set(os.path.join(main_folder, subset, 'S*', '*.wav'))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(main_folder, subset, "S*", "*.wav"))
df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES)
# Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'aishell_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
dest_csv = os.path.join(target_dir, "aishell_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)
def main():
# http://www.openslr.org/33/
parser = get_importers_parser(description='Import AISHELL corpus')
parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import AISHELL corpus")
parser.add_argument("aishell_tgz_file", help="Path to data_aishell.tgz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:

View File

@ -1,34 +1,35 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import sox
import tarfile
import os
import subprocess
import progressbar
import tarfile
from glob import glob
from os import path
from multiprocessing import Pool
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
from util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
import progressbar
import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
print_import_report,
)
from deepspeech_training.util.importers import validate_label_eng as validate_label
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
ARCHIVE_DIR_NAME = 'cv_corpus_v1'
ARCHIVE_NAME = ARCHIVE_DIR_NAME + '.tar.gz'
ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "cv_corpus_v1"
ARCHIVE_NAME = ARCHIVE_DIR_NAME + ".tar.gz"
ARCHIVE_URL = (
"https://s3.us-east-2.amazonaws.com/common-voice-data-download/" + ARCHIVE_NAME
)
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract common voice data
@ -36,56 +37,70 @@ def _download_and_preprocess_data(target_dir):
# Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
else:
print('Found directory "%s" - not extracting it from archive.' % extracted_path)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
for source_csv in glob(path.join(extracted_dir, '*.csv')):
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
extracted_dir = os.path.join(target_dir, extracted_data)
for source_csv in glob(os.path.join(extracted_dir, "*.csv")):
_maybe_convert_set(
extracted_dir,
source_csv,
os.path.join(target_dir, os.path.split(source_csv)[-1]),
)
def one_sample(sample):
mp3_filename = sample[0]
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
frames = int(
subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
)
file_size = -1
if path.exists(wav_filename):
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = validate_label(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
print()
if path.exists(target_csv):
if os.path.exists(target_csv):
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
return
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
@ -93,14 +108,14 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file)
for row in reader:
samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
samples.append((os.path.join(extracted_dir, row["filename"]), row["text"]))
# Mutable counters for the concurrent embedded routine
counter = get_counter()
num_samples = len(samples)
rows = []
print('Importing mp3 files...')
print("Importing mp3 files...")
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
@ -112,21 +127,28 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
pool.join()
print('Writing "%s"...' % target_csv)
with open(target_csv, 'w') as target_csv_file:
with open(target_csv, "w") as target_csv_file:
writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows):
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
@ -134,5 +156,6 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
except sox.core.SoxError:
pass
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])
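A hedged usage sketch for this importer, which takes only a target directory, downloads cv_corpus_v1 into it, and converts the mp3 data to 16 kHz WAV plus DeepSpeech CSVs (the directory is a placeholder):
# Download and convert the Common Voice v1 corpus
python3 bin/import_cv.py /data/CommonVoice_v1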

View File

@ -1,90 +1,96 @@
#!/usr/bin/env python
'''
"""
Broadly speaking, this script takes the audio downloaded from Common Voice
for a certain language, in addition to the *.tsv files output by CorporaCreator,
and the script formats the data and transcripts to be in a state usable by
DeepSpeech.py
Use "python3 import_cv2.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
"""
import csv
import sox
import os
import subprocess
import progressbar
import unicodedata
from os import path
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from util.text import Alphabet
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import progressbar
import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
for dataset in ["train", "test", "dev", "validated", "other"]:
input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv)
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = sample[0]
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter_fun(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((os.path.split(wav_filename)[-1], file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
output_csv = os.path.join(
audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
)
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
# Get audiofile path and transcript for each sentence in tsv
samples = []
with open(input_tsv, encoding='utf-8') as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t')
with open(input_tsv, encoding="utf-8") as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter="\t")
for row in reader:
samples.append((path.join(audio_dir, row['path']), row['sentence']))
samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))
counter = get_counter()
num_samples = len(samples)
@ -101,26 +107,38 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
pool.close()
pool.join()
with open(output_csv, 'w', encoding='utf-8') as output_csv_file:
print('Writing CSV file for DeepSpeech.py as: ', output_csv)
with open(output_csv, "w", encoding="utf-8") as output_csv_file:
print("Writing CSV file for DeepSpeech.py as: ", output_csv)
writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows):
if space_after_every_character:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': ' '.join(transcript)})
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": " ".join(transcript),
}
)
else:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)
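# When --space_after_every_character is set, " ".join(transcript) writes the
# label with a space between every character (e.g. "你好" becomes "你 好"),
# so that scripts written without word boundaries can be handled downstream
# as space-separated single-character tokens - a sketch of the intent behind
# the flag, which its help text below only hints at.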
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
@ -130,24 +148,42 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
if __name__ == "__main__":
PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora')
PARSER.add_argument('tsv_dir', help='Directory containing tsv files')
PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
PARSER.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
PARSER.add_argument(
"--audio_dir",
help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
)
PARSER.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
PARSER.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
PARSER.add_argument(
"--space_after_every_character",
action="store_true",
help="To help transcript join by white space",
)
PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS)
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
AUDIO_DIR = (
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
)
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
def label_filter_fun(label):
if PARAMS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:

View File

@ -1,25 +1,20 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import codecs
import fnmatch
import os
import subprocess
import sys
import unicodedata
import librosa
import pandas
import soundfile # <= Has an external dependency on libsndfile
from deepspeech_training.util.importers import validate_label_eng as validate_label
# Prerequisite: Having the sph2pipe tool in your PATH:
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
import os
import pandas
import subprocess
import unicodedata
import librosa
import soundfile # <= Has an external dependency on libsndfile
from util.importers import validate_label_eng as validate_label
def _download_and_preprocess_data(data_dir):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
@ -29,33 +24,55 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav")
# Conditionally split Fisher wav data
all_2004 = _split_wav_and_sentences(data_dir,
original_data="fisher-2004-wav",
converted_data="fisher-2004-split-wav",
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"))
all_2005 = _split_wav_and_sentences(data_dir,
original_data="fisher-2005-wav",
converted_data="fisher-2005-split-wav",
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"))
all_2004 = _split_wav_and_sentences(
data_dir,
original_data="fisher-2004-wav",
converted_data="fisher-2004-split-wav",
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"),
)
all_2005 = _split_wav_and_sentences(
data_dir,
original_data="fisher-2005-wav",
converted_data="fisher-2005-split-wav",
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"),
)
# The following files have incorrect transcripts that are much longer than
# their audio source. The result is that we end up with more labels than time
# slices, which breaks CTC.
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"), "transcript"] = "correct"
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"), "transcript"] = "that's one of those"
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"), "transcript"] = "they don't want"
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"), "transcript"] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
all_2004.loc[
all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"),
"transcript",
] = "correct"
all_2004.loc[
all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"),
"transcript",
] = "that's one of those"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"),
"transcript",
] = "they don't want"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"),
"transcript",
] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
# The following file is just a short sound and not at all transcribed like provided.
# So we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")]
all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")
]
# The following file is far too long and would ruin our training batch size.
# So we just exclude it.
all_2005 = all_2005[~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")]
all_2005 = all_2005[
~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")
]
# The following file is too large for its transcript, so we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")]
all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")
]
# Conditionally split Fisher data into train/validation/test sets
train_2004, dev_2004, test_2004 = _split_sets(all_2004)
@ -71,6 +88,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False)
def _maybe_convert_wav(data_dir, original_data, converted_data):
source_dir = os.path.join(data_dir, original_data)
target_dir = os.path.join(data_dir, converted_data)
@ -88,10 +106,18 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
for filename in fnmatch.filter(filenames, "*.sph"):
sph_file = os.path.join(root, filename)
for channel in ["1", "2"]:
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "_c" + channel + ".wav"
wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "_c"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename)
print("converting {} to {}".format(sph_file, wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file])
subprocess.check_call(
["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]
)
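# Fisher conversations ship as two-channel SPHERE files with one speaker per
# channel, so every .sph is exported twice: "<name>_c1.wav" and
# "<name>_c2.wav". In the sph2pipe call, "-c" selects the channel, "-p"
# forces 16-bit linear PCM and "-f rif" requests a RIFF/WAV header. Later,
# _split_wav_and_sentences pairs speaker "A:" with the first of these files
# and the other speaker with the second.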
def _parse_transcriptions(trans_file):
segments = []
@ -109,18 +135,23 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array
# expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
segments.append({
"start_time": start_time,
"stop_time": stop_time,
"speaker": speaker,
"transcript": transcript,
})
segments.append(
{
"start_time": start_time,
"stop_time": stop_time,
"speaker": speaker,
"transcript": transcript,
}
)
return segments
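# The normalize/encode/decode chain above is an ASCII fold: NFKD decomposes
# accented characters into a base letter plus combining marks, and
# encode("ascii", "ignore") then drops every non-ASCII code point, e.g.
#
#   unicodedata.normalize("NFKD", "café").encode("ascii", "ignore")
#   # -> b"cafe"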
def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data):
trans_dir = os.path.join(data_dir, trans_data)
source_dir = os.path.join(data_dir, original_data)
@ -137,43 +168,73 @@ def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data
segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file
wav_filenames = [os.path.splitext(os.path.basename(trans_file))[0] + "_c" + channel + ".wav" for channel in ["1", "2"]]
wav_files = [os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames]
wav_filenames = [
os.path.splitext(os.path.basename(trans_file))[0]
+ "_c"
+ channel
+ ".wav"
for channel in ["1", "2"]
]
wav_files = [
os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames
]
print("splitting {} according to {}".format(wav_files, trans_file))
origAudios = [librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files]
origAudios = [
librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files
]
# Loop over segments and split wav_file for each segment
for segment in segments:
# Create wav segment filename
start_time = segment["start_time"]
stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
new_wav_filename = (
os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = os.path.join(target_dir, new_wav_filename)
channel = 0 if segment["speaker"] == "A:" else 1
_split_and_resample_wav(origAudios[channel], start_time, stop_time, new_wav_file)
_split_and_resample_wav(
origAudios[channel], start_time, stop_time, new_wav_file
)
new_wav_filesize = os.path.getsize(new_wav_file)
transcript = validate_label(segment["transcript"])
if transcript != None:
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
def _split_audio(origAudio, start_time, stop_time):
audioData, frameRate = origAudio
nChannels = len(audioData.shape)
startIndex = int(start_time * frameRate)
stopIndex = int(stop_time * frameRate)
return audioData[startIndex: stopIndex] if 1 == nChannels else audioData[:, startIndex: stopIndex]
return (
audioData[startIndex:stopIndex]
if 1 == nChannels
else audioData[:, startIndex:stopIndex]
)
def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio[1]
chunkData = _split_audio(origAudio, start_time, stop_time)
soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16")
def _split_sets(filelist):
# We initially split the entire set into 80% train and 20% test, then
# split the train set into 80% train and 20% validation.
@ -187,9 +248,12 @@ def _split_sets(filelist):
test_beg = dev_end
test_end = len(filelist)
return (filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end])
return (
filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)
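# Net effect of the 80/20-then-80/20 scheme described above: roughly 64% of
# the files end up in train, 16% in dev and 20% in test. With 1000 files,
# for example, the boundaries work out to about 640 / 160 / 200 (the exact
# index arithmetic lives in the lines elided from this hunk).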
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

View File

@ -1,24 +1,18 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import numpy as np
import pandas
import os
import tarfile
import numpy as np
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -26,7 +20,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'ST-CMDS-20170001_1-OS')
main_folder = os.path.join(target_dir, "ST-CMDS-20170001_1-OS")
# Folder structure is now:
# - ST-CMDS-20170001_1-OS/
@ -39,16 +33,16 @@ def preprocess_data(tgz_file, target_dir):
for wav in glob.glob(glob_path):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
txt_filename = os.path.splitext(wav_filename)[0] + '.txt'
with open(txt_filename, 'r') as fin:
txt_filename = os.path.splitext(wav_filename)[0] + ".txt"
with open(txt_filename, "r") as fin:
transcript = fin.read()
set_files.append((wav_filename, wav_filesize, transcript))
return set_files
# Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, '*.wav'))
all_files = load_set(os.path.join(main_folder, "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True)
df.sort_values(by="wav_filename", inplace=True)
indices = np.arange(0, len(df))
np.random.seed(12345)
@ -61,29 +55,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000]
train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'freestmandarin_train.csv')
print('Saving train set into {}...'.format(dest_csv))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, "freestmandarin_train.csv")
print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False)
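# The duration estimate above is pure file-size arithmetic: a canonical PCM
# WAV header is 44 bytes, and 16 kHz mono 16-bit audio takes 16000 * 2 bytes
# per second, so (wav_filesize - 44) / 16000 / 2 approximates the clip
# length in seconds. A 10 s clip, for instance, is about
# 10 * 16000 * 2 + 44 = 320044 bytes.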
dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_dev.csv')
print('Saving dev set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "freestmandarin_dev.csv")
print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False)
test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_test.csv')
print('Saving test set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "freestmandarin_test.csv")
print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False)
def main():
# https://www.openslr.org/38/
parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus')
parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import Free ST Chinese Mandarin corpus")
parser.add_argument("tgz_file", help="Path to ST-CMDS-20170001_1-OS.tar.gz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:

View File

@ -1,24 +1,18 @@
#!/usr/bin/env python
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import math
import urllib
import logging
from util.importers import get_importers_parser, get_validate_label
import math
import os
import subprocess
from os import path
import urllib
from pathlib import Path
import swifter
import pandas as pd
from sox import Transformer
import swifter
from deepspeech_training.util.importers import get_importers_parser, get_validate_label
__version__ = "0.1.0"
_logger = logging.getLogger(__name__)
@ -40,9 +34,7 @@ def parse_args(args):
Returns:
:obj:`argparse.Namespace`: command line parameters namespace
"""
parser = get_importers_parser(
description="Imports GramVaani data for Deep Speech"
)
parser = get_importers_parser(description="Imports GramVaani data for Deep Speech")
parser.add_argument(
"--version",
action="version",
@ -82,6 +74,7 @@ def parse_args(args):
)
return parser.parse_args(args)
def setup_logging(level):
"""Setup basic logging
Args:
@ -92,6 +85,7 @@ def setup_logging(level):
level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S"
)
class GramVaaniCSV:
"""GramVaaniCSV representing a GramVaani dataset.
Args:
@ -107,8 +101,17 @@ class GramVaaniCSV:
_logger.info("Parsing csv file...%s", os.path.abspath(csv_filename))
data = pd.read_csv(
os.path.abspath(csv_filename),
names=["piece_id","audio_url","transcript_labelled","transcript","labels","content_filename","audio_length","user_id"],
usecols=["audio_url","transcript","audio_length"],
names=[
"piece_id",
"audio_url",
"transcript_labelled",
"transcript",
"labels",
"content_filename",
"audio_length",
"user_id",
],
usecols=["audio_url", "transcript", "audio_length"],
skiprows=[0],
engine="python",
encoding="utf-8",
@ -119,6 +122,7 @@ class GramVaaniCSV:
_logger.info("Parsed %d lines csv file." % len(data))
return data
class GramVaaniDownloader:
"""GramVaaniDownloader downloads a GramVaani dataset.
Args:
@ -138,15 +142,17 @@ class GramVaaniDownloader:
mp3_directory (os.path): The directory into which the associated mp3's were downloaded
"""
mp3_directory = self._pre_download()
self.data.swifter.apply(func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True)
self.data.swifter.apply(
func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True
)
return mp3_directory
def _pre_download(self):
mp3_directory = path.join(self.target_dir, "mp3")
if not path.exists(self.target_dir):
mp3_directory = os.path.join(self.target_dir, "mp3")
if not os.path.exists(self.target_dir):
_logger.info("Creating directory...%s", self.target_dir)
os.mkdir(self.target_dir)
if not path.exists(mp3_directory):
if not os.path.exists(mp3_directory):
_logger.info("Creating directory...%s", mp3_directory)
os.mkdir(mp3_directory)
return mp3_directory
@ -154,13 +160,14 @@ class GramVaaniDownloader:
def _download(self, audio_url, transcript, audio_length, mp3_directory):
if audio_url == "audio_url":
return
mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
if not path.exists(mp3_filename):
mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
if not os.path.exists(mp3_filename):
_logger.debug("Downloading mp3 file...%s", audio_url)
urllib.request.urlretrieve(audio_url, mp3_filename)
else:
_logger.debug("Already downloaded mp3 file...%s", audio_url)
class GramVaaniConverter:
"""GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
Args:
@ -181,37 +188,53 @@ class GramVaaniConverter:
wav_directory (os.path): The directory into which the associated wav's were downloaded
"""
wav_directory = self._pre_convert()
for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
if not path.exists(wav_filename):
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
wav_filename = os.path.join(
wav_directory,
os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav",
)
if not os.path.exists(wav_filename):
_logger.debug(
"Converting mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
transformer = Transformer()
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
transformer.convert(
samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
)
transformer.build(str(mp3_filename), str(wav_filename))
else:
_logger.debug("Already converted mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
_logger.debug(
"Already converted mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
return wav_directory
def _pre_convert(self):
wav_directory = path.join(self.target_dir, "wav")
if not path.exists(self.target_dir):
wav_directory = os.path.join(self.target_dir, "wav")
if not os.path.exists(self.target_dir):
_logger.info("Creating directory...%s", self.target_dir)
os.mkdir(self.target_dir)
if not path.exists(wav_directory):
if not os.path.exists(wav_directory):
_logger.info("Creating directory...%s", wav_directory)
os.mkdir(wav_directory)
return wav_directory
class GramVaaniDataSets:
def __init__(self, target_dir, wav_directory, gram_vaani_csv):
self.target_dir = target_dir
self.wav_directory = wav_directory
self.csv_data = gram_vaani_csv.data
self.raw = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.valid = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.train = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.dev = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.test = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.raw = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.valid = pd.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]
)
self.train = pd.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]
)
self.dev = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.test = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
def create(self):
self._convert_csv_data_to_raw_data()
@ -220,30 +243,45 @@ class GramVaaniDataSets:
self.valid = self.valid.sample(frac=1).reset_index(drop=True)
train_size, dev_size, test_size = self._calculate_data_set_sizes()
self.train = self.valid.loc[0:train_size]
self.dev = self.valid.loc[train_size:train_size+dev_size]
self.test = self.valid.loc[train_size+dev_size:train_size+dev_size+test_size]
self.dev = self.valid.loc[train_size : train_size + dev_size]
self.test = self.valid.loc[
train_size + dev_size : train_size + dev_size + test_size
]
def _convert_csv_data_to_raw_data(self):
self.raw[["wav_filename","wav_filesize","transcript"]] = self.csv_data[
["audio_url","transcript","audio_length"]
].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True)
self.raw[["wav_filename", "wav_filesize", "transcript"]] = self.csv_data[
["audio_url", "transcript", "audio_length"]
].swifter.apply(
func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg),
axis=1,
raw=True,
)
self.raw.reset_index()
def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
if audio_url == "audio_url":
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
mp3_filename = os.path.basename(audio_url)
wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
wav_relative_filename = os.path.join(
"wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav"
)
wav_filesize = os.path.getsize(
os.path.join(self.target_dir, wav_relative_filename)
)
transcript = validate_label(transcript)
if None == transcript:
transcript = ""
return pd.Series([wav_relative_filename, wav_filesize, transcript])
return pd.Series([wav_relative_filename, wav_filesize, transcript])
def _is_valid_raw_rows(self):
is_valid_raw_transcripts = self._is_valid_raw_transcripts()
is_valid_raw_wav_frames = self._is_valid_raw_wav_frames()
is_valid_raw_row = [(is_valid_raw_transcript & is_valid_raw_wav_frame) for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(is_valid_raw_transcripts, is_valid_raw_wav_frames)]
is_valid_raw_row = [
(is_valid_raw_transcript & is_valid_raw_wav_frame)
for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(
is_valid_raw_transcripts, is_valid_raw_wav_frames
)
]
series = pd.Series(is_valid_raw_row)
return series
@ -252,16 +290,29 @@ class GramVaaniDataSets:
def _is_valid_raw_wav_frames(self):
transcripts = [str(transcript) for transcript in self.raw.transcript]
wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
wav_filepaths = [
os.path.join(self.target_dir, str(wav_filename))
for wav_filename in self.raw.wav_filename
]
wav_frames = [
int(
subprocess.check_output(
["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
)
)
for wav_filepath in wav_filepaths
]
is_valid_raw_wav_frames = [
self._is_wav_frame_valid(wav_frame, transcript)
for wav_frame, transcript in zip(wav_frames, transcripts)
]
return pd.Series(is_valid_raw_wav_frames)
def _is_wav_frame_valid(self, wav_frame, transcript):
is_wav_frame_valid = True
if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)):
if int(wav_frame / SAMPLE_RATE * 1000 / 10 / 2) < len(str(transcript)):
is_wav_frame_valid = False
elif wav_frame/SAMPLE_RATE > MAX_SECS:
elif wav_frame / SAMPLE_RATE > MAX_SECS:
is_wav_frame_valid = False
return is_wav_frame_valid
@ -280,7 +331,14 @@ class GramVaaniDataSets:
def _save(self, dataset):
dataset_path = os.path.join(self.target_dir, dataset + ".csv")
dataframe = getattr(self, dataset)
dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)
dataframe.to_csv(
dataset_path,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_MINIMAL,
)
def main(args):
"""Main entry point allowing external calls
@ -304,4 +362,5 @@ def main(args):
datasets.save()
_logger.info("Finished GramVaani importer...")
main(sys.argv[1:])

View File

@ -1,28 +1,33 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import sys
import pandas
from util.downloader import maybe_download
from deepspeech_training.util.downloader import maybe_download
def _download_and_preprocess_data(data_dir):
# Conditionally download data
LDC93S1_BASE = "LDC93S1"
LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/"
local_file = maybe_download(LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav")
trans_file = maybe_download(LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt")
local_file = maybe_download(
LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav"
)
trans_file = maybe_download(
LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt"
)
with open(trans_file, "r") as fin:
transcript = ' '.join(fin.read().strip().lower().split(' ')[2:]).replace('.', '')
transcript = " ".join(fin.read().strip().lower().split(" ")[2:]).replace(
".", ""
)
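# The LDC93S1 transcript follows the TIMIT convention of two sample offsets
# followed by the sentence, so split(" ")[2:] drops the offsets, lower()
# normalizes the case and replace(".", "") strips the trailing period,
# turning a line like
#   "<start> <end> She had your dark suit in greasy wash water all year."
# into
#   "she had your dark suit in greasy wash water all year"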
df = pandas.DataFrame(data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"])
df = pandas.DataFrame(
data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"],
)
df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False)
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

View File

@ -1,33 +1,39 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
import pandas
import progressbar
import os
import subprocess
import sys
import tarfile
import unicodedata
import pandas
import progressbar
from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile
from deepspeech_training.util.downloader import maybe_download
SAMPLE_RATE = 16000
def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir
print("Downloading Librivox data set (55GB) into {} if not already present...".format(data_dir))
print(
"Downloading Librivox data set (55GB) into {} if not already present...".format(
data_dir
)
)
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz"
TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz"
TRAIN_CLEAN_100_URL = (
"http://www.openslr.org/resources/12/train-clean-100.tar.gz"
)
TRAIN_CLEAN_360_URL = (
"http://www.openslr.org/resources/12/train-clean-360.tar.gz"
)
TRAIN_OTHER_500_URL = (
"http://www.openslr.org/resources/12/train-other-500.tar.gz"
)
DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
@ -35,12 +41,20 @@ def _download_and_preprocess_data(data_dir):
TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"
def filename_of(x): return os.path.split(x)[1]
train_clean_100 = maybe_download(filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL)
def filename_of(x):
return os.path.split(x)[1]
train_clean_100 = maybe_download(
filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL
)
bar.update(0)
train_clean_360 = maybe_download(filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL)
train_clean_360 = maybe_download(
filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL
)
bar.update(1)
train_other_500 = maybe_download(filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL)
train_other_500 = maybe_download(
filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL
)
bar.update(2)
dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL)
@ -48,9 +62,13 @@ def _download_and_preprocess_data(data_dir):
dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL)
bar.update(4)
test_clean = maybe_download(filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL)
test_clean = maybe_download(
filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL
)
bar.update(5)
test_other = maybe_download(filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL)
test_other = maybe_download(
filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL
)
bar.update(6)
# Conditionally extract LibriSpeech data
@ -61,11 +79,17 @@ def _download_and_preprocess_data(data_dir):
LIBRIVOX_DIR = "LibriSpeech"
work_dir = os.path.join(data_dir, LIBRIVOX_DIR)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100
)
bar.update(0)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360
)
bar.update(1)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500
)
bar.update(2)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
@ -91,28 +115,48 @@ def _download_and_preprocess_data(data_dir):
# data_dir/LibriSpeech/split-wav/1-2-2.txt
# ...
print("Converting FLAC to WAV and splitting transcriptions...")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
train_100 = _convert_audio_and_split_sentences(work_dir, "train-clean-100", "train-clean-100-wav")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
train_100 = _convert_audio_and_split_sentences(
work_dir, "train-clean-100", "train-clean-100-wav"
)
bar.update(0)
train_360 = _convert_audio_and_split_sentences(work_dir, "train-clean-360", "train-clean-360-wav")
train_360 = _convert_audio_and_split_sentences(
work_dir, "train-clean-360", "train-clean-360-wav"
)
bar.update(1)
train_500 = _convert_audio_and_split_sentences(work_dir, "train-other-500", "train-other-500-wav")
train_500 = _convert_audio_and_split_sentences(
work_dir, "train-other-500", "train-other-500-wav"
)
bar.update(2)
dev_clean = _convert_audio_and_split_sentences(work_dir, "dev-clean", "dev-clean-wav")
dev_clean = _convert_audio_and_split_sentences(
work_dir, "dev-clean", "dev-clean-wav"
)
bar.update(3)
dev_other = _convert_audio_and_split_sentences(work_dir, "dev-other", "dev-other-wav")
dev_other = _convert_audio_and_split_sentences(
work_dir, "dev-other", "dev-other-wav"
)
bar.update(4)
test_clean = _convert_audio_and_split_sentences(work_dir, "test-clean", "test-clean-wav")
test_clean = _convert_audio_and_split_sentences(
work_dir, "test-clean", "test-clean-wav"
)
bar.update(5)
test_other = _convert_audio_and_split_sentences(work_dir, "test-other", "test-other-wav")
test_other = _convert_audio_and_split_sentences(
work_dir, "test-other", "test-other-wav"
)
bar.update(6)
# Write sets to disk as CSV files
train_100.to_csv(os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False)
train_360.to_csv(os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False)
train_500.to_csv(os.path.join(data_dir, "librivox-train-other-500.csv"), index=False)
train_100.to_csv(
os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False
)
train_360.to_csv(
os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False
)
train_500.to_csv(
os.path.join(data_dir, "librivox-train-other-500.csv"), index=False
)
dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False)
dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False)
@ -120,6 +164,7 @@ def _download_and_preprocess_data(data_dir):
test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False)
test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(os.path.join(data_dir, extracted_data)):
@ -127,6 +172,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir)
tar.close()
def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
source_dir = os.path.join(extracted_dir, data_set)
target_dir = os.path.join(extracted_dir, dest_dir)
@ -149,20 +195,22 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
# We also convert the corresponding FLACs to WAV in the same pass
files = []
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, '*.trans.txt'):
for filename in fnmatch.filter(filenames, "*.trans.txt"):
trans_filename = os.path.join(root, filename)
with codecs.open(trans_filename, "r", "utf-8") as fin:
for line in fin:
# Parse each segment line
first_space = line.find(" ")
seqid, transcript = line[:first_space], line[first_space+1:]
seqid, transcript = line[:first_space], line[first_space + 1 :]
# We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array
# expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
transcript = transcript.lower().strip()
@ -177,7 +225,10 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
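# Each line of a LibriSpeech *.trans.txt file is "<seqid> <transcript>",
# e.g. "1272-128104-0000 MISTER QUILTER IS THE APOSTLE OF THE MIDDLE
# CLASSES ...", so splitting at the first space yields the id of the
# matching <seqid>.flac file plus the transcript that gets ASCII-folded
# and lower-cased above.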
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

View File

@ -1,44 +1,39 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import argparse
import csv
import os
import re
import sox
import zipfile
import subprocess
import progressbar
import unicodedata
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import zipfile
from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
import progressbar
import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
ARCHIVE_DIR_NAME = 'lingua_libre'
ARCHIVE_NAME = 'Q{qId}-{iso639_3}-{language_English_name}.zip'
ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "lingua_libre"
ARCHIVE_NAME = "Q{qId}-{iso639_3}-{language_English_name}.zip"
ARCHIVE_URL = "https://lingualibre.fr/datasets/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data
@ -46,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files and convert ogg data to wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -58,57 +54,70 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else:
print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = sample[0]
# Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME + "_" + ARCHIVE_NAME.replace(".zip", "_{}.csv")
)
if os.path.isfile(target_csv_template):
return
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace(".zip", ""))
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
glob_dir = os.path.join(ogg_root_dir, "**/*.ogg")
for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '')
record_file = record.replace(ogg_root_dir + os.path.sep, "")
if record_filter(record_file):
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
samples.append(
(
os.path.join(ogg_root_dir, record_file),
os.path.splitext(os.path.basename(record_file))[0],
)
)
counter = get_counter()
num_samples = len(samples)
@ -125,9 +134,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -139,7 +148,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript = validate_label(item[2])
if not transcript:
continue
wav_filename = os.path.join(ogg_root_dir, item[0].replace('.ogg', '.wav'))
wav_filename = os.path.join(
ogg_root_dir, item[0].replace(".ogg", ".wav")
)
i_mod = i % 10
if i_mod == 0:
writer = test_writer
@ -147,38 +158,63 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
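# The i % 10 dispatch above yields a deterministic 80/10/10 split: every
# tenth processed sample (i_mod == 0) goes to the test CSV, the next one
# (i_mod == 1) to dev, and the remaining eight to train, matching the
# percentage comments on the three open() calls.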
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(ogg_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
transformer.build(ogg_filename, wav_filename)
except sox.core.SoxError as ex:
print('SoX processing error', ex, ogg_filename, wav_filename)
print("SoX processing error", ex, ogg_filename, wav_filename)
def handle_args():
parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
parser.add_argument(dest='target_dir')
parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId')
parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code')
parser.add_argument('--english-name', type=str, required=True, help='English name of the language')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.add_argument('--bogus-records', type=argparse.FileType('r'), required=False, help='Text file listing well-known bogus records to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items')
parser = get_importers_parser(
description="Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details."
)
parser.add_argument(dest="target_dir")
parser.add_argument(
"--qId", type=int, required=True, help="LinguaLibre language qId"
)
parser.add_argument(
"--iso639-3", type=str, required=True, help="ISO639-3 language code"
)
parser.add_argument(
"--english-name", type=str, required=True, help="Enligh name of the language"
)
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--bogus-records",
type=argparse.FileType("r"),
required=False,
help="Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items",
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -191,15 +227,17 @@ if __name__ == "__main__":
def record_filter(path):
if any(regex.match(path) for regex in bogus_regexes):
print('Reject', path)
print("Reject", path)
return False
return True
def label_filter(label):
if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:
@ -208,6 +246,14 @@ if __name__ == "__main__":
label = None
return label
ARCHIVE_NAME = ARCHIVE_NAME.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
ARCHIVE_URL = ARCHIVE_URL.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
ARCHIVE_NAME = ARCHIVE_NAME.format(
qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
ARCHIVE_URL = ARCHIVE_URL.format(
qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)

View File

@ -1,43 +1,37 @@
#!/usr/bin/env python3
# pylint: disable=invalid-name
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import os
import subprocess
import progressbar
import unicodedata
import tarfile
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import unicodedata
from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
import progressbar
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15
ARCHIVE_DIR_NAME = '{language}'
ARCHIVE_NAME = '{language}.tgz'
ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "{language}"
ARCHIVE_NAME = "{language}.tgz"
ARCHIVE_URL = "http://www.caito.de/data/Training/stt_tts/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data
@ -48,8 +42,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -65,9 +59,13 @@ def one_sample(sample):
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1])
counter = get_counter()
rows = []
@ -75,27 +73,30 @@ def one_sample(sample):
if file_size == -1:
# Excluding samples that failed upon conversion
print("conversion failure", wav_filename)
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
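Editor's note: the rejection branches above reduce to two duration checks once conversion failures and label validation are set aside. A minimal standalone sketch of that filter, assuming 16 kHz audio and the 15 ms / two-steps-per-character rule used here (keep_sample and its example inputs are hypothetical, not part of the importer):

SAMPLE_RATE = 16000
MAX_SECS = 15

def keep_sample(frames, transcript):
    # Clip duration in seconds at 16 kHz.
    secs = frames / SAMPLE_RATE
    # Capacity check mirrored from one_sample: with ~15 ms feature steps,
    # the clip needs at least two steps per transcript character.
    max_chars = int(secs * 1000 / 15 / 2)
    if len(transcript) > max_chars:
        return "too_short"   # audio cannot carry the transcript
    if secs > MAX_SECS:
        return "too_long"    # overlong clips hurt batch packing
    return "ok"

print(keep_sample(3 * SAMPLE_RATE, "hello world"))  # -> "ok" (3 s fits up to 100 chars)
print(keep_sample(3 * SAMPLE_RATE, "x" * 120))      # -> "too_short"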
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tgz", "_{}.csv")
)
if os.path.isfile(target_csv_template):
return
@ -103,14 +104,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(wav_root_dir, '**/metadata.csv')
glob_dir = os.path.join(wav_root_dir, "**/metadata.csv")
for record in glob(glob_dir, recursive=True):
if any(map(lambda sk: sk in record, SKIP_LIST)): # pylint: disable=cell-var-from-loop
if any(
map(lambda sk: sk in record, SKIP_LIST)
): # pylint: disable=cell-var-from-loop
continue
with open(record, 'r') as rec:
with open(record, "r") as rec:
for re in rec.readlines():
re = re.strip().split('|')
audio = os.path.join(os.path.dirname(record), 'wavs', re[0] + '.wav')
re = re.strip().split("|")
audio = os.path.join(os.path.dirname(record), "wavs", re[0] + ".wav")
transcript = re[2]
samples.append((audio, transcript))
@ -129,9 +132,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -151,39 +154,60 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=os.path.relpath(wav_filename, extracted_dir),
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=os.path.relpath(wav_filename, extracted_dir),
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
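Editor's note: the writer selection above amounts to one csv.DictWriter per split plus an index rule. The exact 80/10/10 rule is not visible in this hunk, so the modulo choice in the sketch below is only an illustrative assumption (write_split_csvs is a hypothetical helper):

import csv

FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]

def write_split_csvs(rows, prefix):
    # One writer per split; the modulo-based 80/10/10 rule below is an
    # assumption for illustration, not necessarily the importer's rule.
    with open(prefix + "_train.csv", "w") as train_f, \
         open(prefix + "_dev.csv", "w") as dev_f, \
         open(prefix + "_test.csv", "w") as test_f:
        writers = {}
        for name, handle in (("train", train_f), ("dev", dev_f), ("test", test_f)):
            writers[name] = csv.DictWriter(handle, fieldnames=FIELDNAMES)
            writers[name].writeheader()
        for i, (wav_filename, wav_filesize, transcript) in enumerate(rows):
            if i % 10 == 8:
                target = "dev"
            elif i % 10 == 9:
                target = "test"
            else:
                target = "train"
            writers[target].writerow(dict(
                wav_filename=wav_filename,
                wav_filesize=wav_filesize,
                transcript=transcript,
            ))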
def handle_args():
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
parser.add_argument(dest='target_dir')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.add_argument('--skiplist', type=str, default='', help='Directories / books to skip, comma separated')
parser.add_argument('--language', required=True, type=str, help='Dataset language to use')
parser = get_importers_parser(
description="Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/."
)
parser.add_argument(dest="target_dir")
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--skiplist",
type=str,
default="",
help="Directories / books to skip, comma separated",
)
parser.add_argument(
"--language", required=True, type=str, help="Dataset language to use"
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(","))
validate_label = get_validate_label(CLI_ARGS)
def label_filter(label):
if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:
View File
@ -1,30 +1,24 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import pandas
import os
import tarfile
import wave
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
def is_file_truncated(wav_filename, wav_filesize):
with wave.open(wav_filename, mode='rb') as fin:
with wave.open(wav_filename, mode="rb") as fin:
assert fin.getframerate() == 16000
assert fin.getsampwidth() == 2
assert fin.getnchannels() == 1
@ -37,8 +31,13 @@ def is_file_truncated(wav_filename, wav_filesize):
def preprocess_data(folder_with_archives, target_dir):
# First extract subset archives
for subset in ('train', 'dev', 'test'):
extract(os.path.join(folder_with_archives, 'magicdata_{}_set.tar.gz'.format(subset)), target_dir)
for subset in ("train", "dev", "test"):
extract(
os.path.join(
folder_with_archives, "magicdata_{}_set.tar.gz".format(subset)
),
target_dir,
)
# Folder structure is now:
# - magicdata_{train,dev,test}.tar.gz
@ -54,58 +53,73 @@ def preprocess_data(folder_with_archives, target_dir):
# name, one containing the speaker ID, and one containing the transcription
def load_set(set_path):
transcripts = pandas.read_csv(os.path.join(set_path, 'TRANS.txt'), sep='\t', index_col=0)
glob_path = os.path.join(set_path, '*', '*.wav')
transcripts = pandas.read_csv(
os.path.join(set_path, "TRANS.txt"), sep="\t", index_col=0
)
glob_path = os.path.join(set_path, "*", "*.wav")
set_files = []
for wav in glob.glob(glob_path):
try:
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.basename(wav)
transcript = transcripts.loc[transcript_key, 'Transcription']
transcript = transcripts.loc[transcript_key, "Transcription"]
# Some files in this dataset are truncated, the header duration
# doesn't match the file size. This causes errors at training
# time, so check here if things are fine before including a file
if is_file_truncated(wav_filename, wav_filesize):
print('Warning: File {} is corrupted, header duration does '
'not match file size. Ignoring.'.format(wav_filename))
print(
"Warning: File {} is corrupted, header duration does "
"not match file size. Ignoring.".format(wav_filename)
)
continue
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(target_dir, subset))
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
# Trim train set to under 10s
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
with_noise = df['transcript'].str.contains(r'\[(FIL|SPK)\]')
df = df[~with_noise]
print('Trimming {} samples with noise ([FIL] or [SPK])'.format(sum(with_noise)))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'magicdata_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
with_noise = df["transcript"].str.contains(r"\[(FIL|SPK)\]")
df = df[~with_noise]
print(
"Trimming {} samples with noise ([FIL] or [SPK])".format(
sum(with_noise)
)
)
dest_csv = os.path.join(target_dir, "magicdata_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)
def main():
# https://openslr.org/68/
parser = get_importers_parser(description='Import MAGICDATA corpus')
parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives')
parser = get_importers_parser(description="Import MAGICDATA corpus")
parser.add_argument(
"folder_with_archives",
help="Path to folder containing magicdata_{train,dev,test}.tar.gz",
)
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives",
)
params = parser.parse_args()
if not params.target_dir:
params.target_dir = os.path.join(params.folder_with_archives, 'magicdata')
params.target_dir = os.path.join(params.folder_with_archives, "magicdata")
preprocess_data(params.folder_with_archives, params.target_dir)
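Editor's note: two details carry the filtering in this importer: durations are derived from the size of a canonical 16 kHz, 16-bit, mono WAV (44-byte header), and truncated files are caught by comparing the byte count on disk with the frame count declared in the header. A standalone sketch under those format assumptions (duration_from_size and is_truncated are illustrative names):

import wave

HEADER_BYTES = 44       # canonical RIFF/WAVE header size
SAMPLE_RATE = 16000
BYTES_PER_SAMPLE = 2    # 16-bit mono

def duration_from_size(wav_filesize):
    # Seconds of audio implied by the file size alone.
    return (wav_filesize - HEADER_BYTES) / SAMPLE_RATE / BYTES_PER_SAMPLE

def is_truncated(wav_filename, wav_filesize):
    # A truncated download holds fewer bytes than the header's declared
    # frame count implies.
    with wave.open(wav_filename, mode="rb") as fin:
        declared_frames = fin.getnframes()
    actual_frames = (wav_filesize - HEADER_BYTES) // BYTES_PER_SAMPLE
    return actual_frames < declared_frames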
View File
@ -1,25 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import json
import numpy as np
import pandas
import os
import tarfile
import numpy as np
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -27,7 +21,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'primewords_md_2018_set1')
main_folder = os.path.join(target_dir, "primewords_md_2018_set1")
# Folder structure is now:
# - primewords_md_2018_set1/
@ -35,14 +29,11 @@ def preprocess_data(tgz_file, target_dir):
# - [0-f]/[00-0f]/*.wav
# - set1_transcript.json
transcripts_path = os.path.join(main_folder, 'set1_transcript.json')
transcripts_path = os.path.join(main_folder, "set1_transcript.json")
with open(transcripts_path) as fin:
transcripts = json.load(fin)
transcripts = {
entry['file']: entry['text']
for entry in transcripts
}
transcripts = {entry["file"]: entry["text"] for entry in transcripts}
def load_set(glob_path):
set_files = []
@ -54,13 +45,13 @@ def preprocess_data(tgz_file, target_dir):
transcript = transcripts[transcript_key]
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
# Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, 'audio_files', '*', '*', '*.wav'))
all_files = load_set(os.path.join(main_folder, "audio_files", "*", "*", "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True)
df.sort_values(by="wav_filename", inplace=True)
indices = np.arange(0, len(df))
np.random.seed(12345)
@ -73,29 +64,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000]
train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 15.0]
print('Trimming {} samples > 15 seconds'.format((durations > 15.0).sum()))
dest_csv = os.path.join(target_dir, 'primewords_train.csv')
print('Saving train set into {}...'.format(dest_csv))
print("Trimming {} samples > 15 seconds".format((durations > 15.0).sum()))
dest_csv = os.path.join(target_dir, "primewords_train.csv")
print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False)
dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'primewords_dev.csv')
print('Saving dev set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "primewords_dev.csv")
print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False)
test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'primewords_test.csv')
print('Saving test set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "primewords_test.csv")
print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False)
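Editor's note: the split here is fully deterministic: sort, seed, one permutation, and the last 10 000 indices are held out. The exact dev/test boundaries inside that hold-out are not visible in this hunk, so the 50/50 split below is an assumption (deterministic_split is an illustrative helper):

import numpy as np
import pandas as pd

def deterministic_split(df, seed=12345, holdout=10000):
    # Sort first so the permutation is reproducible across runs, then
    # shuffle once with a fixed seed.
    df = df.sort_values(by="wav_filename").reset_index(drop=True)
    indices = np.random.RandomState(seed).permutation(len(df))
    train = df.iloc[indices[:-holdout]]
    # Splitting the hold-out 50/50 into dev and test is an assumption.
    dev = df.iloc[indices[-holdout:-holdout // 2]]
    test = df.iloc[indices[-holdout // 2:]]
    return train, dev, test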
def main():
# https://www.openslr.org/47/
parser = get_importers_parser(description='Import Primewords Chinese corpus set 1')
parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import Primewords Chinese corpus set 1")
parser.add_argument("tgz_file", help="Path to primewords_md_2018_set1.tar.gz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:
View File
@ -1,45 +1,39 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import os
import re
import sox
import zipfile
import subprocess
import progressbar
import unicodedata
import tarfile
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import unicodedata
import zipfile
from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
from util.helpers import secs_to_hours
import progressbar
import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15
ARCHIVE_DIR_NAME = 'African_Accented_French'
ARCHIVE_NAME = 'African_Accented_French.tar.gz'
ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "African_Accented_French"
ARCHIVE_NAME = "African_Accented_French.tar.gz"
ARCHIVE_URL = "http://www.openslr.org/resources/57/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data
@ -47,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -60,81 +55,89 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else:
print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1])
counter = get_counter()
rows = []
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tar.gz", "_{}.csv")
)
if os.path.isfile(target_csv_template):
return
wav_root_dir = os.path.join(extracted_dir)
all_files = [
'transcripts/train/yaounde/fn_text.txt',
'transcripts/train/ca16_conv/transcripts.txt',
'transcripts/train/ca16_read/conditioned.txt',
'transcripts/dev/niger_west_african_fr/transcripts.txt',
'speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv',
'transcripts/devtest/ca16_read/conditioned.txt',
'transcripts/test/ca16/prompts.txt',
"transcripts/train/yaounde/fn_text.txt",
"transcripts/train/ca16_conv/transcripts.txt",
"transcripts/train/ca16_read/conditioned.txt",
"transcripts/dev/niger_west_african_fr/transcripts.txt",
"speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv",
"transcripts/devtest/ca16_read/conditioned.txt",
"transcripts/test/ca16/prompts.txt",
]
transcripts = {}
for tr in all_files:
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), 'r') as tr_source:
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), "r") as tr_source:
for line in tr_source.readlines():
line = line.strip()
if '.tsv' in tr:
sep = ' '
if ".tsv" in tr:
sep = " "
else:
sep = ' '
sep = " "
audio = os.path.basename(line.split(sep)[0])
if not ('.wav' in audio):
if '.tdf' in audio:
audio = audio.replace('.tdf', '.wav')
if not (".wav" in audio):
if ".tdf" in audio:
audio = audio.replace(".tdf", ".wav")
else:
audio += '.wav'
audio += ".wav"
transcript = ' '.join(line.split(sep)[1:])
transcript = " ".join(line.split(sep)[1:])
transcripts[audio] = transcript
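Editor's note: the transcript lists above come in several shapes, so the importer flattens them into a single dict keyed by WAV basename, rewriting .tdf references on the way. A condensed sketch of that normalisation (load_transcripts is an illustrative name and the separator handling is simplified):

import os

def load_transcripts(root, transcript_files):
    # Map "<basename>.wav" -> transcript text, however the source list
    # happens to spell the audio reference.
    transcripts = {}
    for tr in transcript_files:
        with open(os.path.join(root, tr), "r") as tr_source:
            for line in tr_source:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(" ")
                audio = os.path.basename(parts[0])
                if ".wav" not in audio:
                    # Some lists reference .tdf files or bare basenames.
                    audio = audio.replace(".tdf", ".wav") if ".tdf" in audio else audio + ".wav"
                transcripts[audio] = " ".join(parts[1:])
    return transcripts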
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(wav_root_dir, '**/*.wav')
glob_dir = os.path.join(wav_root_dir, "**/*.wav")
for record in glob(glob_dir, recursive=True):
record_file = os.path.basename(record)
if record_file in transcripts:
@ -156,9 +159,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -178,25 +181,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args():
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
parser.add_argument(dest='target_dir')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser = get_importers_parser(
description="Importer for African Accented French dataset. More information on http://www.openslr.org/57/."
)
parser.add_argument(dest="target_dir")
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -204,9 +220,11 @@ if __name__ == "__main__":
def label_filter(label):
if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:
View File
@ -1,44 +1,38 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
# ./data/swb/swb1_LDC97S62.tgz
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
import pandas
import os
import subprocess
import sys
import tarfile
import unicodedata
import wave
import codecs
import tarfile
import requests
from util.importers import validate_label_eng as validate_label
import librosa
import soundfile # <= Has an external dependency on libsndfile
import pandas
import requests
import soundfile # <= Has an external dependency on libsndfile
from deepspeech_training.util.importers import validate_label_eng as validate_label
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
ARCHIVE_URL = 'http://www.openslr.org/resources/5/'
ARCHIVE_DIR_NAME = 'LDC97S62'
LDC_DATASET = 'swb1_LDC97S62.tgz'
ARCHIVE_NAME = "switchboard_word_alignments.tar.gz"
ARCHIVE_URL = "http://www.openslr.org/resources/5/"
ARCHIVE_DIR_NAME = "LDC97S62"
LDC_DATASET = "swb1_LDC97S62.tgz"
def download_file(folder, url):
# https://stackoverflow.com/a/16696317/738515
local_filename = url.split('/')[-1]
local_filename = url.split("/")[-1]
full_filename = os.path.join(folder, local_filename)
r = requests.get(url, stream=True)
with open(full_filename, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
with open(full_filename, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
return full_filename
@ -46,7 +40,7 @@ def download_file(folder, url):
def maybe_download(archive_url, target_dir, ldc_dataset):
# If archive file does not exist, download it...
archive_path = os.path.join(target_dir, ldc_dataset)
ldc_path = archive_url+ldc_dataset
ldc_path = archive_url + ldc_dataset
if not os.path.exists(target_dir):
print('No path "%s" - creating ...' % target_dir)
makedirs(target_dir)
@ -65,17 +59,23 @@ def _download_and_preprocess_data(data_dir):
archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET))
# Check swb1_LDC97S62.tgz then extract
assert(os.path.isfile(archive_path))
assert os.path.isfile(archive_path)
_extract(target_dir, archive_path)
# Transcripts
transcripts_path = maybe_download(ARCHIVE_URL, target_dir, ARCHIVE_NAME)
_extract(target_dir, transcripts_path)
# Check swb1_d1/2/3/4/swb_ms98_transcriptions
expected_folders = ["swb1_d1","swb1_d2","swb1_d3","swb1_d4","swb_ms98_transcriptions"]
assert(all([os.path.isdir(os.path.join(target_dir,e)) for e in expected_folders]))
expected_folders = [
"swb1_d1",
"swb1_d2",
"swb1_d3",
"swb1_d4",
"swb_ms98_transcriptions",
]
assert all([os.path.isdir(os.path.join(target_dir, e)) for e in expected_folders])
# Conditionally convert swb sph data to wav
_maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav")
_maybe_convert_wav(target_dir, "swb1_d2", "swb1_d2-wav")
@ -83,13 +83,21 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav")
# Conditionally split wav data
d1 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav")
d2 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav")
d3 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav")
d4 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav")
d1 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav"
)
d2 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav"
)
d3 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav"
)
d4 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav"
)
swb_files = d1.append(d2).append(d3).append(d4)
train_files, dev_files, test_files = _split_sets(swb_files)
# Write sets to disk as CSV files
@ -97,7 +105,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(target_dir, "swb-dev.csv"), index=False)
test_files.to_csv(os.path.join(target_dir, "swb-test.csv"), index=False)
def _extract(target_dir, archive_path):
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -118,25 +126,46 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
# Loop over sph files in source_dir and convert each to 16-bit PCM wav
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, "*.sph"):
for channel in ['1', '2']:
for channel in ["1", "2"]:
sph_file = os.path.join(root, filename)
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav"
wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename)
temp_wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + "-temp.wav"
temp_wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ "-temp.wav"
)
temp_wav_file = os.path.join(target_dir, temp_wav_filename)
print("converting {} to {}".format(sph_file, temp_wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file])
subprocess.check_call(
[
"sph2pipe",
"-c",
channel,
"-p",
"-f",
"rif",
sph_file,
temp_wav_file,
]
)
print("upsampling {} to {}".format(temp_wav_file, wav_file))
audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True)
soundfile.write(wav_file, audioData, frameRate, "PCM_16")
os.remove(temp_wav_file)
def _parse_transcriptions(trans_file):
segments = []
with codecs.open(trans_file, "r", "utf-8") as fin:
for line in fin:
if line.startswith("#") or len(line) <= 1:
if line.startswith("#") or len(line) <= 1:
continue
tokens = line.split()
@ -150,15 +179,19 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array
# expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
segments.append({
"start_time": start_time,
"stop_time": stop_time,
"transcript": transcript,
})
segments.append(
{
"start_time": start_time,
"stop_time": stop_time,
"transcript": transcript,
}
)
return segments
@ -183,8 +216,16 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A']
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav"
channel = ("2", "1")[
(os.path.splitext(os.path.basename(trans_file))[0])[6] == "A"
]
wav_filename = (
"sw0"
+ (os.path.splitext(os.path.basename(trans_file))[0])[2:6]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(source_dir, wav_filename)
print("splitting {} according to {}".format(wav_file, trans_file))
@ -200,26 +241,39 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
# Create wav segment filename
start_time = segment["start_time"]
stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(
start_time) + "-" + str(stop_time) + ".wav"
new_wav_filename = (
os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
if _is_wav_too_short(new_wav_filename):
continue
continue
new_wav_file = os.path.join(target_dir, new_wav_filename)
_split_wav(origAudio, start_time, stop_time, new_wav_file)
new_wav_filesize = os.path.getsize(new_wav_file)
transcript = segment["transcript"]
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)
# Close origAudio
origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _is_wav_too_short(wav_filename):
short_wav_filenames = ['sw2986A-ms98-a-trans-80.6385-83.358875.wav', 'sw2663A-ms98-a-trans-161.12025-164.213375.wav']
short_wav_filenames = [
"sw2986A-ms98-a-trans-80.6385-83.358875.wav",
"sw2663A-ms98-a-trans-161.12025-164.213375.wav",
]
return wav_filename in short_wav_filenames
@ -234,7 +288,7 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
chunkAudio.writeframes(chunkData)
chunkAudio.close()
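Editor's note: the segment extraction above needs nothing beyond the standard-library wave module: seek to the start frame, read the span, and write it back out with the source's format parameters. A self-contained version of that helper (split_wav takes times in seconds, matching the alignment units used here):

import wave

def split_wav(src_path, start_time, stop_time, dst_path):
    # Copy the [start_time, stop_time] span (seconds) of src_path into a
    # new WAV that keeps the source's channel count, width and rate.
    with wave.open(src_path, "r") as orig:
        rate = orig.getframerate()
        orig.setpos(int(start_time * rate))
        chunk = orig.readframes(int((stop_time - start_time) * rate))
        with wave.open(dst_path, "w") as out:
            out.setnchannels(orig.getnchannels())
            out.setsampwidth(orig.getsampwidth())
            out.setframerate(rate)
            out.writeframes(chunk)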
def _split_sets(filelist):
# We initially split the entire set into 80% train and 20% test, then
# split the train set into 80% train and 20% validation.
@ -248,10 +302,24 @@ def _split_sets(filelist):
test_beg = dev_end
test_end = len(filelist)
return (filelist[train_beg:train_end], filelist[dev_beg:dev_end], filelist[test_beg:test_end])
return (
filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)
def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, stride=1, offset=0, next_index=lambda i: i + 1, limit=0):
def _read_data_set(
filelist,
thread_count,
batch_size,
numcep,
numcontext,
stride=1,
offset=0,
next_index=lambda i: i + 1,
limit=0,
):
# Optionally apply dataset size limit
if limit > 0:
filelist = filelist.iloc[:limit]
@ -259,7 +327,9 @@ def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, strid
filelist = filelist[offset::stride]
# Return DataSet
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index)
return DataSet(
txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index
)
if __name__ == "__main__":
View File
@ -1,70 +1,76 @@
#!/usr/bin/env python
'''
"""
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
Use "python3 import_swc.py -h" for help
'''
from __future__ import absolute_import, division, print_function
"""
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import re
import csv
import sox
import wave
import shutil
import random
import tarfile
import argparse
import progressbar
import csv
import os
import random
import re
import shutil
import sys
import tarfile
import unicodedata
import wave
import xml.etree.cElementTree as ET
from os import path
from glob import glob
from collections import Counter
from glob import glob
from multiprocessing.pool import ThreadPool
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
import progressbar
import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar"
LANGUAGES = ['dutch', 'english', 'german']
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES_EXT = FIELDNAMES + ['article', 'speaker']
LANGUAGES = ["dutch", "english", "german"]
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
CHANNELS = 1
SAMPLE_RATE = 16000
UNKNOWN = '<unknown>'
AUDIO_PATTERN = 'audio*.ogg'
WAV_NAME = 'audio.wav'
ALIGNED_NAME = 'aligned.swc'
UNKNOWN = "<unknown>"
AUDIO_PATTERN = "audio*.ogg"
WAV_NAME = "audio.wav"
ALIGNED_NAME = "aligned.swc"
SUBSTITUTIONS = {
'german': [
(re.compile(r'\$'), 'dollar'),
(re.compile(r'€'), 'euro'),
(re.compile(r'£'), 'pfund'),
(re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
(re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
(re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
(re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
(re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
(re.compile(r'punkt null null null punkt null null null'), 'millionen'),
(re.compile(r'eins punkt null null null'), 'ein tausend'),
(re.compile(r'punkt null null null'), 'tausend'),
(re.compile(r'punkt null'), None)
"german": [
(re.compile(r"\$"), "dollar"),
(re.compile(r"€"), "euro"),
(re.compile(r"£"), "pfund"),
(
re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
r"\1zehnhundert \2er ",
),
(re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
(
re.compile(
r"eins punkt null null null punkt null null null punkt null null null"
),
"eine milliarde",
),
(
re.compile(
r"punkt null null null punkt null null null punkt null null null"
),
"milliarden",
),
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
(re.compile(r"eins punkt null null null"), "ein tausend"),
(re.compile(r"punkt null null null"), "tausend"),
(re.compile(r"punkt null"), None),
]
}
DONT_NORMALIZE = {
'german': 'ÄÖÜäöüß'
}
DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))
PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
class Sample:
@ -98,11 +104,14 @@ def get_sample_size(population_size):
margin_of_error = 0.01
fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99%
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2
)
sample_size = 0
for train_size in range(population_size, 0, -1):
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
(margin_of_error ** 2 * train_size)
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2 * train_size
)
sample_size = int(numerator / denominator)
if 2 * sample_size + train_size <= population_size:
break
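Editor's note: the loop above is Cochran's sample-size formula with a finite-population correction: n0 = z^2 * p * (1 - p) / e^2 is shrunk to n = n0 / (1 + n0 / train_size), and the loop searches for the largest train set that still leaves room for a dev and a test set of that size. A direct restatement (dev_test_size is an illustrative name):

def dev_test_size(population_size, z_score=2.58, margin_of_error=0.01, p=0.5):
    # Unadjusted Cochran sample size for a proportion p at the given
    # confidence (z) and margin of error.
    n0 = (z_score ** 2) * p * (1 - p) / (margin_of_error ** 2)
    sample_size = 0
    # Shrink the train set until a dev and a test set of the corrected
    # size fit next to it.
    for train_size in range(population_size, 0, -1):
        sample_size = int(n0 / (1 + n0 / train_size))
        if 2 * sample_size + train_size <= population_size:
            break
    return sample_size

print(dev_test_size(1_000_000))  # roughly 16 000 dev and test samples each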
@ -111,14 +120,16 @@ def get_sample_size(population_size):
def maybe_download_language(language):
lang_upper = language[0].upper() + language[1:]
return maybe_download(SWC_ARCHIVE.format(language=lang_upper),
CLI_ARGS.base_dir,
SWC_URL.format(language=lang_upper))
return maybe_download(
SWC_ARCHIVE.format(language=lang_upper),
CLI_ARGS.base_dir,
SWC_URL.format(language=lang_upper),
)
def maybe_extract(data_dir, extracted_data, archive):
extracted = path.join(data_dir, extracted_data)
if path.isdir(extracted):
extracted = os.path.join(data_dir, extracted_data)
if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted))
else:
print('Extracting "{}"...'.format(archive))
@ -133,29 +144,29 @@ def maybe_extract(data_dir, extracted_data, archive):
def ignored(node):
if node is None:
return False
if node.tag == 'ignored':
if node.tag == "ignored":
return True
return ignored(node.find('..'))
return ignored(node.find(".."))
def read_token(token):
texts, start, end = [], None, None
notes = token.findall('n')
notes = token.findall("n")
if len(notes) > 0:
for note in notes:
attributes = note.attrib
if start is None and 'start' in attributes:
start = int(attributes['start'])
if 'end' in attributes:
token_end = int(attributes['end'])
if start is None and "start" in attributes:
start = int(attributes["start"])
if "end" in attributes:
token_end = int(attributes["end"])
if end is None or token_end > end:
end = token_end
if 'pronunciation' in attributes:
t = attributes['pronunciation']
if "pronunciation" in attributes:
t = attributes["pronunciation"]
texts.append(t)
elif 'text' in token.attrib:
texts.append(token.attrib['text'])
return start, end, ' '.join(texts)
elif "text" in token.attrib:
texts.append(token.attrib["text"])
return start, end, " ".join(texts)
def in_alphabet(alphabet, c):
@ -163,10 +174,12 @@ def in_alphabet(alphabet, c):
ALPHABETS = {}
def get_alphabet(language):
if language in ALPHABETS:
return ALPHABETS[language]
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
alphabet = Alphabet(alphabet_path) if alphabet_path else None
ALPHABETS[language] = alphabet
return alphabet
@ -176,27 +189,35 @@ def label_filter(label, language):
label = label.translate(PRE_FILTER)
label = validate_label(label)
if label is None:
return None, 'validation'
return None, "validation"
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
for pattern, replacement in substitutions:
if replacement is None:
if pattern.match(label):
return None, 'substitution rule'
return None, "substitution rule"
else:
label = pattern.sub(replacement, label)
chars = []
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
alphabet = get_alphabet(language)
for c in label:
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
if (
CLI_ARGS.normalize
and c not in dont_normalize
and not in_alphabet(alphabet, c)
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c:
if not in_alphabet(alphabet, sc):
return None, 'illegal character'
return None, "illegal character"
chars.append(sc)
label = ''.join(chars)
label = "".join(chars)
label = validate_label(label)
return label, 'validation' if label is None else None
return label, "validation" if label is None else None
def collect_samples(base_dir, language):
@ -207,7 +228,9 @@ def collect_samples(base_dir, language):
samples = []
reasons = Counter()
def add_sample(p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason='complete'):
def add_sample(
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
):
if p_start is not None and p_end is not None and p_text is not None:
duration = p_end - p_start
text, filter_reason = label_filter(p_text, language)
@ -217,53 +240,67 @@ def collect_samples(base_dir, language):
p_reason = filter_reason
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
skip = True
p_reason = 'unknown speaker'
p_reason = "unknown speaker"
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
skip = True
p_reason = 'unknown article'
p_reason = "unknown article"
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
skip = True
p_reason = 'exceeded duration'
p_reason = "exceeded duration"
elif int(duration / 30) < len(text):
skip = True
p_reason = 'too short to decode'
p_reason = "too short to decode"
elif duration / len(text) < 10:
skip = True
p_reason = 'length duration ratio'
p_reason = "length duration ratio"
if skip:
reasons[p_reason] += 1
else:
samples.append(Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker))
samples.append(
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
)
elif p_start is None or p_end is None:
reasons['missing timestamps'] += 1
reasons["missing timestamps"] += 1
else:
reasons['missing text'] += 1
reasons["missing text"] += 1
print('Collecting samples...')
print("Collecting samples...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
for root in bar(roots):
wav_path = path.join(root, WAV_NAME)
wav_path = os.path.join(root, WAV_NAME)
aligned = ET.parse(path.join(root, ALIGNED_NAME))
article = UNKNOWN
speaker = UNKNOWN
for prop in aligned.iter('prop'):
for prop in aligned.iter("prop"):
attributes = prop.attrib
if 'key' in attributes and 'value' in attributes:
if attributes['key'] == 'DC.identifier':
article = attributes['value']
elif attributes['key'] == 'reader.name':
speaker = attributes['value']
for sentence in aligned.iter('s'):
if "key" in attributes and "value" in attributes:
if attributes["key"] == "DC.identifier":
article = attributes["value"]
elif attributes["key"] == "reader.name":
speaker = attributes["value"]
for sentence in aligned.iter("s"):
if ignored(sentence):
continue
split = False
tokens = list(map(read_token, sentence.findall('t')))
tokens = list(map(read_token, sentence.findall("t")))
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
for token_start, token_end, token_text in tokens:
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='has numbers')
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="has numbers",
)
sample_start, sample_end, token_texts, sample_texts = (
None,
None,
[],
[],
)
continue
if sample_start is None:
sample_start = token_start
@ -271,20 +308,37 @@ def collect_samples(base_dir, language):
continue
token_texts.append(token_text)
if token_end is not None:
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='split')
if (
token_start != sample_start
and token_end - sample_start > CLI_ARGS.max_duration > 0
):
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split",
)
sample_start = sample_end
sample_texts = []
split = True
sample_end = token_end
sample_texts.extend(token_texts)
token_texts = []
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='split' if split else 'complete')
print('Skipped samples:')
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split" if split else "complete",
)
print("Skipped samples:")
for reason, n in reasons.most_common():
print(' - {}: {}'.format(reason, n))
print(" - {}: {}".format(reason, n))
return samples
@ -294,8 +348,8 @@ def maybe_convert_one_to_wav(entry):
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
combiner = sox.Combiner()
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
output_wav = path.join(root, WAV_NAME)
if path.isfile(output_wav):
output_wav = os.path.join(root, WAV_NAME)
if os.path.isfile(output_wav):
return
files = sorted(glob(path.join(root, AUDIO_PATTERN)))
try:
@ -304,18 +358,18 @@ def maybe_convert_one_to_wav(entry):
elif len(files) > 1:
wav_files = []
for i, file in enumerate(files):
wav_path = path.join(root, 'audio{}.wav'.format(i))
wav_path = os.path.join(root, "audio{}.wav".format(i))
transformer.build(file, wav_path)
wav_files.append(wav_path)
combiner.set_input_format(file_type=['wav'] * len(wav_files))
combiner.build(wav_files, output_wav, 'concatenate')
combiner.set_input_format(file_type=["wav"] * len(wav_files))
combiner.build(wav_files, output_wav, "concatenate")
except sox.core.SoxError:
return
def maybe_convert_to_wav(base_dir):
roots = list(os.walk(base_dir))
print('Converting and joining source audio files...')
print("Converting and joining source audio files...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
tp = ThreadPool()
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
@ -335,53 +389,66 @@ def assign_sub_sets(samples):
sample_set.extend(speakers.pop(0))
train_set = sum(speakers, [])
if len(train_set) == 0:
print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
print(
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
)
random.seed(42) # same source data == same output
random.shuffle(samples)
for index, sample in enumerate(samples):
if index < sample_size:
sample.sub_set = 'dev'
sample.sub_set = "dev"
elif index < 2 * sample_size:
sample.sub_set = 'test'
sample.sub_set = "test"
else:
sample.sub_set = 'train'
sample.sub_set = "train"
else:
for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
for sub_set, sub_set_samples in [
("train", train_set),
("dev", sample_sets[0]),
("test", sample_sets[1]),
]:
for sample in sub_set_samples:
sample.sub_set = sub_set
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
.format(sub_set, len(sub_set_samples), t))
print(
'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
sub_set, len(sub_set_samples), t
)
)
def create_sample_dirs(language):
print('Creating sample directories...')
for set_name in ['train', 'dev', 'test']:
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
if not path.isdir(dir_path):
print("Creating sample directories...")
for set_name in ["train", "dev", "test"]:
dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
if not os.path.isdir(dir_path):
os.mkdir(dir_path)
def split_audio_files(samples, language):
print('Splitting audio files...')
print("Splitting audio files...")
sub_sets = Counter()
src_wav_files = group(samples, lambda s: s.wav_path).items()
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
for wav_path, file_samples in bar(src_wav_files):
file_samples = sorted(file_samples, key=lambda s: s.start)
with wave.open(wav_path, 'r') as src_wav_file:
with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate()
for sample in file_samples:
index = sub_sets[sample.sub_set]
sample_wav_path = path.join(CLI_ARGS.base_dir,
language + '-' + sample.sub_set,
'sample-{0:06d}.wav'.format(index))
sample_wav_path = os.path.join(
CLI_ARGS.base_dir,
language + "-" + sample.sub_set,
"sample-{0:06d}.wav".format(index),
)
sample.wav_path = sample_wav_path
sub_sets[sample.sub_set] += 1
src_wav_file.setpos(int(sample.start * rate / 1000.0))
data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
with wave.open(sample_wav_path, 'w') as sample_wav_file:
data = src_wav_file.readframes(
int((sample.end - sample.start) * rate / 1000.0)
)
with wave.open(sample_wav_path, "w") as sample_wav_file:
sample_wav_file.setnchannels(src_wav_file.getnchannels())
sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
sample_wav_file.setframerate(rate)
@ -391,22 +458,26 @@ def split_audio_files(samples, language):
def write_csvs(samples, language):
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
base_dir = path.abspath(CLI_ARGS.base_dir)
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
base_dir = os.path.abspath(CLI_ARGS.base_dir)
csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
bar = progressbar.ProgressBar(
max_value=len(set_samples), widgets=SIMPLE_BAR
)
for sample in bar(set_samples):
row = {
'wav_filename': path.relpath(sample.wav_path, base_dir),
'wav_filesize': path.getsize(sample.wav_path),
'transcript': sample.text
"wav_filename": os.path.relpath(sample.wav_path, base_dir),
"wav_filesize": os.path.getsize(sample.wav_path),
"transcript": sample.text,
}
if CLI_ARGS.add_meta:
row['article'] = sample.article
row['speaker'] = sample.speaker
row["article"] = sample.article
row["speaker"] = sample.speaker
writer.writerow(row)
@ -414,8 +485,8 @@ def cleanup(archive, language):
if not CLI_ARGS.keep_archive:
print('Removing archive "{}"...'.format(archive))
os.remove(archive)
language_dir = path.join(CLI_ARGS.base_dir, language)
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir):
language_dir = os.path.join(CLI_ARGS.base_dir, language)
if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
print('Removing intermediate files in "{}"...'.format(language_dir))
shutil.rmtree(language_dir)
@ -433,34 +504,75 @@ def prepare_language(language):
def handle_args():
parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
parser.add_argument('base_dir', help='Directory containing all data')
parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
parser.add_argument('--exclude_numbers', type=bool, default=True,
help='If sequences with non-transliterated numbers should be excluded')
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
parser.add_argument('--ignore_too_long', type=bool, default=False,
help='If samples exceeding max_duration should be removed')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument(
"--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
)
parser.add_argument(
"--exclude_numbers",
type=bool,
default=True,
help="If sequences with non-transliterated numbers should be excluded",
)
parser.add_argument(
"--max_duration",
type=int,
default=10000,
help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--ignore_too_long",
type=bool,
default=False,
help="If samples exceeding max_duration should be removed",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
for language in LANGUAGES:
parser.add_argument('--{}_alphabet'.format(language),
help='Exclude {} samples with characters not in provided alphabet file'.format(language))
parser.add_argument('--add_meta', action='store_true', help='Adds article and speaker CSV columns')
parser.add_argument('--exclude_unknown_speakers', action='store_true', help='Exclude unknown speakers')
parser.add_argument('--exclude_unknown_articles', action='store_true', help='Exclude unknown articles')
parser.add_argument('--keep_archive', type=bool, default=True,
help='If downloaded archives should be kept')
parser.add_argument('--keep_intermediate', type=bool, default=False,
help='If intermediate files should be kept')
parser.add_argument(
"--{}_alphabet".format(language),
help="Exclude {} samples with characters not in provided alphabet file".format(
language
),
)
parser.add_argument(
"--add_meta", action="store_true", help="Adds article and speaker CSV columns"
)
parser.add_argument(
"--exclude_unknown_speakers",
action="store_true",
help="Exclude unknown speakers",
)
parser.add_argument(
"--exclude_unknown_articles",
action="store_true",
help="Exclude unknown articles",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
parser.add_argument(
"--keep_intermediate",
type=bool,
default=False,
help="If intermediate files should be kept",
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
if CLI_ARGS.language == 'all':
if CLI_ARGS.language == "all":
for lang in LANGUAGES:
prepare_language(lang)
elif CLI_ARGS.language in LANGUAGES:
prepare_language(CLI_ARGS.language)
else:
fail('Wrong language id')
fail("Wrong language id")
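Editor's note: several loops above lean on a small group() helper that is not part of this diff; judging from the call sites it buckets samples by a key function. A sketch of what such a helper could look like (name and signature taken from the call sites, body is an assumption):

from collections import defaultdict

def group(items, key):
    # Bucket items by key(item); used here to group samples by sub-set
    # ("train"/"dev"/"test") or by their source WAV path.
    buckets = defaultdict(list)
    for item in items:
        buckets[key(item)].append(item)
    return buckets

# Example with plain dicts standing in for Sample objects:
samples = [{"sub_set": "train"}, {"sub_set": "dev"}, {"sub_set": "train"}]
by_set = group(samples, lambda s: s["sub_set"])
print({k: len(v) for k, v in by_set.items()})  # {'train': 2, 'dev': 1}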
View File
@ -1,24 +1,18 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import pandas
import tarfile
import unicodedata
import wave
from glob import glob
from os import makedirs, path, remove, rmdir
import pandas
from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile
from util.stm import parse_stm_file
from deepspeech_training.util.downloader import maybe_download
from deepspeech_training.util.stm import parse_stm_file
def _download_and_preprocess_data(data_dir):
# Conditionally download data
@ -41,6 +35,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(path.join(data_dir, extracted_data)):
@ -48,6 +43,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir)
tar.close()
def _maybe_convert_wav(data_dir, extracted_data):
# Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data)
@ -61,6 +57,7 @@ def _maybe_convert_wav(data_dir, extracted_data):
# Conditionally convert test sph to wav
_maybe_convert_wav_dataset(extracted_dir, "test")
def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Create source dir
source_dir = path.join(extracted_dir, data_set, "sph")
@ -84,6 +81,7 @@ def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Remove source_dir
rmdir(source_dir)
def _maybe_split_sentences(data_dir, extracted_data):
# Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data)
@ -99,6 +97,7 @@ def _maybe_split_sentences(data_dir, extracted_data):
return train_files, dev_files, test_files
def _maybe_split_dataset(extracted_dir, data_set):
# Create stm dir
stm_dir = path.join(extracted_dir, data_set, "stm")
@ -116,14 +115,21 @@ def _maybe_split_dataset(extracted_dir, data_set):
# Open wav corresponding to stm_file
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
wav_file = path.join(wav_dir, wav_filename)
origAudio = wave.open(wav_file,'r')
origAudio = wave.open(wav_file, "r")
# Loop over stm_segments and split wav_file for each segment
for stm_segment in stm_segments:
# Create wav segment filename
start_time = stm_segment.start_time
stop_time = stm_segment.stop_time
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
new_wav_filename = (
path.splitext(path.basename(stm_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = path.join(wav_dir, new_wav_filename)
# If the wav segment filename does not exist create it
@ -131,23 +137,29 @@ def _maybe_split_dataset(extracted_dir, data_set):
_split_wav(origAudio, start_time, stop_time, new_wav_file)
new_wav_filesize = path.getsize(new_wav_file)
files.append((path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript))
files.append(
(path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)
)
# Close origAudio
origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _split_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio.getframerate()
origAudio.setpos(int(start_time*frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time)*frameRate))
chunkAudio = wave.open(new_wav_file,'w')
origAudio.setpos(int(start_time * frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
chunkAudio = wave.open(new_wav_file, "w")
chunkAudio.setnchannels(origAudio.getnchannels())
chunkAudio.setsampwidth(origAudio.getsampwidth())
chunkAudio.setframerate(frameRate)
chunkAudio.writeframes(chunkData)
chunkAudio.close()
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
NAME : LDC TIMIT Dataset
URL : https://catalog.ldc.upenn.edu/ldc93s1
HOURS : 5
@ -8,29 +8,32 @@
AUTHORS : Garofolo, John, et al.
TYPE : LDC Membership
LICENCE : LDC User Agreement
'''
"""
import errno
import fnmatch
import os
from os import path
import subprocess
import sys
import tarfile
import fnmatch
from os import path
import pandas as pd
import subprocess
def clean(word):
# Lowercase all & strip punctuation which is not required
new = word.lower().replace('.', '')
new = new.replace(',', '')
new = new.replace(';', '')
new = new.replace('"', '')
new = new.replace('!', '')
new = new.replace('?', '')
new = new.replace(':', '')
new = new.replace('-', '')
new = word.lower().replace(".", "")
new = new.replace(",", "")
new = new.replace(";", "")
new = new.replace('"', "")
new = new.replace("!", "")
new = new.replace("?", "")
new = new.replace(":", "")
new = new.replace("-", "")
return new
def _preprocess_data(args):
# Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
@ -40,16 +43,24 @@ def _preprocess_data(args):
if ignoreSASentences:
print("Using recommended ignore SA sentences")
print("Ignoring SA sentences (2 x sentences which are repeated by all speakers)")
print(
"Ignoring SA sentences (2 x sentences which are repeated by all speakers)"
)
else:
print("Using unrecommended setting to include SA sentences")
datapath = args
target = path.join(datapath, "TIMIT")
print("Checking to see if data has already been extracted in given argument: %s", target)
print(
"Checking to see if data has already been extracted in given argument: %s",
target,
)
if not path.isdir(target):
print("Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ", datapath)
print(
"Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ",
datapath,
)
filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
if path.isfile(filepath):
print("File found, extracting")
@ -103,40 +114,58 @@ def _preprocess_data(args):
# if ignoreSAsentences we only want those without SA in the name
# OR
# if not ignoreSAsentences we want all to be added
if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences):
if 'train' in full_wav.lower():
if (ignoreSASentences and not ("SA" in os.path.basename(full_wav))) or (
not ignoreSASentences
):
if "train" in full_wav.lower():
train_list_wavs.append(full_wav)
train_list_trans.append(trans)
train_list_size.append(wav_filesize)
elif 'test' in full_wav.lower():
elif "test" in full_wav.lower():
test_list_wavs.append(full_wav)
test_list_trans.append(trans)
test_list_size.append(wav_filesize)
else:
raise IOError
a = {'wav_filename': train_list_wavs,
'wav_filesize': train_list_size,
'transcript': train_list_trans
}
a = {
"wav_filename": train_list_wavs,
"wav_filesize": train_list_size,
"transcript": train_list_trans,
}
c = {'wav_filename': test_list_wavs,
'wav_filesize': test_list_size,
'transcript': test_list_trans
}
c = {
"wav_filename": test_list_wavs,
"wav_filesize": test_list_size,
"transcript": test_list_trans,
}
all = {'wav_filename': train_list_wavs + test_list_wavs,
'wav_filesize': train_list_size + test_list_size,
'transcript': train_list_trans + test_list_trans
}
all = {
"wav_filename": train_list_wavs + test_list_wavs,
"wav_filesize": train_list_size + test_list_size,
"transcript": train_list_trans + test_list_trans,
}
df_all = pd.DataFrame(all, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
df_train = pd.DataFrame(a, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
df_test = pd.DataFrame(c, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
df_all = pd.DataFrame(
all, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_train = pd.DataFrame(
a, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_test = pd.DataFrame(
c, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_all.to_csv(
target + "/timit_all.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_train.to_csv(
target + "/timit_train.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_test.to_csv(
target + "/timit_test.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_all.to_csv(target+"/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
df_train.to_csv(target+"/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
df_test.to_csv(target+"/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
if __name__ == "__main__":
_preprocess_data(sys.argv[1])

View File

@ -1,52 +1,53 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import csv
import os
import re
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import unidecode
import zipfile
import sox
import subprocess
import progressbar
import zipfile
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import progressbar
import sox
from util.downloader import maybe_download
import unidecode
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15
ARCHIVE_NAME = '2019-04-11_fr_FR'
ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME
ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE_NAME + '.zip'
ARCHIVE_NAME = "2019-04-11_fr_FR"
ARCHIVE_DIR_NAME = "ts_" + ARCHIVE_NAME
ARCHIVE_URL = (
"https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/" + ARCHIVE_NAME + ".zip"
)
def _download_and_preprocess_data(target_dir, english_compatible=False):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
archive_path = maybe_download(
"ts_" + ARCHIVE_NAME + ".zip", target_dir, ARCHIVE_URL
)
# Conditionally extract archive data
_maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path)
# Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible)
_maybe_convert_sets(
target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible
)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -58,16 +59,20 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = sample['path']
orig_filename = sample["path"]
# Storing wav files next to the wav ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = sample['text']
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = sample["text"]
rows = []
@ -75,40 +80,41 @@ def one_sample(sample):
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
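# Added note: frames / SAMPLE_RATE * 1000 is the clip duration in milliseconds,
# dividing by 10 gives the number of 10 ms feature windows, and the comparison
# rejects clips with fewer than roughly two windows per transcript character.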
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
target_csv_template = os.path.join(target_dir, "ts_" + ARCHIVE_NAME + "_{}.csv")
if os.path.isfile(target_csv_template):
return
path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
path_to_original_csv = os.path.join(extracted_dir, "data.csv")
with open(path_to_original_csv) as csv_f:
data = [
d for d in csv.DictReader(csv_f, delimiter=',')
if float(d['duration']) <= MAX_SECS
d
for d in csv.DictReader(csv_f, delimiter=",")
if float(d["duration"]) <= MAX_SECS
]
for line in data:
line['path'] = os.path.join(extracted_dir, line['path'])
line["path"] = os.path.join(extracted_dir, line["path"])
num_samples = len(data)
rows = []
@ -125,9 +131,9 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -136,7 +142,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
test_writer.writeheader()
for i, item in enumerate(rows):
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
transcript = validate_label(
cleanup_transcript(
item[2], english_compatible=english_compatible
)
)
if not transcript:
continue
wav_filename = os.path.join(target_dir, extracted_data, item[0])
@ -147,45 +157,53 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(orig_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
transformer.build(orig_filename, wav_filename)
except sox.core.SoxError as ex:
print('SoX processing error', ex, orig_filename, wav_filename)
print("SoX processing error", ex, orig_filename, wav_filename)
PUNCTUATIONS_REG = re.compile(r"\-,;!?.()\[\]*…—]")
MULTIPLE_SPACES_REG = re.compile(r'\s{2,}')
MULTIPLE_SPACES_REG = re.compile(r"\s{2,}")
def cleanup_transcript(text, english_compatible=False):
text = text.replace('', "'").replace('\u00A0', ' ')
text = PUNCTUATIONS_REG.sub(' ', text)
text = MULTIPLE_SPACES_REG.sub(' ', text)
text = text.replace("", "'").replace("\u00A0", " ")
text = PUNCTUATIONS_REG.sub(" ", text)
text = MULTIPLE_SPACES_REG.sub(" ", text)
if english_compatible:
text = unidecode.unidecode(text)
return text.strip().lower()
def handle_args():
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.')
parser.add_argument(dest='target_dir')
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diacritics and other non-ascii chars.')
parser = get_importers_parser(description="Importer for TrainingSpeech dataset.")
parser.add_argument(dest="target_dir")
parser.add_argument(
"--english-compatible",
action="store_true",
dest="english_compatible",
help="Remove diactrics and other non-ascii chars.",
)
return parser.parse_args()

View File

@ -1,45 +1,40 @@
#!/usr/bin/env python
'''
"""
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
Use "python3 import_tuda.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import wave
import tarfile
"""
import argparse
import progressbar
import csv
import os
import tarfile
import unicodedata
import wave
import xml.etree.cElementTree as ET
from os import path
from collections import Counter
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
TUDA_VERSION = 'v2'
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE)
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE)
import progressbar
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet
TUDA_VERSION = "v2"
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
TUDA_PACKAGE
)
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)
CHANNELS = 1
SAMPLE_WIDTH = 2
SAMPLE_RATE = 16000
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
def maybe_extract(archive):
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
if path.isdir(extracted):
extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted))
else:
print('Extracting "{}"...'.format(archive))
@ -52,86 +47,100 @@ def maybe_extract(archive):
def check_and_prepare_sentence(sentence):
sentence = sentence.lower().replace('co2', 'c o zwei')
sentence = sentence.lower().replace("co2", "c o zwei")
chars = []
for c in sentence:
if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)):
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
if (
CLI_ARGS.normalize
and c not in "äöüß"
and (ALPHABET is None or not ALPHABET.has_char(c))
):
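# Added note: NFKD decomposes an accented character into its base character plus
# combining marks; the ASCII encode/decode round trip below then drops the marks,
# keeping only the base character.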
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c:
if ALPHABET is not None and not ALPHABET.has_char(c):
return None
chars.append(sc)
return validate_label(''.join(chars))
return validate_label("".join(chars))
def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements
try:
with wave.open(wav_path, 'r') as src_wav_file:
with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate()
channels = src_wav_file.getnchannels()
sample_width = src_wav_file.getsampwidth()
milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
if rate != SAMPLE_RATE:
return False, 'wrong sample rate'
return False, "wrong sample rate"
if channels != CHANNELS:
return False, 'wrong number of channels'
return False, "wrong number of channels"
if sample_width != SAMPLE_WIDTH:
return False, 'wrong sample width'
return False, "wrong sample width"
if milliseconds / len(sentence) < 30:
return False, 'too short'
return False, "too short"
if milliseconds > CLI_ARGS.max_duration > 0:
return False, 'too long'
return False, "too long"
except wave.Error:
return False, 'invalid wav file'
return False, "invalid wav file"
except EOFError:
return False, 'premature EOF'
return True, 'OK'
return False, "premature EOF"
return True, "OK"
def write_csvs(extracted):
sample_counter = 0
reasons = Counter()
for sub_set in ['train', 'dev', 'test']:
set_path = path.join(extracted, sub_set)
for sub_set in ["train", "dev", "test"]:
set_path = os.path.join(extracted, sub_set)
set_files = os.listdir(set_path)
recordings = {}
for file in set_files:
if file.endswith('.xml'):
if file.endswith(".xml"):
recordings[file[:-4]] = []
for file in set_files:
if file.endswith('.wav') and '_' in file:
prefix = file.split('_')[0]
if file.endswith(".wav") and "_" in file:
prefix = file.split("_")[0]
if prefix in recordings:
recordings[prefix].append(file)
recordings = recordings.items()
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
csv_path = os.path.join(
CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
)
print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file:
with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
set_dir = path.join(extracted, sub_set)
set_dir = os.path.join(extracted, sub_set)
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
for prefix, wav_names in bar(recordings):
xml_path = path.join(set_dir, prefix + '.xml')
xml_path = os.path.join(set_dir, prefix + ".xml")
meta = ET.parse(xml_path).getroot()
sentence = list(meta.iter('cleaned_sentence'))[0].text
sentence = list(meta.iter("cleaned_sentence"))[0].text
sentence = check_and_prepare_sentence(sentence)
if sentence is None:
continue
for wav_name in wav_names:
sample_counter += 1
wav_path = path.join(set_path, wav_name)
wav_path = os.path.join(set_path, wav_name)
keep, reason = check_wav_file(wav_path, sentence)
if keep:
writer.writerow({
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir),
'wav_filesize': path.getsize(wav_path),
'transcript': sentence.lower()
})
writer.writerow(
{
"wav_filename": os.path.relpath(
wav_path, CLI_ARGS.base_dir
),
"wav_filesize": os.path.getsize(wav_path),
"transcript": sentence.lower(),
}
)
else:
reasons[reason] += 1
if len(reasons.keys()) > 0:
print('Excluded samples:')
print("Excluded samples:")
for reason, n in reasons.most_common():
print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
@ -150,13 +159,29 @@ def download_and_prepare():
def handle_args():
parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)')
parser.add_argument('base_dir', help='Directory containing all data')
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file')
parser.add_argument('--keep_archive', type=bool, default=True,
help='If downloaded archives should be kept')
parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument(
"--max_duration",
type=int,
default=10000,
help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--alphabet",
help="Exclude samples with characters not in provided alphabet file",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
return parser.parse_args()

View File

@ -1,29 +1,22 @@
#!/usr/bin/env python
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
# Licenced under Open Data Commons Attribution License (ODC-By) v1.0.
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import random
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
from util.importers import get_counter, get_imported_samples, print_import_report
import re
from multiprocessing import Pool
from zipfile import ZipFile
import librosa
import progressbar
from os import path
from multiprocessing import Pool
from util.downloader import maybe_download, SIMPLE_BAR
from zipfile import ZipFile
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
print_import_report,
)
SAMPLE_RATE = 16000
MAX_SECS = 10
@ -37,7 +30,7 @@ ARCHIVE_URL = (
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract common voice data
@ -48,8 +41,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print(f"No directory {extracted_path} - extracting archive...")
with ZipFile(archive_path, "r") as zipobj:
# Extract all the contents of zip file in current directory
@ -59,15 +52,17 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data, "wav48")
txt_dir = path.join(target_dir, extracted_data, "txt")
extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
txt_dir = os.path.join(target_dir, extracted_data, "txt")
directory = os.path.expanduser(extracted_dir)
srtd = len(sorted(os.listdir(directory)))
all_samples = []
for target in sorted(os.listdir(directory)):
all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
all_samples += _maybe_prepare_set(
path.join(extracted_dir, os.path.split(target)[-1])
)
num_samples = len(all_samples)
print(f"Converting wav files to {SAMPLE_RATE}hz...")
@ -81,6 +76,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
_write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample):
if is_audio_file(sample):
y, sr = librosa.load(sample, sr=16000)
@ -103,6 +99,7 @@ def _maybe_prepare_set(target_csv):
samples = new_samples
return samples
def _write_csv(extracted_dir, txt_dir, target_dir):
print(f"Writing CSV file")
dset_abs_path = extracted_dir
@ -197,7 +194,9 @@ AUDIO_EXTENSIONS = [".wav", "WAV"]
def is_audio_file(filepath):
return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
return any(
os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS
)
if __name__ == "__main__":

View File

@ -1,24 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import codecs
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import tarfile
import pandas
import re
import unicodedata
import tarfile
import threading
from multiprocessing.pool import ThreadPool
from six.moves import urllib
import unicodedata
import urllib
from glob import glob
from multiprocessing.pool import ThreadPool
from os import makedirs, path
import pandas
from bs4 import BeautifulSoup
from tensorflow.python.platform import gfile
from util.downloader import maybe_download
from deepspeech_training.util.downloader import maybe_download
"""The number of jobs to run in parallel"""
NUM_PARALLEL = 8
@ -26,8 +21,10 @@ NUM_PARALLEL = 8
"""Lambda function returns the filename of a path"""
filename_of = lambda x: path.split(x)[1]
class AtomicCounter(object):
"""A class that atomically increments a counter"""
def __init__(self, start_count=0):
"""Initialize the counter
:param start_count: the number to start counting at
@ -50,6 +47,7 @@ class AtomicCounter(object):
"""Returns the current value of the counter (not atomic)"""
return self.__count
def _parallel_downloader(voxforge_url, archive_dir, total, counter):
"""Generate a function to download a file based on given parameters
This works by currying the above given arguments into a closure
@ -61,6 +59,7 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
:param counter: an atomic counter to keep track of # of downloaded files
:return: a function that actually downloads a file given these params
"""
def download(d):
"""Binds voxforge_url, archive_dir, total, and counter into this scope
Downloads the given file
@ -68,12 +67,14 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
of the file to download and file is the name of the file to download
"""
(i, file) = d
download_url = voxforge_url + '/' + file
download_url = voxforge_url + "/" + file
c = counter.increment()
print('Downloading file {} ({}/{})...'.format(i+1, c, total))
print("Downloading file {} ({}/{})...".format(i + 1, c, total))
maybe_download(filename_of(download_url), archive_dir, download_url)
return download
def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter):
"""Generate a function to extract a tar file based on given parameters
This works by currying the above given arguments into a closure
@ -86,6 +87,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
:param counter: an atomic counter to keep track of # of extracted files
:return: a function that actually extracts a tar file given these params
"""
def extract(d):
"""Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope
Extracts the given file
@ -95,58 +97,74 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
(i, archive) = d
if i < number_of_test:
dataset_dir = path.join(data_dir, "test")
elif i<number_of_test+number_of_dev:
elif i < number_of_test + number_of_dev:
dataset_dir = path.join(data_dir, "dev")
else:
dataset_dir = path.join(data_dir, "train")
if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
if not gfile.Exists(
os.path.join(dataset_dir, ".".join(filename_of(archive).split(".")[:-1]))
):
c = counter.increment()
print('Extracting file {} ({}/{})...'.format(i+1, c, total))
print("Extracting file {} ({}/{})...".format(i + 1, c, total))
tar = tarfile.open(archive)
tar.extractall(dataset_dir)
tar.close()
return extract
def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir
if not path.isdir(data_dir):
makedirs(data_dir)
archive_dir = data_dir+"/archive"
archive_dir = data_dir + "/archive"
if not path.isdir(archive_dir):
makedirs(archive_dir)
print("Downloading Voxforge data set into {} if not already present...".format(archive_dir))
print(
"Downloading Voxforge data set into {} if not already present...".format(
archive_dir
)
)
voxforge_url = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit'
voxforge_url = "http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit"
html_page = urllib.request.urlopen(voxforge_url)
soup = BeautifulSoup(html_page, 'html.parser')
soup = BeautifulSoup(html_page, "html.parser")
# list all links
refs = [l['href'] for l in soup.find_all('a') if ".tgz" in l['href']]
refs = [l["href"] for l in soup.find_all("a") if ".tgz" in l["href"]]
# download files in parallel
print('{} files to download'.format(len(refs)))
downloader = _parallel_downloader(voxforge_url, archive_dir, len(refs), AtomicCounter())
print("{} files to download".format(len(refs)))
downloader = _parallel_downloader(
voxforge_url, archive_dir, len(refs), AtomicCounter()
)
p = ThreadPool(NUM_PARALLEL)
p.map(downloader, enumerate(refs))
# Conditionally extract data to dataset_dir
if not path.isdir(path.join(data_dir,"test")):
makedirs(path.join(data_dir,"test"))
if not path.isdir(path.join(data_dir,"dev")):
makedirs(path.join(data_dir,"dev"))
if not path.isdir(path.join(data_dir,"train")):
makedirs(path.join(data_dir,"train"))
if not path.isdir(os.path.join(data_dir, "test")):
makedirs(os.path.join(data_dir, "test"))
if not path.isdir(os.path.join(data_dir, "dev")):
makedirs(os.path.join(data_dir, "dev"))
if not path.isdir(os.path.join(data_dir, "train")):
makedirs(os.path.join(data_dir, "train"))
tarfiles = glob(path.join(archive_dir, "*.tgz"))
tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
number_of_files = len(tarfiles)
number_of_test = number_of_files//100
number_of_dev = number_of_files//100
number_of_test = number_of_files // 100
number_of_dev = number_of_files // 100
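# Added note: roughly 1% of the downloaded archives are held out for the test set
# and another 1% for the dev set; the remaining archives go to the training set.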
# extract tars in parallel
print("Extracting Voxforge data set into {} if not already present...".format(data_dir))
extracter = _parallel_extracter(data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter())
print(
"Extracting Voxforge data set into {} if not already present...".format(
data_dir
)
)
extracter = _parallel_extracter(
data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()
)
p.map(extracter, enumerate(tarfiles))
# Generate data set
@ -156,42 +174,50 @@ def _download_and_preprocess_data(data_dir):
train_files = _generate_dataset(data_dir, "train")
# Write sets to disk as CSV files
train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False)
dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False)
train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
def _generate_dataset(data_dir, data_set):
extracted_dir = path.join(data_dir, data_set)
files = []
for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")):
if path.isdir(path.join(promts_file[:-11],"wav")):
with codecs.open(promts_file, 'r', 'utf-8') as f:
for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")):
if path.isdir(os.path.join(promts_file[:-11], "wav")):
with codecs.open(promts_file, "r", "utf-8") as f:
for line in f:
id = line.split(' ')[0].split('/')[-1]
sentence = ' '.join(line.split(' ')[1:])
sentence = re.sub("[^a-z']"," ",sentence.strip().lower())
id = line.split(" ")[0].split("/")[-1]
sentence = " ".join(line.split(" ")[1:])
sentence = re.sub("[^a-z']", " ", sentence.strip().lower())
transcript = ""
for token in sentence.split(" "):
word = token.strip()
if word!="" and word!=" ":
if word != "" and word != " ":
transcript += word + " "
transcript = unicodedata.normalize("NFKD", transcript.strip()) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav")
transcript = (
unicodedata.normalize("NFKD", transcript.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
if gfile.Exists(wav_file):
wav_filesize = path.getsize(wav_file)
# remove audios that are shorter than 0.5s or longer than 20s.
# remove audios that are too short for transcript.
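# Added note: at 16 kHz with 16-bit samples (mono assumed), 32000 bytes is about one
# second of audio, so wav_filesize / 32000 approximates the duration in seconds and
# the > 1400 bytes-per-character check demands roughly 44 ms of audio per character.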
if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \
wav_filesize/len(transcript)>1400:
files.append((path.abspath(wav_file), wav_filesize, transcript))
if (
(wav_filesize / 32000) > 0.5
and (wav_filesize / 32000) < 20
and transcript != ""
and wav_filesize / len(transcript) > 1400
):
files.append(
(os.path.abspath(wav_file), wav_filesize, transcript)
)
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
if __name__=="__main__":
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

View File

@ -1,15 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tfv1
import sys
import tensorflow.compat.v1 as tfv1
def main():
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read())
print('\n'.join(sorted(set(n.op for n in graph_def.node))))
print("\n".join(sorted(set(n.op for n in graph_def.node))))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -3,19 +3,13 @@
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
Use "python3 build_sdb.py -h" for help
"""
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import random
import argparse
import random
import sys
from util.sample_collections import samples_from_file, LabeledSample
from util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file
def play_sample(samples, index):
@ -24,7 +18,7 @@ def play_sample(samples, index):
if CLI_ARGS.random:
index = random.randint(0, len(samples))
elif index >= len(samples):
print('No sample with index {}'.format(CLI_ARGS.start))
print("No sample with index {}".format(CLI_ARGS.start))
sys.exit(1)
sample = samples[index]
print('Sample "{}"'.format(sample.sample_id))
@ -50,13 +44,28 @@ def play_collection():
def handle_args():
parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
'and DeepSpeech CSV files')
parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
parser.add_argument('--start', type=int, default=0,
help='Sample index to start at (negative numbers are relative to the end of the collection)')
parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
parser = argparse.ArgumentParser(
description="Tool for playing samples from Sample Databases (SDB files) "
"and DeepSpeech CSV files"
)
parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
parser.add_argument(
"--start",
type=int,
default=0,
help="Sample index to start at (negative numbers are relative to the end of the collection)",
)
parser.add_argument(
"--number",
type=int,
default=-1,
help="Number of samples to play (-1 for endless)",
)
parser.add_argument(
"--random",
action="store_true",
help="If samples should be played in random order",
)
return parser.parse_args()
@ -70,5 +79,5 @@ if __name__ == "__main__":
try:
play_collection()
except KeyboardInterrupt:
print(' Stopped')
print(" Stopped")
sys.exit(0)

View File

@ -31,7 +31,7 @@ for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--load_train "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
@ -45,60 +45,7 @@ for LOAD in 'init' 'last' 'auto'; do
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
echo "#################################################################################"
echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "#################################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--drop_source_layers 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
--load 'last' \
--train_files "${ru_csv}" --train_batch_size 1 \
--dev_files "${ru_csv}" --dev_batch_size 1 \
--test_files "${ru_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
done
echo "#######################################################"
echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
echo "##### while iterating over loading logic #####"
echo "#######################################################"
for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
echo "#### Train ENGLISH model with just --checkpoint_dir ####"
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
echo "##############################################################################"
echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--load_train "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
@ -114,13 +61,20 @@ for LOAD in 'init' 'last' 'auto'; do
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--drop_source_layers 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
--load 'last' \
--load_train 'last' \
--train_files "${ru_csv}" --train_batch_size 1 \
--dev_files "${ru_csv}" --dev_batch_size 1 \
--test_files "${ru_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
# Test transfer learning checkpoint
python -u evaluate.py --noshow_progressbar \
--test_files "${ru_csv}" --test_batch_size 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
--load_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--scorer_path '' \
--n_hidden 100
done

View File

@ -1,17 +1,11 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
import argparse
import shutil
import sys
from util.text import Alphabet, UTF8Alphabet
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
@ -48,6 +42,9 @@ def create_bundle(
if use_utf8:
serialized_alphabet = UTF8Alphabet().serialize()
else:
if not alphabet_path:
print("No --alphabet path specified, can't continue.")
sys.exit(1)
serialized_alphabet = Alphabet(alphabet_path).serialize()
alphabet = NativeAlphabet()

View File

@ -1,13 +1,17 @@
C API Usage example
===================
Examples are from `native_client/client.cc`.
Creating a model instance and loading model
-------------------------------------------
.. literalinclude:: ../native_client/client.cc
:language: c
:linenos:
:lines: 370-375,386-390
:lineno-match:
:start-after: sphinx-doc: c_ref_model_start
:end-before: sphinx-doc: c_ref_model_stop
Performing inference
--------------------
@ -15,7 +19,9 @@ Performing inference
.. literalinclude:: ../native_client/client.cc
:language: c
:linenos:
:lines: 59-94
:lineno-match:
:start-after: sphinx-doc: c_ref_inference_start
:end-before: sphinx-doc: c_ref_inference_stop
Full source code
----------------

View File

@ -16,9 +16,9 @@ DeepSpeech Class
:members:
DeepSpeechStream Class
----------------
----------------------
.. doxygenclass:: DeepSpeechClient::DeepSpeechStream
.. doxygenclass:: DeepSpeechClient::Models::DeepSpeechStream
:project: deepspeech-dotnet
:members:

29
doc/DotNet-Examples.rst Normal file
View File

@ -0,0 +1,29 @@
.Net API Usage example
======================
Examples are from `native_client/dotnet/DeepSpeechConsole/Program.cs`.
Creating a model instance and loading model
-------------------------------------------
.. literalinclude:: ../native_client/dotnet/DeepSpeechConsole/Program.cs
:language: csharp
:linenos:
:lineno-match:
:start-after: sphinx-doc: csharp_ref_model_start
:end-before: sphinx-doc: csharp_ref_model_stop
Performing inference
--------------------
.. literalinclude:: ../native_client/dotnet/DeepSpeechConsole/Program.cs
:language: csharp
:linenos:
:lineno-match:
:start-after: sphinx-doc: csharp_ref_inference_start
:end-before: sphinx-doc: csharp_ref_inference_stop
Full source code
----------------
See :download:`Full source code<../native_client/dotnet/DeepSpeechConsole/Program.cs>`.

View File

@ -1,13 +1,17 @@
Java API Usage example
======================
Examples are from `native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java`.
Creating a model instance and loading model
-------------------------------------------
.. literalinclude:: ../native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java
:language: java
:linenos:
:lines: 52
:lineno-match:
:start-after: sphinx-doc: java_ref_model_start
:end-before: sphinx-doc: java_ref_model_stop
Performing inference
--------------------
@ -15,7 +19,9 @@ Performing inference
.. literalinclude:: ../native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java
:language: java
:linenos:
:lines: 101
:lineno-match:
:start-after: sphinx-doc: java_ref_inference_start
:end-before: sphinx-doc: java_ref_inference_stop
Full source code
----------------

View File

@ -1,6 +1,8 @@
JavaScript (NodeJS / ElectronJS)
================================
Support for TypeScript is :download:`provided in index.d.ts<../native_client/javascript/index.d.ts>`
Model
-----

View File

@ -1,23 +1,29 @@
JavaScript API Usage example
=============================
Examples are from `native_client/javascript/client.ts`.
Creating a model instance and loading model
-------------------------------------------
.. literalinclude:: ../native_client/javascript/client.js
.. literalinclude:: ../native_client/javascript/client.ts
:language: javascript
:linenos:
:lines: 56,69
:lineno-match:
:start-after: sphinx-doc: js_ref_model_start
:end-before: sphinx-doc: js_ref_model_stop
Performing inference
--------------------
.. literalinclude:: ../native_client/javascript/client.js
.. literalinclude:: ../native_client/javascript/client.ts
:language: javascript
:linenos:
:lines: 122
:lineno-match:
:start-after: sphinx-doc: js_ref_inference_start
:end-before: sphinx-doc: js_ref_inference_stop
Full source code
----------------
See :download:`Full source code<../native_client/javascript/client.js>`.
See :download:`Full source code<../native_client/javascript/client.ts>`.

View File

@ -1,13 +1,17 @@
Python API Usage example
========================
Examples are from `native_client/python/client.py`.
Creating a model instance and loading model
-------------------------------------------
.. literalinclude:: ../native_client/python/client.py
:language: python
:linenos:
:lines: 111,123
:lineno-match:
:start-after: sphinx-doc: python_ref_model_start
:end-before: sphinx-doc: python_ref_model_stop
Performing inference
--------------------
@ -15,7 +19,9 @@ Performing inference
.. literalinclude:: ../native_client/python/client.py
:language: python
:linenos:
:lines: 143-148
:lineno-match:
:start-after: sphinx-doc: python_ref_inference_start
:end-before: sphinx-doc: python_ref_inference_stop
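
For orientation, here is a minimal sketch of what the included snippets amount to when calling the ``deepspeech`` Python package directly (a 0.7-style API is assumed); the model, scorer and audio file names are illustrative and not taken from ``client.py``:

.. code-block:: python

   import wave

   import numpy as np
   from deepspeech import Model

   # Load the exported acoustic model (file name is an assumption for this sketch)
   ds = Model("output_graph.pbmm")
   # Optionally attach an external scorer for better accuracy (assumed file name)
   ds.enableExternalScorer("kenlm.scorer")

   # Read a 16-bit mono WAV file recorded at the model sample rate (16 kHz)
   with wave.open("audio/sample.wav", "rb") as fin:
       audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

   # Run speech-to-text on the whole buffer
   print(ds.stt(audio))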
Full source code
----------------

View File

@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt
.. code-block::
$ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/
$ python3 -m venv $HOME/tmp/deepspeech-train-venv/
Once this command completes successfully, the environment will be ready to be activated.
@ -38,15 +38,16 @@ Each time you need to work with DeepSpeech, you have to *activate* this virtual
$ source $HOME/tmp/deepspeech-train-venv/bin/activate
Installing Python dependencies
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Installing DeepSpeech Training Code and its dependencies
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Install the required dependencies using ``pip3``\ :
.. code-block:: bash
cd DeepSpeech
pip3 install -r requirements.txt
pip3 install --upgrade pip==20.0.2 wheel==0.34.2 setuptools==46.1.3
pip3 install --upgrade --force-reinstall -e .
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
@ -54,14 +55,6 @@ The ``webrtcvad`` Python package might require you to ensure you have proper too
sudo apt-get install python3-dev
You'll also need to install the ``ds_ctcdecoder`` Python package. ``ds_ctcdecoder`` is required for decoding the outputs of the ``deepspeech`` acoustic model into text. You can use ``util/taskcluster.py`` with the ``--decoder`` flag to get a URL to a binary of the decoder package appropriate for your platform and Python version:
.. code-block:: bash
pip3 install $(python3 util/taskcluster.py --decoder)
This command will download and install the ``ds_ctcdecoder`` package. You can override the platform with ``--arch`` if you want the package for ARM7 (\ ``--arch arm``\ ) or ARM64 (\ ``--arch arm64``\ ). If you prefer building the ``ds_ctcdecoder`` package from source, see the :github:`native_client README file <native_client/README.rst>`.
Recommendations
^^^^^^^^^^^^^^^
@ -70,7 +63,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
.. code-block:: bash
pip3 uninstall tensorflow
pip3 install 'tensorflow-gpu==1.15.0'
pip3 install 'tensorflow-gpu==1.15.2'
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.
@ -148,6 +141,18 @@ Feel also free to pass additional (or overriding) ``DeepSpeech.py`` parameters t
Each dataset has a corresponding importer script in ``bin/`` that can be used to download (if it's freely available) and preprocess the dataset. See ``bin/import_librivox.py`` for an example of how to import and preprocess a large dataset for training with DeepSpeech.
Some importers might require additional code to properly handle your locale-specific requirements. Such handling is dealt with via the ``--validate_label_locale`` flag, which allows you to source an out-of-tree Python script that defines a ``validate_label`` function. Please refer to ``util/importers.py`` for an implementation example of that function.
If you don't provide this argument, the default ``validate_label`` function will be used. This one is only intended for the English language, so you might have consistency issues in your data for other languages.
For example, in order to use a custom validation function that disallows any sample with "a" in its transcript, and lower cases everything else, you could put the following code in a file called ``my_validation.py`` and then use ``--validate_label_locale my_validation.py``:
.. code-block:: python
def validate_label(label):
if 'a' in label: # disallow labels with 'a'
return None
return label.lower() # lower case valid labels
If you've run the old importers (in ``util/importers/``\ ), they could have removed source files that are needed for the new importers to run. In that case, simply remove the extracted folders and let the importer extract and process the dataset from scratch, and things should work.
Training with automatic mixed precision

View File

@ -125,6 +125,8 @@ Please note that as of now, we support:
- Node.JS versions 4 to 13.
- Electron.JS versions 1.6 to 7.1
TypeScript support is also provided.
Alternatively, if you're using Linux and have a supported NVIDIA GPU, you can install the GPU specific package as follows:
.. code-block:: bash
@ -133,7 +135,7 @@ Alternatively, if you're using Linux and have a supported NVIDIA GPU, you can in
See the `release notes <https://github.com/mozilla/DeepSpeech/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.
See :github:`client.js <native_client/javascript/client.js>` for an example of how to use the bindings.
See :github:`client.ts <native_client/javascript/client.ts>` for an example of how to use the bindings.
Using the Command-Line client
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -104,7 +104,7 @@ language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store', 'node_modules']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

View File

@ -14,6 +14,8 @@ Welcome to DeepSpeech's documentation!
TRAINING
Decoder
.. toctree::
:maxdepth: 2
:caption: DeepSpeech Model
@ -52,17 +54,19 @@ Welcome to DeepSpeech's documentation!
C-Examples
NodeJS-Examples
DotNet-Examples
Java-Examples
NodeJS-Examples
Python-Examples
.. toctree::
:maxdepth: 2
:caption: Contributed examples
DotNet-contrib-examples.rst
DotNet-contrib-examples
NodeJS-contrib-Examples

158
evaluate.py Executable file → Normal file
View File

@ -2,155 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.evaluate_tools import calculate_and_print_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
if __name__ == '__main__':
create_flags()
absl.app.run(main)
try:
from deepspeech_training import evaluate as ds_evaluate
except ImportError:
print('Training package is not installed. See training documentation.')
raise
ds_evaluate.run_script()
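
With the training code now living in the installable `deepspeech_training` package, the thin wrapper above is all that remains of this script. For reference, a hedged sketch of calling the same entry point directly once `pip install .` has been run (the CSV path and batch size below are placeholders):

```python
# Hedged sketch: invoke the packaged evaluation entry point without the wrapper.
# Assumes `pip install .` was run; --test_files points at a placeholder CSV.
import sys

from deepspeech_training import evaluate as ds_evaluate

sys.argv += ['--test_files', 'data/test.csv', '--test_batch_size', '8']
ds_evaluate.run_script()  # parses the flags via absl.app and runs the test epoch
```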

View File

@ -10,13 +10,12 @@ import csv
import os
import sys
from functools import partial
from six.moves import zip, range
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from deepspeech import Model
from util.evaluate_tools import calculate_and_print_report
from util.flags import create_flags
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
from deepspeech_training.util.flags import create_flags
from functools import partial
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from six.moves import zip, range
r'''
This module should be self-contained:

View File

@ -2,19 +2,18 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import sys
import optuna
import absl.app
from ds_ctcdecoder import Scorer
import optuna
import sys
import tensorflow.compat.v1 as tfv1
from DeepSpeech import create_model
from evaluate import evaluate
from util.config import Config, initialize_globals
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.evaluate_tools import wer_cer_batch
from deepspeech_training.evaluate import evaluate
from deepspeech_training.train import create_model
from deepspeech_training.util.config import Config, initialize_globals
from deepspeech_training.util.flags import create_flags, FLAGS
from deepspeech_training.util.logging import log_error
from deepspeech_training.util.evaluate_tools import wer_cer_batch
from ds_ctcdecoder import Scorer
def character_based():
@ -28,11 +27,23 @@ def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
tfv1.reset_default_graph()
samples = evaluate(FLAGS.test_files.split(','), create_model)
is_character_based = trial.study.user_attrs['is_character_based']
samples = []
for step, test_file in enumerate(FLAGS.test_files.split(',')):
tfv1.reset_default_graph()
current_samples = evaluate([test_file], create_model)
samples += current_samples
# Report intermediate objective value.
wer, cer = wer_cer_batch(current_samples)
trial.report(cer if is_character_based else wer, step)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
wer, cer = wer_cer_batch(samples)
return cer if is_character_based else wer
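
The per-file `trial.report`/`should_prune` calls above only have an effect when the study driving this objective is created with a pruner. A minimal, hedged sketch of such a driver (the MedianPruner and trial count are illustrative assumptions, not necessarily what lm_optimizer.py uses):

```python
# Hedged sketch: drive the objective above with a pruning-enabled Optuna study.
# MedianPruner and n_trials=6 are illustrative choices, not project defaults.
import optuna

study = optuna.create_study(direction='minimize',
                            pruner=optuna.pruners.MedianPruner())
study.set_user_attr('is_character_based', character_based())
study.optimize(objective, n_trials=6)
print('Best lm_alpha/lm_beta:', study.best_params, '- best metric:', study.best_value)
```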

View File

@ -14,7 +14,10 @@ It is required to use our fork of TensorFlow since it includes fixes for common
If you'd like to build the language bindings or the decoder package, you'll also need:
* `SWIG >= 3.0.12 <http://www.swig.org/>`_. If you intend to build NodeJS / ElectronJS bindings you will need a patched version of SWIG. Please refer to the matching section below.
* `SWIG >= 3.0.12 <http://www.swig.org/>`_.
Unfortunately, SWIG support for NodeJS / ElectronJS 10.x and later is a bit behind, and while patches have been proposed upstream, they are not merged yet.
A prebuilt patched version of SWIG (covering Linux, Windows and macOS) gets installed under `native_client/ <native_client/>`_ as soon as you build any bindings that require it.
* `node-pre-gyp <https://github.com/mapbox/node-pre-gyp>`_ (for Node.JS bindings only)
Dependencies
@ -108,10 +111,6 @@ The API mirrors the C++ API and is demonstrated in `client.py <python/client.py>
Install NodeJS / ElectronJS bindings
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Unfortunately, JavaScript support on SWIG is a bit behind, and while there are pending patches proposed to upstream, it is not yet merged.
You should be able to build from `our fork <https://github.com/lissyx/swig/tree/taskcluster>`_, and you can find pre-built binaries on `TaskCluster <https://community-tc.services.mozilla.com/tasks/index/project.deepspeech.swig>`_ (please look for swig fork sha1).
Extract the `ds-swig.tar.gz` to some place in your `$HOME`, then update `$PATH` accordingly. You might need to symlink `ds-swig` as `swig`, and you will have to `export SWIG_LIB=<path/to/swig/share>` so that it contains path to `share/swig/<VERSION>/`.
After following the above build and installation instructions, the Node.JS bindings can be built:
.. code-block::

View File

@ -162,6 +162,7 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
clock_t ds_start_time = clock();
// sphinx-doc: c_ref_inference_start
if (extended_output) {
Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, 1);
res.string = CandidateTranscriptToString(&result->transcripts[0]);
@ -198,6 +199,7 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
} else {
res.string = DS_SpeechToText(aCtx, aBuffer, aBufferSize);
}
// sphinx-doc: c_ref_inference_stop
clock_t ds_end_infer = clock();
@ -393,6 +395,7 @@ main(int argc, char **argv)
// Initialise DeepSpeech
ModelState* ctx;
// sphinx-doc: c_ref_model_start
int status = DS_CreateModel(model, &ctx);
if (status != 0) {
fprintf(stderr, "Could not create model.\n");
@ -421,6 +424,7 @@ main(int argc, char **argv)
}
}
}
// sphinx-doc: c_ref_model_stop
#ifndef NO_SOX
// Initialise SOX

View File

@ -43,16 +43,16 @@ workspace_status.cc:
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings: clean-keep-third-party workspace_status.cc
bindings: clean-keep-third-party workspace_status.cc ds-swig
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
find temp_build -type f -name "*.o" -delete
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
rm -rf temp_build
bindings-debug: clean-keep-third-party workspace_status.cc
bindings-debug: clean-keep-third-party workspace_status.cc ds-swig
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
$(GENERATE_DEBUG_SYMS)
find temp_build -type f -name "*.o" -delete
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)

View File

@ -12,7 +12,11 @@ TOOL_LD := ld
TOOL_LDD := ldd
TOOL_LIBEXE :=
DEEPSPEECH_BIN := deepspeech
ifeq ($(findstring _NT,$(OS)),_NT)
PLATFORM_EXE_SUFFIX := .exe
endif
DEEPSPEECH_BIN := deepspeech$(PLATFORM_EXE_SUFFIX)
CFLAGS_DEEPSPEECH := -std=c++11 -o $(DEEPSPEECH_BIN)
LINK_DEEPSPEECH := -ldeepspeech
LINK_PATH_DEEPSPEECH := -L${TFDIR}/bazel-bin/native_client
@ -36,7 +40,6 @@ endif
endif
ifeq ($(TARGET),host-win)
DEEPSPEECH_BIN := deepspeech.exe
TOOLCHAIN := '$(VCINSTALLDIR)\bin\amd64\'
TOOL_CC := cl.exe
TOOL_CXX := cl.exe
@ -170,3 +173,36 @@ define copy_missing_libs
done; \
fi;
endef
SWIG_DIST_URL ?=
ifeq ($(findstring Linux,$(OS)),Linux)
SWIG_DIST_URL := "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.linux.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz"
else ifeq ($(findstring Darwin,$(OS)),Darwin)
SWIG_DIST_URL := "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.darwin.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz"
else ifeq ($(findstring _NT,$(OS)),_NT)
SWIG_DIST_URL := "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.win.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz"
else
$(error There is no prebuilt SWIG available for your platform. Please produce one and set SWIG_DIST_URL.)
endif
# Should point to native_client/ subdir by default
SWIG_ROOT ?= $(abspath $(shell dirname "$(lastword $(MAKEFILE_LIST))"))/ds-swig
ifeq ($(findstring _NT,$(OS)),_NT)
SWIG_ROOT ?= $(shell cygpath -u "$(SWIG_ROOT)")
endif
SWIG_LIB ?= $(SWIG_ROOT)/share/swig/4.0.2/
SWIG_BIN := swig$(PLATFORM_EXE_SUFFIX)
DS_SWIG_BIN := ds-swig$(PLATFORM_EXE_SUFFIX)
DS_SWIG_BIN_PATH := $(SWIG_ROOT)/bin
DS_SWIG_ENV := SWIG_LIB="$(SWIG_LIB)" PATH="${PATH}:$(DS_SWIG_BIN_PATH)"
$(DS_SWIG_BIN_PATH)/swig:
mkdir -p $(SWIG_ROOT)
wget -O - "$(SWIG_DIST_URL)" | tar -C $(SWIG_ROOT) -zxf -
ln -s $(DS_SWIG_BIN) $(DS_SWIG_BIN_PATH)/$(SWIG_BIN)
ds-swig: $(DS_SWIG_BIN_PATH)/swig
$(DS_SWIG_ENV) swig -version
$(DS_SWIG_ENV) swig -swiglib

View File

@ -51,8 +51,10 @@ namespace CSharpExamples
{
Console.WriteLine("Loading model...");
stopwatch.Start();
// sphinx-doc: csharp_ref_model_start
using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm"))
{
// sphinx-doc: csharp_ref_model_stop
stopwatch.Stop();
Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms");
@ -72,6 +74,7 @@ namespace CSharpExamples
stopwatch.Start();
string speechResult;
// sphinx-doc: csharp_ref_inference_start
if (extended)
{
Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer,
@ -83,6 +86,7 @@ namespace CSharpExamples
speechResult = sttClient.SpeechToText(waveBuffer.ShortBuffer,
Convert.ToUInt32(waveBuffer.MaxSize / 2));
}
// sphinx-doc: csharp_ref_inference_stop
stopwatch.Stop();
@ -99,4 +103,4 @@ namespace CSharpExamples
}
}
}
}
}

View File

@ -27,5 +27,5 @@ maven-bundle: apk
$(GRADLE) uploadArchives
$(GRADLE) zipMavenArtifacts
bindings: clean
swig -c++ -java -package org.mozilla.deepspeech.libdeepspeech -outdir libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/ -o jni/deepspeech_wrap.cpp jni/deepspeech.i
bindings: clean ds-swig
$(DS_SWIG_ENV) swig -c++ -java -package org.mozilla.deepspeech.libdeepspeech -outdir libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/ -o jni/deepspeech_wrap.cpp jni/deepspeech.i

View File

@ -49,8 +49,10 @@ public class DeepSpeechActivity extends AppCompatActivity {
private void newModel(String tfliteModel) {
this._tfliteStatus.setText("Creating model");
if (this._m == null) {
// sphinx-doc: java_ref_model_start
this._m = new DeepSpeechModel(tfliteModel);
this._m.setBeamWidth(BEAM_WIDTH);
// sphinx-doc: java_ref_model_stop
}
}
@ -98,7 +100,9 @@ public class DeepSpeechActivity extends AppCompatActivity {
long inferenceStartTime = System.currentTimeMillis();
// sphinx-doc: java_ref_inference_start
String decoded = this._m.stt(shorts, shorts.length);
// sphinx-doc: java_ref_inference_stop
inferenceExecTime = System.currentTimeMillis() - inferenceStartTime;

View File

@ -1,26 +1,37 @@
NODE_BUILD_TOOL ?= node-pre-gyp
NODE_ABI_TARGET ?=
NODE_BUILD_VERBOSE ?= --verbose
NPM_TOOL ?= npm
PROJECT_NAME ?= deepspeech
PROJECT_VERSION ?= $(shell cat ../../VERSION | tr -d '\n')
NPM_ROOT ?= $(shell npm root)
NODE_MODULES_BIN ?= $(NPM_ROOT)/.bin/
ifeq ($(findstring _NT,$(OS)),_NT)
# On Windows, we seem to need both in PATH for node-pre-gyp as well as tsc,
# since they do not get installed the same way.
NODE_MODULES_BIN := $(shell cygpath -u $(NPM_ROOT)/.bin/):$(shell cygpath -u `dirname "$(NPM_ROOT)"`)
endif
include ../definitions.mk
ifeq ($(TARGET),host-win)
ifeq ($(findstring _NT,$(OS)),_NT)
LIBS := '$(shell cygpath -w $(subst .lib,,$(LIBS)))'
endif
.PHONY: npm-dev
default: build
clean:
rm -f deepspeech_wrap.cxx package.json
rm -f deepspeech_wrap.cxx package.json package-lock.json
rm -rf ./build/
clean-npm-pack:
rm -fr ./node_modules/
rm -fr ./deepspeech-*.tgz
really-clean: clean clean-npm-pack
rm -fr ./node_modules/
rm -fr ./lib/
package.json: package.json.in
@ -29,22 +40,27 @@ package.json: package.json.in
-e 's/$$(PROJECT_VERSION)/$(PROJECT_VERSION)/' \
package.json.in > package.json && cat package.json
configure: deepspeech_wrap.cxx package.json
$(NODE_BUILD_TOOL) configure $(NODE_BUILD_VERBOSE)
npm-dev: package.json
ifeq ($(findstring _NT,$(OS)),_NT)
# node-gyp@5.x behaves erratically with VS2015 and MSBuild.exe detection
$(NPM_TOOL) install node-gyp@4.x
endif
$(NPM_TOOL) install --prefix=$(NPM_ROOT)/../ --ignore-scripts --force --verbose --production=false .
configure: deepspeech_wrap.cxx package.json npm-dev
PATH="$(NODE_MODULES_BIN):${PATH}" $(NODE_BUILD_TOOL) configure $(NODE_BUILD_VERBOSE)
build: configure deepspeech_wrap.cxx
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" CXXFLAGS="$(CXXFLAGS)" LDFLAGS="$(RPATH_NODEJS) $(LDFLAGS)" LIBS=$(LIBS) $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) rebuild $(NODE_BUILD_VERBOSE)
PATH="$(NODE_MODULES_BIN):${PATH}" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" CXXFLAGS="$(CXXFLAGS)" LDFLAGS="$(RPATH_NODEJS) $(LDFLAGS)" LIBS=$(LIBS) $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) --no-color rebuild $(NODE_BUILD_VERBOSE)
copy-deps: build
$(call copy_missing_libs,lib/binding/*/*/*/deepspeech.node,lib/binding/*/*/)
node-wrapper: copy-deps build
$(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) package $(NODE_BUILD_VERBOSE)
PATH="$(NODE_MODULES_BIN):${PATH}" $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) --no-color package $(NODE_BUILD_VERBOSE)
npm-pack: clean package.json index.js
npm install node-pre-gyp@0.14.x
npm pack $(NODE_BUILD_VERBOSE)
npm-pack: clean package.json index.js npm-dev
PATH="$(NODE_MODULES_BIN):${PATH}" tsc && $(NPM_TOOL) pack $(NODE_BUILD_VERBOSE)
deepspeech_wrap.cxx: deepspeech.i
swig -version
swig -c++ -javascript -node deepspeech.i
deepspeech_wrap.cxx: deepspeech.i ds-swig
$(DS_SWIG_ENV) swig -c++ -javascript -node deepspeech.i

View File

@ -1 +1,18 @@
Full project description and documentation on GitHub: [https://github.com/mozilla/DeepSpeech](https://github.com/mozilla/DeepSpeech).
## Generating TypeScript Type Definitions
You can generate the TypeScript type declaration file using `dts-gen`.
This requires a compiled/installed version of the DeepSpeech NodeJS client.
Whenever the API changes, you need to regenerate the `index.d.ts` type declaration
file by running:
```sh
npm install -g dts-gen
dts-gen --module deepspeech --file index.d.ts
```
### Example usage
See `client.ts`

View File

@ -1,48 +1,42 @@
#!/usr/bin/env node
'use strict';
const Fs = require('fs');
const Sox = require('sox-stream');
const Ds = require('./index.js');
const argparse = require('argparse');
const MemoryStream = require('memory-stream');
const Wav = require('node-wav');
const Duplex = require('stream').Duplex;
const util = require('util');
// This is required for process.versions.electron below
/// <reference types="electron" />
var VersionAction = function VersionAction(options) {
options = options || {};
options.nargs = 0;
argparse.Action.call(this, options);
}
util.inherits(VersionAction, argparse.Action);
import Ds from "./index";
import * as Fs from "fs";
import Sox from "sox-stream";
import * as argparse from "argparse";
VersionAction.prototype.call = function(parser) {
console.log('DeepSpeech ' + Ds.Version());
let runtime = 'Node';
if (process.versions.electron) {
runtime = 'Electron';
const MemoryStream = require("memory-stream");
const Wav = require("node-wav");
const Duplex = require("stream").Duplex;
class VersionAction extends argparse.Action {
call(parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: string | string[], optionString: string | null) {
console.log('DeepSpeech ' + Ds.Version());
let runtime = 'Node';
if (process.versions.electron) {
runtime = 'Electron';
}
console.error('Runtime: ' + runtime);
process.exit(0);
}
console.error('Runtime: ' + runtime);
process.exit(0);
}
var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running DeepSpeech inference.'});
let parser = new argparse.ArgumentParser({addHelp: true, description: 'Running DeepSpeech inference.'});
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'});
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
parser.addArgument(['--version'], {action: VersionAction, nargs: 0, help: 'Print version and exits'});
parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'});
var args = parser.parseArgs();
let args = parser.parseArgs();
function totalTime(hrtimeValue) {
function totalTime(hrtimeValue: number[]): string {
return (hrtimeValue[0] + hrtimeValue[1] / 1000000000).toPrecision(4);
}
function candidateTranscriptToString(transcript) {
function candidateTranscriptToString(transcript: Ds.CandidateTranscript): string {
var retval = ""
for (var i = 0; i < transcript.tokens.length; ++i) {
retval += transcript.tokens[i].text;
@ -50,17 +44,19 @@ function candidateTranscriptToString(transcript) {
return retval;
}
// sphinx-doc: js_ref_model_start
console.error('Loading model from file %s', args['model']);
const model_load_start = process.hrtime();
var model = new Ds.Model(args['model']);
let model = new Ds.Model(args['model']);
const model_load_end = process.hrtime(model_load_start);
console.error('Loaded model in %ds.', totalTime(model_load_end));
if (args['beam_width']) {
model.setBeamWidth(args['beam_width']);
}
// sphinx-doc: js_ref_model_stop
var desired_sample_rate = model.sampleRate();
let desired_sample_rate = model.sampleRate();
if (args['scorer']) {
console.error('Loading scorer from file %s', args['scorer']);
@ -78,23 +74,24 @@ const buffer = Fs.readFileSync(args['audio']);
const result = Wav.decode(buffer);
if (result.sampleRate < desired_sample_rate) {
console.error('Warning: original sample rate (' + result.sampleRate + ') ' +
'is lower than ' + desired_sample_rate + 'Hz. ' +
'Up-sampling might produce erratic speech recognition.');
console.error(`Warning: original sample rate (${result.sampleRate}) ` +
`is lower than ${desired_sample_rate} Hz. ` +
`Up-sampling might produce erratic speech recognition.`);
}
function bufferToStream(buffer) {
function bufferToStream(buffer: Buffer) {
var stream = new Duplex();
stream.push(buffer);
stream.push(null);
return stream;
}
var audioStream = new MemoryStream();
let audioStream = new MemoryStream();
bufferToStream(buffer).
pipe(Sox({
global: {
'no-dither': true,
'replay-gain': 'off',
},
output: {
bits: 16,
@ -115,6 +112,7 @@ audioStream.on('finish', () => {
console.error('Running inference.');
const audioLength = (audioBuffer.length / 2) * (1 / desired_sample_rate);
// sphinx-doc: js_ref_inference_start
if (args['extended']) {
let metadata = model.sttWithMetadata(audioBuffer, 1);
console.log(candidateTranscriptToString(metadata.transcripts[0]));
@ -122,6 +120,7 @@ audioStream.on('finish', () => {
} else {
console.log(model.stt(audioBuffer));
}
// sphinx-doc: js_ref_inference_stop
const inference_stop = process.hrtime(inference_start);
console.error('Inference took %ds for %ds audio file.', totalTime(inference_stop), audioLength.toPrecision(4));
Ds.FreeModel(model);

196
native_client/javascript/index.d.ts vendored Normal file
View File

@ -0,0 +1,196 @@
/**
* Stores text of an individual token, along with its timing information
*/
export interface TokenMetadata {
text: string;
timestep: number;
start_time: number;
}
/**
* A single transcript computed by the model, including a confidence value and
* the metadata for its constituent tokens.
*/
export interface CandidateTranscript {
tokens: TokenMetadata[];
confidence: number;
}
/**
* An array of CandidateTranscript objects computed by the model.
*/
export interface Metadata {
transcripts: CandidateTranscript[];
}
/**
* An object providing an interface to a trained DeepSpeech model.
*
* @param aModelPath The path to the frozen model graph.
*
* @throws on error
*/
export class Model {
constructor(aModelPath: string)
/**
* Get beam width value used by the model. If :js:func:`Model.setBeamWidth` was
* not called before, this will return the default value loaded from the model file.
*
* @return Beam width value used by the model.
*/
beamWidth(): number;
/**
* Set beam width value used by the model.
*
* @param aBeamWidth The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
*
* @return Zero on success, non-zero on failure.
*/
setBeamWidth(aBeamWidth: number): number;
/**
* Return the sample rate expected by the model.
*
* @return Sample rate.
*/
sampleRate(): number;
/**
* Enable decoding using an external scorer.
*
* @param aScorerPath The path to the external scorer file.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
enableExternalScorer(aScorerPath: string): number;
/**
* Disable decoding using an external scorer.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
disableExternalScorer(): number;
/**
* Set hyperparameters alpha and beta of the external scorer.
*
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model weight.
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion weight.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
setScorerAlphaBeta(aLMAlpha: number, aLMBeta: number): number;
/**
* Use the DeepSpeech model to perform Speech-To-Text.
*
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
*
* @return The STT result. Returns undefined on error.
*/
stt(aBuffer: object): string;
/**
* Use the DeepSpeech model to perform Speech-To-Text and output metadata
* about the results.
*
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
* @param aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this.
* Default value is 1 if not specified.
*
* @return :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
* The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
*/
sttWithMetadata(aBuffer: object, aNumResults: number): Metadata;
/**
* Create a new streaming inference state. One can then call :js:func:`Stream.feedAudioContent` and :js:func:`Stream.finishStream` on the returned stream object.
*
* @return a :js:func:`Stream` object that represents the streaming state.
*
* @throws on error
*/
createStream(): object;
}
/**
* @class
* Provides an interface to a DeepSpeech stream. The constructor cannot be called
* directly, use :js:func:`Model.createStream`.
*/
declare class Stream {
/**
* Feed audio samples to an ongoing streaming inference.
*
* @param aBuffer An array of 16-bit, mono raw audio samples at the
* appropriate sample rate (matching what the model was trained on).
*/
feedAudioContent(aBuffer: object): void;
/**
* Compute the intermediate decoding of an ongoing streaming inference.
*
* @return The STT intermediate result.
*/
intermediateDecode(aSctx: object): string;
/**
* Compute the intermediate decoding of an ongoing streaming inference, return results including metadata.
*
* @param aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified.
*
* @return :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
*/
intermediateDecodeWithMetadata (aNumResults: number): Metadata;
/**
* Compute the final decoding of an ongoing streaming inference and return the result. Signals the end of an ongoing streaming inference.
*
* @return The STT result.
*
* This method will free the stream; it must not be used after this method is called.
*/
finishStream(): string;
/**
* Compute the final decoding of an ongoing streaming inference and return the results including metadata. Signals the end of an ongoing streaming inference.
*
* @param aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified.
*
* @return Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`.
*
* This method will free the stream; it must not be used after this method is called.
*/
finishStreamWithMetadata(aNumResults: number): Metadata;
}
/**
* Frees associated resources and destroys model object.
*
* @param model A model pointer returned by :js:func:`Model`
*
*/
export function FreeModel(model: Model): void;
/**
* Free memory allocated for metadata information.
*
* @param metadata Object containing metadata as returned by :js:func:`Model.sttWithMetadata` or :js:func:`Model.finishStreamWithMetadata`
*/
export function FreeMetadata(metadata: Metadata): void;
/**
* Destroy a streaming state without decoding the computed logits. This
* can be used if you no longer need the result of an ongoing streaming
* inference and don't want to perform a costly decode operation.
*
* @param stream A streaming state pointer returned by :js:func:`Model.createStream`.
*/
export function FreeStream(stream: object): void;
/**
* Print version of this library and of the linked TensorFlow library on standard output.
*/
export function Version(): void;
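
The streaming flow documented above (`Model.createStream`, then `feedAudioContent`/`intermediateDecode`, finally `finishStream`) mirrors the other bindings. As an end-to-end illustration, a hedged sketch using the analogous Python API (model path, audio buffer and chunking are placeholders):

```python
# Hedged sketch of the documented streaming flow, using the Python bindings.
# The audio buffer is a stand-in: mono np.int16 samples at the model's rate.
import numpy as np
from deepspeech import Model

model = Model('output_graph.pbmm')
audio = np.zeros(16000, dtype=np.int16)      # 1 s of silence as a placeholder
stream = model.createStream()
for chunk in np.array_split(audio, 10):      # feed audio in pieces
    stream.feedAudioContent(chunk)
print('partial:', stream.intermediateDecode())
print('final:', stream.finishStream())       # frees the stream
```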

View File

@ -2,7 +2,8 @@
"name" : "$(PROJECT_NAME)",
"version" : "$(PROJECT_VERSION)",
"description" : "DeepSpeech NodeJS bindings",
"main" : "./index",
"main" : "./index.js",
"types": "./index.d.ts",
"bin": {
"deepspeech": "./client.js"
},
@ -13,6 +14,7 @@
"README.md",
"client.js",
"index.js",
"index.d.ts",
"lib/*"
],
"bugs": {
@ -37,6 +39,11 @@
"node-wav": "0.0.2"
},
"devDependencies": {
"electron": "^1.7.9",
"node-gyp": "4.x - 5.x",
"typescript": "3.6.x",
"@types/argparse": "1.0.x",
"@types/node": "13.9.x"
},
"scripts": {
"test": "node index.js"

View File

@ -0,0 +1,18 @@
{
"compilerOptions": {
"baseUrl": ".",
"target": "es6",
"module": "commonjs",
"moduleResolution": "node",
"esModuleInterop": true,
"noImplicitAny": true,
"noImplicitThis": true,
"strictFunctionTypes": true,
"strictNullChecks": true,
"forceConsistentCasingInFileNames": true
},
"files": [
"index.d.ts",
"client.ts"
]
}

View File

@ -8,9 +8,9 @@ bindings-clean:
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings-build:
bindings-build: ds-swig
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python ./setup.py build_ext $(PYTHON_PLATFORM_NAME)
PATH=$(TOOLCHAIN):$(DS_SWIG_BIN_PATH):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python ./setup.py build_ext $(PYTHON_PLATFORM_NAME)
MANIFEST.in: bindings-build
> $@

View File

@ -111,7 +111,9 @@ def main():
print('Loading model from file {}'.format(args.model), file=sys.stderr)
model_load_start = timer()
# sphinx-doc: python_ref_model_start
ds = Model(args.model)
# sphinx-doc: python_ref_model_stop
model_load_end = timer() - model_load_start
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
@ -131,24 +133,26 @@ def main():
ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)
fin = wave.open(args.audio, 'rb')
fs = fin.getframerate()
if fs != desired_sample_rate:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, desired_sample_rate), file=sys.stderr)
fs, audio = convert_samplerate(args.audio, desired_sample_rate)
fs_orig = fin.getframerate()
if fs_orig != desired_sample_rate:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs_orig, desired_sample_rate), file=sys.stderr)
fs_new, audio = convert_samplerate(args.audio, desired_sample_rate)
else:
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
audio_length = fin.getnframes() * (1/fs)
audio_length = fin.getnframes() * (1/fs_orig)
fin.close()
print('Running inference.', file=sys.stderr)
inference_start = timer()
# sphinx-doc: python_ref_inference_start
if args.extended:
print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0]))
elif args.json:
print(metadata_json_output(ds.sttWithMetadata(audio, 3)))
else:
print(ds.stt(audio))
# sphinx-doc: python_ref_inference_stop
inference_end = timer() - inference_start
print('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length), file=sys.stderr)
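
Condensed, the single-shot inference path annotated above (model load, optional scorer, one `stt` call) amounts to the following hedged sketch; file names are placeholders and the WAV is assumed to already match the model's sample rate:

```python
# Hedged sketch of the batch inference path shown in client.py above.
# Paths are placeholders; the WAV is assumed to be 16-bit mono at the model rate.
import wave

import numpy as np
from deepspeech import Model

ds = Model('output_graph.pbmm')
ds.enableExternalScorer('kenlm.scorer')   # optional, mirrors the --scorer flag
with wave.open('audio.wav', 'rb') as fin:
    audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
print(ds.stt(audio))
```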

View File

@ -1,24 +0,0 @@
# Main training requirements
tensorflow == 1.15.2
numpy == 1.18.1
progressbar2
six
pyxdg
attrdict
absl-py
semver
opuslib == 2.0.0
# Requirements for building native_client files
setuptools
# Requirements for importers
sox
bs4
pandas
requests
librosa
soundfile
# Requirements for optimizer
optuna

125
setup.py Normal file
View File

@ -0,0 +1,125 @@
import os
import platform
import sys
from pathlib import Path
from pkg_resources import parse_version
from setuptools import find_packages, setup
def get_decoder_pkg_url(version, artifacts_root=None):
is_arm = 'arm' in platform.machine()
is_mac = 'darwin' in sys.platform
is_64bit = sys.maxsize > (2**31 - 1)
if is_arm:
tc_arch = 'arm64-ctc' if is_64bit else 'arm-ctc'
elif is_mac:
tc_arch = 'osx-ctc'
else:
tc_arch = 'cpu-ctc'
ds_version = parse_version(version)
branch = "v{}".format(version)
plat = platform.system().lower()
arch = platform.machine()
if plat == 'linux' and arch == 'x86_64':
plat = 'manylinux1'
if plat == 'darwin':
plat = 'macosx_10_10'
is_ucs2 = sys.maxunicode < 0x10ffff
m_or_mu = 'mu' if is_ucs2 else 'm'
pyver = ''.join(str(i) for i in sys.version_info[0:2])
if not artifacts_root:
artifacts_root = 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.{branch_name}.{tc_arch_string}/artifacts/public'.format(
branch_name=branch,
tc_arch_string=tc_arch)
return 'ds_ctcdecoder @ {artifacts_root}/ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl'.format(
artifacts_root=artifacts_root,
ds_version=ds_version,
pyver=pyver,
m_or_mu=m_or_mu,
platform=plat,
arch=arch,
)
def main():
version_file = Path(__file__).parent / 'VERSION'
with open(str(version_file)) as fin:
version = fin.read().strip()
decoder_pkg_url = get_decoder_pkg_url(version)
install_requires_base = [
'tensorflow == 1.15.2',
'numpy == 1.18.1',
'progressbar2',
'six',
'pyxdg',
'attrdict',
'absl-py',
'semver',
'opuslib == 2.0.0',
'optuna',
'sox',
'bs4',
'pandas',
'requests',
'librosa',
'soundfile',
]
# Due to pip craziness environment variables are the only consistent way to
# get options into this script when doing `pip install`.
tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '')
if tc_decoder_artifacts_root:
# We're running inside the TaskCluster environment, so override the decoder
# package URL with the one we just built.
decoder_pkg_url = get_decoder_pkg_url(version, tc_decoder_artifacts_root)
install_requires = install_requires_base + [decoder_pkg_url]
elif os.environ.get('DS_NODECODER', ''):
install_requires = install_requires_base
else:
install_requires = install_requires_base + [decoder_pkg_url]
setup(
name='deepspeech_training',
version=version,
description='Training code for mozilla DeepSpeech',
url='https://github.com/mozilla/DeepSpeech',
author='Mozilla',
license='MPL-2.0',
# Classifiers help users find your project by categorizing it.
#
# For a list of valid classifiers, see https://pypi.org/classifiers/
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Topic :: Multimedia :: Sound/Audio :: Speech',
'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
'Programming Language :: Python :: 3',
],
package_dir={'': 'training'},
packages=find_packages(where='training'),
python_requires='>=3.5, <4',
install_requires=install_requires,
# If there are data files included in your packages that need to be
# installed, specify them here.
package_data={
'deepspeech_training': [
'VERSION',
'GRAPH_VERSION',
],
},
)
if __name__ == '__main__':
main()
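
The wheel URL construction above can be exercised on its own, which helps when debugging platform/Python detection. A hedged sketch, run from the repository root (the version string and local artifacts path are placeholders):

```python
# Hedged sketch: preview the ds_ctcdecoder wheel URL this setup.py would request.
# '0.7.0' and the file:// artifacts root are illustrative placeholders.
from setup import get_decoder_pkg_url

print(get_decoder_pkg_url('0.7.0'))
print(get_decoder_pkg_url('0.7.0', artifacts_root='file:///tmp/decoder-artifacts'))
```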

View File

@ -1,10 +1,29 @@
#!/usr/bin/env python3
import argparse
import os
import functools
import pandas
from deepspeech_training.util.helpers import secs_to_hours
from pathlib import Path
def read_csvs(csv_files):
# Relative paths are relative to CSV location
def absolutify(csv, path):
path = Path(path)
if path.is_absolute():
return str(path)
return str(csv.parent / path)
sets = []
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
sets.append(file)
# Concat all sets, drop any extra columns, re-index the final result as 0..N
return pandas.concat(sets, join='inner', ignore_index=True)
from util.helpers import secs_to_hours
from util.feeding import read_csvs
def main():
parser = argparse.ArgumentParser()
@ -14,20 +33,16 @@ def main():
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
args = parser.parse_args()
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
in_files = [Path(i).absolute() for i in args.csv_files.split(",")]
csv_dataframe = read_csvs(in_files)
total_bytes = csv_dataframe['wav_filesize'].sum()
total_files = len(csv_dataframe.index)
total_files = len(csv_dataframe)
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()
bytes_without_headers = total_bytes - 44 * total_files
total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8)
print('total_bytes', total_bytes)
print('total_files', total_files)
print('bytes_without_headers', bytes_without_headers)
print('total_time', secs_to_hours(total_time))
print('Total bytes:', total_bytes)
print('Total files:', total_files)
print('Total time:', secs_to_hours(total_seconds))
if __name__ == '__main__':
main()
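
As a quick sanity check of the duration formula above (subtract the 44-byte WAV header, then divide the payload by rate × channels × bytes per sample), with illustrative numbers:

```python
# Illustrative check: a 10 s, 16 kHz, mono, 16-bit PCM WAV has a 44-byte header
# plus 16000 * 1 * 2 * 10 = 320000 data bytes.
wav_filesize = 44 + 320000
seconds = (wav_filesize - 44) / 16000 / 1 / (16 // 8)
assert seconds == 10.0
```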

View File

@ -115,10 +115,6 @@ system:
swig:
repo: "https://github.com/lissyx/swig"
sha1: "b5fea54d39832d1d132d7dd921b69c0c2c9d5118"
cache:
linux_amd64: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.linux.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz'
darwin_amd64: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.darwin.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz'
win_amd64: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.win.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz'
username: 'build-user'
homedir:
linux: '/home/build-user'

View File

@ -88,10 +88,6 @@ payload:
format: tar.gz
content:
url: ${system.node_gyp_cache.url}
- directory: ds-swig
format: tar.gz
content:
url: ${system.swig.cache.darwin_amd64}
- file: home.tar.xz
content:
url: ${build.tensorflow}

View File

@ -1,6 +1,6 @@
breathe==4.13.1
breathe==4.14.2
semver==2.8.1
sphinx==2.2.0
sphinx==2.4.4
sphinx-js==2.8
sphinx-rtd-theme==0.4.3
pygments==2.4.2
pygments==2.6.1

View File

@ -48,7 +48,7 @@ then:
adduser --system --home ${system.homedir.linux} ${system.username} &&
apt-get -qq update && apt-get -qq -y install ${tensorflow.packages_trusty.apt} pixz pkg-config realpath unzip wget zip && ${extraSystemSetup} &&
cd ${system.homedir.linux}/ &&
echo -e "#!/bin/bash\nset -xe\n env && id && (wget -O - $TENSORFLOW_BUILD_ARTIFACT | pixz -d | tar -C ${system.homedir.linux}/ -xf - ) && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && ln -s ~/DeepSpeech/ds/native_client/ ~/DeepSpeech/tf/native_client && mkdir -p ${system.homedir.linux}/.cache/node-gyp/ && wget -O - ${system.node_gyp_cache.url} | tar -C ${system.homedir.linux}/.cache/node-gyp/ -xzf - && mkdir -p ${system.homedir.linux}/ds-swig/bin/ && wget -O - ${system.swig.cache.linux_amd64} | tar -C ${system.homedir.linux}/ds-swig/ -xzf - && mkdir -p ${system.homedir.linux}/pyenv-root/ && wget -O - ${system.pyenv.linux.url} | tar -C ${system.homedir.linux}/pyenv-root/ -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi && if [ ! -z "${build.android_cache.url}" ]; then wget -O - ${build.android_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
echo -e "#!/bin/bash\nset -xe\n env && id && (wget -O - $TENSORFLOW_BUILD_ARTIFACT | pixz -d | tar -C ${system.homedir.linux}/ -xf - ) && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && ln -s ~/DeepSpeech/ds/native_client/ ~/DeepSpeech/tf/native_client && mkdir -p ${system.homedir.linux}/.cache/node-gyp/ && wget -O - ${system.node_gyp_cache.url} | tar -C ${system.homedir.linux}/.cache/node-gyp/ -xzf - && mkdir -p ${system.homedir.linux}/pyenv-root/ && wget -O - ${system.pyenv.linux.url} | tar -C ${system.homedir.linux}/pyenv-root/ -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi && if [ ! -z "${build.android_cache.url}" ]; then wget -O - ${build.android_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
sudo -H -u ${system.username} /bin/bash /tmp/clone.sh && ${extraSystemConfig} &&
sudo -H -u ${system.username} --preserve-env /bin/bash ${system.homedir.linux}/DeepSpeech/ds/${build.scripts.build} &&
sudo -H -u ${system.username} /bin/bash ${system.homedir.linux}/DeepSpeech/ds/${build.scripts.package}

View File

@ -122,25 +122,3 @@ verify_bazel_rebuild()
exit 1
fi;
}
# Should be called from context where Python virtualenv is set
verify_ctcdecoder_url()
{
default_url=$(python util/taskcluster.py --decoder)
echo "${default_url}" | grep -F "deepspeech.native_client.v${DS_VERSION}"
rc_default_url=$?
tag_url=$(python util/taskcluster.py --decoder --branch 'v1.2.3')
echo "${tag_url}" | grep -F "deepspeech.native_client.v1.2.3"
rc_tag_url=$?
master_url=$(python util/taskcluster.py --decoder --branch 'master')
echo "${master_url}" | grep -F "deepspeech.native_client.master"
rc_master_url=$?
if [ ${rc_default_url} -eq 0 -a ${rc_tag_url} -eq 0 -a ${rc_master_url} -eq 0 ]; then
return 0
else
return 1
fi;
}

View File

@ -6,14 +6,12 @@ export OS=$(uname)
if [ "${OS}" = "Linux" ]; then
export DS_ROOT_TASK=${HOME}
export PYENV_ROOT="${DS_ROOT_TASK}/pyenv-root"
export SWIG_ROOT="${HOME}/ds-swig"
export DS_CPU_COUNT=$(nproc)
fi;
if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
export DS_ROOT_TASK=${TASKCLUSTER_TASK_DIR}
export PYENV_ROOT="${TASKCLUSTER_TASK_DIR}/pyenv-root"
export SWIG_ROOT="$(cygpath ${USERPROFILE})/ds-swig"
export PLATFORM_EXE_SUFFIX=.exe
export DS_CPU_COUNT=$(nproc)
@ -22,7 +20,6 @@ if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
fi;
if [ "${OS}" = "Darwin" ]; then
export SWIG_ROOT="${TASKCLUSTER_ORIG_TASKDIR}/ds-swig"
export DS_ROOT_TASK=${TASKCLUSTER_TASK_DIR}
export DS_CPU_COUNT=$(sysctl hw.ncpu |cut -d' ' -f2)
export PYENV_ROOT="${DS_ROOT_TASK}/pyenv-root"
@ -45,19 +42,6 @@ if [ "${OS}" = "Darwin" ]; then
fi;
fi;
SWIG_BIN=swig${PLATFORM_EXE_SUFFIX}
DS_SWIG_BIN=ds-swig${PLATFORM_EXE_SUFFIX}
if [ -f "${SWIG_ROOT}/bin/${DS_SWIG_BIN}" ]; then
export PATH=${SWIG_ROOT}/bin/:$PATH
export SWIG_LIB="$(find ${SWIG_ROOT}/share/swig/ -type f -name "swig.swg" | xargs dirname)"
# Make an alias to be more magic
if [ ! -L "${SWIG_ROOT}/bin/${SWIG_BIN}" ]; then
ln -s ${DS_SWIG_BIN} ${SWIG_ROOT}/bin/${SWIG_BIN}
fi;
swig -version
swig -swiglib
fi;
PY37_OPENSSL_DIR="${PYENV_ROOT}/ssl-xenial"
export PY37_LDPATH="${PY37_OPENSSL_DIR}/usr/lib/"
export LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH

View File

@ -105,16 +105,10 @@ do_deepspeech_nodejs_build()
# Python 2.7 is required for node-pre-gyp; it only needs to be forced on
# Windows
if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
NPM_ROOT=$(cygpath -u "$(npm root)")
PYTHON27=":/c/Python27"
# node-gyp@5.x behaves erratically with VS2015 and MSBuild.exe detection
npm install node-gyp@4.x node-pre-gyp
else
NPM_ROOT="$(npm root)"
npm install node-gyp@5.x node-pre-gyp
PYTHON27="/c/Python27"
fi
export PATH="$NPM_ROOT/.bin/${PYTHON27}:$PATH"
export PATH="${PYTHON27}:$PATH"
for node in ${SUPPORTED_NODEJS_VERSIONS}; do
EXTRA_CFLAGS="${EXTRA_LOCAL_CFLAGS}" EXTRA_LDFLAGS="${EXTRA_LOCAL_LDFLAGS}" EXTRA_LIBS="${EXTRA_LOCAL_LIBS}" make -C native_client/javascript \
@ -157,16 +151,10 @@ do_deepspeech_npm_package()
# Python 2.7 is required for node-pre-gyp; it only needs to be forced on
# Windows
if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
NPM_ROOT=$(cygpath -u "$(npm root)")
PYTHON27=":/c/Python27"
# node-gyp@5.x behaves erratically with VS2015 and MSBuild.exe detection
npm install node-gyp@4.x node-pre-gyp
else
NPM_ROOT="$(npm root)"
npm install node-gyp@5.x node-pre-gyp
PYTHON27="/c/Python27"
fi
export PATH="$NPM_ROOT/.bin/$PYTHON27:$PATH"
export PATH="${NPM_BIN}${PYTHON27}:$PATH"
all_tasks="$(curl -s https://community-tc.services.mozilla.com/api/queue/v1/task/${TASK_ID} | python -c 'import json; import sys; print(" ".join(json.loads(sys.stdin.read())["dependencies"]));')"

View File

@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
set -o pipefail
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail
which deepspeech

View File

@ -17,12 +17,11 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${decoder_pkg_url} | cat
pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-ldc93s1_singleshotinference.sh
popd

View File

@ -16,15 +16,10 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
set +o pipefail
pushd ${HOME}/DeepSpeech/ds/
verify_ctcdecoder_url
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} ${decoder_pkg_url} | cat
set +o pipefail
# Prepare correct arguments for training
case "${bitrate}" in

View File

@ -14,15 +14,10 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
set +o pipefail
pushd ${HOME}/DeepSpeech/ds/
verify_ctcdecoder_url
pushd ${HOME}/DeepSpeech/ds
DS_NODECODER=1 pip install --upgrade . | cat
popd
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} ${decoder_pkg_url} | cat
set +o pipefail
pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-transfer.sh

View File

@ -50,7 +50,7 @@ then:
${extraSystemSetup} && chmod 777 /dev/kvm &&
adduser --system --home ${system.homedir.linux} ${system.username} &&
cd ${system.homedir.linux} &&
echo -e "#!/bin/bash\nset -xe\n env && id && mkdir ~/DeepSpeech/ && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && mkdir -p ${system.homedir.linux}/ds-swig/bin/ && wget -O - ${system.swig.cache.linux_amd64} | tar -C ${system.homedir.linux}/ds-swig/ -xzf - && wget -O - ${build.cache.url} | tar -C ${system.homedir.linux} -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
echo -e "#!/bin/bash\nset -xe\n env && id && mkdir ~/DeepSpeech/ && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && wget -O - ${build.cache.url} | tar -C ${system.homedir.linux} -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
sudo -H -u ${system.username} /bin/bash /tmp/clone.sh &&
sudo -H -u ${system.username} --preserve-env /bin/bash ${build.args.tests_cmdline}

View File

@ -85,10 +85,6 @@ payload:
format: tar.gz
content:
url: ${system.node_gyp_cache.url}
- directory: ds-swig
format: tar.gz
content:
url: ${system.swig.cache.win_amd64}
artifacts:
- type: "directory"

View File

@ -1,10 +1,14 @@
import unittest
from argparse import Namespace
from .importers import validate_label_eng, get_validate_label
from deepspeech_training.util.importers import validate_label_eng, get_validate_label
from pathlib import Path
def from_here(path):
here = Path(__file__)
return here.parent / path
class TestValidateLabelEng(unittest.TestCase):
def test_numbers(self):
label = validate_label_eng("this is a 1 2 3 test")
self.assertEqual(label, None)
@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
self.assertEqual(f('toto1234[{[{[]'), None)
def test_get_validate_label_missing(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
f = get_validate_label(args)
self.assertEqual(f, None)
def test_get_validate_label(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
f = get_validate_label(args)
l = f('toto')
self.assertEqual(l, 'toto')

View File

@ -1,7 +1,7 @@
import unittest
import os
from .text import Alphabet
from deepspeech_training.util.text import Alphabet
class TestAlphabetParsing(unittest.TestCase):

View File

@ -0,0 +1 @@
../../GRAPH_VERSION

View File

@ -0,0 +1 @@
../../VERSION

View File

View File

@ -0,0 +1,155 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_graph_for_evaluation
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version
from .util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()

View File

@ -0,0 +1,942 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get the appropriate variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
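# The identity-filter trick above can be hard to visualise. Below is a minimal,
# illustrative NumPy sketch (not used anywhere in this file; the helper name and
# toy shapes are assumptions) that produces the same overlapping windows for a
# single utterance, zero-padding the edges the way 'SAME' convolution does.
def _overlapping_windows_numpy_sketch(features, n_context):
    # features: [n_steps, n_input] array of per-frame features
    n_steps, _ = features.shape
    window_width = 2 * n_context + 1
    # Zero-pad n_context frames on each side so every step gets a full window
    padded = np.pad(features, ((n_context, n_context), (0, 0)), mode='constant')
    # Stack the window of frames surrounding each step
    # -> shape [n_steps, window_width, n_input], matching one batch element above
    return np.stack([padded[i:i + window_width] for i in range(n_steps)])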
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structured in an old-
# school TensorFlow style where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object-oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
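# The 'static function variable' emulation above is a general Python pattern:
# attributes attached to a function object persist across calls. A minimal,
# generic sketch (illustrative only; the names are made up):
def _function_attribute_singleton_sketch():
    if _function_attribute_singleton_sketch.instance is None:
        # Created exactly once; later calls reuse the same object
        _function_attribute_singleton_sketch.instance = object()
    return _function_attribute_singleton_sketch.instance
_function_attribute_singleton_sketch.instance = None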
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, n_hidden_3]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
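# Shape walk-through for create_model() with overlap=True (illustrative only,
# assuming a toy batch of 2 utterances with 75 feature frames each; n_input,
# n_context and the layer widths come from Config as above):
#   batch_x in:         [2, 75, n_input]
#   after windowing:    [2, 75, 2*n_context+1, n_input]
#   dense layer input:  [75*2, n_input*(2*n_context+1)]
#   RNN input:          [75, 2, n_hidden_3]
#   raw logits out:     [75, 2, n_hidden_6]   (time-major)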
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine fetches a mini-batch, runs it through the model and computes the CTC loss.
It returns the average loss across the batch together with the names of any files
that produced a non-finite (infinite or NaN) loss.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to non-finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss
return avg_loss, non_finite_files
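# A minimal, self-contained sketch of the ctc_loss call used above (illustrative
# only; the single-symbol transcript, the two-frame all-zero logits and the
# three-symbol alphabet, two characters plus the CTC blank, are made-up values):
def _ctc_loss_sketch():
    # One utterance whose transcript is the single symbol with index 1
    labels = tfv1.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    # Time-major logits: [max_time=2, batch_size=1, num_classes=3]
    logits = tf.constant(np.zeros((2, 1, 3)), dtype=tf.float32)
    # Returns one loss value per batch element, exactly as in the code above
    return tfv1.nn.ctc_loss(labels=labels, inputs=logits, sequence_length=[2])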
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
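# Minimal usage sketch (illustrative; 'some_loss' is a placeholder name): on a
# single device the optimizer could simply minimize the loss in one call. The
# tower code below instead splits this into per-GPU compute_gradients() calls
# and a single apply_gradients() on the averaged result.
def _single_device_train_step_sketch(optimizer, some_loss):
    return optimizer.minimize(some_loss)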
# Towers
# ======
# In order to properly make use of multiple GPUs, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPUs.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can introduce, for each GPU, a
tower for whose batch we calculate and return the optimization gradients
and the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non-finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
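# A stripped-down sketch of the tower pattern used above (illustrative only;
# the device strings and the constant op are arbitrary): pairing tf.device()
# with tf.name_scope() keeps each GPU's ops both placed and uniquely named.
def _tower_scoping_sketch(n_towers):
    tower_outputs = []
    for i in range(n_towers):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                # Ops created here are named 'tower_<i>/...' and placed on GPU <i>
                tower_outputs.append(tf.constant(float(i), name='tower_output'))
    return tower_outputs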
def average_gradients(tower_gradients):
r'''
A routine for computing, for each variable, the average of the gradients obtained from the GPUs.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
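# The same averaging, shown for a single variable in plain NumPy (illustrative
# only; 'per_tower_grads' is an assumed list of same-shaped gradient arrays):
def _average_single_variable_sketch(per_tower_grads):
    # Stack along a new 'tower' axis and take the element-wise mean, mirroring
    # the expand_dims/concat/reduce_mean sequence above.
    return np.mean(np.stack(per_tower_grads, axis=0), axis=0)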
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/stddev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
load_or_init_graph_for_training(session)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement,
# wait FLAGS.plateau_epochs before the learning rate is reduced again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
# no state management since n_steps is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size <= 0:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite or streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expects n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
load_graph_for_evaluation(session)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
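# Illustrative path example (assumed directory names): with
# --export_dir path/to/export/eng, the forced trailing slash above makes
# os.path.dirname() return '.../export/eng', so shutil.make_archive() writes
# '.../export/eng.zip' next to the exported directory.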
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
load_graph_for_evaluation(session)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def early_training_checks():
# Check for proper scorer early
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
del scorer
if FLAGS.train_files and FLAGS.test_files and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir:
log_warn('WARNING: You specified different values for --load_checkpoint_dir '
'and --save_checkpoint_dir, but you are running training and testing '
'in a single invocation. The testing step will respect --load_checkpoint_dir, '
'and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. '
'Train and test in two separate invocations, specifying the correct '
'--load_checkpoint_dir in both cases, or use the same location '
'for loading and saving.')
def main(_):
initialize_globals()
early_training_checks()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()

View File

@ -5,7 +5,7 @@ import tempfile
import collections
import numpy as np
from util.helpers import LimitingPool
from .helpers import LimitingPool
DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1

View File

@ -2,11 +2,11 @@ import sys
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from util.flags import FLAGS
from util.logging import log_info, log_error, log_warn
from .flags import FLAGS
from .logging import log_info, log_error, log_warn
def _load_checkpoint(session, checkpoint_path):
def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
# Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then
# we will initialize them instead
@ -45,7 +45,7 @@ def _load_checkpoint(session, checkpoint_path):
'tensors. Missing variables: {}'.format(missing_var_names))
sys.exit(1)
if FLAGS.drop_source_layers > 0:
if allow_drop_layers and FLAGS.drop_source_layers > 0:
# This transfer learning approach requires supplying
# the layers which we exclude from the source model.
# Say we want to exclude all layers except for the first one,
@ -87,20 +87,14 @@ def _initialize_all_variables(session):
session.run(v.initializer)
def load_or_init_graph(session, method_order):
'''
Load variables from checkpoint or initialize variables following the method
order specified in the method_order parameter.
Valid methods are 'best', 'last' and 'init'.
'''
def _load_or_init_impl(session, method_order, allow_drop_layers):
for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == 'best':
ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
if ckpt_path:
log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path)
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
log_info('Could not find best validating checkpoint.')
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
@ -108,7 +102,7 @@ def load_or_init_graph(session, method_order):
ckpt_path = _checkpoint_path_or_none('checkpoint')
if ckpt_path:
log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path)
return _load_checkpoint(session, ckpt_path, allow_drop_layers)
log_info('Could not find most recent checkpoint.')
# Initialize all variables
@ -122,3 +116,31 @@ def load_or_init_graph(session, method_order):
log_error('All initialization methods failed ({}).'.format(method_order))
sys.exit(1)
def load_or_init_graph_for_training(session):
'''
Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint,
and finally initialize the weights from scratch. This can be overridden with
the `--load_train` flag. See its documentation for more info.
'''
if FLAGS.load_train == 'auto':
methods = ['best', 'last', 'init']
else:
methods = [FLAGS.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=True)
def load_graph_for_evaluation(session):
'''
Load variables from checkpoint. Initialization is not allowed. By default
this will try to load the best validating checkpoint, then try the last
checkpoint. This can be overridden with the `--load_evaluate` flag. See its
documentation for more info.
'''
if FLAGS.load_evaluate == 'auto':
methods = ['best', 'last']
else:
methods = [FLAGS.load_evaluate]
_load_or_init_impl(session, methods, allow_drop_layers=False)
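# Illustrative flag mapping (assumed values): --load_evaluate=last makes the
# call above try only the most recent training checkpoint, while the default
# 'auto' tries the best validating checkpoint first and then falls back to the
# most recent one.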

View File

@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
from attrdict import AttrDict
from xdg import BaseDirectory as xdg
from util.flags import FLAGS
from util.gpu import get_available_gpus
from util.logging import log_error
from util.text import Alphabet, UTF8Alphabet
from util.helpers import parse_file_size
from .flags import FLAGS
from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size
class ConfigSingleton:
_config = None
@ -45,8 +45,11 @@ def initialize_globals():
if not FLAGS.checkpoint_dir:
FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))
if FLAGS.load not in ['last', 'best', 'init', 'auto']:
FLAGS.load = 'auto'
if FLAGS.load_train not in ['last', 'best', 'init', 'auto']:
FLAGS.load_train = 'auto'
if FLAGS.load_evaluate not in ['last', 'best', 'auto']:
FLAGS.load_evaluate = 'auto'
# Set default summary dir
if not FLAGS.summary_dir:

View File

@ -7,8 +7,8 @@ import numpy as np
from attrdict import AttrDict
from util.flags import FLAGS
from util.text import levenshtein
from .flags import FLAGS
from .text import levenshtein
def pmap(fun, iterable):

Some files were not shown because too many files have changed in this diff