adding automatic mixed precision training support

This commit is contained in:
Vinh Nguyen 2019-10-14 12:34:10 +00:00
parent b888058e4e
commit 909fa60601

View File

@ -446,6 +446,10 @@ def train():
# Building the graph # Building the graph
optimizer = create_optimizer() optimizer = create_optimizer()
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates) gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs # Average tower gradients across GPUs