Merge pull request #2282 from mozilla/dynamic-batch-size-in-train-val-graph
Use dynamic batch size in train/val graph
commit c76070be19
@@ -141,7 +141,7 @@ def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
     return output, output_state


-def create_model(batch_x, batch_size, seq_length, dropout, reuse=False, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
+def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
     layers = {}

     # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
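Note on the create_model change above: batch_size moves to an optional keyword argument, so when it is not supplied the batch dimension can be taken from the input tensor at run time and the same graph serves any train/val batch size. A minimal sketch of that pattern, with illustrative placeholder shapes (not taken from the diff):

    import tensorflow.compat.v1 as tfv1

    num_features = 26  # illustrative feature width, not a value from this diff

    # [batch, time, features] with both batch and time left dynamic
    batch_x = tfv1.placeholder(tfv1.float32, shape=[None, None, num_features])

    # Recover the batch dimension from the tensor itself when no static size is given
    dynamic_batch_size = tfv1.shape(batch_x)[0]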
@@ -207,7 +207,7 @@ def create_model(batch_x, batch_size, seq_length, dropout, reuse=False, previous
 # Conveniently, this loss function is implemented in TensorFlow.
 # Thus, we can simply make use of this implementation to define our loss.

-def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
+def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     r'''
     This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
     Next to total and average loss it returns the mean edit distance,
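For orientation, the decode-and-score step the docstring above describes usually looks like the following in TF1; this is a generic sketch with assumed tensor names and shapes, not the function's actual body at this revision:

    import tensorflow.compat.v1 as tfv1

    def decode_and_score(logits, batch_seq_len, batch_y):
        # logits: time-major [max_time, batch, num_classes]; batch_seq_len: [batch];
        # batch_y: tf.SparseTensor of int32 label indices (ground truth).
        decoded, _ = tfv1.nn.ctc_beam_search_decoder(logits, batch_seq_len, merge_repeated=False)
        # Edit distance between the best beam and the reference labels, averaged over the batch
        distances = tfv1.edit_distance(tfv1.cast(decoded[0], tfv1.int32), batch_y)
        return tfv1.reduce_mean(distances)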
@@ -222,7 +222,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
     rnn_impl = rnn_impl_lstmblockfusedcell

     # Calculate the logits of the batch
-    logits, _ = create_model(batch_x, batch_size, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
+    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
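The total_loss line above relies on tfv1.nn.ctc_loss, which expects sparse labels and, by default, time-major logits. A self-contained sketch of that call plus the averaging step (shapes and the helper name are illustrative):

    import tensorflow.compat.v1 as tfv1

    def ctc_batch_loss(batch_y, logits, batch_seq_len):
        # batch_y: tf.SparseTensor of int32 label indices
        # logits: [max_time, batch, num_classes], time-major by default for ctc_loss
        # batch_seq_len: [batch] int32 lengths of each input sequence
        total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
        # One scalar per example; average over the (possibly dynamic) batch
        return tfv1.reduce_mean(total_loss)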
@@ -267,7 +267,7 @@ def create_optimizer():
 # on which all operations within the tower execute.
 # For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.

-def get_tower_results(iterator, optimizer, dropout_rates, batch_size):
+def get_tower_results(iterator, optimizer, dropout_rates):
     r'''
     With this preliminary step out of the way, we can for each GPU introduce a
     tower for which's batch we calculate and return the optimization gradients
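The tower layout the comment refers to is the standard TF1 multi-GPU pattern: one replica of the graph per device, pinned with tf.device and grouped under a per-tower name scope. A generic sketch of that layout (the device list and loss callback are placeholders, not the repo's code):

    import tensorflow.compat.v1 as tfv1

    def build_towers(devices, tower_loss_fn):
        # devices: e.g. ['/gpu:0', '/gpu:1']; tower_loss_fn builds one tower's loss.
        tower_losses = []
        for i, device in enumerate(devices):
            with tfv1.device(device):                  # pin this tower's ops to one device
                with tfv1.name_scope('tower_%d' % i):  # group ops/summaries per tower
                    tower_losses.append(tower_loss_fn(reuse=i > 0))
        return tower_losses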
@@ -289,7 +289,7 @@ def get_tower_results(iterator, optimizer, dropout_rates, batch_size):
                 with tf.name_scope('tower_%d' % i):
                     # Calculate the avg_loss and mean_edit_distance and retrieve the decoded
                     # batch along with the original batch's labels (Y) of this tower
-                    avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, batch_size, reuse=i > 0)
+                    avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)

                     # Allow for variables to be re-used by the next tower
                     tfv1.get_variable_scope().reuse_variables()
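The reuse=i > 0 argument and the reuse_variables() call exist so that every tower shares one set of weights: the first tower creates the variables, later towers fetch the same ones instead of allocating copies. A minimal sketch of that mechanism (scope and layer names are illustrative):

    import tensorflow.compat.v1 as tfv1

    def shared_dense(x, in_dim, units, reuse):
        # With reuse=True, get_variable returns the variables created by the first
        # tower rather than allocating a second copy for the next device.
        with tfv1.variable_scope('shared_layer', reuse=reuse):
            w = tfv1.get_variable('w', shape=[in_dim, units])
            b = tfv1.get_variable('b', shape=[units], initializer=tfv1.zeros_initializer())
            return tfv1.matmul(x, w) + b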
@@ -436,7 +436,7 @@ def train():

     # Building the graph
     optimizer = create_optimizer()
-    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates, FLAGS.train_batch_size)
+    gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

     # Average tower gradients across GPUs
     avg_tower_gradients = average_gradients(gradients)
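average_gradients then collapses the per-tower (gradient, variable) lists into a single averaged list before it is applied. The usual reduction looks like this; a generic sketch of the pattern, not necessarily the repo's exact implementation:

    import tensorflow.compat.v1 as tfv1

    def average_gradients(tower_gradients):
        # tower_gradients: one list of (gradient, variable) pairs per tower,
        # with every tower listing the same variables in the same order.
        averaged = []
        for grads_and_var in zip(*tower_gradients):
            grads = tfv1.stack([g for g, _ in grads_and_var], axis=0)
            mean_grad = tfv1.reduce_mean(grads, axis=0)
            averaged.append((mean_grad, grads_and_var[0][1]))  # keep the shared variable handle
        return averaged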