From 1cea2b0fe88b888ae8bbbb4cbe2743c1a6087552 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Fri, 22 Mar 2019 21:14:10 -0300 Subject: [PATCH] Rewrite input pipeline to use tf.data API --- DeepSpeech.py | 441 +++++++++++++++++------------------ README.md | 2 +- bin/run-ldc93s1.sh | 2 +- bin/run-tc-ldc93s1_tflite.sh | 2 +- evaluate.py | 159 +++++-------- native_client/BUILD | 4 + native_client/deepspeech.cc | 82 +++++-- requirements.txt | 28 ++- util/audio.py | 24 -- util/feeding.py | 263 +++++++-------------- util/flags.py | 1 + util/preprocess.py | 101 -------- 12 files changed, 432 insertions(+), 677 deletions(-) delete mode 100644 util/audio.py delete mode 100644 util/preprocess.py diff --git a/DeepSpeech.py b/DeepSpeech.py index ea654345..e3e60ad1 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -18,12 +18,10 @@ import tensorflow as tf from ds_ctcdecoder import ctc_beam_search_decoder, Scorer from six.moves import zip, range from tensorflow.python.tools import freeze_graph -from util.audio import audiofile_to_input_vector from util.config import Config, initialize_globals -from util.feeding import DataSet, ModelFeeder +from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from util.flags import create_flags, FLAGS from util.logging import log_info, log_error, log_debug, log_warn -from util.preprocess import preprocess # Graph Creation @@ -42,25 +40,82 @@ def variable_on_cpu(name, shape, initializer): return var -def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None, tflite=False): - r''' - That done, we will define the learned variables, the weights and biases, - within the method ``BiRNN()`` which also constructs the neural network. - The variables named ``hn``, where ``n`` is an integer, hold the learned weight variables. - The variables named ``bn``, where ``n`` is an integer, hold the learned bias variables. - In particular, the first variable ``h1`` holds the learned weight matrix that - converts an input vector of dimension ``n_input + 2*n_input*n_context`` - to a vector of dimension ``n_hidden_1``. - Similarly, the second variable ``h2`` holds the weight matrix converting - an input vector of dimension ``n_hidden_1`` to one of dimension ``n_hidden_2``. - The variables ``h3``, ``h5``, and ``h6`` are similar. - Likewise, the biases, ``b1``, ``b2``..., hold the biases for the various layers. - ''' +def create_overlapping_windows(batch_x): + batch_size = tf.shape(batch_x)[0] + window_width = 2 * Config.n_context + 1 + num_channels = Config.n_input + + # Create a constant convolution filter using an identity matrix, so that the + # convolution returns patches of the input tensor as is, and we can create + # overlapping windows over the MFCCs. 
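+    # For window_width w and num_channels c the filter has shape [w, c, w * c],
+    # so at each timestep the convolution emits the flattened concatenation of
+    # the w surrounding frames, which the reshape below unflattens back into
+    # [batch_size, n_windows, window_width, n_input].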
+ eye_filter = tf.constant(np.eye(window_width * num_channels) + .reshape(window_width, num_channels, window_width * num_channels), tf.float32) + + # Create overlapping windows + batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME') + + # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input] + batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels]) + + return batch_x + + +def dense(name, x, units, dropout_rate=None, relu=True): + with tf.variable_scope(name): + bias = variable_on_cpu('bias', [units], tf.zeros_initializer()) + weights = variable_on_cpu('weights', [x.shape[-1], units], tf.contrib.layers.xavier_initializer()) + + output = tf.nn.bias_add(tf.matmul(x, weights), bias) + + if relu: + output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip) + + if dropout_rate is not None: + output = tf.nn.dropout(output, rate=dropout_rate) + + return output + + +def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse): + # Forward direction cell: + fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse) + + output, output_state = fw_cell(inputs=x, + dtype=tf.float32, + sequence_length=seq_length, + initial_state=previous_state) + + return output, output_state + + +def rnn_impl_static_rnn(x, seq_length, previous_state, reuse): + # Forward direction cell: + fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse) + + # Split rank N tensor into list of rank N-1 tensors + x = [x[l] for l in range(x.shape[0])] + + # We parametrize the RNN implementation as the training and inference graph + # need to do different things here. + output, output_state = tf.nn.static_rnn(cell=fw_cell, + inputs=x, + initial_state=previous_state, + dtype=tf.float32, + sequence_length=seq_length) + output = tf.concat(output, 0) + + return output, output_state + + +def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell): layers = {} # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context] - if not batch_size: - batch_size = tf.shape(batch_x)[0] + batch_size = tf.shape(batch_x)[0] + + # Create overlapping feature windows if needed + if overlap: + batch_x = create_overlapping_windows(batch_x) # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`. # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`. @@ -73,58 +128,17 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1 # The next three blocks will pass `batch_x` through three hidden layers with # clipped RELU activation and dropout. 
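+    # The dense() helper defined above captures this pattern: an affine
+    # transform followed by a clipped ReLU and, when a rate is supplied, dropout.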
- - # 1st layer - b1 = variable_on_cpu('b1', [Config.n_hidden_1], tf.zeros_initializer()) - h1 = variable_on_cpu('h1', [Config.n_input + 2*Config.n_input*Config.n_context, Config.n_hidden_1], tf.contrib.layers.xavier_initializer()) - layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), FLAGS.relu_clip) - layer_1 = tf.nn.dropout(layer_1, rate=dropout[0]) - layers['layer_1'] = layer_1 - - # 2nd layer - b2 = variable_on_cpu('b2', [Config.n_hidden_2], tf.zeros_initializer()) - h2 = variable_on_cpu('h2', [Config.n_hidden_1, Config.n_hidden_2], tf.contrib.layers.xavier_initializer()) - layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), FLAGS.relu_clip) - layer_2 = tf.nn.dropout(layer_2, rate=dropout[1]) - layers['layer_2'] = layer_2 - - # 3rd layer - b3 = variable_on_cpu('b3', [Config.n_hidden_3], tf.zeros_initializer()) - h3 = variable_on_cpu('h3', [Config.n_hidden_2, Config.n_hidden_3], tf.contrib.layers.xavier_initializer()) - layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip) - layer_3 = tf.nn.dropout(layer_3, rate=dropout[2]) - layers['layer_3'] = layer_3 - - # Now we create the forward and backward LSTM units. - # Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM. - - # Forward direction cell: - if not tflite: - fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse) - layers['fw_cell'] = fw_cell - else: - fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse) + layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1) + layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2) + layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3) # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`, # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`. - layer_3 = tf.reshape(layer_3, [n_steps, batch_size, Config.n_hidden_3]) - if tflite: - # Generated StridedSlice, not supported by NNAPI - #n_layer_3 = [] - #for l in range(layer_3.shape[0]): - # n_layer_3.append(layer_3[l]) - #layer_3 = n_layer_3 + layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3]) - # Unstack/Unpack is not supported by NNAPI - layer_3 = tf.unstack(layer_3, n_steps) - - # We parametrize the RNN implementation as the training and inference graph - # need to do different things here. 
- if not tflite: - output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state) - else: - output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32) - output = tf.concat(output, 0) + # Run through parametrized RNN implementation, as we use different RNNs + # for training and inference + output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse) # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim] # to a tensor of shape [n_steps*batch_size, n_cell_dim] @@ -132,24 +146,16 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1 layers['rnn_output'] = output layers['rnn_output_state'] = output_state - # Now we feed `output` to the fifth hidden layer with clipped RELU activation and dropout - b5 = variable_on_cpu('b5', [Config.n_hidden_5], tf.zeros_initializer()) - h5 = variable_on_cpu('h5', [Config.n_cell_dim, Config.n_hidden_5], tf.contrib.layers.xavier_initializer()) - layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(output, h5), b5)), FLAGS.relu_clip) - layer_5 = tf.nn.dropout(layer_5, rate=dropout[5]) - layers['layer_5'] = layer_5 + # Now we feed `output` to the fifth hidden layer with clipped RELU activation + layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5) - # Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5` - # creating `n_classes` dimensional vectors, the logits. - b6 = variable_on_cpu('b6', [Config.n_hidden_6], tf.zeros_initializer()) - h6 = variable_on_cpu('h6', [Config.n_hidden_5, Config.n_hidden_6], tf.contrib.layers.xavier_initializer()) - layer_6 = tf.add(tf.matmul(layer_5, h6), b6) - layers['layer_6'] = layer_6 + # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits. + layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False) # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6] # to the slightly more useful shape [n_steps, batch_size, n_hidden_6]. # Note, that this differs from the input in that it is time-major. - layer_6 = tf.reshape(layer_6, [n_steps, batch_size, Config.n_hidden_6], name="raw_logits") + layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits') layers['raw_logits'] = layer_6 # Output shape: [n_steps, batch_size, n_hidden_6] @@ -166,17 +172,17 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1 # Conveniently, this loss function is implemented in TensorFlow. # Thus, we can simply make use of this implementation to define our loss. -def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout, reuse): +def calculate_mean_edit_distance_and_loss(iterator, tower, dropout, reuse): r''' This routine beam search decodes a mini-batch and calculates the loss and mean edit distance. Next to total and average loss it returns the mean edit distance, the decoded result and the batch's original Y. 
''' # Obtain the next batch of data - batch_x, batch_seq_len, batch_y = model_feeder.next_batch(tower) + (batch_x, batch_seq_len), batch_y = iterator.get_next() - # Calculate the logits of the batch using BiRNN - logits, _ = BiRNN(batch_x, batch_seq_len, dropout, reuse) + # Calculate the logits of the batch + logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse) # Compute the CTC loss using TensorFlow's `ctc_loss` total_loss = tf.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len) @@ -221,7 +227,7 @@ def create_optimizer(): # on which all operations within the tower execute. # For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`. -def get_tower_results(model_feeder, optimizer, dropout_rates): +def get_tower_results(iterator, optimizer, dropout_rates): r''' With this preliminary step out of the way, we can for each GPU introduce a tower for which's batch we calculate and return the optimization gradients @@ -243,7 +249,7 @@ def get_tower_results(model_feeder, optimizer, dropout_rates): with tf.name_scope('tower_%d' % i) as scope: # Calculate the avg_loss and mean_edit_distance and retrieve the decoded # batch along with the original batch's labels (Y) of this tower - avg_loss = calculate_mean_edit_distance_and_loss(model_feeder, i, dropout_rates, reuse=i>0) + avg_loss = calculate_mean_edit_distance_and_loss(iterator, i, dropout_rates, reuse=i>0) # Allow for variables to be re-used by the next tower tf.get_variable_scope().reuse_variables() @@ -337,19 +343,6 @@ def log_grads_and_vars(grads_and_vars): log_variable(variable, gradient=gradient) -# Helpers -# ======= - - -class SampleIndex: - def __init__(self, index=0): - self.index = index - - def inc(self, old_index): - self.index += 1 - return self.index - - def try_loading(session, saver, checkpoint_filename, caption): try: checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename) @@ -371,48 +364,23 @@ def try_loading(session, saver, checkpoint_filename, caption): def train(): - r''' - Trains the network on a given server of a cluster. - If no server provided, it performs single process training. 
- ''' + # Create training and validation datasets + train_set, train_batches = create_dataset(FLAGS.train_files.split(','), + batch_size=FLAGS.train_batch_size, + cache_path=FLAGS.train_cached_features_path) - # Reading training set - train_index = SampleIndex() + iterator = tf.data.Iterator.from_structure(train_set.output_types, + train_set.output_shapes, + output_classes=train_set.output_classes) - train_data = preprocess(FLAGS.train_files.split(','), - FLAGS.train_batch_size, - Config.n_input, - Config.n_context, - Config.alphabet, - hdf5_cache_path=FLAGS.train_cached_features_path) + # Make initialization ops for switching between the two sets + train_init_op = iterator.make_initializer(train_set) - train_set = DataSet(train_data, - FLAGS.train_batch_size, - limit=FLAGS.limit_train, - next_index=train_index.inc) - - # Reading validation set - dev_index = SampleIndex() - - dev_data = preprocess(FLAGS.dev_files.split(','), - FLAGS.dev_batch_size, - Config.n_input, - Config.n_context, - Config.alphabet, - hdf5_cache_path=FLAGS.dev_cached_features_path) - - dev_set = DataSet(dev_data, - FLAGS.dev_batch_size, - limit=FLAGS.limit_dev, - next_index=dev_index.inc) - - # Combining all sets to a multi set model feeder - model_feeder = ModelFeeder(train_set, - dev_set, - Config.n_input, - Config.n_context, - Config.alphabet, - tower_feeder_count=len(Config.available_devices)) + if FLAGS.dev: + dev_set, dev_batches = create_dataset(FLAGS.dev_files.split(','), + batch_size=FLAGS.dev_batch_size, + cache_path=FLAGS.dev_cached_features_path) + dev_init_op = iterator.make_initializer(dev_set) # Dropout dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)] @@ -425,17 +393,12 @@ def train(): dropout_rates[5]: FLAGS.dropout_rate6, } no_dropout_feed_dict = { - dropout_rates[0]: 0., - dropout_rates[1]: 0., - dropout_rates[2]: 0., - dropout_rates[3]: 0., - dropout_rates[4]: 0., - dropout_rates[5]: 0., + rate: 0. 
for rate in dropout_rates } # Building the graph optimizer = create_optimizer() - gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates) + gradients, loss = get_tower_results(iterator, optimizer, dropout_rates) # Average tower gradients across GPUs avg_tower_gradients = average_gradients(gradients) log_grads_and_vars(avg_tower_gradients) @@ -463,6 +426,7 @@ def train(): with tf.Session(config=Config.session_config) as session: log_debug('Session opened.') + tf.get_default_graph().finalize() # Loading or initializing @@ -481,51 +445,51 @@ def train(): sys.exit(1) # Retrieving global_step from restored model and setting training parameters accordingly - model_feeder.set_data_set(no_dropout_feed_dict, train_set) - step = session.run(global_step, feed_dict=no_dropout_feed_dict) + step = session.run(global_step) num_gpus = len(Config.available_devices) - steps_per_epoch = max(1, train_set.total_batches // num_gpus) - steps_trained = step % steps_per_epoch + steps_per_epoch = max(1, train_batches // num_gpus) current_epoch = step // steps_per_epoch target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch - train_index.index = steps_trained * num_gpus log_debug('step: %d' % step) log_debug('epoch: %d' % current_epoch) log_debug('target epoch: %d' % target_epoch) log_debug('steps per epoch: %d' % steps_per_epoch) log_debug('batches per step (GPUs): %d' % num_gpus) - log_debug('number of batches in train set: %d' % train_set.total_batches) - log_debug('number of batches already trained in epoch: %d' % train_index.index) + log_debug('number of batches in train set: %d' % train_batches) - def run_set(set_name): - data_set = getattr(model_feeder, set_name) + def run_set(set_name, init_op, num_batches): is_train = set_name == 'train' train_op = apply_gradient_op if is_train else [] feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict - model_feeder.set_data_set(feed_dict, data_set) total_loss = 0.0 step_summary_writer = step_summary_writers.get(set_name) - num_steps = max(1, data_set.total_batches // num_gpus) + num_steps = max(1, num_batches // num_gpus) checkpoint_time = time.time() + if FLAGS.show_progressbar: pbar = progressbar.ProgressBar(max_value=num_steps, redirect_stdout=True).start() + else: + pbar = lambda i: i + + # Initialize iterator to the appropriate dataset + session.run(init_op) + # Batch loop - for step_index in range(steps_trained, num_steps): + for step_index in pbar(range(num_steps)): if coord.should_stop(): break + _, current_step, batch_loss, step_summary = \ session.run([train_op, global_step, loss, step_summaries_op], feed_dict=feed_dict) total_loss += batch_loss step_summary_writer.add_summary(step_summary, current_step) - if FLAGS.show_progressbar: - pbar.update(step_index + 1, force=True) + if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs: checkpoint_saver.save(session, checkpoint_path, global_step=current_step) checkpoint_time = time.time() - if FLAGS.show_progressbar: - pbar.finish() + return total_loss / num_steps if target_epoch > current_epoch: @@ -534,68 +498,64 @@ def train(): dev_losses = [] coord = tf.train.Coordinator() with coord.stop_on_exception(): - log_debug('Starting queue runners...') - model_feeder.start_queue_threads(session, coord=coord) - log_debug('Queue runners started.') - # Epoch loop for current_epoch in range(current_epoch, target_epoch): - # Training if coord.should_stop(): break + + # Training log_info('Training epoch %d ...' 
% current_epoch) - train_loss = run_set('train') + train_loss = run_set('train', train_init_op, train_batches) log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss)) checkpoint_saver.save(session, checkpoint_path, global_step=global_step) - steps_trained = 0 - # Validation - log_info('Validating epoch %d ...' % current_epoch) - dev_loss = run_set('dev') - dev_losses.append(dev_loss) - log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss)) - if dev_loss < best_dev_loss: - best_dev_loss = dev_loss - save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename) - log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) - # Early stopping - if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps: - mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1]) - std_loss = np.std(dev_losses[-FLAGS.es_steps:-1]) - dev_losses = dev_losses[-FLAGS.es_steps:] - log_debug('Checking for early stopping (last %d steps) validation loss: ' - '%f, with standard deviation: %f and mean: %f' % - (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) - if dev_losses[-1] > np.max(dev_losses[:-1]) or \ - (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th): - log_info('Early stop triggered as (for last %d steps) validation loss:' - ' %f with standard deviation: %f and mean: %f' % - (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) - break - log_debug('Closing queues...') + + if FLAGS.dev: + # Validation + log_info('Validating epoch %d ...' % current_epoch) + dev_loss = run_set('dev', dev_init_op, dev_batches) + dev_losses.append(dev_loss) + log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss)) + + if dev_loss < best_dev_loss: + best_dev_loss = dev_loss + save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename) + log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) + + # Early stopping + if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps: + mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1]) + std_loss = np.std(dev_losses[-FLAGS.es_steps:-1]) + dev_losses = dev_losses[-FLAGS.es_steps:] + log_debug('Checking for early stopping (last %d steps) validation loss: ' + '%f, with standard deviation: %f and mean: %f' % + (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) + if dev_losses[-1] > np.max(dev_losses[:-1]) or \ + (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th): + log_info('Early stop triggered as (for last %d steps) validation loss:' + ' %f with standard deviation: %f and mean: %f' % + (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) + break coord.request_stop() - model_feeder.close_queues(session) - log_debug('Queues closed.') else: log_info('Target epoch already reached - skipped training.') log_debug('Session closed.') def test(): - # Reading test set - test_data = preprocess(FLAGS.test_files.split(','), - FLAGS.test_batch_size, - Config.n_input, - Config.n_context, - Config.alphabet, - hdf5_cache_path=FLAGS.test_cached_features_path) - - graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1) - evaluate.evaluate(test_data, graph) + evaluate.evaluate(FLAGS.test_files.split(','), create_model) def create_inference_graph(batch_size=1, n_steps=16, tflite=False): batch_size = batch_size if batch_size > 0 else None + + # Create feature computation graph + input_samples = tf.placeholder(tf.float32, 
[None], 'input_samples') + samples = tf.expand_dims(input_samples, -1) + mfccs, mfccs_len = samples_to_mfccs(samples, 16000) + mfccs = tf.identity(mfccs, name='mfccs') + mfccs_len = tf.identity(mfccs_len, name='mfccs_len') + # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input] - input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*Config.n_context+1, Config.n_input], name='input_node') + input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node') seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths') if batch_size <= 0: @@ -611,15 +571,20 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h) - no_dropout = [0.0] * 6 + # One rate per layer + no_dropout = [None] * 6 - logits, layers = BiRNN(batch_x=input_tensor, - seq_length=seq_length if FLAGS.use_seq_length else None, - dropout=no_dropout, - batch_size=batch_size, - n_steps=n_steps, - previous_state=previous_state, - tflite=tflite) + if tflite: + rnn_impl = rnn_impl_static_rnn + else: + rnn_impl = rnn_impl_lstmblockfusedcell + + logits, layers = create_model(batch_x=input_tensor, + seq_length=seq_length if FLAGS.use_seq_length else None, + dropout=no_dropout, + previous_state=previous_state, + overlap=False, + rnn_impl=rnn_impl) # TF Lite runtime will check that input dimensions are 1, 2 or 4 # by default we get 3, the middle one being batch_size which is forced to @@ -659,10 +624,13 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): { 'input': input_tensor, 'input_lengths': seq_length, + 'input_samples': input_samples, }, { 'outputs': logits, 'initialize_state': initialize_state, + 'mfccs': mfccs, + 'mfccs_len': mfccs_len, }, layers ) @@ -671,16 +639,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): new_state_c = tf.identity(new_state_c, name='new_state_c') new_state_h = tf.identity(new_state_h, name='new_state_h') + inputs = { + 'input': input_tensor, + 'previous_state_c': previous_state_c, + 'previous_state_h': previous_state_h, + 'input_samples': input_samples, + } + + if FLAGS.use_seq_length: + inputs.update({'input_lengths': seq_length}) + return ( - { - 'input': input_tensor, - 'previous_state_c': previous_state_c, - 'previous_state_h': previous_state_h, - }, + inputs, { 'outputs': logits, 'new_state_c': new_state_c, 'new_state_h': new_state_h, + 'mfccs': mfccs, + 'mfccs_len': mfccs_len, }, layers ) @@ -753,6 +729,8 @@ def export(): converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values()) converter.post_training_quantize = True + # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite + converter.allow_custom_ops = True tflite_model = converter.convert() with open(output_tflite_path, 'wb') as fout: @@ -764,6 +742,7 @@ def export(): except RuntimeError as e: log_error(str(e)) + def do_single_file_inference(input_file_path): with tf.Session(config=Config.session_config) as session: inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1) @@ -784,22 +763,20 @@ def do_single_file_inference(input_file_path): saver.restore(session, checkpoint_path) session.run(outputs['initialize_state']) - features = audiofile_to_input_vector(input_file_path, Config.n_input, Config.n_context) - num_strides = len(features) - (Config.n_context * 2) + 
features, features_len = audiofile_to_features(input_file_path) - # Create a view into the array with overlapping strides of size - # numcontext (past) + 1 (present) + numcontext (future) - window_size = 2*Config.n_context+1 - features = np.lib.stride_tricks.as_strided( - features, - (num_strides, window_size, Config.n_input), - (features.strides[0], features.strides[0], features.strides[1]), - writeable=False) + # Add batch dimension + features = tf.expand_dims(features, 0) + features_len = tf.expand_dims(features_len, 0) - logits = session.run(outputs['outputs'], feed_dict = { - inputs['input']: [features], - inputs['input_lengths']: [num_strides], - }) + # Evaluate + features = create_overlapping_windows(features).eval(session=session) + features_len = features_len.eval(session=session) + + logits = outputs['outputs'].eval(feed_dict={ + inputs['input']: features, + inputs['input_lengths']: features_len, + }, session=session) logits = np.squeeze(logits) diff --git a/README.md b/README.md index 3a906d90..78851870 100644 --- a/README.md +++ b/README.md @@ -334,7 +334,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information ### Exporting a model for TFLite -If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--export_tflite` flag. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --export_tflite --export_dir /model/export/destination`. +If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --nouse_seq_length --export_tflite --export_dir /model/export/destination`. 
### Making a mmap-able model for inference diff --git a/bin/run-ldc93s1.sh b/bin/run-ldc93s1.sh index a48a85e5..a50d1dfb 100755 --- a/bin/run-ldc93s1.sh +++ b/bin/run-ldc93s1.sh @@ -16,7 +16,7 @@ else checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/ldc93s1"))') fi -python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ +python -u DeepSpeech.py --noshow_progressbar --nodev \ --train_files data/ldc93s1/ldc93s1.csv \ --dev_files data/ldc93s1/ldc93s1.csv \ --test_files data/ldc93s1/ldc93s1.csv \ diff --git a/bin/run-tc-ldc93s1_tflite.sh b/bin/run-tc-ldc93s1_tflite.sh index 04b0ce82..874c7bf2 100755 --- a/bin/run-tc-ldc93s1_tflite.sh +++ b/bin/run-tc-ldc93s1_tflite.sh @@ -17,4 +17,4 @@ python -u DeepSpeech.py --noshow_progressbar \ --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --lm_trie_path 'data/smoke_test/vocab.trie' \ --notrain --notest \ - --export_tflite \ + --export_tflite --nouse_seq_length \ diff --git a/evaluate.py b/evaluate.py index a231029f..e60abc09 100755 --- a/evaluate.py +++ b/evaluate.py @@ -5,92 +5,67 @@ from __future__ import absolute_import, division, print_function import itertools import json import numpy as np -import os -import pandas import progressbar -import sys -import tables import tensorflow as tf -from collections import namedtuple from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer -from multiprocessing import Pool, cpu_count +from multiprocessing import cpu_count from six.moves import zip, range -from util.audio import audiofile_to_input_vector from util.config import Config, initialize_globals +from util.evaluate_tools import calculate_report +from util.feeding import create_dataset from util.flags import create_flags, FLAGS from util.logging import log_error -from util.preprocess import preprocess -from util.text import Alphabet, levenshtein -from util.evaluate_tools import process_decode_result, calculate_report - -def split_data(dataset, batch_size): - remainder = len(dataset) % batch_size - if remainder != 0: - dataset = dataset[:-remainder] - - for i in range(0, len(dataset), batch_size): - yield dataset[i:i + batch_size] +from util.text import levenshtein -def pad_to_dense(jagged): - maxlen = max(len(r) for r in jagged) - subshape = jagged[0].shape - - padded = np.zeros((len(jagged), maxlen) + - subshape[1:], dtype=jagged[0].dtype) - for i, row in enumerate(jagged): - padded[i, :len(row)] = row - return padded +def sparse_tensor_value_to_texts(value, alphabet): + r""" + Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings + representing its values, converting tokens to strings using ``alphabet``. 
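+    One string is returned per batch element of ``value``.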
+ """ + return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet) -def evaluate(test_data, inference_graph): +def sparse_tuple_to_texts(tuple, alphabet): + indices = tuple[0] + values = tuple[1] + results = [''] * tuple[2][0] + for i in range(len(indices)): + index = indices[i][0] + results[index] += alphabet.string_from_label(values[i]) + # List of strings + return results + + +def evaluate(test_csvs, create_model): scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet) + test_set, test_batches = create_dataset(test_csvs, + batch_size=FLAGS.test_batch_size, + cache_path=FLAGS.test_cached_features_path) + it = test_set.make_one_shot_iterator() - def create_windows(features): - num_strides = len(features) - (Config.n_context * 2) + (batch_x, batch_x_len), batch_y = it.get_next() - # Create a view into the array with overlapping strides of size - # numcontext (past) + 1 (present) + numcontext (future) - window_size = 2*Config.n_context+1 - features = np.lib.stride_tricks.as_strided( - features, - (num_strides, window_size, Config.n_input), - (features.strides[0], features.strides[0], features.strides[1]), - writeable=False) + # One rate per layer + no_dropout = [None] * 6 + logits, _ = create_model(batch_x=batch_x, + seq_length=batch_x_len, + dropout=no_dropout) - return features + # Transpose to batch major and apply softmax for decoder + transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) - # Create overlapping windows over the features - test_data['features'] = test_data['features'].apply(create_windows) + loss = tf.nn.ctc_loss(labels=batch_y, + inputs=logits, + sequence_length=batch_x_len) with tf.Session(config=Config.session_config) as session: - inputs, outputs, layers = inference_graph - - # Transpose to batch major for decoder - transposed = tf.transpose(outputs['outputs'], [1, 0, 2]) - - labels_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size, None], name="labels") - label_lengths_ph = tf.placeholder(tf.int32, [FLAGS.test_batch_size], name="label_lengths") - - # We add 1 to all elements of the transcript to avoid any zero values - # since we use that as an end-of-sequence token for converting the batch - # into a SparseTensor. So here we convert the placeholder back into a - # SparseTensor and subtract ones to get the real labels. 
- sparse_labels = tf.contrib.layers.dense_to_sparse(labels_ph) - neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape) - sparse_labels = tf.sparse_add(sparse_labels, neg_ones) - - loss = tf.nn.ctc_loss(labels=sparse_labels, - inputs=layers['raw_logits'], - sequence_length=inputs['input_lengths']) - # Create a saver using variables from the above newly created graph - mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')} - saver = tf.train.Saver(mapping) + saver = tf.train.Saver() # Restore variables from training checkpoint checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) @@ -103,51 +78,38 @@ def evaluate(test_data, inference_graph): logitses = [] losses = [] + seq_lengths = [] + ground_truths = [] print('Computing acoustic model predictions...') - batch_count = len(test_data) // FLAGS.test_batch_size - bar = progressbar.ProgressBar(max_value=batch_count, + bar = progressbar.ProgressBar(max_value=test_batches, widget=progressbar.AdaptiveETA) # First pass, compute losses and transposed logits for decoding - for batch in bar(split_data(test_data, FLAGS.test_batch_size)): - session.run(outputs['initialize_state']) - - features = pad_to_dense(batch['features'].values) - features_len = batch['features_len'].values - labels = pad_to_dense(batch['transcript'].values + 1) - label_lengths = batch['transcript_len'].values - - logits, loss_ = session.run([transposed, loss], feed_dict={ - inputs['input']: features, - inputs['input_lengths']: features_len, - labels_ph: labels, - label_lengths_ph: label_lengths - }) + for batch in bar(range(test_batches)): + logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y]) logitses.append(logits) losses.extend(loss_) + seq_lengths.append(lengths) + ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet)) - ground_truths = [] predictions = [] - print('Decoding predictions...') - bar = progressbar.ProgressBar(max_value=batch_count, - widget=progressbar.AdaptiveETA) - # Get number of accessible CPU cores for this process try: num_processes = cpu_count() except: num_processes = 1 - # Second pass, decode logits and compute WER and edit distance metrics - for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))): - seq_lengths = batch['features_len'].values.astype(np.int32) - decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, Config.alphabet, FLAGS.beam_width, - num_processes=num_processes, scorer=scorer) + print('Decoding predictions...') + bar = progressbar.ProgressBar(max_value=test_batches, + widget=progressbar.AdaptiveETA) - ground_truths.extend(Config.alphabet.decode(l) for l in batch['transcript']) + # Second pass, decode logits and compute WER and edit distance metrics + for logits, seq_length in bar(zip(logitses, seq_lengths)): + decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width, + num_processes=num_processes, scorer=scorer) predictions.extend(d[0][1] for d in decoded) distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)] @@ -179,21 +141,8 @@ def main(_): 'the --test_files flag.') exit(1) - # sort examples by length, improves packing of batches and timesteps - test_data = preprocess( - FLAGS.test_files.split(','), - FLAGS.test_batch_size, - alphabet=Config.alphabet, - numcep=Config.n_input, - numcontext=Config.n_context, - hdf5_cache_path=FLAGS.hdf5_test_set).sort_values( 
-        by="features_len",
-        ascending=False)
-
-    from DeepSpeech import create_inference_graph
-    graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
-
-    samples = evaluate(test_data, graph)
+    from DeepSpeech import create_model
+    samples = evaluate(FLAGS.test_files.split(','), create_model)
 
     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
diff --git a/native_client/BUILD b/native_client/BUILD
index 1e5587de..9699460b 100644
--- a/native_client/BUILD
+++ b/native_client/BUILD
@@ -119,6 +119,10 @@ tf_cc_shared_object(
         "//tensorflow/core/kernels:control_flow_ops",  # Enter
         "//tensorflow/core/kernels:tile_ops",  # Tile
         "//tensorflow/core/kernels:gather_op",  # Gather
+        "//tensorflow/core/kernels:mfcc_op",  # Mfcc
+        "//tensorflow/core/kernels:spectrogram_op",  # AudioSpectrogram
+        "//tensorflow/core/kernels:strided_slice_op",  # StridedSlice
+        "//tensorflow/core/kernels:slice_op",  # Slice, needed by StridedSlice
        "//tensorflow/contrib/rnn:lstm_ops_kernels",  # BlockLSTM
        "//tensorflow/core/kernels:random_ops",  # RandomGammaGrad
        "//tensorflow/core/kernels:pack_op",  # Pack
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 31418072..f710f4e4 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -108,7 +108,6 @@ using std::vector;
 struct StreamingState {
   vector<float> accumulated_logits;
   vector<float> audio_buffer;
-  float last_sample; // used for preemphasis
   vector<float> mfcc_buffer;
   vector<float> batch_buffer;
   ModelState* model;
@@ -152,10 +151,13 @@ struct ModelState {
   int input_node_idx;
   int previous_state_c_idx;
   int previous_state_h_idx;
+  int input_samples_idx;
   int logits_idx;
   int new_state_c_idx;
   int new_state_h_idx;
+  int mfccs_idx;
+  int mfccs_len_idx;
 #endif
 
   ModelState();
@@ -204,7 +206,9 @@ struct ModelState {
    *
    * @param[out] output_logits Where to store computed logits.
    */
-  void infer(const float* mfcc, unsigned int n_frames, vector<float>& output_logits);
+  void infer(const float* mfcc, unsigned int n_frames, vector<float>& logits_output);
+
+  void compute_mfcc(const vector<float> audio_buffer, vector<float>& mfcc_output);
 };
 
 StreamingState* setupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
@@ -258,10 +262,9 @@ StreamingState::feedAudioContent(const short* buffer,
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
     while (buffer_size > 0 && audio_buffer.size() < AUDIO_WIN_LEN_SAMPLES) {
-      // Apply preemphasis to input sample and buffer it
-      float sample = (float)(*buffer) - (PREEMPHASIS_COEFF * last_sample);
-      audio_buffer.push_back(sample);
-      last_sample = *buffer;
+      // Convert i16 sample into f32
+      float multiplier = 1.0f / (1 << 15);
+      audio_buffer.push_back((float)(*buffer) * multiplier);
       ++buffer;
       --buffer_size;
     }
@@ -304,15 +307,11 @@ void
 StreamingState::processAudioWindow(const vector<float>& buf)
 {
   // Compute MFCC features
-  float* mfcc;
-  int n_frames = csf_mfcc(buf.data(), buf.size(), SAMPLE_RATE,
-                          AUDIO_WIN_LEN, AUDIO_WIN_STEP, MFCC_FEATURES, N_FILTERS, N_FFT,
-                          LOWFREQ, SAMPLE_RATE/2, 0.f, CEP_LIFTER, 1, hamming_window.data(),
-                          &mfcc);
-  assert(n_frames == 1);
+  vector<float> mfcc;
+  mfcc.reserve(MFCC_FEATURES);
+  model->compute_mfcc(buf, mfcc);
 
-  pushMfccBuffer(mfcc, n_frames * MFCC_FEATURES);
-  free(mfcc);
+  pushMfccBuffer(mfcc.data(), MFCC_FEATURES);
 }
 
 void
@@ -396,7 +395,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
     input_mapped(i) = aMfcc[i];
   }
   for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
-    input_mapped(i) = 0;
+    input_mapped(i) = 0.;
   }
 
   Tensor input_lengths(DT_INT32, TensorShape({1}));
@@ -454,6 +453,53 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
 #endif // USE_TFLITE
 }
 
+void
+ModelState::compute_mfcc(const vector<float> samples, vector<float>& mfcc_output)
+{
+#ifndef USE_TFLITE
+  Tensor input(DT_FLOAT, TensorShape({static_cast<long long>(samples.size())}));
+  auto input_mapped = input.flat<float>();
+  for (int i = 0; i < samples.size(); ++i) {
+    input_mapped(i) = samples[i];
+  }
+
+  vector<Tensor> outputs;
+  Status status = session->Run({{"input_samples", input}}, {"mfccs", "mfccs_len"}, {}, &outputs);
+
+  if (!status.ok()) {
+    std::cerr << "Error running session: " << status << "\n";
+    return;
+  }
+
+  auto mfcc_len_mapped = outputs[1].flat<int32>();
+  int n_windows = mfcc_len_mapped(0);
+
+  auto mfcc_mapped = outputs[0].flat<float>();
+  for (int i = 0; i < n_windows * MFCC_FEATURES; ++i) {
+    mfcc_output.push_back(mfcc_mapped(i));
+  }
+#else
+  // Feeding input_node
+  float* input_samples = interpreter->typed_tensor<float>(input_samples_idx);
+  for (int i = 0; i < samples.size(); ++i) {
+    input_samples[i] = samples[i];
+  }
+
+  TfLiteStatus status = interpreter->Invoke();
+  if (status != kTfLiteOk) {
+    std::cerr << "Error running session: " << status << "\n";
+    return;
+  }
+
+  int n_windows = *interpreter->typed_tensor<int>(mfccs_len_idx);
+
+  float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
+  for (int i = 0; i < n_windows * MFCC_FEATURES; ++i) {
+    mfcc_output.push_back(outputs[i]);
+  }
+#endif
+}
+
 char*
 ModelState::decode(vector<float>& logits)
 {
@@ -640,8 +686,6 @@ DS_CreateModel(const char* aModelPath,
   *retval = model.release();
   return DS_ERR_OK;
 #else // USE_TFLITE
-  TfLiteStatus status;
-
   model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
   if (!model->fbmodel) {
     std::cerr << "Error at reading model file " << aModelPath << std::endl;
@@ -663,9 +707,12 @@ DS_CreateModel(const
char* aModelPath, model->input_node_idx = tflite_get_input_tensor_by_name(model.get(), "input_node"); model->previous_state_c_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_c"); model->previous_state_h_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_h"); + model->input_samples_idx = tflite_get_input_tensor_by_name(model.get(), "input_samples"); model->logits_idx = tflite_get_output_tensor_by_name(model.get(), "logits"); model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c"); model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h"); + model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs"); + model->mfccs_len_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs_len"); TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims; @@ -796,7 +843,6 @@ DS_SetupStream(ModelState* aCtx, ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes); ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES); - ctx->last_sample = 0; ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep); ctx->mfcc_buffer.resize(MFCC_FEATURES*aCtx->n_context, 0.f); ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep); diff --git a/requirements.txt b/requirements.txt index 4d5f1439..2958643b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,23 @@ -pandas -progressbar2 -python-utils +# Main training requirements tensorflow == 1.13.1 numpy == 1.15.4 -matplotlib -scipy -sox -paramiko >= 2.1 -python_speech_features -pyxdg -bs4 +progressbar2 +pandas six -requests -tables +pyxdg attrdict + +# Requirements for building native_client files setuptools + +# Requirements for importers +sox +bs4 +requests librosa soundfile + +# Miscellaneous scripts +paramiko >= 2.1 +scipy +matplotlib diff --git a/util/audio.py b/util/audio.py deleted file mode 100644 index ac9dde63..00000000 --- a/util/audio.py +++ /dev/null @@ -1,24 +0,0 @@ -import numpy as np -import scipy.io.wavfile as wav - -from python_speech_features import mfcc - - -def audiofile_to_input_vector(audio_filename, numcep, numcontext): - r""" - Given a WAV audio file at ``audio_filename``, calculates ``numcep`` MFCC features - at every 0.01s time step with a window length of 0.025s. Appends ``numcontext`` - context frames to the left and right of each time step, and returns this data - in a numpy array. - """ - # Load wav files - fs, audio = wav.read(audio_filename) - - # Get mfcc coefficients - features = mfcc(audio, samplerate=fs, numcep=numcep, winlen=0.032, winstep=0.02, winfunc=np.hamming) - - # Add empty initial and final contexts - empty_context = np.zeros((numcontext, numcep), dtype=features.dtype) - features = np.concatenate((empty_context, features, empty_context)) - - return features diff --git a/util/feeding.py b/util/feeding.py index 71e2a9fc..2feb5bbc 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -1,198 +1,97 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, division, print_function + import numpy as np +import os +import pandas import tensorflow as tf -from math import ceil -from six.moves import range -from threading import Thread -from util.gpu import get_available_gpus +from functools import partial +from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio +from util.config import Config +from util.text import text_to_char_array -class ModelFeeder(object): - ''' - Feeds data into a model. 
- Feeding is parallelized by independent units called tower feeders (usually one per GPU). - Each tower feeder provides data from runtime switchable sources (train, dev). - These sources are to be provided by the DataSet instances whose references are kept. - Creates, owns and delegates to tower_feeder_count internal tower feeder objects. - ''' - def __init__(self, - train_set, - dev_set, - numcep, - numcontext, - alphabet, - tower_feeder_count=-1, - threads_per_queue=4): - - self.train = train_set - self.dev = dev_set - self.sets = [train_set, dev_set] - self.numcep = numcep - self.numcontext = numcontext - self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count - self.threads_per_queue = threads_per_queue - - self.ph_x = tf.placeholder(tf.float32, [None, 2*numcontext+1, numcep]) - self.ph_x_length = tf.placeholder(tf.int32, []) - self.ph_y = tf.placeholder(tf.int32, [None,]) - self.ph_y_length = tf.placeholder(tf.int32, []) - self.ph_batch_size = tf.placeholder(tf.int32, []) - self.ph_queue_selector = tf.placeholder(tf.int32, name='Queue_Selector') - - self._tower_feeders = [_TowerFeeder(self, i, alphabet) for i in range(self.tower_feeder_count)] - - def start_queue_threads(self, session, coord): - ''' - Starts required queue threads on all tower feeders. - ''' - queue_threads = [] - for tower_feeder in self._tower_feeders: - queue_threads += tower_feeder.start_queue_threads(session, coord) - return queue_threads - - def close_queues(self, session): - ''' - Closes queues of all tower feeders. - ''' - for tower_feeder in self._tower_feeders: - tower_feeder.close_queues(session) - - def set_data_set(self, feed_dict, data_set): - ''' - Switches all tower feeders to a different source DataSet. - The provided feed_dict will get enriched with required placeholder/value pairs. - The DataSet has to be one of those that got passed into the constructor. - ''' - index = self.sets.index(data_set) - assert index >= 0 - feed_dict[self.ph_queue_selector] = index - feed_dict[self.ph_batch_size] = data_set.batch_size - - def next_batch(self, tower_feeder_index): - ''' - Draw the next batch from one of the tower feeders. - ''' - return self._tower_feeders[tower_feeder_index].next_batch() +def read_csvs(csv_files): + source_data = None + for csv in csv_files: + file = pandas.read_csv(csv, encoding='utf-8', na_filter=False) + #FIXME: not cross-platform + csv_dir = os.path.dirname(os.path.abspath(csv)) + file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) + if source_data is None: + source_data = file + else: + source_data = source_data.append(file) + return source_data -class DataSet(object): - ''' - Represents a collection of audio samples and their respective transcriptions. - Takes a set of CSV files produced by importers in /bin. 
- ''' - def __init__(self, data, batch_size, skip=0, limit=0, ascending=True, next_index=lambda i: i + 1): - self.data = data - self.data.sort_values(by="features_len", ascending=ascending, inplace=True) - self.batch_size = batch_size - self.next_index = next_index - self.total_batches = int(ceil(len(self.data) / batch_size)) +def samples_to_mfccs(samples, sample_rate): + spectrogram = contrib_audio.audio_spectrogram(samples, window_size=512, stride=320, magnitude_squared=True) + mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input) + mfccs = tf.reshape(mfccs, [-1, Config.n_input]) + + return mfccs, tf.shape(mfccs)[0] -class _DataSetLoader(object): - ''' - Internal class that represents an input queue with data from one of the DataSet objects. - Each tower feeder will create and combine three data set loaders to one switchable queue. - Keeps a ModelFeeder reference for accessing shared settings and placeholders. - Keeps a DataSet reference to access its samples. - ''' - def __init__(self, model_feeder, data_set, alphabet): - self._model_feeder = model_feeder - self._data_set = data_set - self.queue = tf.PaddingFIFOQueue(shapes=[[None, 2 * model_feeder.numcontext + 1, model_feeder.numcep], [], [None,], []], - dtypes=[tf.float32, tf.int32, tf.int32, tf.int32], - capacity=data_set.batch_size * 8) - self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length]) - self._close_op = self.queue.close(cancel_pending_enqueues=True) - self._alphabet = alphabet +def audiofile_to_features(wav_filename): + samples = tf.read_file(wav_filename) + decoded = contrib_audio.decode_wav(samples, desired_channels=1) + features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate) - def start_queue_threads(self, session, coord): - ''' - Starts concurrent queue threads for reading samples from the data set. - ''' - queue_threads = [Thread(target=self._populate_batch_queue, args=(session, coord)) - for i in range(self._model_feeder.threads_per_queue)] - for queue_thread in queue_threads: - coord.register_thread(queue_thread) - queue_thread.daemon = True - queue_thread.start() - return queue_threads - - def close_queue(self, session): - ''' - Closes the data set queue. - ''' - session.run(self._close_op) - - def _populate_batch_queue(self, session, coord): - ''' - Queue thread routine. - ''' - file_count = len(self._data_set.data) - index = -1 - while not coord.should_stop(): - index = self._data_set.next_index(index) % file_count - features, num_strides, transcript, transcript_len = self._data_set.data.iloc[index] - - # Create a view into the array with overlapping strides of size - # numcontext (past) + 1 (present) + numcontext (future) - window_size = 2*self._model_feeder.numcontext+1 - features = np.lib.stride_tricks.as_strided( - features, - (num_strides, window_size, self._model_feeder.numcep), - (features.strides[0], features.strides[0], features.strides[1]), - writeable=False) - - # We add 1 to all elements of the transcript here to avoid any zero - # values since we use that as an end-of-sequence token for converting - # the batch into a SparseTensor. 
- try: - session.run(self._enqueue_op, feed_dict={ - self._model_feeder.ph_x: features, - self._model_feeder.ph_x_length: num_strides, - self._model_feeder.ph_y: transcript + 1, - self._model_feeder.ph_y_length: transcript_len - }) - except tf.errors.CancelledError: - return + return features, features_len -class _TowerFeeder(object): - ''' - Internal class that represents a switchable input queue for one tower. - It creates, owns and combines three _DataSetLoader instances. - Keeps a ModelFeeder reference for accessing shared settings and placeholders. - ''' - def __init__(self, model_feeder, index, alphabet): - self._model_feeder = model_feeder - self.index = index - self._loaders = [_DataSetLoader(model_feeder, data_set, alphabet) for data_set in model_feeder.sets] - self._queues = [set_queue.queue for set_queue in self._loaders] - self._queue = tf.QueueBase.from_list(model_feeder.ph_queue_selector, self._queues) - self._close_op = self._queue.close(cancel_pending_enqueues=True) +def entry_to_features(wav_filename, transcript): + # https://bugs.python.org/issue32117 + features, features_len = audiofile_to_features(wav_filename) + return features, features_len, tf.SparseTensor(*transcript) - def next_batch(self): - ''' - Draw the next batch from from the combined switchable queue. - ''' - source, source_lengths, target, target_lengths = self._queue.dequeue_many(self._model_feeder.ph_batch_size) - # Back to sparse, then subtract one to get the real labels - sparse_labels = tf.contrib.layers.dense_to_sparse(target) - neg_ones = tf.SparseTensor(sparse_labels.indices, -1 * tf.ones_like(sparse_labels.values), sparse_labels.dense_shape) - return source, source_lengths, tf.sparse_add(sparse_labels, neg_ones) - def start_queue_threads(self, session, coord): - ''' - Starts the queue threads of all owned _DataSetLoader instances. - ''' - queue_threads = [] - for set_queue in self._loaders: - queue_threads += set_queue.start_queue_threads(session, coord) - return queue_threads +def to_sparse_tuple(sequence): + r"""Creates a sparse representention of ``sequence``. + Returns a tuple with (indices, values, shape) + """ + indices = np.asarray(list(zip([0]*len(sequence), range(len(sequence)))), dtype=np.int64) + shape = np.asarray([1, len(sequence)], dtype=np.int64) + return indices, sequence, shape - def close_queues(self, session): - ''' - Closes queues of all owned _DataSetLoader instances. - ''' - for set_queue in self._loaders: - set_queue.close_queue(session) +def create_dataset(csvs, batch_size, cache_path): + df = read_csvs(csvs) + df.sort_values(by='wav_filesize', inplace=True) + + num_batches = len(df) // batch_size + + # Convert to character index arrays + df['transcript'] = df['transcript'].apply(partial(text_to_char_array, alphabet=Config.alphabet)) + + def generate_values(): + for _, row in df.iterrows(): + yield row.wav_filename, to_sparse_tuple(row.transcript) + + # Batching a dataset of 2D SparseTensors creates 3D batches, which fail + # when passed to tf.nn.ctc_loss, so we reshape them to remove the extra + # dimension here. 
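+    # For example, two [1, time] transcripts batched together arrive as a
+    # single SparseTensor of dense_shape [2, 1, max_time]; sparse_reshape
+    # below flattens that to the [2, max_time] shape tf.nn.ctc_loss expects.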
+ def sparse_reshape(sparse): + shape = sparse.dense_shape + return tf.sparse.reshape(sparse, [shape[0], shape[2]]) + + def batch_fn(features, features_len, transcripts): + features = tf.data.Dataset.zip((features, features_len)) + features = features.padded_batch(batch_size, + padded_shapes=([None, Config.n_input], [])) + transcripts = transcripts.batch(batch_size).map(sparse_reshape) + return tf.data.Dataset.zip((features, transcripts)) + + num_gpus = len(Config.available_devices) + + dataset = (tf.data.Dataset.from_generator(generate_values, + output_types=(tf.string, (tf.int64, tf.int32, tf.int64))) + .map(entry_to_features, num_parallel_calls=tf.data.experimental.AUTOTUNE) + .cache(cache_path) + .window(batch_size, drop_remainder=True).flat_map(batch_fn) + .prefetch(num_gpus) + .repeat()) + + return dataset, num_batches diff --git a/util/flags.py b/util/flags.py index f0bac5e4..25cc9b8d 100644 --- a/util/flags.py +++ b/util/flags.py @@ -23,6 +23,7 @@ def create_flags(): # ================ tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network') + tf.app.flags.DEFINE_boolean ('dev', True, 'whether to run validation epochs') tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network') tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained') diff --git a/util/preprocess.py b/util/preprocess.py deleted file mode 100644 index 1feb1ebb..00000000 --- a/util/preprocess.py +++ /dev/null @@ -1,101 +0,0 @@ -import numpy as np -import os -import pandas -import tables - -from functools import partial -from multiprocessing.dummy import Pool -from util.audio import audiofile_to_input_vector -from util.text import text_to_char_array - -def pmap(fun, iterable): - pool = Pool() - results = pool.map(fun, iterable) - pool.close() - return results - - -def process_single_file(row, numcep, numcontext, alphabet): - # row = index, Series - _, file = row - features = audiofile_to_input_vector(file.wav_filename, numcep, numcontext) - features_len = len(features) - 2*numcontext - transcript = text_to_char_array(file.transcript, alphabet) - - if features_len < len(transcript): - raise ValueError('Error: Audio file {} is too short for transcription.'.format(file.wav_filename)) - - return features, features_len, transcript, len(transcript) - - -# load samples from CSV, compute features, optionally cache results on disk -def preprocess(csv_files, batch_size, numcep, numcontext, alphabet, hdf5_cache_path=None): - COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len') - - print('Preprocessing', csv_files) - - if hdf5_cache_path and os.path.exists(hdf5_cache_path): - with tables.open_file(hdf5_cache_path, 'r') as file: - features = file.root.features[:] - features_len = file.root.features_len[:] - transcript = file.root.transcript[:] - transcript_len = file.root.transcript_len[:] - - # features are stored flattened, so reshape into [n_steps, numcep] - for i in range(len(features)): - features[i].shape = [features_len[i]+2*numcontext, numcep] - - in_data = list(zip(features, features_len, - transcript, transcript_len)) - print('Loaded from cache at', hdf5_cache_path) - return pandas.DataFrame(data=in_data, columns=COLUMNS) - - source_data = None - for csv in csv_files: - file = pandas.read_csv(csv, encoding='utf-8', na_filter=False) - #FIXME: not cross-platform - csv_dir = os.path.dirname(os.path.abspath(csv)) - file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: 
os.path.join(csv_dir, m.group(1))) - if source_data is None: - source_data = file - else: - source_data = source_data.append(file) - - step_fn = partial(process_single_file, - numcep=numcep, - numcontext=numcontext, - alphabet=alphabet) - out_data = pmap(step_fn, source_data.iterrows()) - - if hdf5_cache_path: - print('Saving to', hdf5_cache_path) - - # list of tuples -> tuple of lists - features, features_len, transcript, transcript_len = zip(*out_data) - - with tables.open_file(hdf5_cache_path, 'w') as file: - features_dset = file.create_vlarray(file.root, - 'features', - tables.Float32Atom(), - filters=tables.Filters(complevel=1)) - # VLArray atoms need to be 1D, so flatten feature array - for f in features: - features_dset.append(np.reshape(f, -1)) - - features_len_dset = file.create_array(file.root, - 'features_len', - features_len) - - transcript_dset = file.create_vlarray(file.root, - 'transcript', - tables.Int32Atom(), - filters=tables.Filters(complevel=1)) - for t in transcript: - transcript_dset.append(t) - - transcript_len_dset = file.create_array(file.root, - 'transcript_len', - transcript_len) - - print('Preprocessing done') - return pandas.DataFrame(data=out_data, columns=COLUMNS)
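
For reference, a minimal sketch (not taken from the patch) of the reinitializable-iterator
pattern the new train() relies on: one tf.data.Iterator served by several datasets, with toy
tensors standing in for the real feature pipeline. It omits output_classes, which train()
needs only because its dataset elements contain SparseTensors:

    import tensorflow as tf

    train_set = tf.data.Dataset.from_tensor_slices([1., 2., 3.]).batch(1)
    dev_set = tf.data.Dataset.from_tensor_slices([4., 5.]).batch(1)

    # One iterator serves both datasets, so the rest of the graph is built once.
    iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                               train_set.output_shapes)
    next_element = iterator.get_next()

    train_init_op = iterator.make_initializer(train_set)
    dev_init_op = iterator.make_initializer(dev_set)

    with tf.Session() as session:
        session.run(train_init_op)          # start an epoch over the training set
        print(session.run(next_element))    # [1.]
        session.run(dev_init_op)            # switch to validation data in place
        print(session.run(next_element))    # [4.]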