Merge pull request #2265 from mozilla/cudnnrnn_compatible
Allow loading a CuDNN RNN checkpoint in a CPU-capable graph (Fixes #2264)
This commit is contained in:
commit
86fff2f660
@ -467,11 +467,48 @@ def train():
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
# Loading or initializing
|
||||
loaded = False
|
||||
if FLAGS.load in ['auto', 'last']:
|
||||
|
||||
# Initialize training from a CuDNN RNN checkpoint
|
||||
if FLAGS.cudnn_checkpoint:
|
||||
if FLAGS.use_cudnn_rnn:
|
||||
log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
|
||||
'was specified. The --cudnn_checkpoint flag is only '
|
||||
'needed when converting a CuDNN RNN checkpoint to '
|
||||
'a CPU-capable graph. If your system is capable of '
|
||||
'using CuDNN RNN, you can just specify the CuDNN RNN '
|
||||
'checkpoint normally with --checkpoint_dir.')
|
||||
exit(1)
|
||||
|
||||
log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
|
||||
ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
|
||||
missing_variables = []
|
||||
|
||||
# Load compatible variables from checkpoint
|
||||
for v in tfv1.global_variables():
|
||||
try:
|
||||
v.load(ckpt.get_tensor(v.op.name), session=session)
|
||||
except tf.errors.NotFoundError:
|
||||
missing_variables.append(v)
|
||||
|
||||
# Check that the only missing variables are the Adam moment tensors
|
||||
if any('Adam' not in v.op.name for v in missing_variables):
|
||||
log_error('Tried to load a CuDNN RNN checkpoint but there were '
|
||||
'more missing variables than just the Adam moment '
|
||||
'tensors.')
|
||||
exit(1)
|
||||
|
||||
# Initialize Adam moment tensors from scratch to allow use of CuDNN
|
||||
# RNN checkpoints.
|
||||
log_info('Initializing missing Adam moment tensors.')
|
||||
init_op = tfv1.variables_initializer(missing_variables)
|
||||
session.run(init_op)
|
||||
loaded = True
|
||||
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
if not loaded and FLAGS.load in ['auto', 'last']:
|
||||
loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
|
||||
if not loaded and FLAGS.load in ['auto', 'best']:
|
||||
loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
|
||||
|
@ -56,6 +56,7 @@ def create_flags():
|
||||
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
|
||||
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
|
||||
f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
|
||||
f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
|
||||
|
||||
# Sample limits
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user