Rewrite input pipeline to use tf.data API

Reuben Morais 2019-03-22 21:14:10 -03:00
parent bd7358d94e
commit 1cea2b0fe8
12 changed files with 432 additions and 677 deletions

DeepSpeech.py

@ -18,12 +18,10 @@ import tensorflow as tf
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph
from util.audio import audiofile_to_input_vector
from util.config import Config, initialize_globals
from util.feeding import DataSet, ModelFeeder
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.logging import log_info, log_error, log_debug, log_warn
from util.preprocess import preprocess
# Graph Creation
@ -42,25 +40,82 @@ def variable_on_cpu(name, shape, initializer):
return var
def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None, tflite=False):
r'''
That done, we will define the learned variables, the weights and biases,
within the method ``BiRNN()`` which also constructs the neural network.
The variables named ``hn``, where ``n`` is an integer, hold the learned weight variables.
The variables named ``bn``, where ``n`` is an integer, hold the learned bias variables.
In particular, the first variable ``h1`` holds the learned weight matrix that
converts an input vector of dimension ``n_input + 2*n_input*n_context``
to a vector of dimension ``n_hidden_1``.
Similarly, the second variable ``h2`` holds the weight matrix converting
an input vector of dimension ``n_hidden_1`` to one of dimension ``n_hidden_2``.
The variables ``h3``, ``h5``, and ``h6`` are similar.
Likewise, the biases, ``b1``, ``b2``..., hold the biases for the various layers.
'''
def create_overlapping_windows(batch_x):
batch_size = tf.shape(batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32)
# Create overlapping windows
batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
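
As a sanity check, here is a minimal sketch (toy dimensions standing in for Config.n_context and Config.n_input) showing that the identity-filter convolution produces the expected overlapping windows:

import numpy as np
import tensorflow as tf

n_context, n_input = 2, 3  # toy values for illustration
window_width = 2 * n_context + 1

batch_x = tf.constant(np.arange(30, dtype=np.float32).reshape(1, 10, n_input))
eye_filter = tf.constant(np.eye(window_width * n_input)
                         .reshape(window_width, n_input, window_width * n_input), tf.float32)
windows = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
windows = tf.reshape(windows, [1, -1, window_width, n_input])

with tf.Session() as session:
    out = session.run(windows)
    print(out.shape)             # (1, 10, 5, 3)
    # The window at time step t holds (zero-padded) steps t-2 .. t+2,
    # so its center equals the input at step t:
    print(out[0, 5, n_context])  # equals batch_x[0, 5]
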
def dense(name, x, units, dropout_rate=None, relu=True):
with tf.variable_scope(name):
bias = variable_on_cpu('bias', [units], tf.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tf.contrib.layers.xavier_initializer())
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
# Forward direction cell:
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
# Forward direction cell:
fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
# We parametrize the RNN implementation as the training and inference graph
# need to do different things here.
output, output_state = tf.nn.static_rnn(cell=fw_cell,
inputs=x,
initial_state=previous_state,
dtype=tf.float32,
sequence_length=seq_length)
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(batch_x)[0]
batch_size = tf.shape(batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
@ -73,58 +128,17 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
# 1st layer
b1 = variable_on_cpu('b1', [Config.n_hidden_1], tf.zeros_initializer())
h1 = variable_on_cpu('h1', [Config.n_input + 2*Config.n_input*Config.n_context, Config.n_hidden_1], tf.contrib.layers.xavier_initializer())
layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), FLAGS.relu_clip)
layer_1 = tf.nn.dropout(layer_1, rate=dropout[0])
layers['layer_1'] = layer_1
# 2nd layer
b2 = variable_on_cpu('b2', [Config.n_hidden_2], tf.zeros_initializer())
h2 = variable_on_cpu('h2', [Config.n_hidden_1, Config.n_hidden_2], tf.contrib.layers.xavier_initializer())
layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), FLAGS.relu_clip)
layer_2 = tf.nn.dropout(layer_2, rate=dropout[1])
layers['layer_2'] = layer_2
# 3rd layer
b3 = variable_on_cpu('b3', [Config.n_hidden_3], tf.zeros_initializer())
h3 = variable_on_cpu('h3', [Config.n_hidden_2, Config.n_hidden_3], tf.contrib.layers.xavier_initializer())
layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip)
layer_3 = tf.nn.dropout(layer_3, rate=dropout[2])
layers['layer_3'] = layer_3
# Now we create the forward and backward LSTM units.
# Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
# Forward direction cell:
if not tflite:
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
layers['fw_cell'] = fw_cell
else:
fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1)
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2)
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3)
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [n_steps, batch_size, Config.n_hidden_3])
if tflite:
# Generated StridedSlice, not supported by NNAPI
#n_layer_3 = []
#for l in range(layer_3.shape[0]):
# n_layer_3.append(layer_3[l])
#layer_3 = n_layer_3
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Unstack/Unpack is not supported by NNAPI
layer_3 = tf.unstack(layer_3, n_steps)
# We parametrize the RNN implementation as the training and inference graph
# need to do different things here.
if not tflite:
output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
else:
output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32)
output = tf.concat(output, 0)
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
@ -132,24 +146,16 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation and dropout
b5 = variable_on_cpu('b5', [Config.n_hidden_5], tf.zeros_initializer())
h5 = variable_on_cpu('h5', [Config.n_cell_dim, Config.n_hidden_5], tf.contrib.layers.xavier_initializer())
layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(output, h5), b5)), FLAGS.relu_clip)
layer_5 = tf.nn.dropout(layer_5, rate=dropout[5])
layers['layer_5'] = layer_5
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5)
# Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5`
# creating `n_classes` dimensional vectors, the logits.
b6 = variable_on_cpu('b6', [Config.n_hidden_6], tf.zeros_initializer())
h6 = variable_on_cpu('h6', [Config.n_hidden_5, Config.n_hidden_6], tf.contrib.layers.xavier_initializer())
layer_6 = tf.add(tf.matmul(layer_5, h6), b6)
layers['layer_6'] = layer_6
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note, that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [n_steps, batch_size, Config.n_hidden_6], name="raw_logits")
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
@ -166,17 +172,17 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout, reuse):
def calculate_mean_edit_distance_and_loss(iterator, tower, dropout, reuse):
r'''
This routine beam-search decodes a mini-batch and calculates the loss and mean edit distance.
In addition to the total and average loss, it returns the mean edit distance,
the decoded result and the batch's original Y.
'''
# Obtain the next batch of data
batch_x, batch_seq_len, batch_y = model_feeder.next_batch(tower)
(batch_x, batch_seq_len), batch_y = iterator.get_next()
# Calculate the logits of the batch using BiRNN
logits, _ = BiRNN(batch_x, batch_seq_len, dropout, reuse)
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tf.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
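
For reference, a minimal sketch of this TF1 API with toy, illustrative shapes: `inputs` must be time-major `[max_time, batch_size, num_classes]` and `labels` a SparseTensor of int32 class indices.

import tensorflow as tf

max_time, batch_size, num_classes = 50, 2, 29  # toy shapes
logits = tf.random_normal([max_time, batch_size, num_classes])
labels = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                         values=[5, 3, 7],
                         dense_shape=[batch_size, 10])
seq_len = tf.fill([batch_size], max_time)
loss = tf.nn.ctc_loss(labels=labels, inputs=logits, sequence_length=seq_len)

with tf.Session() as session:
    print(session.run(loss))  # one loss value per batch element
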
@ -221,7 +227,7 @@ def create_optimizer():
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(model_feeder, optimizer, dropout_rates):
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can, for each GPU, introduce a
tower for whose batch we calculate and return the optimization gradients
@ -243,7 +249,7 @@ def get_tower_results(model_feeder, optimizer, dropout_rates):
with tf.name_scope('tower_%d' % i) as scope:
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss = calculate_mean_edit_distance_and_loss(model_feeder, i, dropout_rates, reuse=i>0)
avg_loss = calculate_mean_edit_distance_and_loss(iterator, i, dropout_rates, reuse=i>0)
# Allow for variables to be re-used by the next tower
tf.get_variable_scope().reuse_variables()
@ -337,19 +343,6 @@ def log_grads_and_vars(grads_and_vars):
log_variable(variable, gradient=gradient)
# Helpers
# =======
class SampleIndex:
def __init__(self, index=0):
self.index = index
def inc(self, old_index):
self.index += 1
return self.index
def try_loading(session, saver, checkpoint_filename, caption):
try:
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
@ -371,48 +364,23 @@ def try_loading(session, saver, checkpoint_filename, caption):
def train():
r'''
Trains the network on a given server of a cluster.
If no server is provided, it performs single-process training.
'''
# Create training and validation datasets
train_set, train_batches = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
cache_path=FLAGS.train_cached_features_path)
# Reading training set
train_index = SampleIndex()
iterator = tf.data.Iterator.from_structure(train_set.output_types,
train_set.output_shapes,
output_classes=train_set.output_classes)
train_data = preprocess(FLAGS.train_files.split(','),
FLAGS.train_batch_size,
Config.n_input,
Config.n_context,
Config.alphabet,
hdf5_cache_path=FLAGS.train_cached_features_path)
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
train_set = DataSet(train_data,
FLAGS.train_batch_size,
limit=FLAGS.limit_train,
next_index=train_index.inc)
# Reading validation set
dev_index = SampleIndex()
dev_data = preprocess(FLAGS.dev_files.split(','),
FLAGS.dev_batch_size,
Config.n_input,
Config.n_context,
Config.alphabet,
hdf5_cache_path=FLAGS.dev_cached_features_path)
dev_set = DataSet(dev_data,
FLAGS.dev_batch_size,
limit=FLAGS.limit_dev,
next_index=dev_index.inc)
# Combining all sets to a multi set model feeder
model_feeder = ModelFeeder(train_set,
dev_set,
Config.n_input,
Config.n_context,
Config.alphabet,
tower_feeder_count=len(Config.available_devices))
if FLAGS.dev:
dev_set, dev_batches = create_dataset(FLAGS.dev_files.split(','),
batch_size=FLAGS.dev_batch_size,
cache_path=FLAGS.dev_cached_features_path)
dev_init_op = iterator.make_initializer(dev_set)
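
The code above is the standard reinitializable-iterator idiom: one iterator, one initializer op per dataset. A self-contained toy sketch of the same pattern:

import tensorflow as tf

train_set = tf.data.Dataset.range(10)
dev_set = tf.data.Dataset.range(5)

iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                           train_set.output_shapes)
next_element = iterator.get_next()
train_init_op = iterator.make_initializer(train_set)
dev_init_op = iterator.make_initializer(dev_set)

with tf.Session() as session:
    session.run(train_init_op)        # point the iterator at the train data
    print(session.run(next_element))  # 0
    session.run(dev_init_op)          # switch the same graph to the dev data
    print(session.run(next_element))  # 0
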
# Dropout
dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
@ -425,17 +393,12 @@ def train():
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
dropout_rates[0]: 0.,
dropout_rates[1]: 0.,
dropout_rates[2]: 0.,
dropout_rates[3]: 0.,
dropout_rates[4]: 0.,
dropout_rates[5]: 0.,
rate: 0. for rate in dropout_rates
}
# Building the graph
optimizer = create_optimizer()
gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)
gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
@ -463,6 +426,7 @@ def train():
with tf.Session(config=Config.session_config) as session:
log_debug('Session opened.')
tf.get_default_graph().finalize()
# Loading or initializing
@ -481,51 +445,51 @@ def train():
sys.exit(1)
# Retrieving global_step from restored model and setting training parameters accordingly
model_feeder.set_data_set(no_dropout_feed_dict, train_set)
step = session.run(global_step, feed_dict=no_dropout_feed_dict)
step = session.run(global_step)
num_gpus = len(Config.available_devices)
steps_per_epoch = max(1, train_set.total_batches // num_gpus)
steps_trained = step % steps_per_epoch
steps_per_epoch = max(1, train_batches // num_gpus)
current_epoch = step // steps_per_epoch
target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
train_index.index = steps_trained * num_gpus
log_debug('step: %d' % step)
log_debug('epoch: %d' % current_epoch)
log_debug('target epoch: %d' % target_epoch)
log_debug('steps per epoch: %d' % steps_per_epoch)
log_debug('batches per step (GPUs): %d' % num_gpus)
log_debug('number of batches in train set: %d' % train_set.total_batches)
log_debug('number of batches already trained in epoch: %d' % train_index.index)
log_debug('number of batches in train set: %d' % train_batches)
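# Example: with 1000 batches in the train set and 4 GPUs, each step consumes
# num_gpus batches, so steps_per_epoch = 1000 // 4 = 250.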
def run_set(set_name):
data_set = getattr(model_feeder, set_name)
def run_set(set_name, init_op, num_batches):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
model_feeder.set_data_set(feed_dict, data_set)
total_loss = 0.0
step_summary_writer = step_summary_writers.get(set_name)
num_steps = max(1, data_set.total_batches // num_gpus)
num_steps = max(1, num_batches // num_gpus)
checkpoint_time = time.time()
if FLAGS.show_progressbar:
pbar = progressbar.ProgressBar(max_value=num_steps, redirect_stdout=True).start()
else:
pbar = lambda i: i
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
for step_index in range(steps_trained, num_steps):
for step_index in pbar(range(num_steps)):
if coord.should_stop():
break
_, current_step, batch_loss, step_summary = \
session.run([train_op, global_step, loss, step_summaries_op],
feed_dict=feed_dict)
total_loss += batch_loss
step_summary_writer.add_summary(step_summary, current_step)
if FLAGS.show_progressbar:
pbar.update(step_index + 1, force=True)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
if FLAGS.show_progressbar:
pbar.finish()
return total_loss / num_steps
if target_epoch > current_epoch:
@ -534,68 +498,64 @@ def train():
dev_losses = []
coord = tf.train.Coordinator()
with coord.stop_on_exception():
log_debug('Starting queue runners...')
model_feeder.start_queue_threads(session, coord=coord)
log_debug('Queue runners started.')
# Epoch loop
for current_epoch in range(current_epoch, target_epoch):
# Training
if coord.should_stop():
break
# Training
log_info('Training epoch %d ...' % current_epoch)
train_loss = run_set('train')
train_loss = run_set('train', train_init_op, train_batches)
log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
steps_trained = 0
# Validation
log_info('Validating epoch %d ...' % current_epoch)
dev_loss = run_set('dev')
dev_losses.append(dev_loss)
log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename)
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
dev_losses = dev_losses[-FLAGS.es_steps:]
log_debug('Checking for early stopping (last %d steps) validation loss: '
'%f, with standard deviation: %f and mean: %f' %
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
if dev_losses[-1] > np.max(dev_losses[:-1]) or \
(abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
log_info('Early stop triggered as (for last %d steps) validation loss:'
' %f with standard deviation: %f and mean: %f' %
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
break
log_debug('Closing queues...')
if FLAGS.dev:
# Validation
log_info('Validating epoch %d ...' % current_epoch)
dev_loss = run_set('dev', dev_init_op, dev_batches)
dev_losses.append(dev_loss)
log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename)
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
dev_losses = dev_losses[-FLAGS.es_steps:]
log_debug('Checking for early stopping (last %d steps) validation loss: '
'%f, with standard deviation: %f and mean: %f' %
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
if dev_losses[-1] > np.max(dev_losses[:-1]) or \
(abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
log_info('Early stop triggered as (for last %d steps) validation loss:'
' %f with standard deviation: %f and mean: %f' %
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
break
coord.request_stop()
model_feeder.close_queues(session)
log_debug('Queues closed.')
else:
log_info('Target epoch already reached - skipped training.')
log_debug('Session closed.')
def test():
# Reading test set
test_data = preprocess(FLAGS.test_files.split(','),
FLAGS.test_batch_size,
Config.n_input,
Config.n_context,
Config.alphabet,
hdf5_cache_path=FLAGS.test_cached_features_path)
graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
evaluate.evaluate(test_data, graph)
evaluate.evaluate(FLAGS.test_files.split(','), create_model)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tf.placeholder(tf.float32, [None], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, mfccs_len = samples_to_mfccs(samples, 16000)
mfccs = tf.identity(mfccs, name='mfccs')
mfccs_len = tf.identity(mfccs_len, name='mfccs_len')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*Config.n_context+1, Config.n_input], name='input_node')
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
@ -611,15 +571,20 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
no_dropout = [0.0] * 6
# One rate per layer
no_dropout = [None] * 6
logits, layers = BiRNN(batch_x=input_tensor,
seq_length=seq_length if FLAGS.use_seq_length else None,
dropout=no_dropout,
batch_size=batch_size,
n_steps=n_steps,
previous_state=previous_state,
tflite=tflite)
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
seq_length=seq_length if FLAGS.use_seq_length else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
@ -659,10 +624,13 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
{
'input': input_tensor,
'input_lengths': seq_length,
'input_samples': input_samples,
},
{
'outputs': logits,
'initialize_state': initialize_state,
'mfccs': mfccs,
'mfccs_len': mfccs_len,
},
layers
)
@ -671,16 +639,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if FLAGS.use_seq_length:
inputs.update({'input_lengths': seq_length})
return (
{
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
},
inputs,
{
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
'mfccs_len': mfccs_len,
},
layers
)
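
Since the exported graph now contains the feature-computation nodes, the same MFCC subgraph the native client calls can also be exercised from Python. A hypothetical sketch; the frozen-graph path 'output_graph.pb' is an assumed example:

import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('output_graph.pb', 'rb') as f:  # assumed export path
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=graph) as session:
        audio = np.zeros(16000, dtype=np.float32)  # 1s of silence at 16kHz
        mfccs, mfccs_len = session.run(
            ['mfccs:0', 'mfccs_len:0'],
            feed_dict={'input_samples:0': audio})
        print(mfccs.shape, mfccs_len)  # [n_windows, n_input], n_windows
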
@ -753,6 +729,8 @@ def export():
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.post_training_quantize = True
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
@ -764,6 +742,7 @@ def export():
except RuntimeError as e:
log_error(str(e))
def do_single_file_inference(input_file_path):
with tf.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
@ -784,22 +763,20 @@ def do_single_file_inference(input_file_path):
saver.restore(session, checkpoint_path)
session.run(outputs['initialize_state'])
features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context)
num_strides = len(features) - (Config.n_context * 2)
features, features_len = audiofile_to_features(input_file_path)
# Create a view into the array with overlapping strides of size
# numcontext (past) + 1 (present) + numcontext (future)
window_size = 2*Config.n_context+1
features = np.lib.stride_tricks.as_strided(
features,
(num_strides, window_size, Config.n_input),
(features.strides[0], features.strides[0], features.strides[1]),
writeable=False)
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
logits = session.run(outputs['outputs'], feed_dict = {
inputs['input']: [features],
inputs['input_lengths']: [num_strides],
})
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
}, session=session)
logits = np.squeeze(logits)

README.md

@ -334,7 +334,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information
### Exporting a model for TFLite
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--export_tflite` flag. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --export_tflite --export_dir /model/export/destination`.
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --nouse_seq_length --export_tflite --export_dir /model/export/destination`.
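For example, to re-export an existing checkpoint for TF Lite (paths illustrative): `python -u DeepSpeech.py --notrain --notest --nouse_seq_length --export_tflite --checkpoint_dir /path/to/checkpoint --export_dir /model/export/destination`.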
### Making a mmap-able model for inference

bin/run-ldc93s1.sh

@ -16,7 +16,7 @@ else
checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/ldc93s1"))')
fi
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
python -u DeepSpeech.py --noshow_progressbar --nodev \
--train_files data/ldc93s1/ldc93s1.csv \
--dev_files data/ldc93s1/ldc93s1.csv \
--test_files data/ldc93s1/ldc93s1.csv \

bin/run-tc-ldc93s1_tflite.sh

@ -17,4 +17,4 @@ python -u DeepSpeech.py --noshow_progressbar \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--notrain --notest \
--export_tflite \
--export_tflite --nouse_seq_length \

evaluate.py

@ -5,92 +5,67 @@ from __future__ import absolute_import, division, print_function
import itertools
import json
import numpy as np
import os
import pandas
import progressbar
import sys
import tables
import tensorflow as tf
from collections import namedtuple
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from multiprocessing import Pool, cpu_count
from multiprocessing import cpu_count
from six.moves import zip, range
from util.audio import audiofile_to_input_vector
from util.config import Config, initialize_globals
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.preprocess import preprocess
from util.text import Alphabet, levenshtein
from util.evaluate_tools import process_decode_result, calculate_report
def split_data(dataset, batch_size):
remainder = len(dataset) % batch_size
if remainder != 0:
dataset = dataset[:-remainder]
for i in range(0, len(dataset), batch_size):
yield dataset[i:i + batch_size]
from util.text import levenshtein
def pad_to_dense(jagged):
maxlen = max(len(r) for r in jagged)
subshape = jagged[0].shape
padded = np.zeros((len(jagged), maxlen) +
subshape[1:], dtype=jagged[0].dtype)
for i, row in enumerate(jagged):
padded[i, :len(row)] = row
return padded
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def evaluate(test_data, inference_graph):
def sparse_tuple_to_texts(tuple, alphabet):
indices = tuple[0]
values = tuple[1]
results = [''] * tuple[2][0]
for i in range(len(indices)):
index = indices[i][0]
results[index] += alphabet.string_from_label(values[i])
# List of strings
return results
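
A toy illustration (the stand-in alphabet below is hypothetical; the real Config.alphabet maps integer labels to characters the same way):

class ToyAlphabet:
    def string_from_label(self, label):
        return 'abc'[label]

indices = [[0, 0], [0, 1], [1, 0]]  # (row, position) per value
values = [0, 1, 2]
dense_shape = [2, 2]
print(sparse_tuple_to_texts((indices, values, dense_shape), ToyAlphabet()))
# -> ['ab', 'c']
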
def evaluate(test_csvs, create_model):
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
Config.alphabet)
test_set, test_batches = create_dataset(test_csvs,
batch_size=FLAGS.test_batch_size,
cache_path=FLAGS.test_cached_features_path)
it = test_set.make_one_shot_iterator()
def create_windows(features):
num_strides = len(features) - (Config.n_context * 2)
(batch_x, batch_x_len), batch_y = it.get_next()
# Create a view into the array with overlapping strides of size
# numcontext (past) + 1 (present) + numcontext (future)
window_size = 2*Config.n_context+1
features = np.lib.stride_tricks.as_strided(
features,
(num_strides, window_size, Config.n_input),
(features.strides[0], features.strides[0], features.strides[1]),
writeable=False)
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
seq_length=batch_x_len,
dropout=no_dropout)
return features
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
# Create overlapping windows over the features
test_data['features'] = test_data['features'].apply(create_windows)
loss = tf.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
with tf.Session(config=Config.session_config) as session:
inputs, outputs, layers = inference_graph
# Transpose to batch major for decoder
transposed = tf.transpose(outputs['outputs'], [1, 0, 2])
labels_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size, None], name="labels")
label_lengths_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size], name="label_lengths")
# We add 1 to all elements of the transcript to avoid any zero values
# since we use that as an end-of-sequence token for converting the batch
# into a SparseTensor. So here we convert the placeholder back into a
# SparseTensor and subtract ones to get the real labels.
sparse_labels = tf.contrib.layers.dense_to_sparse(labels_ph)
neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
sparse_labels = tf.sparse_add(sparse_labels, neg_ones)
loss = tf.nn.ctc_loss(labels=sparse_labels,
inputs=layers['raw_logits'],
sequence_length=inputs['input_lengths'])
# Create a saver using variables from the above newly created graph
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
saver = tf.train.Saver(mapping)
saver = tf.train.Saver()
# Restore variables from training checkpoint
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
@ -103,51 +78,38 @@ def evaluate(test_data, inference_graph):
logitses = []
losses = []
seq_lengths = []
ground_truths = []
print('Computing acoustic model predictions...')
batch_count = len(test_data) // FLAGS.test_batch_size
bar = progressbar.ProgressBar(max_value=batch_count,
bar = progressbar.ProgressBar(max_value=test_batches,
widget=progressbar.AdaptiveETA)
# First pass, compute losses and transposed logits for decoding
for batch in bar(split_data(test_data, FLAGS.test_batch_size)):
session.run(outputs['initialize_state'])
features = pad_to_dense(batch['features'].values)
features_len = batch['features_len'].values
labels = pad_to_dense(batch['transcript'].values + 1)
label_lengths = batch['transcript_len'].values
logits, loss_ = session.run([transposed, loss], feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
labels_ph: labels,
label_lengths_ph: label_lengths
})
for batch in bar(range(test_batches)):
logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
logitses.append(logits)
losses.extend(loss_)
seq_lengths.append(lengths)
ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))
ground_truths = []
predictions = []
print('Decoding predictions...')
bar = progressbar.ProgressBar(max_value=batch_count,
widget=progressbar.AdaptiveETA)
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except:
num_processes = 1
# Second pass, decode logits and compute WER and edit distance metrics
for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
seq_lengths = batch['features_len'].values.astype(np.int32)
decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer)
print('Decoding predictions...')
bar = progressbar.ProgressBar(max_value=test_batches,
widget=progressbar.AdaptiveETA)
ground_truths.extend(Config.alphabet.decode(l) for l in batch['transcript'])
# Second pass, decode logits and compute WER and edit distance metrics
for logits, seq_length in bar(zip(logitses, seq_lengths)):
decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer)
predictions.extend(d[0][1] for d in decoded)
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
@ -179,21 +141,8 @@ def main(_):
'the --test_files flag.')
exit(1)
# sort examples by length, improves packing of batches and timesteps
test_data = preprocess(
FLAGS.test_files.split(','),
FLAGS.test_batch_size,
alphabet=Config.alphabet,
numcep=Config.n_input,
numcontext=Config.n_context,
hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
by="features_len",
ascending=False)
from DeepSpeech import create_inference_graph
graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
samples = evaluate(test_data, graph)
from DeepSpeech import create_model
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats

native_client/BUILD

@ -119,6 +119,10 @@ tf_cc_shared_object(
"//tensorflow/core/kernels:control_flow_ops", # Enter
"//tensorflow/core/kernels:tile_ops", # Tile
"//tensorflow/core/kernels:gather_op", # Gather
"//tensorflow/core/kernels:mfcc_op", # Mfcc
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
"//tensorflow/core/kernels:strided_slice_op", # StridedSlice
"//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
"//tensorflow/core/kernels:random_ops", # RandomGammaGrad
"//tensorflow/core/kernels:pack_op", # Pack

native_client/deepspeech.cc

@ -108,7 +108,6 @@ using std::vector;
struct StreamingState {
vector<float> accumulated_logits;
vector<float> audio_buffer;
float last_sample; // used for preemphasis
vector<float> mfcc_buffer;
vector<float> batch_buffer;
ModelState* model;
@ -152,10 +151,13 @@ struct ModelState {
int input_node_idx;
int previous_state_c_idx;
int previous_state_h_idx;
int input_samples_idx;
int logits_idx;
int new_state_c_idx;
int new_state_h_idx;
int mfccs_idx;
int mfccs_len_idx;
#endif
ModelState();
@ -204,7 +206,9 @@ struct ModelState {
*
* @param[out] output_logits Where to store computed logits.
*/
void infer(const float* mfcc, unsigned int n_frames, vector<float>& output_logits);
void infer(const float* mfcc, unsigned int n_frames, vector<float>& logits_output);
void compute_mfcc(const vector<float> audio_buffer, vector<float>& mfcc_output);
};
StreamingState* setupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
@ -258,10 +262,9 @@ StreamingState::feedAudioContent(const short* buffer,
// Consume all the data that was passed in, processing full buffers if needed
while (buffer_size > 0) {
while (buffer_size > 0 && audio_buffer.size() < AUDIO_WIN_LEN_SAMPLES) {
// Apply preemphasis to input sample and buffer it
float sample = (float)(*buffer) - (PREEMPHASIS_COEFF * last_sample);
audio_buffer.push_back(sample);
last_sample = *buffer;
// Convert i16 sample into f32
float multiplier = 1.0f / (1 << 15);
audio_buffer.push_back((float)(*buffer) * multiplier);
++buffer;
--buffer_size;
}
@ -304,15 +307,11 @@ void
StreamingState::processAudioWindow(const vector<float>& buf)
{
// Compute MFCC features
float* mfcc;
int n_frames = csf_mfcc(buf.data(), buf.size(), SAMPLE_RATE,
AUDIO_WIN_LEN, AUDIO_WIN_STEP, MFCC_FEATURES, N_FILTERS, N_FFT,
LOWFREQ, SAMPLE_RATE/2, 0.f, CEP_LIFTER, 1, hamming_window.data(),
&mfcc);
assert(n_frames == 1);
vector<float> mfcc;
mfcc.reserve(MFCC_FEATURES);
model->compute_mfcc(buf, mfcc);
pushMfccBuffer(mfcc, n_frames * MFCC_FEATURES);
free(mfcc);
pushMfccBuffer(mfcc.data(), MFCC_FEATURES);
}
void
@ -396,7 +395,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
input_mapped(i) = aMfcc[i];
}
for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
input_mapped(i) = 0;
input_mapped(i) = 0.;
}
Tensor input_lengths(DT_INT32, TensorShape({1}));
@ -454,6 +453,53 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
#endif // USE_TFLITE
}
void
ModelState::compute_mfcc(const vector<float> samples, vector<float>& mfcc_output)
{
#ifndef USE_TFLITE
Tensor input(DT_FLOAT, TensorShape({static_cast<long long>(samples.size())}));
auto input_mapped = input.flat<float>();
for (int i = 0; i < samples.size(); ++i) {
input_mapped(i) = samples[i];
}
vector<Tensor> outputs;
Status status = session->Run({{"input_samples", input}}, {"mfccs", "mfccs_len"}, {}, &outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
auto mfcc_len_mapped = outputs[1].flat<int32>();
int n_windows = mfcc_len_mapped(0);
auto mfcc_mapped = outputs[0].flat<float>();
for (int i = 0; i < n_windows * MFCC_FEATURES; ++i) {
mfcc_output.push_back(mfcc_mapped(i));
}
#else
// Feeding input_node
float* input_samples = interpreter->typed_tensor<float>(input_samples_idx);
for (int i = 0; i < samples.size(); ++i) {
input_samples[i] = samples[i];
}
TfLiteStatus status = interpreter->Invoke();
if (status != kTfLiteOk) {
std::cerr << "Error running session: " << status << "\n";
return;
}
int n_windows = *interpreter->typed_tensor<float>(mfccs_len_idx);
float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
for (int i = 0; i < n_windows * MFCC_FEATURES; ++i) {
mfcc_output.push_back(outputs[i]);
}
#endif
}
char*
ModelState::decode(vector<float>& logits)
{
@ -640,8 +686,6 @@ DS_CreateModel(const char* aModelPath,
*retval = model.release();
return DS_ERR_OK;
#else // USE_TFLITE
TfLiteStatus status;
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
if (!model->fbmodel) {
std::cerr << "Error at reading model file " << aModelPath << std::endl;
@ -663,9 +707,12 @@ DS_CreateModel(const char* aModelPath,
model->input_node_idx = tflite_get_input_tensor_by_name(model.get(), "input_node");
model->previous_state_c_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_c");
model->previous_state_h_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_h");
model->input_samples_idx = tflite_get_input_tensor_by_name(model.get(), "input_samples");
model->logits_idx = tflite_get_output_tensor_by_name(model.get(), "logits");
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
model->mfccs_len_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs_len");
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
@ -796,7 +843,6 @@ DS_SetupStream(ModelState* aCtx,
ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES);
ctx->last_sample = 0;
ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
ctx->mfcc_buffer.resize(MFCC_FEATURES*aCtx->n_context, 0.f);
ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);

requirements.txt

@ -1,19 +1,23 @@
pandas
progressbar2
python-utils
# Main training requirements
tensorflow == 1.13.1
numpy == 1.15.4
matplotlib
scipy
sox
paramiko >= 2.1
python_speech_features
pyxdg
bs4
progressbar2
pandas
six
requests
tables
pyxdg
attrdict
# Requirements for building native_client files
setuptools
# Requirements for importers
sox
bs4
requests
librosa
soundfile
# Miscellaneous scripts
paramiko >= 2.1
scipy
matplotlib

util/audio.py (deleted)

@ -1,24 +0,0 @@
import numpy as np
import scipy.io.wavfile as wav
from python_speech_features import mfcc
def audiofile_to_input_vector(audio_filename, numcep, numcontext):
r"""
Given a WAV audio file at ``audio_filename``, calculates ``numcep`` MFCC features
at every 0.01s time step with a window length of 0.025s. Appends ``numcontext``
context frames to the left and right of each time step, and returns this data
in a numpy array.
"""
# Load wav files
fs, audio = wav.read(audio_filename)
# Get mfcc coefficients
features = mfcc(audio, samplerate=fs, numcep=numcep, winlen=0.032, winstep=0.02, winfunc=np.hamming)
# Add empty initial and final contexts
empty_context = np.zeros((numcontext, numcep), dtype=features.dtype)
features = np.concatenate((empty_context, features, empty_context))
return features

util/feeding.py

@ -1,198 +1,97 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import numpy as np
import os
import pandas
import tensorflow as tf
from math import ceil
from six.moves import range
from threading import Thread
from util.gpu import get_available_gpus
from functools import partial
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from util.config import Config
from util.text import text_to_char_array
class ModelFeeder(object):
'''
Feeds data into a model.
Feeding is parallelized by independent units called tower feeders (usually one per GPU).
Each tower feeder provides data from runtime switchable sources (train, dev).
These sources are to be provided by the DataSet instances whose references are kept.
Creates, owns and delegates to tower_feeder_count internal tower feeder objects.
'''
def __init__(self,
train_set,
dev_set,
numcep,
numcontext,
alphabet,
tower_feeder_count=-1,
threads_per_queue=4):
self.train = train_set
self.dev = dev_set
self.sets = [train_set, dev_set]
self.numcep = numcep
self.numcontext = numcontext
self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count
self.threads_per_queue = threads_per_queue
self.ph_x = tf.placeholder(tf.float32, [None, 2*numcontext+1, numcep])
self.ph_x_length = tf.placeholder(tf.int32, [])
self.ph_y = tf.placeholder(tf.int32, [None,])
self.ph_y_length = tf.placeholder(tf.int32, [])
self.ph_batch_size = tf.placeholder(tf.int32, [])
self.ph_queue_selector = tf.placeholder(tf.int32, name='Queue_Selector')
self._tower_feeders = [_TowerFeeder(self, i, alphabet) for i in range(self.tower_feeder_count)]
def start_queue_threads(self, session, coord):
'''
Starts required queue threads on all tower feeders.
'''
queue_threads = []
for tower_feeder in self._tower_feeders:
queue_threads += tower_feeder.start_queue_threads(session, coord)
return queue_threads
def close_queues(self, session):
'''
Closes queues of all tower feeders.
'''
for tower_feeder in self._tower_feeders:
tower_feeder.close_queues(session)
def set_data_set(self, feed_dict, data_set):
'''
Switches all tower feeders to a different source DataSet.
The provided feed_dict will get enriched with required placeholder/value pairs.
The DataSet has to be one of those that got passed into the constructor.
'''
index = self.sets.index(data_set)
assert index >= 0
feed_dict[self.ph_queue_selector] = index
feed_dict[self.ph_batch_size] = data_set.batch_size
def next_batch(self, tower_feeder_index):
'''
Draw the next batch from one of the tower feeders.
'''
return self._tower_feeders[tower_feeder_index].next_batch()
def read_csvs(csv_files):
source_data = None
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
#FIXME: not cross-platform
csv_dir = os.path.dirname(os.path.abspath(csv))
file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
if source_data is None:
source_data = file
else:
source_data = source_data.append(file)
return source_data
class DataSet(object):
'''
Represents a collection of audio samples and their respective transcriptions.
Takes a set of CSV files produced by importers in /bin.
'''
def __init__(self, data, batch_size, skip=0, limit=0, ascending=True, next_index=lambda i: i + 1):
self.data = data
self.data.sort_values(by="features_len", ascending=ascending, inplace=True)
self.batch_size = batch_size
self.next_index = next_index
self.total_batches = int(ceil(len(self.data) / batch_size))
def samples_to_mfccs(samples, sample_rate):
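# At a 16kHz sample rate, the 512-sample window and 320-sample stride below
# equal the 0.032s window length and 0.02s step of the removed
# python_speech_features pipeline (see util/audio.py).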
spectrogram = contrib_audio.audio_spectrogram(samples, window_size=512, stride=320, magnitude_squared=True)
mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
mfccs = tf.reshape(mfccs, [-1, Config.n_input])
return mfccs, tf.shape(mfccs)[0]
class _DataSetLoader(object):
'''
Internal class that represents an input queue with data from one of the DataSet objects.
Each tower feeder will create and combine three data set loaders to one switchable queue.
Keeps a ModelFeeder reference for accessing shared settings and placeholders.
Keeps a DataSet reference to access its samples.
'''
def __init__(self, model_feeder, data_set, alphabet):
self._model_feeder = model_feeder
self._data_set = data_set
self.queue = tf.PaddingFIFOQueue(shapes=[[None, 2 * model_feeder.numcontext + 1, model_feeder.numcep], [], [None,], []],
dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
capacity=data_set.batch_size * 8)
self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length])
self._close_op = self.queue.close(cancel_pending_enqueues=True)
self._alphabet = alphabet
def audiofile_to_features(wav_filename):
samples = tf.read_file(wav_filename)
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate)
def start_queue_threads(self, session, coord):
'''
Starts concurrent queue threads for reading samples from the data set.
'''
queue_threads = [Thread(target=self._populate_batch_queue, args=(session, coord))
for i in range(self._model_feeder.threads_per_queue)]
for queue_thread in queue_threads:
coord.register_thread(queue_thread)
queue_thread.daemon = True
queue_thread.start()
return queue_threads
def close_queue(self, session):
'''
Closes the data set queue.
'''
session.run(self._close_op)
def _populate_batch_queue(self, session, coord):
'''
Queue thread routine.
'''
file_count = len(self._data_set.data)
index = -1
while not coord.should_stop():
index = self._data_set.next_index(index) % file_count
features, num_strides, transcript, transcript_len = self._data_set.data.iloc[index]
# Create a view into the array with overlapping strides of size
# numcontext (past) + 1 (present) + numcontext (future)
window_size = 2*self._model_feeder.numcontext+1
features = np.lib.stride_tricks.as_strided(
features,
(num_strides, window_size, self._model_feeder.numcep),
(features.strides[0], features.strides[0], features.strides[1]),
writeable=False)
# We add 1 to all elements of the transcript here to avoid any zero
# values since we use that as an end-of-sequence token for converting
# the batch into a SparseTensor.
try:
session.run(self._enqueue_op, feed_dict={
self._model_feeder.ph_x: features,
self._model_feeder.ph_x_length: num_strides,
self._model_feeder.ph_y: transcript + 1,
self._model_feeder.ph_y_length: transcript_len
})
except tf.errors.CancelledError:
return
return features, features_len
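
A minimal usage sketch for this helper ('audio.wav' is an assumed path; assumes flags have been parsed and initialize_globals() called so that Config.n_input is set):

import tensorflow as tf
from util.feeding import audiofile_to_features

features, features_len = audiofile_to_features('audio.wav')  # assumed path
with tf.Session() as session:
    f, n = session.run([features, features_len])
    print(f.shape, n)  # (n_windows, Config.n_input), n_windows
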
class _TowerFeeder(object):
'''
Internal class that represents a switchable input queue for one tower.
It creates, owns and combines three _DataSetLoader instances.
Keeps a ModelFeeder reference for accessing shared settings and placeholders.
'''
def __init__(self, model_feeder, index, alphabet):
self._model_feeder = model_feeder
self.index = index
self._loaders = [_DataSetLoader(model_feeder, data_set, alphabet) for data_set in model_feeder.sets]
self._queues = [set_queue.queue for set_queue in self._loaders]
self._queue = tf.QueueBase.from_list(model_feeder.ph_queue_selector, self._queues)
self._close_op = self._queue.close(cancel_pending_enqueues=True)
def entry_to_features(wav_filename, transcript):
# https://bugs.python.org/issue32117
features, features_len = audiofile_to_features(wav_filename)
return features, features_len, tf.SparseTensor(*transcript)
def next_batch(self):
'''
Draw the next batch from from the combined switchable queue.
'''
source, source_lengths, target, target_lengths = self._queue.dequeue_many(self._model_feeder.ph_batch_size)
# Back to sparse, then subtract one to get the real labels
sparse_labels = tf.contrib.layers.dense_to_sparse(target)
neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape)
return source, source_lengths, tf.sparse_add(sparse_labels, neg_ones)
def start_queue_threads(self, session, coord):
'''
Starts the queue threads of all owned _DataSetLoader instances.
'''
queue_threads = []
for set_queue in self._loaders:
queue_threads += set_queue.start_queue_threads(session, coord)
return queue_threads
def to_sparse_tuple(sequence):
r"""Creates a sparse representention of ``sequence``.
Returns a tuple with (indices, values, shape)
"""
indices = np.asarray(list(zip([0]*len(sequence), range(len(sequence)))), dtype=np.int64)
shape = np.asarray([1, len(sequence)], dtype=np.int64)
return indices, sequence, shape
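
For example, a three-label transcript maps to a 1x3 sparse layout:

import numpy as np

indices, values, shape = to_sparse_tuple(np.asarray([5, 3, 7]))
# indices -> [[0, 0], [0, 1], [0, 2]], values -> [5, 3, 7], shape -> [1, 3]
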
def close_queues(self, session):
'''
Closes queues of all owned _DataSetLoader instances.
'''
for set_queue in self._loaders:
set_queue.close_queue(session)
def create_dataset(csvs, batch_size, cache_path):
df = read_csvs(csvs)
df.sort_values(by='wav_filesize', inplace=True)
num_batches = len(df) // batch_size
# Convert to character index arrays
df['transcript'] = df['transcript'].apply(partial(text_to_char_array, alphabet=Config.alphabet))
def generate_values():
for _, row in df.iterrows():
yield row.wav_filename, to_sparse_tuple(row.transcript)
# Batching a dataset of 2D SparseTensors creates 3D batches, which fail
# when passed to tf.nn.ctc_loss, so we reshape them to remove the extra
# dimension here.
def sparse_reshape(sparse):
shape = sparse.dense_shape
return tf.sparse.reshape(sparse, [shape[0], shape[2]])
def batch_fn(features, features_len, transcripts):
features = tf.data.Dataset.zip((features, features_len))
features = features.padded_batch(batch_size,
padded_shapes=([None, Config.n_input], []))
transcripts = transcripts.batch(batch_size).map(sparse_reshape)
return tf.data.Dataset.zip((features, transcripts))
num_gpus = len(Config.available_devices)
dataset = (tf.data.Dataset.from_generator(generate_values,
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
.map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.cache(cache_path)
.window(batch_size, drop_remainder=True).flat_map(batch_fn)
.prefetch(num_gpus)
.repeat())
return dataset, num_batches
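
The .window().flat_map() combination above batches padded features together with un-padded sparse transcripts. A self-contained toy sketch of the same pattern (dummy data, batch size 2):

import tensorflow as tf

# Dummy variable-length "features": rows of shape [length, 2]
lengths = tf.data.Dataset.from_tensor_slices(tf.constant([3, 1, 2, 4], tf.int32))
features = lengths.map(lambda l: tf.ones([l, 2]))

def batch_fn(features_window, lengths_window):
    # Pad features within the window; batch the lengths alongside them
    features_batch = features_window.padded_batch(2, padded_shapes=[None, 2])
    lengths_batch = lengths_window.batch(2)
    return tf.data.Dataset.zip((features_batch, lengths_batch))

batched = (tf.data.Dataset.zip((features, lengths))
           .window(2, drop_remainder=True)
           .flat_map(batch_fn))

with tf.Session() as session:
    feats, lens = session.run(batched.make_one_shot_iterator().get_next())
    print(feats.shape, lens)  # (2, 3, 2) [3 1]
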

util/flags.py

@ -23,6 +23,7 @@ def create_flags():
# ================
tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network')
tf.app.flags.DEFINE_boolean ('dev', True, 'whether to run validation epochs')
tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network')
tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')

util/preprocess.py (deleted)

@ -1,101 +0,0 @@
import numpy as np
import os
import pandas
import tables
from functools import partial
from multiprocessing.dummy import Pool
from util.audio import audiofile_to_input_vector
from util.text import text_to_char_array
def pmap(fun, iterable):
pool = Pool()
results = pool.map(fun, iterable)
pool.close()
return results
def process_single_file(row, numcep, numcontext, alphabet):
# row = index, Series
_, file = row
features = audiofile_to_input_vector(file.wav_filename, numcep, numcontext)
features_len = len(features) - 2*numcontext
transcript = text_to_char_array(file.transcript, alphabet)
if features_len < len(transcript):
raise ValueError('Error: Audio file {} is too short for transcription.'.format(file.wav_filename))
return features, features_len, transcript, len(transcript)
# load samples from CSV, compute features, optionally cache results on disk
def preprocess(csv_files, batch_size, numcep, numcontext, alphabet, hdf5_cache_path=None):
COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len')
print('Preprocessing', csv_files)
if hdf5_cache_path and os.path.exists(hdf5_cache_path):
with tables.open_file(hdf5_cache_path, 'r') as file:
features = file.root.features[:]
features_len = file.root.features_len[:]
transcript = file.root.transcript[:]
transcript_len = file.root.transcript_len[:]
# features are stored flattened, so reshape into [n_steps, numcep]
for i in range(len(features)):
features[i].shape = [features_len[i]+2*numcontext, numcep]
in_data = list(zip(features, features_len,
transcript, transcript_len))
print('Loaded from cache at', hdf5_cache_path)
return pandas.DataFrame(data=in_data, columns=COLUMNS)
source_data = None
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
#FIXME: not cross-platform
csv_dir = os.path.dirname(os.path.abspath(csv))
file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
if source_data is None:
source_data = file
else:
source_data = source_data.append(file)
step_fn = partial(process_single_file,
numcep=numcep,
numcontext=numcontext,
alphabet=alphabet)
out_data = pmap(step_fn, source_data.iterrows())
if hdf5_cache_path:
print('Saving to', hdf5_cache_path)
# list of tuples -> tuple of lists
features, features_len, transcript, transcript_len = zip(*out_data)
with tables.open_file(hdf5_cache_path, 'w') as file:
features_dset = file.create_vlarray(file.root,
'features',
tables.Float32Atom(),
filters=tables.Filters(complevel=1))
# VLArray atoms need to be 1D, so flatten feature array
for f in features:
features_dset.append(np.reshape(f, -1))
features_len_dset = file.create_array(file.root,
'features_len',
features_len)
transcript_dset = file.create_vlarray(file.root,
'transcript',
tables.Int32Atom(),
filters=tables.Filters(complevel=1))
for t in transcript:
transcript_dset.append(t)
transcript_len_dset = file.create_array(file.root,
'transcript_len',
transcript_len)
print('Preprocessing done')
return pandas.DataFrame(data=out_data, columns=COLUMNS)