Reduce learning rate on plateau.

This commit is contained in:
Daniel 2020-02-11 16:52:56 +01:00 committed by Reuben Morais
parent 44ff4c54b9
commit 17ddc5600e
2 changed files with 39 additions and 21 deletions

47
DeepSpeech.py Executable file → Normal file
View File

@ -260,8 +260,8 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used, # (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980), # we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning. # because, generally, it requires less fine-tuning.
def create_optimizer(): def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1, beta1=FLAGS.beta1,
beta2=FLAGS.beta2, beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon) epsilon=FLAGS.epsilon)
@ -452,7 +452,9 @@ def train():
} }
# Building the graph # Building the graph
optimizer = create_optimizer() learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training # Enable mixed precision training
if FLAGS.automatic_mixed_precision: if FLAGS.automatic_mixed_precision:
@ -571,6 +573,7 @@ def train():
train_start_time = datetime.utcnow() train_start_time = datetime.utcnow()
best_dev_loss = float('inf') best_dev_loss = float('inf')
dev_losses = [] dev_losses = []
epochs_without_improvement = 0
try: try:
for epoch in range(FLAGS.epochs): for epoch in range(FLAGS.epochs):
# Training # Training
@ -589,29 +592,39 @@ def train():
dev_loss += set_loss * steps dev_loss += set_loss * steps
total_steps += steps total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss)) log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
dev_loss = dev_loss / total_steps
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss) dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss: if dev_loss < best_dev_loss:
best_dev_loss = dev_loss best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping # Early stopping
if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps: if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1]) log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
std_loss = np.std(dev_losses[-FLAGS.es_steps:-1]) epochs_without_improvement))
dev_losses = dev_losses[-FLAGS.es_steps:] break
log_debug('Checking for early stopping (last %d steps) validation loss: '
'%f, with standard deviation: %f and mean: %f' % # Reduce learning rate on plateau
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) if (FLAGS.reduce_lr_on_plateau and
if dev_losses[-1] > np.max(dev_losses[:-1]) or \ epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
(abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th): # If the learning rate was reduced and there is still no improvement
log_info('Early stop triggered as (for last %d steps) validation loss:' # wait FLAGS.plateau_epochs before the learning rate is reduced again
' %f with standard deviation: %f and mean: %f' % session.run(reduce_learning_rate_op)
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) current_learning_rate = learning_rate_var.eval()
break log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time)) log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))

View File

@ -139,10 +139,15 @@ def create_flags():
# Early Stopping # Early Stopping
f.DEFINE_boolean('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.') f.DEFINE_boolean('early_stop', True, 'Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
f.DEFINE_integer('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point') f.DEFINE_integer('es_epochs', 25, 'Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
f.DEFINE_float('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required') f.DEFINE_float('es_min_delta', 0.05, 'Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau')
f.DEFINE_float('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
# Reduce learning rate on plateau
f.DEFINE_boolean('reduce_lr_on_plateau', True, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.')
f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping')
f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.')
# Decoder # Decoder