Ignore epochs in checkpoints, always start epoch count from zero
parent 57450893ea
commit 2f3f095048

DeepSpeech.py (108 changed lines)
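In effect, the commit redefines `--epoch`: instead of deriving the current epoch from the checkpoint's global step and treating `FLAGS.epoch` as an absolute target (with negative values meaning "additional epochs"), training now always counts epochs from zero. A minimal sketch of the before/after logic, with illustrative values standing in for quantities computed inside `train()`:

```python
restored_step, steps_per_epoch, flags_epoch = 5000, 1000, 3  # illustrative values

# Before: the epoch counter was reconstructed from the restored global step;
# FLAGS.epoch was an absolute target, and a negative value meant "train that
# many additional epochs on top of the checkpoint".
current_epoch = restored_step // steps_per_epoch
target_epoch = current_epoch + abs(flags_epoch) if flags_epoch < 0 else flags_epoch
epochs_to_run = range(current_epoch, target_epoch)  # here: range(5, 3) -> empty!

# After: the checkpoint still restores weights and the global step, but the
# epoch counter always starts at zero and FLAGS.epoch simply means
# "run this many epochs now".
epochs_to_run = range(flags_epoch)  # here: epochs 0, 1, 2
```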
@@ -21,7 +21,7 @@ from tensorflow.python.tools import freeze_graph
 from util.config import Config, initialize_globals
 from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
-from util.logging import log_info, log_error, log_debug, log_warn
+from util.logging import log_info, log_error, log_debug
 
 
 # Graph Creation
@@ -350,7 +350,8 @@ def try_loading(session, saver, checkpoint_filename, caption):
             return False
         checkpoint_path = checkpoint.model_checkpoint_path
         saver.restore(session, checkpoint_path)
-        log_info('Restored model from %s checkpoint at %s' % (caption, checkpoint_path))
+        restored_step = session.run(tf.train.get_global_step())
+        log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
         return True
     except tf.errors.InvalidArgumentError as e:
         log_error(str(e))
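For orientation, a rough reconstruction of the helper after this change. This is a sketch only: the `tf.train.get_checkpoint_state` lookup and the simplified error handling are inferred from the hunk's context lines, not shown by this diff.

```python
import tensorflow as tf
from util.flags import FLAGS                    # project modules, per the imports above
from util.logging import log_info, log_error

def try_loading(session, saver, checkpoint_filename, caption):
    try:
        # Inferred: locate the checkpoint state in the configured directory.
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
        if not checkpoint:
            return False
        checkpoint_path = checkpoint.model_checkpoint_path
        saver.restore(session, checkpoint_path)
        # New in this commit: report the step the checkpoint was saved at.
        # tf.train.get_global_step() returns the graph's registered
        # global-step tensor; running it yields the value just restored.
        restored_step = session.run(tf.train.get_global_step())
        log_info('Restored variables from %s checkpoint at %s, step %d'
                 % (caption, checkpoint_path, restored_step))
        return True
    except tf.errors.InvalidArgumentError as e:
        log_error(str(e))
        return False  # simplified; the real handler may do more
```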
@@ -399,11 +400,13 @@ def train():
     # Building the graph
     optimizer = create_optimizer()
     gradients, loss = get_tower_results(iterator, optimizer, dropout_rates)
+
     # Average tower gradients across GPUs
     avg_tower_gradients = average_gradients(gradients)
+    log_grads_and_vars(avg_tower_gradients)
 
     # global_step is automagically incremented by the optimizer
-    global_step = tf.Variable(0, trainable=False, name='global_step')
+    global_step = tf.train.get_or_create_global_step()
     apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
 
     # Summaries
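The swap to `tf.train.get_or_create_global_step()` is what makes the `try_loading` change above work: the helper registers the step variable in the graph's `GLOBAL_STEP` collection, so `tf.train.get_global_step()` can find it without a reference being passed around. A minimal TF 1.x sketch:

```python
import tensorflow as tf

with tf.Graph().as_default():
    # Creates an int64 variable named 'global_step' (or returns the existing
    # one) and registers it in the GLOBAL_STEP graph collection.
    global_step = tf.train.get_or_create_global_step()
    # Because it is registered, unrelated code can look it up later:
    assert tf.train.get_global_step() is not None
```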
@@ -432,39 +435,25 @@ def train():
         # Loading or initializing
         loaded = False
         if FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent epoch')
+            loaded = try_loading(session, checkpoint_saver, checkpoint_filename, 'most recent')
         if not loaded and FLAGS.load in ['auto', 'best']:
             loaded = try_loading(session, best_dev_saver, best_dev_filename, 'best validation')
         if not loaded:
             if FLAGS.load in ['auto', 'init']:
-                log_info('Initializing...')
+                log_info('Initializing variables...')
                 session.run(initializer)
             else:
                 log_error('Unable to load %s model from specified checkpoint dir'
                           ' - consider using load option "auto" or "init".' % FLAGS.load)
                 sys.exit(1)
 
-        # Retrieving global_step from restored model and setting training parameters accordingly
-        step = session.run(global_step)
-        num_gpus = len(Config.available_devices)
-        steps_per_epoch = max(1, train_batches // num_gpus)
-        current_epoch = step // steps_per_epoch
-        target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
-
-        log_debug('step: %d' % step)
-        log_debug('epoch: %d' % current_epoch)
-        log_debug('target epoch: %d' % target_epoch)
-        log_debug('steps per epoch: %d' % steps_per_epoch)
-        log_debug('batches per step (GPUs): %d' % num_gpus)
-        log_debug('number of batches in train set: %d' % train_batches)
-
         def run_set(set_name, init_op, num_batches):
             is_train = set_name == 'train'
             train_op = apply_gradient_op if is_train else []
             feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
             total_loss = 0.0
             step_summary_writer = step_summary_writers.get(set_name)
-            num_steps = max(1, num_batches // num_gpus)
+            num_steps = max(1, num_batches // len(Config.available_devices))
             checkpoint_time = time.time()
 
             if FLAGS.show_progressbar:
@@ -492,51 +481,48 @@ def train():
 
             return total_loss / num_steps
 
-        if target_epoch > current_epoch:
-            log_info('STARTING Optimization')
-            best_dev_loss = float('inf')
-            dev_losses = []
-            coord = tf.train.Coordinator()
-            with coord.stop_on_exception():
-                for current_epoch in range(current_epoch, target_epoch):
-                    if coord.should_stop():
-                        break
+        log_info('STARTING Optimization')
+        best_dev_loss = float('inf')
+        dev_losses = []
+        coord = tf.train.Coordinator()
+        with coord.stop_on_exception():
+            for epoch in range(FLAGS.epoch):
+                if coord.should_stop():
+                    break
 
-                    # Training
-                    log_info('Training epoch %d ...' % current_epoch)
-                    train_loss = run_set('train', train_init_op, train_batches)
-                    log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
-                    checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
+                # Training
+                log_info('Training epoch %d...' % epoch)
+                train_loss = run_set('train', train_init_op, train_batches)
+                log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
 
-                    if FLAGS.dev_files:
-                        # Validation
-                        log_info('Validating epoch %d ...' % current_epoch)
-                        dev_loss = run_set('dev', dev_init_op, dev_batches)
-                        dev_losses.append(dev_loss)
-                        log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))
+                if FLAGS.dev_files:
+                    # Validation
+                    log_info('Validating epoch %d...' % epoch)
+                    dev_loss = run_set('dev', dev_init_op, dev_batches)
+                    dev_losses.append(dev_loss)
+                    log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
 
-                        if dev_loss < best_dev_loss:
-                            best_dev_loss = dev_loss
-                            save_path = best_dev_saver.save(session, best_dev_path, latest_filename=best_dev_filename)
-                            log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
+                    if dev_loss < best_dev_loss:
+                        best_dev_loss = dev_loss
+                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename=best_dev_filename)
+                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
 
-                        # Early stopping
-                        if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
-                            mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
-                            std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
-                            dev_losses = dev_losses[-FLAGS.es_steps:]
-                            log_debug('Checking for early stopping (last %d steps) validation loss: '
-                                      '%f, with standard deviation: %f and mean: %f' %
-                                      (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
-                            if dev_losses[-1] > np.max(dev_losses[:-1]) or \
-                               (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
-                                log_info('Early stop triggered as (for last %d steps) validation loss:'
-                                         ' %f with standard deviation: %f and mean: %f' %
-                                         (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
-                                break
-            coord.request_stop()
-        else:
-            log_info('Target epoch already reached - skipped training.')
+                    # Early stopping
+                    if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
+                        mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
+                        std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
+                        dev_losses = dev_losses[-FLAGS.es_steps:]
+                        log_debug('Checking for early stopping (last %d steps) validation loss: '
+                                  '%f, with standard deviation: %f and mean: %f' %
+                                  (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
+                        if dev_losses[-1] > np.max(dev_losses[:-1]) or \
+                           (abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
+                            log_info('Early stop triggered as (for last %d steps) validation loss:'
+                                     ' %f with standard deviation: %f and mean: %f' %
+                                     (FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
+                            break
+        coord.request_stop()
         log_debug('Session closed.')
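The early-stopping block above is only re-indented by this commit, but the criterion it implements is easy to miss: stop when the newest validation loss is worse than everything else in the recent window, or when the window has flattened out (newest loss close to the window mean, with low spread). A standalone numpy sketch with made-up losses and thresholds:

```python
import numpy as np

es_steps, es_mean_th, es_std_th = 4, 0.5, 0.5   # made-up thresholds
dev_losses = [12.1, 11.8, 11.9, 11.85]          # made-up validation losses

mean_loss = np.mean(dev_losses[-es_steps:-1])   # window mean, newest excluded
std_loss = np.std(dev_losses[-es_steps:-1])
dev_losses = dev_losses[-es_steps:]

should_stop = (dev_losses[-1] > np.max(dev_losses[:-1]) or
               (abs(dev_losses[-1] - mean_loss) < es_mean_th and
                std_loss < es_std_th))
print(should_stop)  # True here: the losses have plateaued
```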
@@ -354,7 +354,7 @@ For example, if you want to fine tune the entire graph using your own data in `m
 
 ```bash
 mkdir fine_tuning_checkpoints
-python3 DeepSpeech.py --n_hidden 2048 --checkpoint_dir path/to/checkpoint/folder --epoch -3 --train_files my-train.csv --dev_files my-dev.csv --test_files my_dev.csv --learning_rate 0.0001
+python3 DeepSpeech.py --n_hidden 2048 --checkpoint_dir path/to/checkpoint/folder --epoch 3 --train_files my-train.csv --dev_files my-dev.csv --test_files my_dev.csv --learning_rate 0.0001
 ```
 
-Note: the released models were trained with `--n_hidden 2048`, so you need to use that same value when initializing from the release models. Note as well the use of a negative epoch count -3 (meaning 3 more epochs) since the checkpoint you're loading from was already trained for several epochs.
+Note: the released models were trained with `--n_hidden 2048`, so you need to use that same value when initializing from the release models.
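A practical consequence of zero-based counting for this recipe: `--epoch 3` now always means "three epochs in this run", so re-running the same command keeps extending training from the newest checkpoint. A sketch reusing the hypothetical file names from the example above:

```bash
# Trains 3 epochs on top of whatever the loaded checkpoint contains.
python3 DeepSpeech.py --n_hidden 2048 --checkpoint_dir path/to/checkpoint/folder \
    --epoch 3 --train_files my-train.csv --dev_files my-dev.csv \
    --test_files my_dev.csv --learning_rate 0.0001

# Running the exact same command again resumes from the saved checkpoint and
# trains 3 further epochs - the epoch counter restarts at zero every run.
```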
@@ -16,13 +16,13 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
-  --n_hidden 100 --epoch -1 \
+  --n_hidden 100 --epoch 1 \
   --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
   --learning_rate 0.001 --dropout_rate 0.05 \
   --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
   --lm_trie_path 'data/smoke_test/vocab.trie' | tee /tmp/resume.log
 
-if ! grep "Training epoch $epoch_count" /tmp/resume.log; then
+if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then
     echo "Did not resume training from checkpoint"
     exit 1
 else
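The string the smoke test now greps for is the log line added to `try_loading` above, emitted with the caption `'most recent'`. A quick check of the formatting (path and step are illustrative):

```python
# Reproduces the message format from try_loading; values are illustrative.
print('Restored variables from %s checkpoint at %s, step %d'
      % ('most recent', '/tmp/ckpt/model-1000', 1000))
# -> Restored variables from most recent checkpoint at /tmp/ckpt/model-1000, step 1000
```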
@@ -22,7 +22,7 @@ def create_flags():
     # Global Constants
     # ================
 
-    tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')
+    tf.app.flags.DEFINE_integer ('epoch', 75, 'how many epochs (complete runs through the train files) to train for')
 
     tf.app.flags.DEFINE_float   ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
     tf.app.flags.DEFINE_float   ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')