Merge pull request #2282 from mozilla/dynamic-batch-size-in-train-val-graph

Use dynamic batch size in train/val graph
Reuben Morais, 2019-08-07 10:03:53 +02:00 (committed by GitHub)
commit c76070be19


@@ -141,7 +141,7 @@ def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
     return output, output_state
 
 
-def create_model(batch_x, batch_size, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
+def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
     layers = {}
 
     # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
@@ -207,7 +207,7 @@ def create_model(batch_x, batch_size, seq_length, dropout, reuse=False, previous
 # Conveniently, this loss function is implemented in TensorFlow.
 # Thus, we can simply make use of this implementation to define our loss.
 
-def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
+def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     r'''
     This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
     Next to total and average loss it returns the mean edit distance,
@@ -222,7 +222,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
         rnn_impl = rnn_impl_lstmblockfusedcell
 
     # Calculate the logits of the batch
-    logits, _ = create_model(batch_x, batch_size, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
+    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
 
     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
@@ -267,7 +267,7 @@ def create_optimizer():
 # on which all operations within the tower execute.
 # For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
 
-def get_tower_results(iterator, optimizer, dropout_rates, batch_size):
+def get_tower_results(iterator, optimizer, dropout_rates):
     r'''
     With this preliminary step out of the way, we can for each GPU introduce a
     tower for which's batch we calculate and return the optimization gradients
@@ -289,7 +289,7 @@ def get_tower_results(iterator, optimizer, dropout_rates, batch_size):
             with tf.name_scope('tower_%d' % i):
                 # Calculate the avg_loss and mean_edit_distance and retrieve the decoded
                 # batch along with the original batch's labels (Y) of this tower
-                avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, batch_size, reuse=i > 0)
+                avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
 
                 # Allow for variables to be re-used by the next tower
                 tfv1.get_variable_scope().reuse_variables()
@@ -436,7 +436,7 @@ def train():
     # Building the graph
     optimizer = create_optimizer()
 
-    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates, FLAGS.train_batch_size)
+    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
 
     # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
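
The signature change above is the whole mechanism: `batch_size` becomes an optional keyword argument defaulting to None, so the train/val callers stop passing `FLAGS.train_batch_size` and the graph instead picks up whatever batch size the input tensor carries at run time. Below is a minimal sketch of that pattern, assuming hypothetical names (`build_layer`, `N_INPUT`, `N_HIDDEN`) rather than the actual DeepSpeech code:

# Sketch only, not the DeepSpeech implementation: with batch_size=None the
# graph reads the batch size from the input tensor instead of baking a fixed
# FLAGS value into its shapes.
import tensorflow.compat.v1 as tfv1

tfv1.disable_eager_execution()

N_INPUT = 26      # hypothetical feature width
N_HIDDEN = 128    # hypothetical layer width

def build_layer(batch_x, batch_size=None):
    # If no static batch size is supplied, take the dynamic one from the tensor.
    if batch_size is None:
        batch_size = tfv1.shape(batch_x)[0]
    n_steps = tfv1.shape(batch_x)[1]

    # Collapse [batch, time, features] -> [batch*time, features] for a dense layer.
    x = tfv1.reshape(batch_x, [-1, N_INPUT])
    h = tfv1.layers.dense(x, N_HIDDEN, activation=tfv1.nn.relu)

    # Restore the batch/time layout using the (possibly dynamic) batch size.
    return tfv1.reshape(h, [batch_size, n_steps, N_HIDDEN])

# Both batch and time dimensions are left unknown, so the same graph serves
# training and validation batches of different sizes.
batch_x = tfv1.placeholder(tfv1.float32, [None, None, N_INPUT])
output = build_layer(batch_x)  # batch size resolved per feed, as in the train/val graph

Because no tensor shape records a fixed batch dimension, the same graph handles the last, smaller batch of an epoch and differing train/dev batch sizes without being rebuilt.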