Merge pull request #2265 from mozilla/cudnnrnn_compatible

Allow loading a CuDNN RNN checkpoint in a CPU-capable graph (Fixes #2264)
This commit is contained in:
Reuben Morais 2019-08-07 14:57:16 +02:00 committed by GitHub
commit 86fff2f660
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 3 deletions

View File

@ -467,11 +467,48 @@ def train():
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
tfv1.get_default_graph().finalize()
# Loading or initializing
loaded = False
if FLAGS.load in ['auto', 'last']:
# Initialize training from a CuDNN RNN checkpoint
if FLAGS.cudnn_checkpoint:
if FLAGS.use_cudnn_rnn:
log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
'was specified. The --cudnn_checkpoint flag is only '
'needed when converting a CuDNN RNN checkpoint to '
'a CPU-capable graph. If your system is capable of '
'using CuDNN RNN, you can just specify the CuDNN RNN '
'checkpoint normally with --checkpoint_dir.')
exit(1)
log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
missing_variables = []
# Load compatible variables from checkpoint
for v in tfv1.global_variables():
try:
v.load(ckpt.get_tensor(v.op.name), session=session)
except tf.errors.NotFoundError:
missing_variables.append(v)
# Check that the only missing variables are the Adam moment tensors
if any('Adam' not in v.op.name for v in missing_variables):
log_error('Tried to load a CuDNN RNN checkpoint but there were '
'more missing variables than just the Adam moment '
'tensors.')
exit(1)
# Initialize Adam moment tensors from scratch to allow use of CuDNN
# RNN checkpoints.
log_info('Initializing missing Adam moment tensors.')
init_op = tfv1.variables_initializer(missing_variables)
session.run(init_op)
loaded = True
tfv1.get_default_graph().finalize()
if not loaded and FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')

View File

@ -56,6 +56,7 @@ def create_flags():
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
# Sample limits