adding automatic mixed precision training support
This commit is contained in:
parent
b888058e4e
commit
909fa60601
@ -446,6 +446,10 @@ def train():
|
||||
|
||||
# Building the graph
|
||||
optimizer = create_optimizer()
|
||||
|
||||
# Enable mixed precision training
|
||||
if FLAGS.automatic_mixed_precision:
|
||||
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
|
||||
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
|
Loading…
Reference in New Issue
Block a user