Merge PR #2434 - Add flag for automatic mixed precision training

This commit is contained in:
Reuben Morais 2019-10-15 13:44:57 +02:00
commit ef3bdb2540
2 changed files with 8 additions and 0 deletions

View File

@@ -458,6 +458,12 @@ def train():
# Building the graph
optimizer = create_optimizer()
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs

View File

@@ -80,6 +80,8 @@ def create_flags():
f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
# Sample limits
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')