From 6afd96e30f3ce1f0d32132643032c17deff2b8f2 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Mon, 22 Jul 2019 10:16:04 +0200
Subject: [PATCH] Use static batch size whenever it's known

---
 DeepSpeech.py | 16 +++++++++-------
 evaluate.py   |  1 +
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 8492fa66..19e16d3b 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -141,11 +141,12 @@ def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
     return output, output_state
 
 
-def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
+def create_model(batch_x, batch_size, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
     layers = {}
 
     # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
-    batch_size = tf.shape(batch_x)[0]
+    if not batch_size:
+        batch_size = tf.shape(batch_x)[0]
 
     # Create overlapping feature windows if needed
     if overlap:
@@ -206,7 +207,7 @@ def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None,
 # Conveniently, this loss function is implemented in TensorFlow.
 # Thus, we can simply make use of this implementation to define our loss.
 
-def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
+def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
     r'''
     This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
     Next to total and average loss it returns the mean edit distance,
@@ -221,7 +222,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
         rnn_impl = rnn_impl_lstmblockfusedcell
 
     # Calculate the logits of the batch
-    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
+    logits, _ = create_model(batch_x, batch_size, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
 
     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
@@ -266,7 +267,7 @@ def create_optimizer():
 # on which all operations within the tower execute.
 # For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
 
-def get_tower_results(iterator, optimizer, dropout_rates):
+def get_tower_results(iterator, optimizer, dropout_rates, batch_size):
     r'''
     With this preliminary step out of the way, we can for each GPU introduce a
     tower for which's batch we calculate and return the optimization gradients
@@ -288,7 +289,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
                 with tf.name_scope('tower_%d' % i):
                     # Calculate the avg_loss and mean_edit_distance and retrieve the decoded
                     # batch along with the original batch's labels (Y) of this tower
-                    avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
+                    avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, batch_size, reuse=i > 0)
 
                     # Allow for variables to be re-used by the next tower
                     tfv1.get_variable_scope().reuse_variables()
@@ -435,7 +436,7 @@ def train():
     # Building the graph
     optimizer = create_optimizer()
 
-    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
+    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates, FLAGS.train_batch_size)
 
     # Average tower gradients across GPUs
     avg_tower_gradients = average_gradients(gradients)
@@ -626,6 +627,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         rnn_impl = rnn_impl_lstmblockfusedcell
 
     logits, layers = create_model(batch_x=input_tensor,
+                                  batch_size=batch_size,
                                   seq_length=seq_length if not FLAGS.export_tflite else None,
                                   dropout=no_dropout,
                                   previous_state=previous_state,
diff --git a/evaluate.py b/evaluate.py
index a864935a..a8de7dc7 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -57,6 +57,7 @@ def evaluate(test_csvs, create_model, try_loading):
     # One rate per layer
     no_dropout = [None] * 6
 
     logits, _ = create_model(batch_x=batch_x,
+                             batch_size=FLAGS.test_batch_size,
                              seq_length=batch_x_len,
                              dropout=no_dropout)
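The core pattern applied above is to prefer a plain Python int for the batch dimension when the caller knows it (train_batch_size, test_batch_size, the export graph's batch_size) and only fall back to a dynamic tf.shape() lookup otherwise, so downstream tensors keep fully defined shapes. Below is a minimal sketch of that pattern, assuming TF 1.x-style graph code; the helper make_initial_state and the tensor shapes are illustrative and not taken from the DeepSpeech sources.

    # Illustrative sketch (assumed helper, not DeepSpeech code): fall back to
    # tf.shape() only when no static batch size is supplied.
    import tensorflow.compat.v1 as tfv1

    tfv1.disable_eager_execution()

    def make_initial_state(batch_x, n_hidden, batch_size=None):
        if not batch_size:
            # Dynamic fallback: batch_size becomes a scalar tensor, so the
            # leading dimension of downstream tensors stays unknown.
            batch_size = tfv1.shape(batch_x)[0]
        # With a static batch_size (a plain Python int) this zeros tensor has a
        # fully defined shape that shape inference and exporters can rely on.
        return tfv1.zeros([batch_size, n_hidden])

    batch_x = tfv1.placeholder(tfv1.float32, [None, 16, 26])
    dyn = make_initial_state(batch_x, 2048)                # leading dim unknown
    sta = make_initial_state(batch_x, 2048, batch_size=1)  # static shape (1, 2048)

Threading the known batch size through create_model, calculate_mean_edit_distance_and_loss, and get_tower_results, as the patch does, keeps the static value available at every call site instead of recomputing a dynamic one inside the model.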