Error message if a sample has non-finite loss
This commit is contained in:
parent
c76070be19
commit
248c01001e
@ -214,7 +214,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
|
||||
the decoded result and the batch's original Y.
|
||||
'''
|
||||
# Obtain the next batch of data
|
||||
_, (batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
|
||||
if FLAGS.use_cudnn_rnn:
|
||||
rnn_impl = rnn_impl_cudnn_rnn
|
||||
@ -227,11 +227,14 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
|
||||
# Compute the CTC loss using TensorFlow's `ctc_loss`
|
||||
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
|
||||
|
||||
# Check if any files lead to non finite loss
|
||||
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
|
||||
|
||||
# Calculate the average loss across the batch
|
||||
avg_loss = tf.reduce_mean(total_loss)
|
||||
|
||||
# Finally we return the average loss
|
||||
return avg_loss
|
||||
return avg_loss, non_finite_files
|
||||
|
||||
|
||||
# Adam Optimization
|
||||
@ -279,6 +282,9 @@ def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
# Tower gradients to return
|
||||
tower_gradients = []
|
||||
|
||||
# Aggregate any non finite files in the batches
|
||||
tower_non_finite_files = []
|
||||
|
||||
with tfv1.variable_scope(tfv1.get_variable_scope()):
|
||||
# Loop over available_devices
|
||||
for i in range(len(Config.available_devices)):
|
||||
@ -289,7 +295,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
with tf.name_scope('tower_%d' % i):
|
||||
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
|
||||
# batch along with the original batch's labels (Y) of this tower
|
||||
avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
|
||||
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
|
||||
|
||||
# Allow for variables to be re-used by the next tower
|
||||
tfv1.get_variable_scope().reuse_variables()
|
||||
@ -303,13 +309,15 @@ def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
# Retain tower's gradients
|
||||
tower_gradients.append(gradients)
|
||||
|
||||
tower_non_finite_files.append(non_finite_files)
|
||||
|
||||
avg_loss_across_towers = tf.reduce_mean(tower_avg_losses, 0)
|
||||
|
||||
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
|
||||
|
||||
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
|
||||
|
||||
# Return gradients and the average loss
|
||||
return tower_gradients, avg_loss_across_towers
|
||||
return tower_gradients, avg_loss_across_towers, all_non_finite_files
|
||||
|
||||
|
||||
def average_gradients(tower_gradients):
|
||||
@ -436,7 +444,7 @@ def train():
|
||||
|
||||
# Building the graph
|
||||
optimizer = create_optimizer()
|
||||
gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
|
||||
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
@ -517,12 +525,17 @@ def train():
|
||||
# Batch loop
|
||||
while True:
|
||||
try:
|
||||
_, current_step, batch_loss, step_summary = \
|
||||
session.run([train_op, global_step, loss, step_summaries_op],
|
||||
_, current_step, batch_loss, problem_files, step_summary = \
|
||||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
|
||||
log_error('The following files caused an infinite (or NaN) '
|
||||
'loss: {}'.format(','.join(problem_files)))
|
||||
|
||||
total_loss += batch_loss
|
||||
step_count += 1
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user