From 17ddc5600e8587544d1f26b95f583d092f413773 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 11 Feb 2020 16:52:56 +0100 Subject: [PATCH 1/4] Reduce learning rate on plateau. --- DeepSpeech.py | 47 ++++++++++++++++++++++++++++++----------------- util/flags.py | 13 +++++++++---- 2 files changed, 39 insertions(+), 21 deletions(-) mode change 100755 => 100644 DeepSpeech.py diff --git a/DeepSpeech.py b/DeepSpeech.py old mode 100755 new mode 100644 index dabb3143..8ebd1e25 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -260,8 +260,8 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse): # (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used, # we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980), # because, generally, it requires less fine-tuning. -def create_optimizer(): - optimizer = tfv1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, +def create_optimizer(learning_rate_var): + optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var, beta1=FLAGS.beta1, beta2=FLAGS.beta2, epsilon=FLAGS.epsilon) @@ -452,7 +452,9 @@ def train(): } # Building the graph - optimizer = create_optimizer() + learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False) + reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction)) + optimizer = create_optimizer(learning_rate_var) # Enable mixed precision training if FLAGS.automatic_mixed_precision: @@ -571,6 +573,7 @@ def train(): train_start_time = datetime.utcnow() best_dev_loss = float('inf') dev_losses = [] + epochs_without_improvement = 0 try: for epoch in range(FLAGS.epochs): # Training @@ -589,29 +592,39 @@ def train(): dev_loss += set_loss * steps total_steps += steps log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss)) - dev_loss = dev_loss / total_steps + dev_loss = dev_loss / total_steps dev_losses.append(dev_loss) + # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau + # the improvement has to be greater than FLAGS.es_min_delta + if dev_loss > best_dev_loss - FLAGS.es_min_delta: + epochs_without_improvement += 1 + else: + epochs_without_improvement = 0 + + # Save new best model if dev_loss < best_dev_loss: best_dev_loss = dev_loss save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) # Early stopping - if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps: - mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1]) - std_loss = np.std(dev_losses[-FLAGS.es_steps:-1]) - dev_losses = dev_losses[-FLAGS.es_steps:] - log_debug('Checking for early stopping (last %d steps) validation loss: ' - '%f, with standard deviation: %f and mean: %f' % - (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) - if dev_losses[-1] > np.max(dev_losses[:-1]) or \ - (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th): - log_info('Early stop triggered as (for last %d steps) validation loss:' - ' %f with standard deviation: %f and mean: %f' % - (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss)) - break + if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs: + log_info('Early stop triggered as the loss did not improve the last {} epochs'.format( + epochs_without_improvement)) + break + + # Reduce learning rate on plateau + if (FLAGS.reduce_lr_on_plateau and + epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0): + # If the learning rate was reduced and there is still no improvement + # wait FLAGS.plateau_epochs before the learning rate is reduced again + session.run(reduce_learning_rate_op) + current_learning_rate = learning_rate_var.eval() + log_info('Encountered a plateau, reducing learning rate to {}'.format( + current_learning_rate)) + except KeyboardInterrupt: pass log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time)) diff --git a/util/flags.py b/util/flags.py index 6f3d1bea..24ff764f 100644 --- a/util/flags.py +++ b/util/flags.py @@ -139,10 +139,15 @@ def create_flags(): # Early Stopping - f.DEFINE_boolean('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.') - f.DEFINE_integer('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point') - f.DEFINE_float('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required') - f.DEFINE_float('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required') + f.DEFINE_boolean('early_stop', True, 'Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.') + f.DEFINE_integer('es_epochs', 25, 'Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point') + f.DEFINE_float('es_min_delta', 0.05, 'Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau') + + # Reduce learning rate on plateau + + f.DEFINE_boolean('reduce_lr_on_plateau', True, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.') + f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping') + f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.') # Decoder From 6e12b7caed7b4deb6282cb4856fd647b6e3d6ed7 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 18 Feb 2020 14:42:09 +0100 Subject: [PATCH 2/4] Allow missing learning rate variable in older checkpoints --- util/checkpoints.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/util/checkpoints.py b/util/checkpoints.py index 753c5fde..94b1273e 100644 --- a/util/checkpoints.py +++ b/util/checkpoints.py @@ -11,29 +11,38 @@ def _load_checkpoint(session, checkpoint_path): # we will exclude variables we do not wish to load and then # we will initialize them instead ckpt = tfv1.train.load_checkpoint(checkpoint_path) + vars_in_ckpt = frozenset(ckpt.get_variable_to_shape_map().keys()) load_vars = set(tfv1.global_variables()) init_vars = set() + # We explicitly allow the learning rate variable to be missing for backwards + # compatibility with older checkpoints. + if 'learning_rate' not in vars_in_ckpt: + lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') + assert len(lr_var) == 1 + load_vars -= lr_var + init_vars |= lr_var + if FLAGS.load_cudnn: # Initialize training from a CuDNN RNN checkpoint # Identify the variables which we cannot load, and set them # for initialization + missing_vars = set() for v in load_vars: - try: - ckpt.get_tensor(v.op.name) - except tf.errors.NotFoundError: - log_error('CUDNN variable not found: %s' % (v.op.name)) + if v.op.name not in vars_in_ckpt: + log_warn('CUDNN variable not found: %s' % (v.op.name)) + missing_vars.add(v) init_vars.add(v) load_vars -= init_vars # Check that the only missing variables (i.e. those to be initialised) # are the Adam moment tensors, if they aren't then we have an issue - init_var_names = [v.op.name for v in init_vars] - if any('Adam' not in v for v in init_var_names): + missing_var_names = [v.op.name for v in missing_vars] + if any('Adam' not in v for v in missing_var_names): log_error('Tried to load a CuDNN RNN checkpoint but there were ' 'more missing variables than just the Adam moment ' - 'tensors. Missing variables: {}'.format(init_var_names)) + 'tensors. Missing variables: {}'.format(missing_var_names)) sys.exit(1) if FLAGS.drop_source_layers > 0: From 78e8dfdf386d533c6581ec45c082ce94a78a11d4 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 18 Feb 2020 16:14:33 +0100 Subject: [PATCH 3/4] Disable early stopping and LR reduction on plateau by default --- util/flags.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/util/flags.py b/util/flags.py index 24ff764f..f46fdc81 100644 --- a/util/flags.py +++ b/util/flags.py @@ -139,13 +139,13 @@ def create_flags(): # Early Stopping - f.DEFINE_boolean('early_stop', True, 'Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.') + f.DEFINE_boolean('early_stop', False, 'Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.') f.DEFINE_integer('es_epochs', 25, 'Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point') f.DEFINE_float('es_min_delta', 0.05, 'Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau') # Reduce learning rate on plateau - f.DEFINE_boolean('reduce_lr_on_plateau', True, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.') + f.DEFINE_boolean('reduce_lr_on_plateau', False, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.') f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping') f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.') From 559042a21846b1a3336486e014b9e0023794c00e Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 18 Feb 2020 18:11:16 +0100 Subject: [PATCH 4/4] Increase epoch count in train tests to guarantee outputs in 8kHz mode --- taskcluster/tc-train-tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taskcluster/tc-train-tests.sh b/taskcluster/tc-train-tests.sh index 5346f540..47bacfa8 100644 --- a/taskcluster/tc-train-tests.sh +++ b/taskcluster/tc-train-tests.sh @@ -78,7 +78,7 @@ mv "${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/${sample_name}" "${DS_ROOT_TAS pushd ${HOME}/DeepSpeech/ds/ # Run twice to test preprocessed features - time ./bin/run-tc-ldc93s1_new.sh 219 "${sample_rate}" + time ./bin/run-tc-ldc93s1_new.sh 249 "${sample_rate}" time ./bin/run-tc-ldc93s1_new.sh 1 "${sample_rate}" time ./bin/run-tc-ldc93s1_tflite.sh "${sample_rate}" popd