Do separate validation epochs if multiple input files are specified

commit a85af3da49
parent 68c17611c6
@@ -376,10 +376,9 @@ def train():
     train_init_op = iterator.make_initializer(train_set)

     if FLAGS.dev_files:
-        dev_set = create_dataset(FLAGS.dev_files.split(','),
-                                 batch_size=FLAGS.dev_batch_size,
-                                 cache_path=FLAGS.dev_cached_features_path)
-        dev_init_op = iterator.make_initializer(dev_set)
+        dev_csvs = FLAGS.dev_files.split(',')
+        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
+        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

     # Dropout
     dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
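The new dev_init_ops list relies on tf.data's reinitializable-iterator pattern: one iterator is defined by structure alone and re-pointed at a different dataset each time one of its initializer ops runs. A minimal TF 1.x sketch of that pattern, with toy range datasets standing in for what create_dataset returns:

    # Minimal sketch (TF 1.x tf.data API); toy datasets stand in for
    # create_dataset's output.
    import tensorflow as tf

    dev_sets = [tf.data.Dataset.range(8 * (i + 1)).batch(4) for i in range(3)]

    # One structure-defined iterator serves every dataset whose types and
    # shapes match; each initializer re-points it at one dataset.
    iterator = tf.data.Iterator.from_structure(dev_sets[0].output_types,
                                               dev_sets[0].output_shapes)
    dev_init_ops = [iterator.make_initializer(ds) for ds in dev_sets]
    next_batch = iterator.get_next()

    with tf.Session() as session:
        for init_op in dev_init_ops:            # one full pass per dataset
            session.run(init_op)
            try:
                while True:
                    session.run(next_batch)
            except tf.errors.OutOfRangeError:   # current dataset exhausted
                pass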
@@ -425,6 +424,15 @@ def train():

    initializer = tf.global_variables_initializer()

+    # Disable progress logging if needed
+    if FLAGS.show_progressbar:
+        pbar_class = progressbar.ProgressBar
+        def log_progress(*args, **kwargs):
+            pass
+    else:
+        pbar_class = progressbar.NullBar
+        log_progress = log_info
+
    with tf.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

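The toggle above keeps every call site unconditional: with the bar visible, textual progress messages become no-ops; with it hidden, progressbar2's NullBar silently absorbs the bar calls and log_info takes over. The same idea in isolation (make_progress_tools is a hypothetical helper, not part of this commit):

    import progressbar  # the progressbar2 package

    def make_progress_tools(show_progressbar, log_info):
        # Hypothetical helper: choose both implementations once, so callers
        # never need to check the flag again.
        if show_progressbar:
            # Bar is visible; textual progress lines would be redundant.
            return progressbar.ProgressBar, lambda *args, **kwargs: None
        # No bar; NullBar ignores start/update/finish, plain logging instead.
        return progressbar.NullBar, log_info

    pbar_class, log_progress = make_progress_tools(False, print)
    log_progress('Training epoch 0...')  # prints, since the bar is disabled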
@@ -445,7 +453,7 @@ def train():
                      ' - consider using load option "auto" or "init".' % FLAGS.load)
            sys.exit(1)

-        def run_set(set_name, init_op):
+        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
@@ -456,6 +464,7 @@ def train():
            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

+            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
@@ -464,12 +473,12 @@ def train():
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

-            if FLAGS.show_progressbar:
-                pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
-                                                        ' | ', progressbar.widgets.Timer(),
-                                                        ' | Steps: ', progressbar.widgets.Counter(),
-                                                        ' | ', LossWidget()])
-                pbar.start()
+            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
+            widgets = [' | ', progressbar.widgets.Timer(),
+                       ' | Steps: ', progressbar.widgets.Counter(),
+                       ' | ', LossWidget()]
+            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
+            pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)
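This unified construction works because NullBar subclasses ProgressBar in progressbar2 and accepts the same constructor arguments, so prefix, suffix, widgets, and fd can be passed either way and the disabled case simply renders nothing. A rough standalone demo (all values illustrative):

    import sys
    import time
    import progressbar

    pbar_class = progressbar.ProgressBar   # or progressbar.NullBar when hidden
    widgets = [' | ', progressbar.widgets.Timer(),
               ' | Steps: ', progressbar.widgets.Counter()]
    pbar = pbar_class(prefix='Epoch 0 | Validation', widgets=widgets,
                      suffix=' | Dataset: dev.csv', fd=sys.stdout).start()
    for step in range(5):
        time.sleep(0.1)                    # stand-in for one step of work
        pbar.update(step + 1)
    pbar.finish()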
@@ -486,8 +495,7 @@ def train():
                total_loss += batch_loss
                step_count += 1

-                if FLAGS.show_progressbar:
-                    pbar.update(step_count)
+                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

@@ -495,10 +503,9 @@ def train():
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

-            if FLAGS.show_progressbar:
-                pbar.finish()
-
-            return total_loss / step_count
+            pbar.finish()
+            mean_loss = total_loss / step_count if step_count > 0 else 0.0
+            return mean_loss

        log_info('STARTING Optimization')
        best_dev_loss = float('inf')
@@ -506,20 +513,21 @@ def train():
        try:
            for epoch in range(FLAGS.epochs):
                # Training
-                if not FLAGS.show_progressbar:
-                    log_info('Training epoch %d...' % epoch)
-                train_loss = run_set('train', train_init_op)
-                if not FLAGS.show_progressbar:
-                    log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                log_progress('Training epoch %d...' % epoch)
+                train_loss = run_set('train', epoch, train_init_op)
+                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
-                    if not FLAGS.show_progressbar:
-                        log_info('Validating epoch %d...' % epoch)
-                    dev_loss = run_set('dev', dev_init_op)
-                    if not FLAGS.show_progressbar:
-                        log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
+                    dev_loss = 0.0
+                    for csv, init_op in zip(dev_csvs, dev_init_ops):
+                        log_progress('Validating epoch %d on %s...' % (epoch, csv))
+                        set_loss = run_set('dev', epoch, init_op, dataset=csv)
+                        dev_loss += set_loss
+                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
+                    dev_loss = dev_loss / len(dev_csvs)

                    dev_losses.append(dev_loss)

                    if dev_loss < best_dev_loss:
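Note the arithmetic in the new validation loop: run_set already returns a per-step mean for each file, and the epoch's dev_loss is then the unweighted mean of those per-file means, so every validation file counts equally regardless of how many samples it contains. Reduced to a toy example:

    def epoch_dev_loss(per_file_losses):
        # Unweighted mean over files, mirroring dev_loss / len(dev_csvs).
        return sum(per_file_losses) / len(per_file_losses)

    print(epoch_dev_loss([0.8, 1.2, 1.0]))  # -> 1.0

A size-weighted alternative would scale each file's mean by its step count; the simpler per-file mean is what the commit uses.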
@@ -63,7 +63,7 @@ def to_sparse_tuple(sequence):
    return indices, sequence, shape


-def create_dataset(csvs, batch_size, cache_path):
+def create_dataset(csvs, batch_size, cache_path=''):
    df = read_csvs(csvs)
    df.sort_values(by='wav_filesize', inplace=True)

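The new default is convenient if create_dataset forwards cache_path to tf.data's cache transformation (an assumption here; the function body is not shown in this diff): an empty filename is tf.data's documented way of caching in memory, so the per-file dev sets above can simply omit the argument:

    import tensorflow as tf

    ds = tf.data.Dataset.range(100)
    mem_cached = ds.cache('')                     # empty filename: in-memory cache
    disk_cached = ds.cache('/tmp/feature_cache')  # path: cache files on disk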
@@ -17,7 +17,6 @@ def create_flags():
    f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')

    f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
-    f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
    f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')

    f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')