Fix checkpointing logic
parent b7b44f3573
commit 9ca61b077e
@@ -350,8 +350,12 @@ class SampleIndex:
         return self.index
 
 
-def try_loading(session, saver, checkpoint_path, caption):
+def try_loading(session, saver, checkpoint_filename, caption):
     try:
+        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
+        if not checkpoint:
+            return False
+        checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
         log_info('Restored model from %s checkpoint at %s' % (caption, checkpoint_path))
         return True
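
For reference, the new `try_loading()` resolves a checkpoint through a named index file rather than a hard-coded path. A simplified standalone sketch of that lookup, assuming TensorFlow 1.x; the hypothetical `/tmp/ckpt_demo` directory and the `print` call stand in for the script's `FLAGS` and `log_info` plumbing, and the surrounding try/except is omitted:

import os
import tensorflow as tf

checkpoint_dir = '/tmp/ckpt_demo'  # hypothetical; stands in for FLAGS.checkpoint_dir


def try_loading(session, saver, checkpoint_filename, caption):
    # Read the index file named `checkpoint_filename` inside checkpoint_dir;
    # get_checkpoint_state() returns None if no such index exists yet.
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir, checkpoint_filename)
    if not checkpoint:
        return False
    # The index records the path of the most recent checkpoint written by the
    # saver that owns this index file.
    checkpoint_path = checkpoint.model_checkpoint_path
    saver.restore(session, checkpoint_path)
    print('Restored %s checkpoint from %s' % (caption, checkpoint_path))
    return True
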
@@ -448,9 +452,13 @@ def train():
 
     # Checkpointing
     checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
-    checkpoint_path = FLAGS.checkpoint_dir
+    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
+    checkpoint_filename = 'checkpoint'
 
     best_dev_saver = tf.train.Saver(max_to_keep=1)
-    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev.ckpt')
+    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
+    best_dev_filename = 'best_dev_checkpoint'
 
     initializer = tf.global_variables_initializer()
 
     with tf.Session(config=Config.session_config) as session:
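
Why the two distinct index-file names: in TF 1.x every `Saver.save()` call also rewrites a small index file (default name 'checkpoint') that `get_checkpoint_state()` reads, so two savers sharing one directory need separate index names or they overwrite each other's pointer. A minimal sketch of that behaviour, assuming TensorFlow 1.x, with a throwaway variable, a literal `max_to_keep` value, and a hypothetical `/tmp/ckpt_demo` directory in place of the script's `FLAGS`:

import os
import tensorflow as tf

checkpoint_dir = '/tmp/ckpt_demo'  # hypothetical demo directory
os.makedirs(checkpoint_dir, exist_ok=True)

step = tf.Variable(0, name='step')  # throwaway variable so there is something to save

checkpoint_saver = tf.train.Saver(max_to_keep=5)
best_dev_saver = tf.train.Saver(max_to_keep=1)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    # Regular training checkpoints: data files 'train-<step>.*', tracked by the
    # default index file named 'checkpoint'.
    checkpoint_saver.save(session, os.path.join(checkpoint_dir, 'train'), global_step=100)

    # Best-validation checkpoint: data files 'best_dev.*', tracked by its own
    # index file 'best_dev_checkpoint' so it does not clobber the index above.
    best_dev_saver.save(session, os.path.join(checkpoint_dir, 'best_dev'),
                        latest_filename='best_dev_checkpoint')

# Each set can now be resolved independently by index name:
print(tf.train.get_checkpoint_state(checkpoint_dir).model_checkpoint_path)
print(tf.train.get_checkpoint_state(checkpoint_dir, 'best_dev_checkpoint').model_checkpoint_path)
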
@@ -460,9 +468,9 @@ def train():
         # Loading or initializing
         loaded = False
         if FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, checkpoint_path, 'most recent epoch')
+            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent epoch')
         if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, best_dev_saver, best_dev_path, 'best validation')
+            loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
         if not loaded:
             if FLAGS.load in ['auto', 'init']:
                 log_info('Initializing...')
@@ -513,8 +521,8 @@ def train():
                 step_summary_writer.add_summary(step_summary, current_step)
                 if FLAGS.show_progressbar:
                     pbar.update(step_index + 1, force=True)
-                if FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
-                    checkpoint_saver.save(session, checkpoint_path)
+                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
+                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                     checkpoint_time = time.time()
             if FLAGS.show_progressbar:
                 pbar.finish()
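
On the periodic save above: passing `global_step` makes the saver append the step number to the checkpoint prefix (for example `train-1200.*`), so timed saves no longer overwrite a single file and `max_to_keep` decides how many of them survive. A small sketch of that naming behaviour, again assuming TensorFlow 1.x, with a hypothetical directory and a literal `max_to_keep` in place of the script's `FLAGS`:

import os
import tensorflow as tf

checkpoint_dir = '/tmp/ckpt_demo_steps'  # hypothetical demo directory
os.makedirs(checkpoint_dir, exist_ok=True)

step = tf.Variable(0, name='step')
saver = tf.train.Saver(max_to_keep=3)  # only the three most recent step checkpoints are kept
checkpoint_path = os.path.join(checkpoint_dir, 'train')

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for current_step in [100, 200, 300, 400]:
        # Writes train-100.*, train-200.*, ...; with max_to_keep=3 the oldest
        # set (train-100.*) is deleted once train-400.* appears.
        saver.save(session, checkpoint_path, global_step=current_step)

print(tf.train.latest_checkpoint(checkpoint_dir))  # .../train-400
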
@@ -537,7 +545,7 @@ def train():
             log_info('Training epoch %d ...' % current_epoch)
             train_loss = run_set('train')
             log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
-            checkpoint_saver.save(session, checkpoint_path)
+            checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
             steps_trained = 0
             # Validation
             log_info('Validating epoch %d ...' % current_epoch)
@@ -546,7 +554,7 @@ def train():
             log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))
             if dev_loss < best_dev_loss:
                 best_dev_loss = dev_loss
-                save_path = best_dev_saver.save(session, best_dev_path)
+                save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename)
                 log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
             # Early stopping
             if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps: