Allow loading a CuDNN RNN checkpoint in a CPU-capable graph
parent 1d50667234
commit e3d0a44e83
@@ -466,11 +466,48 @@ def train():
     with tfv1.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
 
-        tfv1.get_default_graph().finalize()
-
         # Loading or initializing
         loaded = False
-        if FLAGS.load in ['auto', 'last']:
+
+        # Initialize training from a CuDNN RNN checkpoint
+        if FLAGS.cudnn_checkpoint:
+            if FLAGS.use_cudnn_rnn:
+                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
+                          'was specified. The --cudnn_checkpoint flag is only '
+                          'needed when converting a CuDNN RNN checkpoint to '
+                          'a CPU-capable graph. If your system is capable of '
+                          'using CuDNN RNN, you can just specify the CuDNN RNN '
+                          'checkpoint normally with --checkpoint_dir.')
+                exit(1)
+
+            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
+            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
+            missing_variables = []
+
+            # Load compatible variables from checkpoint
+            for v in tfv1.global_variables():
+                try:
+                    v.load(ckpt.get_tensor(v.op.name), session=session)
+                except tf.errors.NotFoundError:
+                    missing_variables.append(v)
+
+            # Check that the only missing variables are the Adam moment tensors
+            if any('Adam' not in v.op.name for v in missing_variables):
+                log_error('Tried to load a CuDNN RNN checkpoint but there were '
+                          'more missing variables than just the Adam moment '
+                          'tensors.')
+                exit(1)
+
+            # Initialize Adam moment tensors from scratch to allow use of CuDNN
+            # RNN checkpoints.
+            log_info('Initializing missing Adam moment tensors.')
+            init_op = tfv1.variables_initializer(missing_variables)
+            session.run(init_op)
+            loaded = True
+
+        tfv1.get_default_graph().finalize()
+
+        if not loaded and FLAGS.load in ['auto', 'last']:
             loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
         if not loaded and FLAGS.load in ['auto', 'best']:
             loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
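The conversion above is a partial-restore pattern: read every graph variable's tensor from the CuDNN RNN checkpoint by name, collect the variables the checkpoint does not contain, and re-initialize only those. A minimal standalone sketch of the same pattern, assuming a TF1-style graph; the helper name partial_restore and the ckpt_path argument are hypothetical:

import tensorflow as tf

tfv1 = tf.compat.v1

def partial_restore(session, ckpt_path):
    """Load every variable present in the checkpoint; return the ones that are not."""
    reader = tfv1.train.load_checkpoint(ckpt_path)
    missing = []
    for v in tfv1.global_variables():
        try:
            # Variable.load assigns the value inside the given session (TF1 API).
            v.load(reader.get_tensor(v.op.name), session=session)
        except tf.errors.NotFoundError:
            missing.append(v)
    # Variables absent from the checkpoint (here: optimizer slots) start from scratch.
    session.run(tfv1.variables_initializer(missing))
    return missing

In the patch itself, only the Adam moment tensors are allowed to be missing; they hold optimizer state rather than model weights, so re-initializing them is safe, at the cost of restarting Adam's moment estimates.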
@@ -56,6 +56,7 @@ def create_flags():
     f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
+    f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
 
     # Sample limits
 
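For context, the f.DEFINE_* helpers follow an absl-style flags API (an assumption here, not confirmed by this diff). A self-contained sketch of how the new string flag pairs with the mutual-exclusion check from the first hunk, with a hypothetical entry point:

from absl import app, flags

flags.DEFINE_boolean('use_cudnn_rnn', False,
                     'use CuDNN RNN backend for training on GPU')
flags.DEFINE_string('cudnn_checkpoint', '',
                    'path to a checkpoint created using --use_cudnn_rnn')

FLAGS = flags.FLAGS

def main(_):
    # --cudnn_checkpoint is only meaningful when converting to a CPU-capable
    # graph, i.e. when CuDNN RNN is *not* being used for this run.
    if FLAGS.cudnn_checkpoint and FLAGS.use_cudnn_rnn:
        raise app.UsageError('--cudnn_checkpoint cannot be combined with --use_cudnn_rnn')
    print('converting from:', FLAGS.cudnn_checkpoint or '<no conversion requested>')

if __name__ == '__main__':
    app.run(main)

In the actual training code the same condition simply logs an error and exits, directing the user to pass the checkpoint via --checkpoint_dir instead.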