diff --git a/DeepSpeech.py b/DeepSpeech.py
index a2dd045a..870250e8 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -446,6 +446,10 @@ def train():
 
     # Building the graph
     optimizer = create_optimizer()
+    
+    # Enable mixed precision training
+    if FLAGS.automatic_mixed_precision:
+        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
     gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
 
     # Average tower gradients across GPUs