Reduce learning rate on plateau.
This commit is contained in:
parent 44ff4c54b9
commit 17ddc5600e
@@ -260,8 +260,8 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
 # (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
 # we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
 # because, generally, it requires less fine-tuning.
-def create_optimizer():
-    optimizer = tfv1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
+def create_optimizer(learning_rate_var):
+    optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
                                          beta1=FLAGS.beta1,
                                          beta2=FLAGS.beta2,
                                          epsilon=FLAGS.epsilon)
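The signature change works because tfv1.train.AdamOptimizer accepts a tensor or variable for learning_rate, not only a Python float, so later assignments to that variable change the effective rate without rebuilding the optimizer. A minimal sketch of that idea (the initial value is an illustrative placeholder, not FLAGS.learning_rate):

import tensorflow.compat.v1 as tfv1  # TF 1.x-style graph mode, as used in the file above
tfv1.disable_eager_execution()

lr = tfv1.get_variable('lr', initializer=0.001, trainable=False)  # 0.001 is a placeholder value
optimizer = tfv1.train.AdamOptimizer(learning_rate=lr)  # reads lr at run time, not at graph-build time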
@@ -452,7 +452,9 @@ def train():
     }
 
     # Building the graph
-    optimizer = create_optimizer()
+    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
+    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
+    optimizer = create_optimizer(learning_rate_var)
 
     # Enable mixed precision training
     if FLAGS.automatic_mixed_precision:
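This plumbing is the whole mechanism: the learning rate becomes a non-trainable graph variable, and running reduce_learning_rate_op rescales it by FLAGS.plateau_reduction. A standalone sketch of that behaviour, with illustrative constants standing in for the FLAGS values (this is not the project's training code):

import tensorflow as tf
import tensorflow.compat.v1 as tfv1
tfv1.disable_eager_execution()

initial_learning_rate = 0.001   # stands in for FLAGS.learning_rate
plateau_reduction = 0.1         # stands in for FLAGS.plateau_reduction

learning_rate_var = tfv1.get_variable('learning_rate', initializer=initial_learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, plateau_reduction))

with tfv1.Session() as session:
    session.run(tfv1.global_variables_initializer())
    print(session.run(learning_rate_var))   # 0.001
    session.run(reduce_learning_rate_op)    # simulate hitting one plateau
    print(session.run(learning_rate_var))   # ~0.0001 (float32 rounding aside)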
@@ -571,6 +573,7 @@ def train():
         train_start_time = datetime.utcnow()
         best_dev_loss = float('inf')
         dev_losses = []
+        epochs_without_improvement = 0
         try:
             for epoch in range(FLAGS.epochs):
                 # Training
@@ -589,29 +592,39 @@ def train():
                     dev_loss += set_loss * steps
                     total_steps += steps
                     log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
-                dev_loss = dev_loss / total_steps
+
+                dev_loss = dev_loss / total_steps
                 dev_losses.append(dev_loss)
 
+                # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
+                # the improvement has to be greater than FLAGS.es_min_delta
+                if dev_loss > best_dev_loss - FLAGS.es_min_delta:
+                    epochs_without_improvement += 1
+                else:
+                    epochs_without_improvement = 0
+
+                # Save new best model
                 if dev_loss < best_dev_loss:
                     best_dev_loss = dev_loss
                     save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                     log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
 
                 # Early stopping
-                if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
-                    mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
-                    std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
-                    dev_losses = dev_losses[-FLAGS.es_steps:]
-                    log_debug('Checking for early stopping (last %d steps) validation loss: '
-                              '%f, with standard deviation: %f and mean: %f' %
-                              (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
-                    if dev_losses[-1] > np.max(dev_losses[:-1]) or \
-                       (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
-                        log_info('Early stop triggered as (for last %d steps) validation loss:'
-                                 ' %f with standard deviation: %f and mean: %f' %
-                                 (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
-                        break
+                if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
+                    log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
+                        epochs_without_improvement))
+                    break
+
+                # Reduce learning rate on plateau
+                if (FLAGS.reduce_lr_on_plateau and
+                        epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
+                    # If the learning rate was reduced and there is still no improvement
+                    # wait FLAGS.plateau_epochs before the learning rate is reduced again
+                    session.run(reduce_learning_rate_op)
+                    current_learning_rate = learning_rate_var.eval()
+                    log_info('Encountered a plateau, reducing learning rate to {}'.format(
+                        current_learning_rate))
+
         except KeyboardInterrupt:
             pass
         log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
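In plain terms, the rewritten loop replaces the mean/standard-deviation early-stopping heuristic with a single counter: an epoch counts as "no improvement" unless it beats the best dev loss by more than FLAGS.es_min_delta, early stopping fires once the counter reaches FLAGS.es_epochs, and the learning rate is cut every FLAGS.plateau_epochs consecutive stagnant epochs. A framework-free sketch of that bookkeeping (the constants mirror the new defaults and are assumptions, not read from FLAGS):

es_min_delta = 0.05        # mirrors the new es_min_delta default
es_epochs = 25             # mirrors the new es_epochs default
plateau_epochs = 10        # mirrors the new plateau_epochs default
plateau_reduction = 0.1    # mirrors the new plateau_reduction default

best_dev_loss = float('inf')
epochs_without_improvement = 0
learning_rate = 0.001      # illustrative starting rate

def after_validation(dev_loss):
    """Return 'stop', 'reduce_lr' or 'continue' for one epoch's validation loss."""
    global best_dev_loss, epochs_without_improvement, learning_rate
    # The counter only resets when the loss improves on the best by more than es_min_delta
    if dev_loss > best_dev_loss - es_min_delta:
        epochs_without_improvement += 1
    else:
        epochs_without_improvement = 0
    best_dev_loss = min(best_dev_loss, dev_loss)
    if epochs_without_improvement == es_epochs:
        return 'stop'
    # One reduction per plateau_epochs consecutive epochs without improvement
    if epochs_without_improvement > 0 and epochs_without_improvement % plateau_epochs == 0:
        learning_rate *= plateau_reduction
        return 'reduce_lr'
    return 'continue'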
@@ -139,10 +139,15 @@ def create_flags():
 
     # Early Stopping
 
-    f.DEFINE_boolean('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
-    f.DEFINE_integer('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
-    f.DEFINE_float('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
-    f.DEFINE_float('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
+    f.DEFINE_boolean('early_stop', True, 'Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
+    f.DEFINE_integer('es_epochs', 25, 'Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
+    f.DEFINE_float('es_min_delta', 0.05, 'Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau')
+
+    # Reduce learning rate on plateau
+
+    f.DEFINE_boolean('reduce_lr_on_plateau', True, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.')
+    f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping')
+    f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.')
 
     # Decoder
 
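With the new defaults, plateau_epochs (10) fits into es_epochs (25) twice, so a fully stagnant run would get two rate reductions before early stopping ends training; as the help text notes, plateau_epochs has to stay smaller than es_epochs for any reduction to happen at all. A throwaway calculation of that timeline (assumed default values only, not part of the change):

es_epochs, plateau_epochs, plateau_reduction = 25, 10, 0.1   # new defaults from the flags above
learning_rate = 0.001                                        # illustrative starting rate

for stagnant_epochs in range(1, es_epochs + 1):
    if stagnant_epochs % plateau_epochs == 0:
        learning_rate *= plateau_reduction
        print('after %d epochs without improvement: learning rate reduced to %g' % (stagnant_epochs, learning_rate))
print('early stop after %d epochs without improvement' % es_epochs)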