diff --git a/DeepSpeech.py b/DeepSpeech.py index a2dd045a..870250e8 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -446,6 +446,10 @@ def train(): # Building the graph optimizer = create_optimizer() + + # Enable mixed precision training + if FLAGS.automatic_mixed_precision: + optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates) # Average tower gradients across GPUs