adding automatic mixed precision training support
This commit is contained in:
parent
b888058e4e
commit
909fa60601
@ -446,6 +446,10 @@ def train():
|
|||||||
|
|
||||||
# Building the graph
|
# Building the graph
|
||||||
optimizer = create_optimizer()
|
optimizer = create_optimizer()
|
||||||
|
|
||||||
|
# Enable mixed precision training
|
||||||
|
if FLAGS.automatic_mixed_precision:
|
||||||
|
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
|
||||||
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
||||||
|
|
||||||
# Average tower gradients across GPUs
|
# Average tower gradients across GPUs
|
||||||
|
Loading…
Reference in New Issue
Block a user