Use static batch size whenever it's known

This commit is contained in:
Reuben Morais 2019-07-22 10:16:04 +02:00
parent 23f5bc090d
commit 6afd96e30f
2 changed files with 10 additions and 7 deletions

View File

@ -141,11 +141,12 @@ def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
def create_model(batch_x, batch_size, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
batch_size = tf.shape(batch_x)[0]
if not batch_size:
batch_size = tf.shape(batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
@ -206,7 +207,7 @@ def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None,
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
r'''
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
Next to total and average loss it returns the mean edit distance,
@ -221,7 +222,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
logits, _ = create_model(batch_x, batch_size, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
@ -266,7 +267,7 @@ def create_optimizer():
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
def get_tower_results(iterator, optimizer, dropout_rates, batch_size):
r'''
With this preliminary step out of the way, we can for each GPU introduce a
tower for which's batch we calculate and return the optimization gradients
@ -288,7 +289,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, batch_size, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
@ -435,7 +436,7 @@ def train():
# Building the graph
optimizer = create_optimizer()
gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
gradients, loss = get_tower_results(iterator, optimizer, dropout_rates, FLAGS.train_batch_size)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
@ -626,6 +627,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,

View File

@ -57,6 +57,7 @@ def evaluate(test_csvs, create_model, try_loading):
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)