Use try_loading and FLAGS.load in --one_shot_infer code

This commit is contained in:
Reuben Morais 2019-10-14 11:37:32 +02:00
parent e75d1e4b61
commit e0d8ef75e8

View File

@ -394,15 +394,18 @@ def log_grads_and_vars(grads_and_vars):
log_variable(variable, gradient=gradient)
def try_loading(session, saver, checkpoint_filename, caption):
def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
try:
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
if not checkpoint:
return False
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
restored_step = session.run(tfv1.train.get_global_step())
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
if load_step:
restored_step = session.run(tfv1.train.get_global_step())
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
else:
log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
return True
except tf.errors.InvalidArgumentError as e:
log_error(str(e))
@ -479,11 +482,9 @@ def train():
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
checkpoint_filename = 'checkpoint'
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
best_dev_filename = 'best_dev_checkpoint'
# Save flags next to checkpoints
os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
@ -509,7 +510,7 @@ def train():
'a CPU-capable graph. If your system is capable of '
'using CuDNN RNN, you can just specify the CuDNN RNN '
'checkpoint normally with --checkpoint_dir.')
exit(1)
sys.exit(1)
log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
@ -527,7 +528,7 @@ def train():
log_error('Tried to load a CuDNN RNN checkpoint but there were '
'more missing variables than just the Adam moment '
'tensors.')
exit(1)
sys.exit(1)
# Initialize Adam moment tensors from scratch to allow use of CuDNN
# RNN checkpoints.
@ -539,9 +540,9 @@ def train():
tfv1.get_default_graph().finalize()
if not loaded and FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
if not loaded:
if FLAGS.load in ['auto', 'init']:
log_info('Initializing variables...')
@ -638,7 +639,7 @@ def train():
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
@ -855,15 +856,14 @@ def do_single_file_inference(input_file_path):
saver = tfv1.train.Saver()
# Restore variables from training checkpoint
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
# over-fitting, we may want to restore an earlier checkpoint.
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if not checkpoint:
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
exit(1)
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
loaded = False
if not loaded and FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
if not loaded:
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
sys.exit(1)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])