Fix checkpointing logic

This commit is contained in:
Reuben Morais 2019-04-01 18:53:06 -03:00
parent b7b44f3573
commit 9ca61b077e

View File

@ -350,8 +350,12 @@ class SampleIndex:
return self.index
def try_loading(session, saver, checkpoint_path, caption):
def try_loading(session, saver, checkpoint_filename, caption):
try:
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
if not checkpoint:
return False
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
log_info('Restored model from %s checkpoint at %s' % (caption, checkpoint_path))
return True
@ -448,9 +452,13 @@ def train():
# Checkpointing
checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = FLAGS.checkpoint_dir
checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
checkpoint_filename = 'checkpoint'
best_dev_saver = tf.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev.ckpt')
best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
best_dev_filename = 'best_dev_checkpoint'
initializer = tf.global_variables_initializer()
with tf.Session(config=Config.session_config) as session:
@ -460,9 +468,9 @@ def train():
# Loading or initializing
loaded = False
if FLAGS.load in ['auto', 'last']:
loaded = try_loading(session, checkpoint_saver, checkpoint_path, 'most recent epoch')
loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent epoch')
if not loaded and FLAGS.load in ['auto', 'best']:
loaded = try_loading(session, best_dev_saver, best_dev_path, 'best validation')
loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
if not loaded:
if FLAGS.load in ['auto', 'init']:
log_info('Initializing...')
@ -513,8 +521,8 @@ def train():
step_summary_writer.add_summary(step_summary, current_step)
if FLAGS.show_progressbar:
pbar.update(step_index + 1, force=True)
if FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
if FLAGS.show_progressbar:
pbar.finish()
@ -537,7 +545,7 @@ def train():
log_info('Training epoch %d ...' % current_epoch)
train_loss = run_set('train')
log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path)
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
steps_trained = 0
# Validation
log_info('Validating epoch %d ...' % current_epoch)
@ -546,7 +554,7 @@ def train():
log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path)
save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename)
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps: