Add read only validation metrics

For now this is just CTC loss like a validation set, but without affecting
best validation checkpoint tracking logic. Eventually this could compute WER
on a smaller set, for example.
This commit is contained in:
Reuben Morais 2020-06-08 15:26:37 +02:00
parent 572963e7bd
commit e069b6d61f
2 changed files with 32 additions and 4 deletions

View File

@ -450,6 +450,16 @@ def train():
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
if FLAGS.metrics_files:
metrics_sources = FLAGS.metrics_files.split(',')
metrics_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in metrics_sources]
metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
@ -488,7 +498,14 @@ def train():
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120),
'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120),
}
human_readable_set_names = {
'train': 'Training',
'dev': 'Validation',
'metrics': 'Metrics',
}
# Checkpointing
@ -533,7 +550,7 @@ def train():
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
@ -635,6 +652,16 @@ def train():
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
if FLAGS.metrics_files:
# Read only metrics, not affecting best validation loss tracking
for source, init_op in zip(metrics_sources, metrics_init_ops):
log_progress('Metrics for epoch %d on %s...' % (epoch, source))
set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
print('-' * 80)
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))

View File

@ -13,8 +13,9 @@ def create_flags():
f = absl.flags
f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the datasets used for validation. Multiple files will get reported separately. If empty, validation will not be run.')
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the datasets used for testing. Multiple files will get reported separately. If empty, the model will not be tested.')
f.DEFINE_string('metrics_files', '', 'comma separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed.')
f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)')
f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.')