Do separate validation epochs if multiple input files are specified

Reuben Morais 2019-04-10 16:29:11 -03:00
parent 68c17611c6
commit a85af3da49
3 changed files with 36 additions and 29 deletions

DeepSpeech.py

@@ -376,10 +376,9 @@ def train():
     train_init_op = iterator.make_initializer(train_set)

     if FLAGS.dev_files:
-        dev_set = create_dataset(FLAGS.dev_files.split(','),
-                                 batch_size=FLAGS.dev_batch_size,
-                                 cache_path=FLAGS.dev_cached_features_path)
-        dev_init_op = iterator.make_initializer(dev_set)
+        dev_csvs = FLAGS.dev_files.split(',')
+        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
+        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

     # Dropout
     dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
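The change builds one dev dataset per CSV file but keeps a single reinitializable iterator, so the same get_next() tensors can be re-pointed at the training set or at each validation file in turn. A minimal sketch of that pattern, assuming TensorFlow 1.x tf.data; the toy datasets are invented:

import tensorflow as tf

# The iterator is built from the common structure, not from any one dataset,
# so it can be switched between datasets that share types and shapes.
train_set = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0]).batch(1)
dev_sets = [tf.data.Dataset.from_tensor_slices([4.0, 5.0]).batch(1),
            tf.data.Dataset.from_tensor_slices([6.0]).batch(1)]

iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                           train_set.output_shapes)
next_element = iterator.get_next()

train_init_op = iterator.make_initializer(train_set)
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

with tf.Session() as session:
    for init_op in [train_init_op] + dev_init_ops:
        session.run(init_op)              # point the iterator at this dataset
        while True:                       # one "epoch" over the current dataset
            try:
                session.run(next_element)
            except tf.errors.OutOfRangeError:
                break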
@@ -425,6 +424,15 @@ def train():
     initializer = tf.global_variables_initializer()

+    # Disable progress logging if needed
+    if FLAGS.show_progressbar:
+        pbar_class = progressbar.ProgressBar
+        def log_progress(*args, **kwargs):
+            pass
+    else:
+        pbar_class = progressbar.NullBar
+        log_progress = log_info
+
     with tf.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
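Choosing between progressbar.ProgressBar and progressbar.NullBar once, up front, is a null-object pattern: the rest of the loop can call start(), update() and finish() unconditionally, and log_progress() only prints when the bar is not drawn. A small self-contained sketch, assuming the progressbar2 package:

import progressbar

show_progressbar = False                  # stand-in for FLAGS.show_progressbar

if show_progressbar:
    pbar_class = progressbar.ProgressBar
    def log_progress(*args, **kwargs):
        pass                              # the bar itself conveys progress
else:
    pbar_class = progressbar.NullBar      # same interface, renders nothing
    log_progress = print                  # stand-in for log_info

pbar = pbar_class(max_value=10).start()
for i in range(10):
    pbar.update(i + 1)                    # safe either way
pbar.finish()
log_progress('done')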
@@ -445,7 +453,7 @@ def train():
                           ' - consider using load option "auto" or "init".' % FLAGS.load)
                 sys.exit(1)

-        def run_set(set_name, init_op):
+        def run_set(set_name, epoch, init_op, dataset=None):
             is_train = set_name == 'train'
             train_op = apply_gradient_op if is_train else []
             feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
@@ -456,6 +464,7 @@ def train():
             step_summary_writer = step_summary_writers.get(set_name)
             checkpoint_time = time.time()

+            # Setup progress bar
             class LossWidget(progressbar.widgets.FormatLabel):
                 def __init__(self):
                     progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
@@ -464,12 +473,12 @@ def train():
                     data['mean_loss'] = total_loss / step_count if step_count else 0.0
                     return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

-            if FLAGS.show_progressbar:
-                pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
-                                                        ' | ', progressbar.widgets.Timer(),
-                                                        ' | Steps: ', progressbar.widgets.Counter(),
-                                                        ' | ', LossWidget()])
-                pbar.start()
+            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
+            widgets = [' | ', progressbar.widgets.Timer(),
+                       ' | Steps: ', progressbar.widgets.Counter(),
+                       ' | ', LossWidget()]
+            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
+            pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start()

             # Initialize iterator to the appropriate dataset
             session.run(init_op)
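progressbar2 accepts prefix and suffix format strings alongside the widget list, which is what lets the epoch label move out of the widgets and the dataset name appear only when one is passed. A short usage sketch with invented values; fd=sys.stdout keeps the bar on stdout instead of the default stderr:

import sys
import progressbar

# prefix and suffix frame the widget area of the rendered bar.
pbar = progressbar.ProgressBar(prefix='Epoch 0 | Validation',
                               widgets=[' | ', progressbar.widgets.Timer(),
                                        ' | Steps: ', progressbar.widgets.Counter()],
                               suffix=' | Dataset: dev-clean.csv',
                               fd=sys.stdout).start()
for step in range(5):
    pbar.update(step + 1)
pbar.finish()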
@@ -486,8 +495,7 @@ def train():
                 total_loss += batch_loss
                 step_count += 1

-                if FLAGS.show_progressbar:
-                    pbar.update(step_count)
+                pbar.update(step_count)

                 step_summary_writer.add_summary(step_summary, current_step)
@@ -495,10 +503,9 @@ def train():
                     checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                     checkpoint_time = time.time()

-            if FLAGS.show_progressbar:
-                pbar.finish()
-
-            return total_loss / step_count
+            pbar.finish()
+            mean_loss = total_loss / step_count if step_count > 0 else 0.0
+            return mean_loss

         log_info('STARTING Optimization')
         best_dev_loss = float('inf')
@@ -506,20 +513,21 @@ def train():
         try:
             for epoch in range(FLAGS.epochs):
                 # Training
-                if not FLAGS.show_progressbar:
-                    log_info('Training epoch %d...' % epoch)
-                train_loss = run_set('train', train_init_op)
-                if not FLAGS.show_progressbar:
-                    log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                log_progress('Training epoch %d...' % epoch)
+                train_loss = run_set('train', epoch, train_init_op)
+                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                 checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                 if FLAGS.dev_files:
                     # Validation
-                    if not FLAGS.show_progressbar:
-                        log_info('Validating epoch %d...' % epoch)
-                    dev_loss = run_set('dev', dev_init_op)
-                    if not FLAGS.show_progressbar:
-                        log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
+                    dev_loss = 0.0
+                    for csv, init_op in zip(dev_csvs, dev_init_ops):
+                        log_progress('Validating epoch %d on %s...' % (epoch, csv))
+                        set_loss = run_set('dev', epoch, init_op, dataset=csv)
+                        dev_loss += set_loss
+                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
+                    dev_loss = dev_loss / len(dev_csvs)

                     dev_losses.append(dev_loss)
                     if dev_loss < best_dev_loss:
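Note that the epoch's dev loss is the unweighted mean of the per-file mean losses, so each CSV counts equally toward early stopping regardless of how many samples it contains. A standalone sketch of the aggregation, with invented file names and values:

# Each file's mean loss contributes equally, independent of dataset size.
def combine_dev_losses(per_set_losses):
    return sum(per_set_losses) / len(per_set_losses)

set_losses = {'dev-clean.csv': 14.2, 'dev-other.csv': 31.7}   # hypothetical
for name, loss in set_losses.items():
    print('Finished validating on %s - loss: %f' % (name, loss))
print('Epoch dev loss: %f' % combine_dev_losses(list(set_losses.values())))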

util/feeding.py

@@ -63,7 +63,7 @@ def to_sparse_tuple(sequence):
     return indices, sequence, shape


-def create_dataset(csvs, batch_size, cache_path):
+def create_dataset(csvs, batch_size, cache_path=''):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)
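Defaulting cache_path to '' makes caching optional at the call site. Assuming create_dataset forwards the value to tf.data's Dataset.cache (the call is not shown in this hunk), an empty filename means the cache lives in memory rather than on disk; a sketch of that behavior:

import tensorflow as tf

def make_pipeline(dataset, cache_path=''):
    # Empty string: tf.data caches elements in memory after the first pass.
    # Non-empty string: the cache is written to files with that prefix.
    return dataset.cache(cache_path)

ds = tf.data.Dataset.range(5)
in_memory = make_pipeline(ds)              # default, in-memory cache
on_disk = make_pipeline(ds, '/tmp/feats')  # hypothetical on-disk cache prefix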

util/flags.py

@@ -17,7 +17,6 @@ def create_flags():
     f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')

     f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
-    f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
     f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')

     f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')