Use try_loading and FLAGS.load in --one_shot_infer code
This commit is contained in:
parent
e75d1e4b61
commit
e0d8ef75e8
@ -394,15 +394,18 @@ def log_grads_and_vars(grads_and_vars):
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def try_loading(session, saver, checkpoint_filename, caption):
|
||||
def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
|
||||
try:
|
||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
|
||||
if not checkpoint:
|
||||
return False
|
||||
checkpoint_path = checkpoint.model_checkpoint_path
|
||||
saver.restore(session, checkpoint_path)
|
||||
restored_step = session.run(tfv1.train.get_global_step())
|
||||
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
|
||||
if load_step:
|
||||
restored_step = session.run(tfv1.train.get_global_step())
|
||||
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
|
||||
else:
|
||||
log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
|
||||
return True
|
||||
except tf.errors.InvalidArgumentError as e:
|
||||
log_error(str(e))
|
||||
@ -479,11 +482,9 @@ def train():
|
||||
# Checkpointing
|
||||
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
|
||||
checkpoint_filename = 'checkpoint'
|
||||
|
||||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
|
||||
best_dev_filename = 'best_dev_checkpoint'
|
||||
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
|
||||
@ -509,7 +510,7 @@ def train():
|
||||
'a CPU-capable graph. If your system is capable of '
|
||||
'using CuDNN RNN, you can just specify the CuDNN RNN '
|
||||
'checkpoint normally with --checkpoint_dir.')
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
|
||||
ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
|
||||
@ -527,7 +528,7 @@ def train():
|
||||
log_error('Tried to load a CuDNN RNN checkpoint but there were '
|
||||
'more missing variables than just the Adam moment '
|
||||
'tensors.')
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
# Initialize Adam moment tensors from scratch to allow use of CuDNN
|
||||
# RNN checkpoints.
|
||||
@ -539,9 +540,9 @@ def train():
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
if not loaded and FLAGS.load in ['auto', 'last']:
|
||||
loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
|
||||
loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
|
||||
if not loaded and FLAGS.load in ['auto', 'best']:
|
||||
loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
|
||||
loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
|
||||
if not loaded:
|
||||
if FLAGS.load in ['auto', 'init']:
|
||||
log_info('Initializing variables...')
|
||||
@ -638,7 +639,7 @@ def train():
|
||||
|
||||
if dev_loss < best_dev_loss:
|
||||
best_dev_loss = dev_loss
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
|
||||
|
||||
# Early stopping
|
||||
@ -855,15 +856,14 @@ def do_single_file_inference(input_file_path):
|
||||
saver = tfv1.train.Saver()
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
|
||||
# over-fitting, we may want to restore an earlier checkpoint.
|
||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
|
||||
if not checkpoint:
|
||||
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
|
||||
exit(1)
|
||||
|
||||
checkpoint_path = checkpoint.model_checkpoint_path
|
||||
saver.restore(session, checkpoint_path)
|
||||
loaded = False
|
||||
if not loaded and FLAGS.load in ['auto', 'last']:
|
||||
loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
|
||||
if not loaded and FLAGS.load in ['auto', 'best']:
|
||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
|
||||
if not loaded:
|
||||
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
|
||||
sys.exit(1)
|
||||
|
||||
features, features_len = audiofile_to_features(input_file_path)
|
||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||
|
Loading…
x
Reference in New Issue
Block a user