Ignore epochs in checkpoints, always start epoch count from zero

Reuben Morais 2019-04-04 19:49:33 -03:00
parent 57450893ea
commit 2f3f095048
4 changed files with 51 additions and 65 deletions
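In short: the trainer previously reconstructed the current epoch from the checkpoint's global step and treated a negative `--epoch` value as "that many additional epochs"; after this change, `--epoch` is simply the number of epochs to run, always counted from zero. A plain-Python sketch of the two behaviours, using hypothetical values in place of `FLAGS.epoch` and the restored step:

```python
# Hypothetical values, for illustration only.
restored_step = 7000      # global step read from a checkpoint
steps_per_epoch = 1000

# Old behaviour (removed below): the epoch is inferred from the checkpoint,
# and a negative --epoch means "train this many additional epochs".
current_epoch = restored_step // steps_per_epoch                 # 7
flags_epoch = -3
target_epoch = current_epoch + abs(flags_epoch) if flags_epoch < 0 else flags_epoch
print(list(range(current_epoch, target_epoch)))                  # [7, 8, 9]

# New behaviour: --epoch is simply how many epochs to run, starting at zero.
flags_epoch = 3
print(list(range(flags_epoch)))                                  # [0, 1, 2]
```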


@@ -21,7 +21,7 @@ from tensorflow.python.tools import freeze_graph
 from util.config import Config, initialize_globals
 from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
-from util.logging import log_info, log_error, log_debug, log_warn
+from util.logging import log_info, log_error, log_debug

 # Graph Creation
@@ -350,7 +350,8 @@ def try_loading(session, saver, checkpoint_filename, caption):
             return False
         checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
-        log_info('Restored model from %s checkpoint at %s' % (caption, checkpoint_path))
+        restored_step = session.run(tf.train.get_global_step())
+        log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
         return True
     except tf.errors.InvalidArgumentError as e:
         log_error(str(e))
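For context, the new log line reads the step counter back out of the restored graph. `tf.train.get_global_step()` only finds a step variable that already exists, which is why the graph-building code below switches to `tf.train.get_or_create_global_step()`. A minimal self-contained sketch of the pattern (TensorFlow 1.x API; the checkpoint prefix is hypothetical):

```python
import tensorflow as tf  # TensorFlow 1.x, as used by this codebase

with tf.Graph().as_default():
    # Create (or fetch) the step counter so the Saver knows to restore it.
    global_step = tf.train.get_or_create_global_step()
    saver = tf.train.Saver()
    with tf.Session() as session:
        saver.restore(session, '/tmp/ckpt/model')  # hypothetical checkpoint prefix
        # Same call as in try_loading: the step the checkpoint was written at.
        restored_step = session.run(tf.train.get_global_step())
        print('restored at step %d' % restored_step)
```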
@@ -399,11 +400,13 @@ def train():
     # Building the graph
     optimizer = create_optimizer()
     gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)

     # Average tower gradients across GPUs
     avg_tower_gradients = average_gradients(gradients)
     log_grads_and_vars(avg_tower_gradients)

-    global_step = tf.Variable(0, trainable=False, name='global_step')
+    # global_step is automagically incremented by the optimizer
+    global_step = tf.train.get_or_create_global_step()
     apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

     # Summaries
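Why `tf.train.get_or_create_global_step()` rather than a hand-rolled `tf.Variable`: the helper registers the variable in TensorFlow's global-step collection, so `tf.train.get_global_step()` (used by `try_loading` above) can locate it, and `apply_gradients(..., global_step=...)` increments it once per step. A toy sketch with a made-up scalar loss standing in for the real towers:

```python
import tensorflow as tf  # TensorFlow 1.x

with tf.Graph().as_default():
    x = tf.Variable(3.0)
    loss = tf.square(x)  # toy stand-in for the CTC loss
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    # Passing global_step makes the optimizer bump it on every update.
    train_op = optimizer.minimize(loss, global_step=global_step)
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for _ in range(3):
            session.run(train_op)
        print(session.run(global_step))  # 3: one increment per training step
```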
@@ -432,39 +435,25 @@ def train():
         # Loading or initializing
         loaded = False
         if FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent epoch')
+            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
         if not loaded and FLAGS.load in ['auto', 'best']:
             loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
         if not loaded:
             if FLAGS.load in ['auto', 'init']:
-                log_info('Initializing...')
+                log_info('Initializing variables...')
                 session.run(initializer)
             else:
                 log_error('Unable to load %s model from specified checkpoint dir'
                           ' - consider using load option "auto" or "init".' % FLAGS.load)
                 sys.exit(1)

-        # Retrieving global_step from restored model and setting training parameters accordingly
-        step = session.run(global_step)
-        num_gpus = len(Config.available_devices)
-        steps_per_epoch = max(1, train_batches // num_gpus)
-        current_epoch = step // steps_per_epoch
-        target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
-        log_debug('step: %d' % step)
-        log_debug('epoch: %d' % current_epoch)
-        log_debug('target epoch: %d' % target_epoch)
-        log_debug('steps per epoch: %d' % steps_per_epoch)
-        log_debug('batches per step (GPUs): %d' % num_gpus)
-        log_debug('number of batches in train set: %d' % train_batches)

         def run_set(set_name, init_op, num_batches):
             is_train = set_name == 'train'
             train_op = apply_gradient_op if is_train else []
             feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
             total_loss = 0.0
             step_summary_writer = step_summary_writers.get(set_name)
-            num_steps = max(1, num_batches // num_gpus)
+            num_steps = max(1, num_batches // len(Config.available_devices))
             checkpoint_time = time.time()

             if FLAGS.show_progressbar:
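The block removed above is where epochs used to be reconstructed from the checkpoint. One reason the derivation was brittle, and plausibly part of why this commit drops it: the inferred epoch depends on the current run's `steps_per_epoch`, not the one in effect when the checkpoint was written, so changing the GPU count or dataset silently shifts the count. A hypothetical illustration:

```python
# Hypothetical numbers, for illustration only.
step = 3000                     # global step restored from a checkpoint

steps_per_epoch = 1000          # original run: train_batches // num_gpus
print(step // steps_per_epoch)  # 3: three epochs completed

steps_per_epoch = 750           # same checkpoint, more GPUs or fewer batches
print(step // steps_per_epoch)  # 4: the inferred epoch silently shifts
```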
@@ -492,51 +481,48 @@ def train():
             return total_loss / num_steps

-        if target_epoch > current_epoch:
-            log_info('STARTING Optimization')
-            best_dev_loss = float('inf')
-            dev_losses = []
-            coord = tf.train.Coordinator()
-            with coord.stop_on_exception():
-                for current_epoch in range(current_epoch, target_epoch):
-                    if coord.should_stop():
-                        break
+        log_info('STARTING Optimization')
+        best_dev_loss = float('inf')
+        dev_losses = []
+        coord = tf.train.Coordinator()
+        with coord.stop_on_exception():
+            for epoch in range(FLAGS.epoch):
+                if coord.should_stop():
+                    break

-                    # Training
-                    log_info('Training epoch %d ...' % current_epoch)
-                    train_loss = run_set('train', train_init_op, train_batches)
-                    log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
-                    checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
+                # Training
+                log_info('Training epoch %d...' % epoch)
+                train_loss = run_set('train', train_init_op, train_batches)
+                log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

-                    if FLAGS.dev_files:
-                        # Validation
-                        log_info('Validating epoch %d ...' % current_epoch)
-                        dev_loss = run_set('dev', dev_init_op, dev_batches)
-                        dev_losses.append(dev_loss)
-                        log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))
+                if FLAGS.dev_files:
+                    # Validation
+                    log_info('Validating epoch %d...' % epoch)
+                    dev_loss = run_set('dev', dev_init_op, dev_batches)
+                    dev_losses.append(dev_loss)
+                    log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))

-                        if dev_loss < best_dev_loss:
-                            best_dev_loss = dev_loss
-                            save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename)
-                            log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
+                    if dev_loss < best_dev_loss:
+                        best_dev_loss = dev_loss
+                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
+                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

-                        # Early stopping
-                        if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
-                            mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
-                            std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
-                            dev_losses = dev_losses[-FLAGS.es_steps:]
-                            log_debug('Checking for early stopping (last %d steps) validation loss: '
-                                      '%f, with standard deviation: %f and mean: %f' %
-                                      (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
-                            if dev_losses[-1] > np.max(dev_losses[:-1]) or \
-                               (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
-                                log_info('Early stop triggered as (for last %d steps) validation loss:'
-                                         ' %f with standard deviation: %f and mean: %f' %
-                                         (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
-                                break
-            coord.request_stop()
-        else:
-            log_info('Target epoch already reached - skipped training.')
+                    # Early stopping
+                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
+                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
+                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
+                        dev_losses = dev_losses[-FLAGS.es_steps:]
+                        log_debug('Checking for early stopping (last %d steps) validation loss: '
+                                  '%f, with standard deviation: %f and mean: %f' %
+                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
+                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
+                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
+                            log_info('Early stop triggered as (for last %d steps) validation loss:'
+                                     ' %f with standard deviation: %f and mean: %f' %
+                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
+                            break
+        coord.request_stop()

         log_debug('Session closed.')
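The early-stopping test itself is unchanged by this commit, only re-indented: it triggers when the newest validation loss is the worst in the window, or when the window has flattened out (deviation from the mean and standard deviation both under their thresholds). A worked numeric example with hypothetical flag values:

```python
import numpy as np

es_steps, es_mean_th, es_std_th = 4, 0.5, 0.5  # hypothetical flag values

dev_losses = [10.0, 9.1, 9.05, 9.02]           # a plateauing validation curve
mean_loss = np.mean(dev_losses[-es_steps:-1])  # ~9.38
std_loss = np.std(dev_losses[-es_steps:-1])    # ~0.44
stop = dev_losses[-1] > np.max(dev_losses[:-1]) or \
       (abs(dev_losses[-1] - mean_loss) < es_mean_th and std_loss < es_std_th)
print(stop)  # True: the curve has flattened, so training stops early
```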


@@ -354,7 +354,7 @@ For example, if you want to fine tune the entire graph using your own data in `m
 ```bash
 mkdir fine_tuning_checkpoints
-python3 DeepSpeech.py --n_hidden 2048 --checkpoint_dir path/to/checkpoint/folder --epoch -3 --train_files my-train.csv --dev_files my-dev.csv --test_files my_dev.csv --learning_rate 0.0001
+python3 DeepSpeech.py --n_hidden 2048 --checkpoint_dir path/to/checkpoint/folder --epoch 3 --train_files my-train.csv --dev_files my-dev.csv --test_files my_dev.csv --learning_rate 0.0001
 ```

 Note: the released models were trained with `--n_hidden 2048`, so you need to use that same value when initializing from the release models. Note as well the use of a negative epoch count -3 (meaning 3 more epochs) since the checkpoint you're loading from was already trained for several epochs.


@@ -16,13 +16,13 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
-  --n_hidden 100 --epoch -1 \
+  --n_hidden 100 --epoch 1 \
   --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
   --learning_rate 0.001 --dropout_rate 0.05 \
   --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
   --lm_trie_path 'data/smoke_test/vocab.trie' | tee /tmp/resume.log

-if ! grep "Training epoch $epoch_count" /tmp/resume.log; then
+if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then
     echo "Did not resume training from checkpoint"
     exit 1
 else


@@ -22,7 +22,7 @@ def create_flags():
     # Global Constants
     # ================

-    tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')
+    tf.app.flags.DEFINE_integer ('epoch', 75, 'how many epochs (complete runs through the train files) to train for')

     tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
     tf.app.flags.DEFINE_float ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')