diff --git a/DeepSpeech.py b/DeepSpeech.py index 0e91b0f5..3146d6a2 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -376,10 +376,9 @@ def train(): train_init_op = iterator.make_initializer(train_set) if FLAGS.dev_files: - dev_set = create_dataset(FLAGS.dev_files.split(','), - batch_size=FLAGS.dev_batch_size, - cache_path=FLAGS.dev_cached_features_path) - dev_init_op = iterator.make_initializer(dev_set) + dev_csvs = FLAGS.dev_files.split(',') + dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs] + dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] # Dropout dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)] @@ -425,6 +424,15 @@ def train(): initializer = tf.global_variables_initializer() + # Disable progress logging if needed + if FLAGS.show_progressbar: + pbar_class = progressbar.ProgressBar + def log_progress(*args, **kwargs): + pass + else: + pbar_class = progressbar.NullBar + log_progress = log_info + with tf.Session(config=Config.session_config) as session: log_debug('Session opened.') @@ -445,7 +453,7 @@ def train(): ' - consider using load option "auto" or "init".' % FLAGS.load) sys.exit(1) - def run_set(set_name, init_op): + def run_set(set_name, epoch, init_op, dataset=None): is_train = set_name == 'train' train_op = apply_gradient_op if is_train else [] feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict @@ -456,6 +464,7 @@ def train(): step_summary_writer = step_summary_writers.get(set_name) checkpoint_time = time.time() + # Setup progress bar class LossWidget(progressbar.widgets.FormatLabel): def __init__(self): progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f') @@ -464,12 +473,12 @@ def train(): data['mean_loss'] = total_loss / step_count if step_count else 0.0 return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs) - if FLAGS.show_progressbar: - pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch), - ' | ', progressbar.widgets.Timer(), - ' | Steps: ', progressbar.widgets.Counter(), - ' | ', LossWidget()]) - pbar.start() + prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation') + widgets = [' | ', progressbar.widgets.Timer(), + ' | Steps: ', progressbar.widgets.Counter(), + ' | ', LossWidget()] + suffix = ' | Dataset: {}'.format(dataset) if dataset else None + pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start() # Initialize iterator to the appropriate dataset session.run(init_op) @@ -486,8 +495,7 @@ def train(): total_loss += batch_loss step_count += 1 - if FLAGS.show_progressbar: - pbar.update(step_count) + pbar.update(step_count) step_summary_writer.add_summary(step_summary, current_step) @@ -495,10 +503,9 @@ def train(): checkpoint_saver.save(session, checkpoint_path, global_step=current_step) checkpoint_time = time.time() - if FLAGS.show_progressbar: - pbar.finish() - - return total_loss / step_count + pbar.finish() + mean_loss = total_loss / step_count if step_count > 0 else 0.0 + return mean_loss log_info('STARTING Optimization') best_dev_loss = float('inf') @@ -506,20 +513,21 @@ def train(): try: for epoch in range(FLAGS.epochs): # Training - if not FLAGS.show_progressbar: - log_info('Training epoch %d...' % epoch) - train_loss = run_set('train', train_init_op) - if not FLAGS.show_progressbar: - log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss)) + log_progress('Training epoch %d...' % epoch) + train_loss = run_set('train', epoch, train_init_op) + log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss)) checkpoint_saver.save(session, checkpoint_path, global_step=global_step) if FLAGS.dev_files: # Validation - if not FLAGS.show_progressbar: - log_info('Validating epoch %d...' % epoch) - dev_loss = run_set('dev', dev_init_op) - if not FLAGS.show_progressbar: - log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss)) + dev_loss = 0.0 + for csv, init_op in zip(dev_csvs, dev_init_ops): + log_progress('Validating epoch %d on %s...' % (epoch, csv)) + set_loss = run_set('dev', epoch, init_op, dataset=csv) + dev_loss += set_loss + log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss)) + dev_loss = dev_loss / len(dev_csvs) + dev_losses.append(dev_loss) if dev_loss < best_dev_loss: diff --git a/util/feeding.py b/util/feeding.py index 2f01c880..e15914ab 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -63,7 +63,7 @@ def to_sparse_tuple(sequence): return indices, sequence, shape -def create_dataset(csvs, batch_size, cache_path): +def create_dataset(csvs, batch_size, cache_path=''): df = read_csvs(csvs) df.sort_values(by='wav_filesize', inplace=True) diff --git a/util/flags.py b/util/flags.py index a6b6386f..f9d3bcd3 100644 --- a/util/flags.py +++ b/util/flags.py @@ -17,7 +17,6 @@ def create_flags(): f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.') f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged') - f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged') f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged') f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')