Rewrite input pipeline to use tf.data API

parent bd7358d94e
commit 1cea2b0fe8

DeepSpeech.py (383 changed lines)
@@ -18,12 +18,10 @@ import tensorflow as tf
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
 from six.moves import zip, range
 from tensorflow.python.tools import freeze_graph
-from util.audio import audiofile_to_input_vector
 from util.config import Config, initialize_globals
-from util.feeding import DataSet, ModelFeeder
+from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
 from util.logging import log_info, log_error, log_debug, log_warn
-from util.preprocess import preprocess


 # Graph Creation
@@ -42,26 +40,83 @@ def variable_on_cpu(name, shape, initializer):
     return var


-def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None, tflite=False):
-    r'''
-    That done, we will define the learned variables, the weights and biases,
-    within the method ``BiRNN()`` which also constructs the neural network.
-    The variables named ``hn``, where ``n`` is an integer, hold the learned weight variables.
-    The variables named ``bn``, where ``n`` is an integer, hold the learned bias variables.
-    In particular, the first variable ``h1`` holds the learned weight matrix that
-    converts an input vector of dimension ``n_input + 2*n_input*n_context``
-    to a vector of dimension ``n_hidden_1``.
-    Similarly, the second variable ``h2`` holds the weight matrix converting
-    an input vector of dimension ``n_hidden_1`` to one of dimension ``n_hidden_2``.
-    The variables ``h3``, ``h5``, and ``h6`` are similar.
-    Likewise, the biases, ``b1``, ``b2``..., hold the biases for the various layers.
-    '''
+def create_overlapping_windows(batch_x):
+    batch_size = tf.shape(batch_x)[0]
+    window_width = 2 * Config.n_context + 1
+    num_channels = Config.n_input
+
+    # Create a constant convolution filter using an identity matrix, so that the
+    # convolution returns patches of the input tensor as is, and we can create
+    # overlapping windows over the MFCCs.
+    eye_filter = tf.constant(np.eye(window_width * num_channels)
+                               .reshape(window_width, num_channels, window_width * num_channels), tf.float32)
+
+    # Create overlapping windows
+    batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
+
+    # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
+    batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
+
+    return batch_x
+
+
+def dense(name, x, units, dropout_rate=None, relu=True):
+    with tf.variable_scope(name):
+        bias = variable_on_cpu('bias', [units], tf.zeros_initializer())
+        weights = variable_on_cpu('weights', [x.shape[-1], units], tf.contrib.layers.xavier_initializer())
+
+    output = tf.nn.bias_add(tf.matmul(x, weights), bias)
+
+    if relu:
+        output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
+
+    if dropout_rate is not None:
+        output = tf.nn.dropout(output, rate=dropout_rate)
+
+    return output
+
+
+def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
+    # Forward direction cell:
+    fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
+
+    output, output_state = fw_cell(inputs=x,
+                                   dtype=tf.float32,
+                                   sequence_length=seq_length,
+                                   initial_state=previous_state)
+
+    return output, output_state
+
+
+def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
+    # Forward direction cell:
+    fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
+
+    # Split rank N tensor into list of rank N-1 tensors
+    x = [x[l] for l in range(x.shape[0])]
+
+    # We parametrize the RNN implementation as the training and inference graph
+    # need to do different things here.
+    output, output_state = tf.nn.static_rnn(cell=fw_cell,
+                                            inputs=x,
+                                            initial_state=previous_state,
+                                            dtype=tf.float32,
+                                            sequence_length=seq_length)
+    output = tf.concat(output, 0)
+
+    return output, output_state
+
+
+def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
     layers = {}

     # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
-    if not batch_size:
-        batch_size = tf.shape(batch_x)[0]
+    batch_size = tf.shape(batch_x)[0]
+
+    # Create overlapping feature windows if needed
+    if overlap:
+        batch_x = create_overlapping_windows(batch_x)

     # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
     # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
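Note: the identity-filter trick in create_overlapping_windows() can be hard to read in diff form. The following self-contained sketch (not part of this commit) demonstrates the same transformation with the window width and channel count hard-coded to the values the defaults would give (n_context = 9, n_input = 26); only the shapes matter here.

    import numpy as np
    import tensorflow as tf

    window_width = 19   # 2 * n_context + 1
    num_channels = 26   # n_input (MFCC coefficients per frame)

    batch_x = tf.placeholder(tf.float32, [None, None, num_channels])
    batch_size = tf.shape(batch_x)[0]

    # Identity filter: the convolution simply copies each window_width-frame patch.
    eye_filter = tf.constant(np.eye(window_width * num_channels)
                               .reshape(window_width, num_channels, window_width * num_channels),
                             tf.float32)
    windows = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
    windows = tf.reshape(windows, [batch_size, -1, window_width, num_channels])

    with tf.Session() as sess:
        out = sess.run(windows, feed_dict={batch_x: np.zeros((1, 100, num_channels), np.float32)})
        print(out.shape)  # (1, 100, 19, 26): [batch, time, window, n_input]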
@@ -73,58 +128,17 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1

     # The next three blocks will pass `batch_x` through three hidden layers with
     # clipped RELU activation and dropout.
-
-    # 1st layer
-    b1 = variable_on_cpu('b1', [Config.n_hidden_1], tf.zeros_initializer())
-    h1 = variable_on_cpu('h1', [Config.n_input + 2*Config.n_input*Config.n_context, Config.n_hidden_1], tf.contrib.layers.xavier_initializer())
-    layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), FLAGS.relu_clip)
-    layer_1 = tf.nn.dropout(layer_1, rate=dropout[0])
-    layers['layer_1'] = layer_1
-
-    # 2nd layer
-    b2 = variable_on_cpu('b2', [Config.n_hidden_2], tf.zeros_initializer())
-    h2 = variable_on_cpu('h2', [Config.n_hidden_1, Config.n_hidden_2], tf.contrib.layers.xavier_initializer())
-    layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), FLAGS.relu_clip)
-    layer_2 = tf.nn.dropout(layer_2, rate=dropout[1])
-    layers['layer_2'] = layer_2
-
-    # 3rd layer
-    b3 = variable_on_cpu('b3', [Config.n_hidden_3], tf.zeros_initializer())
-    h3 = variable_on_cpu('h3', [Config.n_hidden_2, Config.n_hidden_3], tf.contrib.layers.xavier_initializer())
-    layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip)
-    layer_3 = tf.nn.dropout(layer_3, rate=dropout[2])
-    layers['layer_3'] = layer_3
-
-    # Now we create the forward and backward LSTM units.
-    # Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
-
-    # Forward direction cell:
-    if not tflite:
-        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
-        layers['fw_cell'] = fw_cell
-    else:
-        fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
+    layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1)
+    layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2)
+    layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3)

     # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
     # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
-    layer_3 = tf.reshape(layer_3, [n_steps, batch_size, Config.n_hidden_3])
-    if tflite:
-        # Generated StridedSlice, not supported by NNAPI
-        #n_layer_3 = []
-        #for l in range(layer_3.shape[0]):
-        #    n_layer_3.append(layer_3[l])
-        #layer_3 = n_layer_3
-
-        # Unstack/Unpack is not supported by NNAPI
-        layer_3 = tf.unstack(layer_3, n_steps)
-
-    # We parametrize the RNN implementation as the training and inference graph
-    # need to do different things here.
-    if not tflite:
-        output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
-    else:
-        output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32)
-        output = tf.concat(output, 0)
+    layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
+
+    # Run through parametrized RNN implementation, as we use different RNNs
+    # for training and inference
+    output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)

     # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
     # to a tensor of shape [n_steps*batch_size, n_cell_dim]
@@ -132,24 +146,16 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
     layers['rnn_output'] = output
     layers['rnn_output_state'] = output_state

-    # Now we feed `output` to the fifth hidden layer with clipped RELU activation and dropout
-    b5 = variable_on_cpu('b5', [Config.n_hidden_5], tf.zeros_initializer())
-    h5 = variable_on_cpu('h5', [Config.n_cell_dim, Config.n_hidden_5], tf.contrib.layers.xavier_initializer())
-    layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(output, h5), b5)), FLAGS.relu_clip)
-    layer_5 = tf.nn.dropout(layer_5, rate=dropout[5])
-    layers['layer_5'] = layer_5
+    # Now we feed `output` to the fifth hidden layer with clipped RELU activation
+    layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5)

-    # Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5`
-    # creating `n_classes` dimensional vectors, the logits.
-    b6 = variable_on_cpu('b6', [Config.n_hidden_6], tf.zeros_initializer())
-    h6 = variable_on_cpu('h6', [Config.n_hidden_5, Config.n_hidden_6], tf.contrib.layers.xavier_initializer())
-    layer_6 = tf.add(tf.matmul(layer_5, h6), b6)
-    layers['layer_6'] = layer_6
+    # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
+    layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)

     # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
     # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
     # Note, that this differs from the input in that it is time-major.
-    layer_6 = tf.reshape(layer_6, [n_steps, batch_size, Config.n_hidden_6], name="raw_logits")
+    layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
     layers['raw_logits'] = layer_6

     # Output shape: [n_steps, batch_size, n_hidden_6]
@@ -166,17 +172,17 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
 # Conveniently, this loss function is implemented in TensorFlow.
 # Thus, we can simply make use of this implementation to define our loss.

-def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout, reuse):
+def calculate_mean_edit_distance_and_loss(iterator, tower, dropout, reuse):
     r'''
     This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
     Next to total and average loss it returns the mean edit distance,
     the decoded result and the batch's original Y.
     '''
     # Obtain the next batch of data
-    batch_x, batch_seq_len, batch_y = model_feeder.next_batch(tower)
+    (batch_x, batch_seq_len), batch_y = iterator.get_next()

-    # Calculate the logits of the batch using BiRNN
-    logits, _ = BiRNN(batch_x, batch_seq_len, dropout, reuse)
+    # Calculate the logits of the batch
+    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse)

     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tf.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
@@ -221,7 +227,7 @@ def create_optimizer():
 # on which all operations within the tower execute.
 # For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.

-def get_tower_results(model_feeder, optimizer, dropout_rates):
+def get_tower_results(iterator, optimizer, dropout_rates):
     r'''
     With this preliminary step out of the way, we can for each GPU introduce a
     tower for which's batch we calculate and return the optimization gradients
@@ -243,7 +249,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
             with tf.name_scope('tower_%d' % i) as scope:
                 # Calculate the avg_loss and mean_edit_distance and retrieve the decoded
                 # batch along with the original batch's labels (Y) of this tower
-                avg_loss = calculate_mean_edit_distance_and_loss(model_feeder, i, dropout_rates, reuse=i>0)
+                avg_loss = calculate_mean_edit_distance_and_loss(iterator, i, dropout_rates, reuse=i>0)

                 # Allow for variables to be re-used by the next tower
                 tf.get_variable_scope().reuse_variables()
@@ -337,19 +343,6 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)


-# Helpers
-# =======
-
-
-class SampleIndex:
-    def __init__(self, index=0):
-        self.index = index
-
-    def inc(self, old_index):
-        self.index += 1
-        return self.index
-
-
 def try_loading(session, saver, checkpoint_filename, caption):
     try:
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
@@ -371,48 +364,23 @@ def try_loading(session, saver, checkpoint_filename, caption):


 def train():
-    r'''
-    Trains the network on a given server of a cluster.
-    If no server provided, it performs single process training.
-    '''
-
-    # Reading training set
-    train_index = SampleIndex()
-
-    train_data = preprocess(FLAGS.train_files.split(','),
-                            FLAGS.train_batch_size,
-                            Config.n_input,
-                            Config.n_context,
-                            Config.alphabet,
-                            hdf5_cache_path=FLAGS.train_cached_features_path)
-
-    train_set = DataSet(train_data,
-                        FLAGS.train_batch_size,
-                        limit=FLAGS.limit_train,
-                        next_index=train_index.inc)
-
-    # Reading validation set
-    dev_index = SampleIndex()
-
-    dev_data = preprocess(FLAGS.dev_files.split(','),
-                          FLAGS.dev_batch_size,
-                          Config.n_input,
-                          Config.n_context,
-                          Config.alphabet,
-                          hdf5_cache_path=FLAGS.dev_cached_features_path)
-
-    dev_set = DataSet(dev_data,
-                      FLAGS.dev_batch_size,
-                      limit=FLAGS.limit_dev,
-                      next_index=dev_index.inc)
-
-    # Combining all sets to a multi set model feeder
-    model_feeder = ModelFeeder(train_set,
-                               dev_set,
-                               Config.n_input,
-                               Config.n_context,
-                               Config.alphabet,
-                               tower_feeder_count=len(Config.available_devices))
+    # Create training and validation datasets
+    train_set, train_batches = create_dataset(FLAGS.train_files.split(','),
+                                              batch_size=FLAGS.train_batch_size,
+                                              cache_path=FLAGS.train_cached_features_path)
+
+    iterator = tf.data.Iterator.from_structure(train_set.output_types,
+                                               train_set.output_shapes,
+                                               output_classes=train_set.output_classes)
+
+    # Make initialization ops for switching between the two sets
+    train_init_op = iterator.make_initializer(train_set)
+
+    if FLAGS.dev:
+        dev_set, dev_batches = create_dataset(FLAGS.dev_files.split(','),
+                                              batch_size=FLAGS.dev_batch_size,
+                                              cache_path=FLAGS.dev_cached_features_path)
+        dev_init_op = iterator.make_initializer(dev_set)

     # Dropout
     dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
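Note: train() now uses a single reinitializable tf.data iterator that is pointed at the training or validation dataset via initializer ops. A minimal, standalone sketch of that pattern (not project code, toy datasets only):

    import tensorflow as tf

    train_set = tf.data.Dataset.from_tensor_slices(tf.range(10)).batch(2)
    dev_set = tf.data.Dataset.from_tensor_slices(tf.range(100, 106)).batch(2)

    # One iterator with the structure of the datasets, plus one init op per dataset
    iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                               train_set.output_shapes)
    train_init_op = iterator.make_initializer(train_set)
    dev_init_op = iterator.make_initializer(dev_set)
    next_batch = iterator.get_next()

    with tf.Session() as session:
        session.run(train_init_op)       # point the iterator at the training data
        print(session.run(next_batch))   # [0 1]
        session.run(dev_init_op)         # re-initialize on the validation data
        print(session.run(next_batch))   # [100 101]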
@@ -425,17 +393,12 @@ def train():
         dropout_rates[5]: FLAGS.dropout_rate6,
     }
     no_dropout_feed_dict = {
-        dropout_rates[0]: 0.,
-        dropout_rates[1]: 0.,
-        dropout_rates[2]: 0.,
-        dropout_rates[3]: 0.,
-        dropout_rates[4]: 0.,
-        dropout_rates[5]: 0.,
+        rate: 0. for rate in dropout_rates
     }

     # Building the graph
     optimizer = create_optimizer()
-    gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)
+    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
     # Average tower gradients across GPUs
     avg_tower_gradients = average_gradients(gradients)
     log_grads_and_vars(avg_tower_gradients)
@@ -463,6 +426,7 @@ def train():

     with tf.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
+
         tf.get_default_graph().finalize()

         # Loading or initializing
@@ -481,51 +445,51 @@ def train():
                 sys.exit(1)

         # Retrieving global_step from restored model and setting training parameters accordingly
-        model_feeder.set_data_set(no_dropout_feed_dict, train_set)
-        step = session.run(global_step, feed_dict=no_dropout_feed_dict)
+        step = session.run(global_step)
         num_gpus = len(Config.available_devices)
-        steps_per_epoch = max(1, train_set.total_batches // num_gpus)
-        steps_trained = step % steps_per_epoch
+        steps_per_epoch = max(1, train_batches // num_gpus)
         current_epoch = step // steps_per_epoch
         target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
-        train_index.index = steps_trained * num_gpus

         log_debug('step: %d' % step)
         log_debug('epoch: %d' % current_epoch)
         log_debug('target epoch: %d' % target_epoch)
         log_debug('steps per epoch: %d' % steps_per_epoch)
         log_debug('batches per step (GPUs): %d' % num_gpus)
-        log_debug('number of batches in train set: %d' % train_set.total_batches)
-        log_debug('number of batches already trained in epoch: %d' % train_index.index)
+        log_debug('number of batches in train set: %d' % train_batches)

-        def run_set(set_name):
-            data_set = getattr(model_feeder, set_name)
+        def run_set(set_name, init_op, num_batches):
             is_train = set_name == 'train'
             train_op = apply_gradient_op if is_train else []
             feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
-            model_feeder.set_data_set(feed_dict, data_set)
             total_loss = 0.0
             step_summary_writer = step_summary_writers.get(set_name)
-            num_steps = max(1, data_set.total_batches // num_gpus)
+            num_steps = max(1, num_batches // num_gpus)
             checkpoint_time = time.time()

             if FLAGS.show_progressbar:
                 pbar = progressbar.ProgressBar(max_value=num_steps, redirect_stdout=True).start()
+            else:
+                pbar = lambda i: i
+
+            # Initialize iterator to the appropriate dataset
+            session.run(init_op)

             # Batch loop
-            for step_index in range(steps_trained, num_steps):
+            for step_index in pbar(range(num_steps)):
                 if coord.should_stop():
                     break

                 _, current_step, batch_loss, step_summary = \
                     session.run([train_op, global_step, loss, step_summaries_op],
                                 feed_dict=feed_dict)
                 total_loss += batch_loss
                 step_summary_writer.add_summary(step_summary, current_step)
-                if FLAGS.show_progressbar:
-                    pbar.update(step_index + 1, force=True)
                 if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                     checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                     checkpoint_time = time.time()
-                if FLAGS.show_progressbar:
-                    pbar.finish()
             return total_loss / num_steps

         if target_epoch > current_epoch:
@@ -534,28 +498,28 @@ def train():
             dev_losses = []
             coord = tf.train.Coordinator()
             with coord.stop_on_exception():
-                log_debug('Starting queue runners...')
-                model_feeder.start_queue_threads(session, coord=coord)
-                log_debug('Queue runners started.')
-
-                # Epoch loop
                 for current_epoch in range(current_epoch, target_epoch):
-                    # Training
                     if coord.should_stop():
                         break

+                    # Training
                     log_info('Training epoch %d ...' % current_epoch)
-                    train_loss = run_set('train')
+                    train_loss = run_set('train', train_init_op, train_batches)
                     log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
                     checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
-                    steps_trained = 0

+                    if FLAGS.dev:
                         # Validation
                         log_info('Validating epoch %d ...' % current_epoch)
-                        dev_loss = run_set('dev')
+                        dev_loss = run_set('dev', dev_init_op, dev_batches)
                         dev_losses.append(dev_loss)
                         log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))

                         if dev_loss < best_dev_loss:
                             best_dev_loss = dev_loss
                             save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename)
                             log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                         # Early stopping
                         if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
                             mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
@@ -570,30 +534,26 @@ def train():
                                       ' %f with standard deviation: %f and mean: %f' %
                                       (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
                             break
-                log_debug('Closing queues...')
                 coord.request_stop()
-                model_feeder.close_queues(session)
-                log_debug('Queues closed.')
         else:
             log_info('Target epoch already reached - skipped training.')
     log_debug('Session closed.')


 def test():
-    # Reading test set
-    test_data = preprocess(FLAGS.test_files.split(','),
-                           FLAGS.test_batch_size,
-                           Config.n_input,
-                           Config.n_context,
-                           Config.alphabet,
-                           hdf5_cache_path=FLAGS.test_cached_features_path)
-
-    graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
-    evaluate.evaluate(test_data, graph)
+    evaluate.evaluate(FLAGS.test_files.split(','), create_model)


 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     batch_size = batch_size if batch_size > 0 else None

+    # Create feature computation graph
+    input_samples = tf.placeholder(tf.float32, [None], 'input_samples')
+    samples = tf.expand_dims(input_samples, -1)
+    mfccs, mfccs_len = samples_to_mfccs(samples, 16000)
+    mfccs = tf.identity(mfccs, name='mfccs')
+    mfccs_len = tf.identity(mfccs_len, name='mfccs_len')
+
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
     input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
     seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
@@ -611,15 +571,20 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):

     previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)

-    no_dropout = [0.0] * 6
+    # One rate per layer
+    no_dropout = [None] * 6
+
+    if tflite:
+        rnn_impl = rnn_impl_static_rnn
+    else:
+        rnn_impl = rnn_impl_lstmblockfusedcell

-    logits, layers = BiRNN(batch_x=input_tensor,
-                           seq_length=seq_length if FLAGS.use_seq_length else None,
-                           dropout=no_dropout,
-                           batch_size=batch_size,
-                           n_steps=n_steps,
-                           previous_state=previous_state,
-                           tflite=tflite)
+    logits, layers = create_model(batch_x=input_tensor,
+                                  seq_length=seq_length if FLAGS.use_seq_length else None,
+                                  dropout=no_dropout,
+                                  previous_state=previous_state,
+                                  overlap=False,
+                                  rnn_impl=rnn_impl)

     # TF Lite runtime will check that input dimensions are 1, 2 or 4
     # by default we get 3, the middle one being batch_size which is forced to
@@ -659,10 +624,13 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
             {
                 'input': input_tensor,
                 'input_lengths': seq_length,
+                'input_samples': input_samples,
             },
             {
                 'outputs': logits,
                 'initialize_state': initialize_state,
+                'mfccs': mfccs,
+                'mfccs_len': mfccs_len,
             },
             layers
         )
@@ -671,16 +639,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     new_state_c = tf.identity(new_state_c, name='new_state_c')
     new_state_h = tf.identity(new_state_h, name='new_state_h')

-    return (
-        {
-            'input': input_tensor,
-            'previous_state_c': previous_state_c,
-            'previous_state_h': previous_state_h,
-        },
+    inputs = {
+        'input': input_tensor,
+        'previous_state_c': previous_state_c,
+        'previous_state_h': previous_state_h,
+        'input_samples': input_samples,
+    }
+
+    if FLAGS.use_seq_length:
+        inputs.update({'input_lengths': seq_length})
+
+    return (
+        inputs,
         {
             'outputs': logits,
             'new_state_c': new_state_c,
             'new_state_h': new_state_h,
+            'mfccs': mfccs,
+            'mfccs_len': mfccs_len,
         },
         layers
     )
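Note: the inference graph now also exposes a feature-computation subgraph. A hedged sketch of how a client could drive it on its own, assuming flags and globals are initialized the same way do_single_file_inference() does; only the 'input_samples', 'mfccs' and 'mfccs_len' endpoints defined above are used, and the acoustic-model part of the graph is left untouched:

    import numpy as np
    import tensorflow as tf

    inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

    with tf.Session(config=Config.session_config) as session:
        audio = np.zeros(16000, dtype=np.float32)   # one second of 16 kHz audio as floats
        mfccs, mfccs_len = session.run([outputs['mfccs'], outputs['mfccs_len']],
                                       feed_dict={inputs['input_samples']: audio})
        print(mfccs.shape, mfccs_len)               # MFCC frames and their count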
@@ -753,6 +729,8 @@ def export():

             converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
             converter.post_training_quantize = True
+            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
+            converter.allow_custom_ops = True
             tflite_model = converter.convert()

             with open(output_tflite_path, 'wb') as fout:
@@ -764,6 +742,7 @@ def export():
     except RuntimeError as e:
         log_error(str(e))

+
 def do_single_file_inference(input_file_path):
     with tf.Session(config=Config.session_config) as session:
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
@@ -784,22 +763,20 @@ def do_single_file_inference(input_file_path):
         saver.restore(session, checkpoint_path)
         session.run(outputs['initialize_state'])

-        features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context)
-        num_strides = len(features) - (Config.n_context * 2)
+        features, features_len = audiofile_to_features(input_file_path)

-        # Create a view into the array with overlapping strides of size
-        # numcontext (past) + 1 (present) + numcontext (future)
-        window_size = 2*Config.n_context+1
-        features = np.lib.stride_tricks.as_strided(
-            features,
-            (num_strides, window_size, Config.n_input),
-            (features.strides[0], features.strides[0], features.strides[1]),
-            writeable=False)
+        # Add batch dimension
+        features = tf.expand_dims(features, 0)
+        features_len = tf.expand_dims(features_len, 0)

-        logits = session.run(outputs['outputs'], feed_dict = {
-            inputs['input']: [features],
-            inputs['input_lengths']: [num_strides],
-        })
+        # Evaluate
+        features = create_overlapping_windows(features).eval(session=session)
+        features_len = features_len.eval(session=session)
+
+        logits = outputs['outputs'].eval(feed_dict={
+            inputs['input']: features,
+            inputs['input_lengths']: features_len,
+        }, session=session)

         logits = np.squeeze(logits)
@@ -334,7 +334,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information

 ### Exporting a model for TFLite

-If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--export_tflite` flag. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --export_tflite --export_dir /model/export/destination`.
+If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --nouse_seq_length --export_tflite --export_dir /model/export/destination`.

 ### Making a mmap-able model for inference
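Note: putting the flags from the paragraph above together, a re-export invocation would look like `python -u DeepSpeech.py --notrain --notest --nouse_seq_length --export_tflite --checkpoint_dir <your training checkpoint_dir> --export_dir /model/export/destination` (the checkpoint and export paths here are placeholders).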
@@ -16,7 +16,7 @@ else
   checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/ldc93s1"))')
 fi

-python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+python -u DeepSpeech.py --noshow_progressbar --nodev \
   --train_files data/ldc93s1/ldc93s1.csv \
   --dev_files data/ldc93s1/ldc93s1.csv \
   --test_files data/ldc93s1/ldc93s1.csv \
@@ -17,4 +17,4 @@ python -u DeepSpeech.py --noshow_progressbar \
   --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
   --lm_trie_path 'data/smoke_test/vocab.trie' \
   --notrain --notest \
-  --export_tflite \
+  --export_tflite --nouse_seq_length \

evaluate.py (159 changed lines)
@@ -5,92 +5,67 @@ from __future__ import absolute_import, division, print_function
 import itertools
 import json
 import numpy as np
-import os
-import pandas
 import progressbar
-import sys
-import tables
 import tensorflow as tf

-from collections import namedtuple
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
-from multiprocessing import Pool, cpu_count
+from multiprocessing import cpu_count
 from six.moves import zip, range
-from util.audio import audiofile_to_input_vector
 from util.config import Config, initialize_globals
+from util.evaluate_tools import calculate_report
+from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
 from util.logging import log_error
-from util.preprocess import preprocess
-from util.text import Alphabet, levenshtein
-from util.evaluate_tools import process_decode_result, calculate_report
+from util.text import levenshtein


-def split_data(dataset, batch_size):
-    remainder = len(dataset) % batch_size
-    if remainder != 0:
-        dataset = dataset[:-remainder]
-
-    for i in range(0, len(dataset), batch_size):
-        yield dataset[i:i + batch_size]
-
-
-def pad_to_dense(jagged):
-    maxlen = max(len(r) for r in jagged)
-    subshape = jagged[0].shape
-
-    padded = np.zeros((len(jagged), maxlen) +
-                      subshape[1:], dtype=jagged[0].dtype)
-    for i, row in enumerate(jagged):
-        padded[i, :len(row)] = row
-    return padded
-
-
-def evaluate(test_data, inference_graph):
+def sparse_tensor_value_to_texts(value, alphabet):
+    r"""
+    Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
+    representing its values, converting tokens to strings using ``alphabet``.
+    """
+    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
+
+
+def sparse_tuple_to_texts(tuple, alphabet):
+    indices = tuple[0]
+    values = tuple[1]
+    results = [''] * tuple[2][0]
+    for i in range(len(indices)):
+        index = indices[i][0]
+        results[index] += alphabet.string_from_label(values[i])
+    # List of strings
+    return results
+
+
+def evaluate(test_csvs, create_model):
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                     FLAGS.lm_binary_path, FLAGS.lm_trie_path,
                     Config.alphabet)

+    test_set, test_batches = create_dataset(test_csvs,
+                                            batch_size=FLAGS.test_batch_size,
+                                            cache_path=FLAGS.test_cached_features_path)
+    it = test_set.make_one_shot_iterator()
+
-    def create_windows(features):
-        num_strides = len(features) - (Config.n_context * 2)
-
-        # Create a view into the array with overlapping strides of size
-        # numcontext (past) + 1 (present) + numcontext (future)
-        window_size = 2*Config.n_context+1
-        features = np.lib.stride_tricks.as_strided(
-            features,
-            (num_strides, window_size, Config.n_input),
-            (features.strides[0], features.strides[0], features.strides[1]),
-            writeable=False)
-
-        return features
-
-    # Create overlapping windows over the features
-    test_data['features'] = test_data['features'].apply(create_windows)
+    (batch_x, batch_x_len), batch_y = it.get_next()
+
+    # One rate per layer
+    no_dropout = [None] * 6
+    logits, _ = create_model(batch_x=batch_x,
+                             seq_length=batch_x_len,
+                             dropout=no_dropout)
+
+    # Transpose to batch major and apply softmax for decoder
+    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
+
+    loss = tf.nn.ctc_loss(labels=batch_y,
+                          inputs=logits,
+                          sequence_length=batch_x_len)

     with tf.Session(config=Config.session_config) as session:
-        inputs, outputs, layers = inference_graph
-
-        # Transpose to batch major for decoder
-        transposed = tf.transpose(outputs['outputs'], [1, 0, 2])
-
-        labels_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size, None], name="labels")
-        label_lengths_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size], name="label_lengths")
-
-        # We add 1 to all elements of the transcript to avoid any zero values
-        # since we use that as an end-of-sequence token for converting the batch
-        # into a SparseTensor. So here we convert the placeholder back into a
-        # SparseTensor and subtract ones to get the real labels.
-        sparse_labels = tf.contrib.layers.dense_to_sparse(labels_ph)
-        neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
-        sparse_labels = tf.sparse_add(sparse_labels, neg_ones)
-
-        loss = tf.nn.ctc_loss(labels=sparse_labels,
-                              inputs=layers['raw_logits'],
-                              sequence_length=inputs['input_lengths'])
-
         # Create a saver using variables from the above newly created graph
-        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-        saver = tf.train.Saver(mapping)
+        saver = tf.train.Saver()

         # Restore variables from training checkpoint
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
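Note: the new sparse_tuple_to_texts() helper above turns a (indices, values, dense_shape) triple back into per-utterance strings. A toy illustration (not project code) with a stand-in alphabet that maps label 0 to 'a', 1 to 'b', and so on:

    class FakeAlphabet(object):
        def string_from_label(self, label):
            return 'abcdefghijklmnopqrstuvwxyz'[label]

    indices = [[0, 0], [0, 1], [1, 0]]   # (utterance index, character position)
    values = [2, 0, 1]                   # character labels
    dense_shape = [2, 2]                 # two utterances in the batch

    print(sparse_tuple_to_texts((indices, values, dense_shape), FakeAlphabet()))
    # ['ca', 'b']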
@@ -103,51 +78,38 @@ def evaluate(test_csvs, create_model):

         logitses = []
         losses = []
+        seq_lengths = []
+        ground_truths = []

         print('Computing acoustic model predictions...')
-        batch_count = len(test_data) // FLAGS.test_batch_size
-        bar = progressbar.ProgressBar(max_value=batch_count,
+        bar = progressbar.ProgressBar(max_value=test_batches,
                                       widget=progressbar.AdaptiveETA)

         # First pass, compute losses and transposed logits for decoding
-        for batch in bar(split_data(test_data, FLAGS.test_batch_size)):
-            session.run(outputs['initialize_state'])
-
-            features = pad_to_dense(batch['features'].values)
-            features_len = batch['features_len'].values
-            labels = pad_to_dense(batch['transcript'].values + 1)
-            label_lengths = batch['transcript_len'].values
-
-            logits, loss_ = session.run([transposed, loss], feed_dict={
-                inputs['input']: features,
-                inputs['input_lengths']: features_len,
-                labels_ph: labels,
-                label_lengths_ph: label_lengths
-            })
+        for batch in bar(range(test_batches)):
+            logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])

             logitses.append(logits)
             losses.extend(loss_)
+            seq_lengths.append(lengths)
+            ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))

-        ground_truths = []
         predictions = []

-        print('Decoding predictions...')
-        bar = progressbar.ProgressBar(max_value=batch_count,
-                                      widget=progressbar.AdaptiveETA)
-
         # Get number of accessible CPU cores for this process
         try:
             num_processes = cpu_count()
         except:
             num_processes = 1

-        # Second pass, decode logits and compute WER and edit distance metrics
-        for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
-            seq_lengths = batch['features_len'].values.astype(np.int32)
-            decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, Config.alphabet, FLAGS.beam_width,
-                                                    num_processes=num_processes, scorer=scorer)
-
-            ground_truths.extend(Config.alphabet.decode(l) for l in batch['transcript'])
+        print('Decoding predictions...')
+        bar = progressbar.ProgressBar(max_value=test_batches,
+                                      widget=progressbar.AdaptiveETA)
+
+        # Second pass, decode logits and compute WER and edit distance metrics
+        for logits, seq_length in bar(zip(logitses, seq_lengths)):
+            decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
+                                                    num_processes=num_processes, scorer=scorer)
             predictions.extend(d[0][1] for d in decoded)

         distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
@@ -179,21 +141,8 @@ def main(_):
                   'the --test_files flag.')
         exit(1)

-    # sort examples by length, improves packing of batches and timesteps
-    test_data = preprocess(
-        FLAGS.test_files.split(','),
-        FLAGS.test_batch_size,
-        alphabet=Config.alphabet,
-        numcep=Config.n_input,
-        numcontext=Config.n_context,
-        hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
-        by="features_len",
-        ascending=False)
-
-    from DeepSpeech import create_inference_graph
-    graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
-
-    samples = evaluate(test_data, graph)
+    from DeepSpeech import create_model
+    samples = evaluate(FLAGS.test_files.split(','), create_model)

     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
@@ -119,6 +119,10 @@ tf_cc_shared_object(
         "//tensorflow/core/kernels:control_flow_ops", # Enter
         "//tensorflow/core/kernels:tile_ops", # Tile
         "//tensorflow/core/kernels:gather_op", # Gather
+        "//tensorflow/core/kernels:mfcc_op", # Mfcc
+        "//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
+        "//tensorflow/core/kernels:strided_slice_op", # StridedSlice
+        "//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
         "//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
         "//tensorflow/core/kernels:random_ops", # RandomGammaGrad
         "//tensorflow/core/kernels:pack_op", # Pack
@@ -108,7 +108,6 @@ using std::vector;
 struct StreamingState {
   vector<float> accumulated_logits;
   vector<float> audio_buffer;
-  float last_sample; // used for preemphasis
   vector<float> mfcc_buffer;
   vector<float> batch_buffer;
   ModelState* model;
@@ -152,10 +151,13 @@ struct ModelState {
   int input_node_idx;
   int previous_state_c_idx;
   int previous_state_h_idx;
+  int input_samples_idx;
+
   int logits_idx;
   int new_state_c_idx;
   int new_state_h_idx;
+  int mfccs_idx;
+  int mfccs_len_idx;
 #endif

   ModelState();
@@ -204,7 +206,9 @@ struct ModelState {
    *
    * @param[out] output_logits Where to store computed logits.
    */
-  void infer(const float* mfcc, unsigned int n_frames, vector<float>& output_logits);
+  void infer(const float* mfcc, unsigned int n_frames, vector<float>& logits_output);
+
+  void compute_mfcc(const vector<float> audio_buffer, vector<float>& mfcc_output);
 };

 StreamingState* setupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
@@ -258,10 +262,9 @@ StreamingState::feedAudioContent(const short* buffer,
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
     while (buffer_size > 0 && audio_buffer.size() < AUDIO_WIN_LEN_SAMPLES) {
-      // Apply preemphasis to input sample and buffer it
-      float sample = (float)(*buffer) - (PREEMPHASIS_COEFF * last_sample);
-      audio_buffer.push_back(sample);
-      last_sample = *buffer;
+      // Convert i16 sample into f32
+      float multiplier = 1.0f / (1 << 15);
+      audio_buffer.push_back((float)(*buffer) * multiplier);
       ++buffer;
       --buffer_size;
     }
@@ -304,15 +307,11 @@ void
 StreamingState::processAudioWindow(const vector<float>& buf)
 {
   // Compute MFCC features
-  float* mfcc;
-  int n_frames = csf_mfcc(buf.data(), buf.size(), SAMPLE_RATE,
-                          AUDIO_WIN_LEN, AUDIO_WIN_STEP, MFCC_FEATURES, N_FILTERS, N_FFT,
-                          LOWFREQ, SAMPLE_RATE/2, 0.f, CEP_LIFTER, 1, hamming_window.data(),
-                          &mfcc);
-  assert(n_frames == 1);
+  vector<float> mfcc;
+  mfcc.reserve(MFCC_FEATURES);
+  model->compute_mfcc(buf, mfcc);

-  pushMfccBuffer(mfcc, n_frames * MFCC_FEATURES);
-  free(mfcc);
+  pushMfccBuffer(mfcc.data(), MFCC_FEATURES);
 }

 void
@@ -396,7 +395,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
     input_mapped(i) = aMfcc[i];
   }
   for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
-    input_mapped(i) = 0;
+    input_mapped(i) = 0.;
   }

   Tensor input_lengths(DT_INT32, TensorShape({1}));
@@ -454,6 +453,53 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
 #endif // USE_TFLITE
 }

+void
+ModelState::compute_mfcc(const vector<float> samples, vector<float>& mfcc_output)
+{
+#ifndef USE_TFLITE
+  Tensor input(DT_FLOAT, TensorShape({static_cast<long long>(samples.size())}));
+  auto input_mapped = input.flat<float>();
+  for (int i = 0; i < samples.size(); ++i) {
+    input_mapped(i) = samples[i];
+  }
+
+  vector<Tensor> outputs;
+  Status status = session->Run({{"input_samples", input}}, {"mfccs", "mfccs_len"}, {}, &outputs);
+
+  if (!status.ok()) {
+    std::cerr << "Error running session: " << status << "\n";
+    return;
+  }
+
+  auto mfcc_len_mapped = outputs[1].flat<int32>();
+  int n_windows = mfcc_len_mapped(0);
+
+  auto mfcc_mapped = outputs[0].flat<float>();
+  for (int i = 0; i < n_windows * MFCC_FEATURES; ++i) {
+    mfcc_output.push_back(mfcc_mapped(i));
+  }
+#else
+  // Feeding input_node
+  float* input_samples = interpreter->typed_tensor<float>(input_samples_idx);
+  for (int i = 0; i < samples.size(); ++i) {
+    input_samples[i] = samples[i];
+  }
+
+  TfLiteStatus status = interpreter->Invoke();
+  if (status != kTfLiteOk) {
+    std::cerr << "Error running session: " << status << "\n";
+    return;
+  }
+
+  int n_windows = *interpreter->typed_tensor<float>(mfccs_len_idx);
+
+  float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
+  for (int i = 0; i < n_windows * MFCC_FEATURES; ++i) {
+    mfcc_output.push_back(outputs[i]);
+  }
+#endif
+}
+
 char*
 ModelState::decode(vector<float>& logits)
 {
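The new ModelState::compute_mfcc delegates feature extraction to the exported graph: it feeds raw float samples to the "input_samples" node and fetches "mfccs" and "mfccs_len". The snippet below is a hedged Python sketch of how such a subgraph could be wired with matching node names; it is not the actual export code from DeepSpeech.py (which is outside this excerpt). The 512/320 window parameters mirror samples_to_mfccs in util/feeding.py; the 26-coefficient MFCC count and 16 kHz sample rate are assumptions.

import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

# Assumptions, not values taken from this diff:
MFCC_FEATURES = 26    # assumed to match the native client's MFCC_FEATURES
SAMPLE_RATE = 16000   # assumed model sample rate

# Rank-1 float samples in [-1, 1], as fed by ModelState::compute_mfcc
input_samples = tf.placeholder(tf.float32, [None], name='input_samples')
samples = tf.expand_dims(input_samples, -1)  # audio_spectrogram wants [len, channels]

spectrogram = contrib_audio.audio_spectrogram(samples,
                                              window_size=512,
                                              stride=320,
                                              magnitude_squared=True)
mfccs = contrib_audio.mfcc(spectrogram, SAMPLE_RATE,
                           dct_coefficient_count=MFCC_FEATURES)
mfccs = tf.reshape(mfccs, [-1, MFCC_FEATURES], name='mfccs')
mfccs_len = tf.identity(tf.shape(mfccs)[0], name='mfccs_len')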
@@ -640,8 +686,6 @@ DS_CreateModel(const char* aModelPath,
   *retval = model.release();
   return DS_ERR_OK;
 #else // USE_TFLITE
-  TfLiteStatus status;
-
   model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
   if (!model->fbmodel) {
     std::cerr << "Error at reading model file " << aModelPath << std::endl;
@@ -663,9 +707,12 @@ DS_CreateModel(const char* aModelPath,
   model->input_node_idx       = tflite_get_input_tensor_by_name(model.get(), "input_node");
   model->previous_state_c_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_c");
   model->previous_state_h_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_h");
+  model->input_samples_idx    = tflite_get_input_tensor_by_name(model.get(), "input_samples");
   model->logits_idx           = tflite_get_output_tensor_by_name(model.get(), "logits");
   model->new_state_c_idx      = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
   model->new_state_h_idx      = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
+  model->mfccs_idx            = tflite_get_output_tensor_by_name(model.get(), "mfccs");
+  model->mfccs_len_idx        = tflite_get_output_tensor_by_name(model.get(), "mfccs_len");

   TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
@@ -796,7 +843,6 @@ DS_SetupStream(ModelState* aCtx,
   ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);

   ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES);
-  ctx->last_sample = 0;
   ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
   ctx->mfcc_buffer.resize(MFCC_FEATURES*aCtx->n_context, 0.f);
   ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
requirements.txt
@@ -1,19 +1,23 @@
-pandas
-progressbar2
-python-utils
+# Main training requirements
 tensorflow == 1.13.1
 numpy == 1.15.4
-matplotlib
-scipy
-sox
-paramiko >= 2.1
-python_speech_features
-pyxdg
-bs4
+progressbar2
+pandas
 six
-requests
-tables
+pyxdg
 attrdict
+
+# Requirements for building native_client files
 setuptools
+
+# Requirements for importers
+sox
+bs4
+requests
 librosa
 soundfile
+
+# Miscellaneous scripts
+paramiko >= 2.1
+scipy
+matplotlib
util/audio.py
@@ -1,24 +0,0 @@
-import numpy as np
-import scipy.io.wavfile as wav
-
-from python_speech_features import mfcc
-
-
-def audiofile_to_input_vector(audio_filename, numcep, numcontext):
-    r"""
-    Given a WAV audio file at ``audio_filename``, calculates ``numcep`` MFCC features
-    at every 0.01s time step with a window length of 0.025s. Appends ``numcontext``
-    context frames to the left and right of each time step, and returns this data
-    in a numpy array.
-    """
-    # Load wav files
-    fs, audio = wav.read(audio_filename)
-
-    # Get mfcc coefficients
-    features = mfcc(audio, samplerate=fs, numcep=numcep, winlen=0.032, winstep=0.02, winfunc=np.hamming)
-
-    # Add empty initial and final contexts
-    empty_context = np.zeros((numcontext, numcep), dtype=features.dtype)
-    features = np.concatenate((empty_context, features, empty_context))
-
-    return features
263
util/feeding.py
@@ -1,198 +1,97 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, division, print_function
+
 import numpy as np
+import os
+import pandas
 import tensorflow as tf

-from math import ceil
-from six.moves import range
-from threading import Thread
-from util.gpu import get_available_gpus
+from functools import partial
+from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
+from util.config import Config
+from util.text import text_to_char_array


-class ModelFeeder(object):
-    '''
-    Feeds data into a model.
-    Feeding is parallelized by independent units called tower feeders (usually one per GPU).
-    Each tower feeder provides data from runtime switchable sources (train, dev).
-    These sources are to be provided by the DataSet instances whose references are kept.
-    Creates, owns and delegates to tower_feeder_count internal tower feeder objects.
-    '''
-    def __init__(self,
-                 train_set,
-                 dev_set,
-                 numcep,
-                 numcontext,
-                 alphabet,
-                 tower_feeder_count=-1,
-                 threads_per_queue=4):
-
-        self.train = train_set
-        self.dev = dev_set
-        self.sets = [train_set, dev_set]
-        self.numcep = numcep
-        self.numcontext = numcontext
-        self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count
-        self.threads_per_queue = threads_per_queue
-
-        self.ph_x = tf.placeholder(tf.float32, [None, 2*numcontext+1, numcep])
-        self.ph_x_length = tf.placeholder(tf.int32, [])
-        self.ph_y = tf.placeholder(tf.int32, [None,])
-        self.ph_y_length = tf.placeholder(tf.int32, [])
-        self.ph_batch_size = tf.placeholder(tf.int32, [])
-        self.ph_queue_selector = tf.placeholder(tf.int32, name='Queue_Selector')
-
-        self._tower_feeders = [_TowerFeeder(self, i, alphabet) for i in range(self.tower_feeder_count)]
-
-    def start_queue_threads(self, session, coord):
-        '''
-        Starts required queue threads on all tower feeders.
-        '''
-        queue_threads = []
-        for tower_feeder in self._tower_feeders:
-            queue_threads += tower_feeder.start_queue_threads(session, coord)
-        return queue_threads
-
-    def close_queues(self, session):
-        '''
-        Closes queues of all tower feeders.
-        '''
-        for tower_feeder in self._tower_feeders:
-            tower_feeder.close_queues(session)
-
-    def set_data_set(self, feed_dict, data_set):
-        '''
-        Switches all tower feeders to a different source DataSet.
-        The provided feed_dict will get enriched with required placeholder/value pairs.
-        The DataSet has to be one of those that got passed into the constructor.
-        '''
-        index = self.sets.index(data_set)
-        assert index >= 0
-        feed_dict[self.ph_queue_selector] = index
-        feed_dict[self.ph_batch_size] = data_set.batch_size
-
-    def next_batch(self, tower_feeder_index):
-        '''
-        Draw the next batch from one of the tower feeders.
-        '''
-        return self._tower_feeders[tower_feeder_index].next_batch()
+def read_csvs(csv_files):
+    source_data = None
+    for csv in csv_files:
+        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
+        #FIXME: not cross-platform
+        csv_dir = os.path.dirname(os.path.abspath(csv))
+        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
+        if source_data is None:
+            source_data = file
+        else:
+            source_data = source_data.append(file)
+    return source_data
-class DataSet(object):
-    '''
-    Represents a collection of audio samples and their respective transcriptions.
-    Takes a set of CSV files produced by importers in /bin.
-    '''
-    def __init__(self, data, batch_size, skip=0, limit=0, ascending=True, next_index=lambda i: i + 1):
-        self.data = data
-        self.data.sort_values(by="features_len", ascending=ascending, inplace=True)
-        self.batch_size = batch_size
-        self.next_index = next_index
-        self.total_batches = int(ceil(len(self.data) / batch_size))
-
-
-class _DataSetLoader(object):
-    '''
-    Internal class that represents an input queue with data from one of the DataSet objects.
-    Each tower feeder will create and combine three data set loaders to one switchable queue.
-    Keeps a ModelFeeder reference for accessing shared settings and placeholders.
-    Keeps a DataSet reference to access its samples.
-    '''
-    def __init__(self, model_feeder, data_set, alphabet):
-        self._model_feeder = model_feeder
-        self._data_set = data_set
-        self.queue = tf.PaddingFIFOQueue(shapes=[[None, 2 * model_feeder.numcontext + 1, model_feeder.numcep], [], [None,], []],
-                                         dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
-                                         capacity=data_set.batch_size * 8)
-        self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length])
-        self._close_op = self.queue.close(cancel_pending_enqueues=True)
-        self._alphabet = alphabet
-
+def samples_to_mfccs(samples, sample_rate):
+    spectrogram = contrib_audio.audio_spectrogram(samples, window_size=512, stride=320, magnitude_squared=True)
+    mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
+    mfccs = tf.reshape(mfccs, [-1, Config.n_input])
+
+    return mfccs, tf.shape(mfccs)[0]
+
+
+def audiofile_to_features(wav_filename):
+    samples = tf.read_file(wav_filename)
+    decoded = contrib_audio.decode_wav(samples, desired_channels=1)
+    features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate)
+
+    return features, features_len
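For reference, the in-graph spectrogram parameters reproduce the frame geometry of the old python_speech_features call in the removed util/audio.py (winlen=0.032, winstep=0.02), assuming the 16 kHz sample rate the importers produce. A quick arithmetic check:

# 512-sample windows every 320 samples at an assumed 16 kHz correspond to the
# old 32 ms window / 20 ms step used by python_speech_features.
SAMPLE_RATE = 16000
window_size, stride = 512, 320
print(window_size / SAMPLE_RATE)  # 0.032 s window length
print(stride / SAMPLE_RATE)       # 0.020 s window step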
-    def start_queue_threads(self, session, coord):
-        '''
-        Starts concurrent queue threads for reading samples from the data set.
-        '''
-        queue_threads = [Thread(target=self._populate_batch_queue, args=(session, coord))
-                         for i in range(self._model_feeder.threads_per_queue)]
-        for queue_thread in queue_threads:
-            coord.register_thread(queue_thread)
-            queue_thread.daemon = True
-            queue_thread.start()
-        return queue_threads
-
-    def close_queue(self, session):
-        '''
-        Closes the data set queue.
-        '''
-        session.run(self._close_op)
-
-    def _populate_batch_queue(self, session, coord):
-        '''
-        Queue thread routine.
-        '''
-        file_count = len(self._data_set.data)
-        index = -1
-        while not coord.should_stop():
-            index = self._data_set.next_index(index) % file_count
-            features, num_strides, transcript, transcript_len = self._data_set.data.iloc[index]
-
-            # Create a view into the array with overlapping strides of size
-            # numcontext (past) + 1 (present) + numcontext (future)
-            window_size = 2*self._model_feeder.numcontext+1
-            features = np.lib.stride_tricks.as_strided(
-                features,
-                (num_strides, window_size, self._model_feeder.numcep),
-                (features.strides[0], features.strides[0], features.strides[1]),
-                writeable=False)
-
-            # We add 1 to all elements of the transcript here to avoid any zero
-            # values since we use that as an end-of-sequence token for converting
-            # the batch into a SparseTensor.
-            try:
-                session.run(self._enqueue_op, feed_dict={
-                    self._model_feeder.ph_x: features,
-                    self._model_feeder.ph_x_length: num_strides,
-                    self._model_feeder.ph_y: transcript + 1,
-                    self._model_feeder.ph_y_length: transcript_len
-                })
-            except tf.errors.CancelledError:
-                return
-
-
-class _TowerFeeder(object):
-    '''
-    Internal class that represents a switchable input queue for one tower.
-    It creates, owns and combines three _DataSetLoader instances.
-    Keeps a ModelFeeder reference for accessing shared settings and placeholders.
-    '''
-    def __init__(self, model_feeder, index, alphabet):
-        self._model_feeder = model_feeder
-        self.index = index
-        self._loaders = [_DataSetLoader(model_feeder, data_set, alphabet) for data_set in model_feeder.sets]
-        self._queues = [set_queue.queue for set_queue in self._loaders]
-        self._queue = tf.QueueBase.from_list(model_feeder.ph_queue_selector, self._queues)
-        self._close_op = self._queue.close(cancel_pending_enqueues=True)
-
-    def next_batch(self):
-        '''
-        Draw the next batch from from the combined switchable queue.
-        '''
-        source, source_lengths, target, target_lengths = self._queue.dequeue_many(self._model_feeder.ph_batch_size)
-        # Back to sparse, then subtract one to get the real labels
-        sparse_labels = tf.contrib.layers.dense_to_sparse(target)
-        neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
-        return source, source_lengths, tf.sparse_add(sparse_labels, neg_ones)
-
-    def start_queue_threads(self, session, coord):
-        '''
-        Starts the queue threads of all owned _DataSetLoader instances.
-        '''
-        queue_threads = []
-        for set_queue in self._loaders:
-            queue_threads += set_queue.start_queue_threads(session, coord)
-        return queue_threads
-
-    def close_queues(self, session):
-        '''
-        Closes queues of all owned _DataSetLoader instances.
-        '''
-        for set_queue in self._loaders:
-            set_queue.close_queue(session)
+def entry_to_features(wav_filename, transcript):
+    # https://bugs.python.org/issue32117
+    features, features_len = audiofile_to_features(wav_filename)
+    return features, features_len, tf.SparseTensor(*transcript)
+
+
+def to_sparse_tuple(sequence):
+    r"""Creates a sparse representention of ``sequence``.
+        Returns a tuple with (indices, values, shape)
+    """
+    indices = np.asarray(list(zip([0]*len(sequence), range(len(sequence)))), dtype=np.int64)
+    shape = np.asarray([1, len(sequence)], dtype=np.int64)
+    return indices, sequence, shape
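To make the SparseTensor plumbing concrete, here is a small illustration (not code from the commit) of what to_sparse_tuple yields for a three-label transcript; the leading batch dimension of 1 is what later forces the reshape after batching in create_dataset below.

import numpy as np

# Toy example: a transcript encoded as three label indices.
sequence = np.array([5, 2, 7], dtype=np.int32)

indices = np.asarray(list(zip([0]*len(sequence), range(len(sequence)))), dtype=np.int64)
shape = np.asarray([1, len(sequence)], dtype=np.int64)

print(indices.tolist())   # [[0, 0], [0, 1], [0, 2]]
print(sequence.tolist())  # [5, 2, 7]
print(shape.tolist())     # [1, 3]
# i.e. a [1, 3] SparseTensor; batching such tensors yields rank-3 shapes,
# hence the sparse_reshape step in create_dataset.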
+def create_dataset(csvs, batch_size, cache_path):
+    df = read_csvs(csvs)
+    df.sort_values(by='wav_filesize', inplace=True)
+
+    num_batches = len(df) // batch_size
+
+    # Convert to character index arrays
+    df['transcript'] = df['transcript'].apply(partial(text_to_char_array, alphabet=Config.alphabet))
+
+    def generate_values():
+        for _, row in df.iterrows():
+            yield row.wav_filename, to_sparse_tuple(row.transcript)
+
+    # Batching a dataset of 2D SparseTensors creates 3D batches, which fail
+    # when passed to tf.nn.ctc_loss, so we reshape them to remove the extra
+    # dimension here.
+    def sparse_reshape(sparse):
+        shape = sparse.dense_shape
+        return tf.sparse.reshape(sparse, [shape[0], shape[2]])
+
+    def batch_fn(features, features_len, transcripts):
+        features = tf.data.Dataset.zip((features, features_len))
+        features = features.padded_batch(batch_size,
+                                         padded_shapes=([None, Config.n_input], []))
+        transcripts = transcripts.batch(batch_size).map(sparse_reshape)
+        return tf.data.Dataset.zip((features, transcripts))
+
+    num_gpus = len(Config.available_devices)
+
+    dataset = (tf.data.Dataset.from_generator(generate_values,
+                                              output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
+               .map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+               .cache(cache_path)
+               .window(batch_size, drop_remainder=True).flat_map(batch_fn)
+               .prefetch(num_gpus)
+               .repeat())
+
+    return dataset, num_batches
util/flags.py
@@ -23,6 +23,7 @@ def create_flags():
     # ================

     tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network')
+    tf.app.flags.DEFINE_boolean ('dev', True, 'whether to run validation epochs')
     tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network')
     tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')
util/preprocess.py
@@ -1,101 +0,0 @@
-import numpy as np
-import os
-import pandas
-import tables
-
-from functools import partial
-from multiprocessing.dummy import Pool
-from util.audio import audiofile_to_input_vector
-from util.text import text_to_char_array
-
-
-def pmap(fun, iterable):
-    pool = Pool()
-    results = pool.map(fun, iterable)
-    pool.close()
-    return results
-
-
-def process_single_file(row, numcep, numcontext, alphabet):
-    # row = index, Series
-    _, file = row
-    features = audiofile_to_input_vector(file.wav_filename, numcep, numcontext)
-    features_len = len(features) - 2*numcontext
-    transcript = text_to_char_array(file.transcript, alphabet)
-
-    if features_len < len(transcript):
-        raise ValueError('Error: Audio file {} is too short for transcription.'.format(file.wav_filename))
-
-    return features, features_len, transcript, len(transcript)
-
-
-# load samples from CSV, compute features, optionally cache results on disk
-def preprocess(csv_files, batch_size, numcep, numcontext, alphabet, hdf5_cache_path=None):
-    COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len')
-
-    print('Preprocessing', csv_files)
-
-    if hdf5_cache_path and os.path.exists(hdf5_cache_path):
-        with tables.open_file(hdf5_cache_path, 'r') as file:
-            features = file.root.features[:]
-            features_len = file.root.features_len[:]
-            transcript = file.root.transcript[:]
-            transcript_len = file.root.transcript_len[:]
-
-            # features are stored flattened, so reshape into [n_steps, numcep]
-            for i in range(len(features)):
-                features[i].shape = [features_len[i]+2*numcontext, numcep]
-
-            in_data = list(zip(features, features_len,
-                               transcript, transcript_len))
-            print('Loaded from cache at', hdf5_cache_path)
-            return pandas.DataFrame(data=in_data, columns=COLUMNS)
-
-    source_data = None
-    for csv in csv_files:
-        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
-        #FIXME: not cross-platform
-        csv_dir = os.path.dirname(os.path.abspath(csv))
-        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
-        if source_data is None:
-            source_data = file
-        else:
-            source_data = source_data.append(file)
-
-    step_fn = partial(process_single_file,
-                      numcep=numcep,
-                      numcontext=numcontext,
-                      alphabet=alphabet)
-    out_data = pmap(step_fn, source_data.iterrows())
-
-    if hdf5_cache_path:
-        print('Saving to', hdf5_cache_path)
-
-        # list of tuples -> tuple of lists
-        features, features_len, transcript, transcript_len = zip(*out_data)
-
-        with tables.open_file(hdf5_cache_path, 'w') as file:
-            features_dset = file.create_vlarray(file.root,
-                                                'features',
-                                                tables.Float32Atom(),
-                                                filters=tables.Filters(complevel=1))
-            # VLArray atoms need to be 1D, so flatten feature array
-            for f in features:
-                features_dset.append(np.reshape(f, -1))
-
-            features_len_dset = file.create_array(file.root,
-                                                  'features_len',
-                                                  features_len)
-
-            transcript_dset = file.create_vlarray(file.root,
-                                                  'transcript',
-                                                  tables.Int32Atom(),
-                                                  filters=tables.Filters(complevel=1))
-            for t in transcript:
-                transcript_dset.append(t)
-
-            transcript_len_dset = file.create_array(file.root,
-                                                    'transcript_len',
-                                                    transcript_len)
-
-    print('Preprocessing done')
-    return pandas.DataFrame(data=out_data, columns=COLUMNS)