diff --git a/DeepSpeech.py b/DeepSpeech.py
index 8492fa66..4e11b1f9 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -466,11 +466,48 @@ def train():
     with tfv1.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
 
-        tfv1.get_default_graph().finalize()
-
         # Loading or initializing
         loaded = False
-        if FLAGS.load in ['auto', 'last']:
+
+        # Initialize training from a CuDNN RNN checkpoint
+        if FLAGS.cudnn_checkpoint:
+            if FLAGS.use_cudnn_rnn:
+                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
+                          'was specified. The --cudnn_checkpoint flag is only '
+                          'needed when converting a CuDNN RNN checkpoint to '
+                          'a CPU-capable graph. If your system is capable of '
+                          'using CuDNN RNN, you can just specify the CuDNN RNN '
+                          'checkpoint normally with --checkpoint_dir.')
+                exit(1)
+
+            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
+            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
+            missing_variables = []
+
+            # Load compatible variables from checkpoint
+            for v in tfv1.global_variables():
+                try:
+                    v.load(ckpt.get_tensor(v.op.name), session=session)
+                except tf.errors.NotFoundError:
+                    missing_variables.append(v)
+
+            # Check that the only missing variables are the Adam moment tensors
+            if any('Adam' not in v.op.name for v in missing_variables):
+                log_error('Tried to load a CuDNN RNN checkpoint but there were '
+                          'more missing variables than just the Adam moment '
+                          'tensors.')
+                exit(1)
+
+            # Initialize Adam moment tensors from scratch to allow use of CuDNN
+            # RNN checkpoints.
+            log_info('Initializing missing Adam moment tensors.')
+            init_op = tfv1.variables_initializer(missing_variables)
+            session.run(init_op)
+            loaded = True
+
+        tfv1.get_default_graph().finalize()
+
+        if not loaded and FLAGS.load in ['auto', 'last']:
             loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
         if not loaded and FLAGS.load in ['auto', 'best']:
             loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
diff --git a/util/flags.py b/util/flags.py
index 923b4471..6cebb3a6 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -56,6 +56,7 @@ def create_flags():
     f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
+    f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
 
     # Sample limits
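
For reference, the conversion path added above restores by name every variable that the CuDNN RNN checkpoint contains and freshly initializes the ones it does not (the Adam moment tensors, which only carry optimizer state and are safe to reset). A minimal standalone sketch of that technique, assuming a TF1-compat graph already built with the CPU-capable RNN cells, tfv1 aliased to tensorflow.compat.v1, and a hypothetical helper name restore_from_cudnn_checkpoint:

import tensorflow as tf
import tensorflow.compat.v1 as tfv1

def restore_from_cudnn_checkpoint(session, checkpoint_path):
    # Read raw tensors from the CuDNN RNN checkpoint.
    reader = tfv1.train.load_checkpoint(checkpoint_path)
    missing = []
    for v in tfv1.global_variables():
        try:
            # Restore by canonical variable name; this works when the current
            # graph's variable names match those stored in the checkpoint.
            v.load(reader.get_tensor(v.op.name), session=session)
        except tf.errors.NotFoundError:
            missing.append(v)
    # Anything absent from the checkpoint (e.g. the Adam moment tensors) is
    # initialized from scratch instead.
    session.run(tfv1.variables_initializer(missing))
    return missing

The early bail-out when --use_cudnn_rnn is also set reflects the design choice in the patch: if the system can run CuDNN RNN, no conversion is needed and the CuDNN checkpoint can be loaded directly via --checkpoint_dir.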