Merge pull request #2038 from mozilla/split-dev-test-epochs
Perform separate validation and test epochs per dataset when multiple files are specified (Fixes #1634 and #2043)
This commit is contained in:
commit
1e601d5c4a
@ -14,6 +14,7 @@ import progressbar
|
|||||||
import shutil
|
import shutil
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||||
from evaluate import evaluate
|
from evaluate import evaluate
|
||||||
from six.moves import zip, range
|
from six.moves import zip, range
|
||||||
@ -21,7 +22,7 @@ from tensorflow.python.tools import freeze_graph
|
|||||||
from util.config import Config, initialize_globals
|
from util.config import Config, initialize_globals
|
||||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||||
from util.flags import create_flags, FLAGS
|
from util.flags import create_flags, FLAGS
|
||||||
from util.logging import log_info, log_error, log_debug
|
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||||
|
|
||||||
|
|
||||||
# Graph Creation
|
# Graph Creation
|
||||||
@ -366,7 +367,7 @@ def train():
|
|||||||
# Create training and validation datasets
|
# Create training and validation datasets
|
||||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||||
batch_size=FLAGS.train_batch_size,
|
batch_size=FLAGS.train_batch_size,
|
||||||
cache_path=FLAGS.train_cached_features_path)
|
cache_path=FLAGS.feature_cache)
|
||||||
|
|
||||||
iterator = tf.data.Iterator.from_structure(train_set.output_types,
|
iterator = tf.data.Iterator.from_structure(train_set.output_types,
|
||||||
train_set.output_shapes,
|
train_set.output_shapes,
|
||||||
@ -376,10 +377,9 @@ def train():
|
|||||||
train_init_op = iterator.make_initializer(train_set)
|
train_init_op = iterator.make_initializer(train_set)
|
||||||
|
|
||||||
if FLAGS.dev_files:
|
if FLAGS.dev_files:
|
||||||
dev_set = create_dataset(FLAGS.dev_files.split(','),
|
dev_csvs = FLAGS.dev_files.split(',')
|
||||||
batch_size=FLAGS.dev_batch_size,
|
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size) for csv in dev_csvs]
|
||||||
cache_path=FLAGS.dev_cached_features_path)
|
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||||
dev_init_op = iterator.make_initializer(dev_set)
|
|
||||||
|
|
||||||
# Dropout
|
# Dropout
|
||||||
dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||||
@ -445,7 +445,7 @@ def train():
|
|||||||
' - consider using load option "auto" or "init".' % FLAGS.load)
|
' - consider using load option "auto" or "init".' % FLAGS.load)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
def run_set(set_name, init_op):
|
def run_set(set_name, epoch, init_op, dataset=None):
|
||||||
is_train = set_name == 'train'
|
is_train = set_name == 'train'
|
||||||
train_op = apply_gradient_op if is_train else []
|
train_op = apply_gradient_op if is_train else []
|
||||||
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
||||||
@ -456,6 +456,7 @@ def train():
|
|||||||
step_summary_writer = step_summary_writers.get(set_name)
|
step_summary_writer = step_summary_writers.get(set_name)
|
||||||
checkpoint_time = time.time()
|
checkpoint_time = time.time()
|
||||||
|
|
||||||
|
# Setup progress bar
|
||||||
class LossWidget(progressbar.widgets.FormatLabel):
|
class LossWidget(progressbar.widgets.FormatLabel):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
||||||
@ -464,12 +465,12 @@ def train():
|
|||||||
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
||||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
||||||
|
|
||||||
if FLAGS.show_progressbar:
|
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
|
||||||
pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
|
widgets = [' | ', progressbar.widgets.Timer(),
|
||||||
' | ', progressbar.widgets.Timer(),
|
' | Steps: ', progressbar.widgets.Counter(),
|
||||||
' | Steps: ', progressbar.widgets.Counter(),
|
' | ', LossWidget()]
|
||||||
' | ', LossWidget()])
|
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
|
||||||
pbar.start()
|
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
|
||||||
|
|
||||||
# Initialize iterator to the appropriate dataset
|
# Initialize iterator to the appropriate dataset
|
||||||
session.run(init_op)
|
session.run(init_op)
|
||||||
@ -486,8 +487,7 @@ def train():
|
|||||||
total_loss += batch_loss
|
total_loss += batch_loss
|
||||||
step_count += 1
|
step_count += 1
|
||||||
|
|
||||||
if FLAGS.show_progressbar:
|
pbar.update(step_count)
|
||||||
pbar.update(step_count)
|
|
||||||
|
|
||||||
step_summary_writer.add_summary(step_summary, current_step)
|
step_summary_writer.add_summary(step_summary, current_step)
|
||||||
|
|
||||||
@ -495,31 +495,34 @@ def train():
|
|||||||
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
|
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
|
||||||
checkpoint_time = time.time()
|
checkpoint_time = time.time()
|
||||||
|
|
||||||
if FLAGS.show_progressbar:
|
pbar.finish()
|
||||||
pbar.finish()
|
mean_loss = total_loss / step_count if step_count > 0 else 0.0
|
||||||
|
return mean_loss, step_count
|
||||||
return total_loss / step_count
|
|
||||||
|
|
||||||
log_info('STARTING Optimization')
|
log_info('STARTING Optimization')
|
||||||
|
train_start_time = datetime.utcnow()
|
||||||
best_dev_loss = float('inf')
|
best_dev_loss = float('inf')
|
||||||
dev_losses = []
|
dev_losses = []
|
||||||
try:
|
try:
|
||||||
for epoch in range(FLAGS.epochs):
|
for epoch in range(FLAGS.epochs):
|
||||||
# Training
|
# Training
|
||||||
if not FLAGS.show_progressbar:
|
log_progress('Training epoch %d...' % epoch)
|
||||||
log_info('Training epoch %d...' % epoch)
|
train_loss, _ = run_set('train', epoch, train_init_op)
|
||||||
train_loss = run_set('train', train_init_op)
|
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
|
||||||
if not FLAGS.show_progressbar:
|
|
||||||
log_info('Finished training epoch %d - loss: %f' % (epoch, train_loss))
|
|
||||||
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
||||||
|
|
||||||
if FLAGS.dev_files:
|
if FLAGS.dev_files:
|
||||||
# Validation
|
# Validation
|
||||||
if not FLAGS.show_progressbar:
|
dev_loss = 0.0
|
||||||
log_info('Validating epoch %d...' % epoch)
|
total_steps = 0
|
||||||
dev_loss = run_set('dev', dev_init_op)
|
for csv, init_op in zip(dev_csvs, dev_init_ops):
|
||||||
if not FLAGS.show_progressbar:
|
log_progress('Validating epoch %d on %s...' % (epoch, csv))
|
||||||
log_info('Finished validating epoch %d - loss: %f' % (epoch, dev_loss))
|
set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
|
||||||
|
dev_loss += set_loss * steps
|
||||||
|
total_steps += steps
|
||||||
|
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
|
||||||
|
dev_loss = dev_loss / total_steps
|
||||||
|
|
||||||
dev_losses.append(dev_loss)
|
dev_losses.append(dev_loss)
|
||||||
|
|
||||||
if dev_loss < best_dev_loss:
|
if dev_loss < best_dev_loss:
|
||||||
@ -543,6 +546,7 @@ def train():
|
|||||||
break
|
break
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
|
||||||
log_debug('Session closed.')
|
log_debug('Session closed.')
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ fi;
|
|||||||
|
|
||||||
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||||
--train_files ${ldc93s1_csv} --train_batch_size 1 \
|
--train_files ${ldc93s1_csv} --train_batch_size 1 \
|
||||||
--train_cached_features_path '/tmp/ldc93s1_cache' \
|
--feature_cache '/tmp/ldc93s1_cache' \
|
||||||
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
|
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
|
||||||
--test_files ${ldc93s1_csv} --test_batch_size 1 \
|
--test_files ${ldc93s1_csv} --test_batch_size 1 \
|
||||||
--n_hidden 100 --epochs $epoch_count \
|
--n_hidden 100 --epochs $epoch_count \
|
||||||
|
129
evaluate.py
129
evaluate.py
@ -18,7 +18,7 @@ from util.config import Config, initialize_globals
|
|||||||
from util.evaluate_tools import calculate_report
|
from util.evaluate_tools import calculate_report
|
||||||
from util.feeding import create_dataset
|
from util.feeding import create_dataset
|
||||||
from util.flags import create_flags, FLAGS
|
from util.flags import create_flags, FLAGS
|
||||||
from util.logging import log_error
|
from util.logging import log_error, log_progress, create_progressbar
|
||||||
from util.text import levenshtein
|
from util.text import levenshtein
|
||||||
|
|
||||||
|
|
||||||
@ -45,12 +45,14 @@ def evaluate(test_csvs, create_model, try_loading):
|
|||||||
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
||||||
Config.alphabet)
|
Config.alphabet)
|
||||||
|
|
||||||
test_set = create_dataset(test_csvs,
|
test_csvs = FLAGS.test_files.split(',')
|
||||||
batch_size=FLAGS.test_batch_size,
|
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size) for csv in test_csvs]
|
||||||
cache_path=FLAGS.test_cached_features_path)
|
iterator = tf.data.Iterator.from_structure(test_sets[0].output_types,
|
||||||
it = test_set.make_one_shot_iterator()
|
test_sets[0].output_shapes,
|
||||||
|
output_classes=test_sets[0].output_classes)
|
||||||
|
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
||||||
|
|
||||||
(batch_x, batch_x_len), batch_y = it.get_next()
|
(batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||||
|
|
||||||
# One rate per layer
|
# One rate per layer
|
||||||
no_dropout = [None] * 6
|
no_dropout = [None] * 6
|
||||||
@ -67,10 +69,16 @@ def evaluate(test_csvs, create_model, try_loading):
|
|||||||
|
|
||||||
tf.train.get_or_create_global_step()
|
tf.train.get_or_create_global_step()
|
||||||
|
|
||||||
with tf.Session(config=Config.session_config) as session:
|
# Get number of accessible CPU cores for this process
|
||||||
# Create a saver using variables from the above newly created graph
|
try:
|
||||||
saver = tf.train.Saver()
|
num_processes = cpu_count()
|
||||||
|
except NotImplementedError:
|
||||||
|
num_processes = 1
|
||||||
|
|
||||||
|
# Create a saver using variables from the above newly created graph
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
|
||||||
|
with tf.Session(config=Config.session_config) as session:
|
||||||
# Restore variables from training checkpoint
|
# Restore variables from training checkpoint
|
||||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
|
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
|
||||||
if not loaded:
|
if not loaded:
|
||||||
@ -79,70 +87,75 @@ def evaluate(test_csvs, create_model, try_loading):
|
|||||||
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
|
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir))
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
logitses = []
|
def run_test(init_op, dataset):
|
||||||
losses = []
|
logitses = []
|
||||||
seq_lengths = []
|
losses = []
|
||||||
ground_truths = []
|
seq_lengths = []
|
||||||
|
ground_truths = []
|
||||||
|
|
||||||
print('Computing acoustic model predictions...')
|
bar = create_progressbar(prefix='Computing acoustic model predictions | ',
|
||||||
bar = progressbar.ProgressBar(widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()])
|
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
||||||
|
log_progress('Computing acoustic model predictions...')
|
||||||
|
|
||||||
step_count = 0
|
step_count = 0
|
||||||
|
|
||||||
# First pass, compute losses and transposed logits for decoding
|
# Initialize iterator to the appropriate dataset
|
||||||
while True:
|
session.run(init_op)
|
||||||
try:
|
|
||||||
logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
|
|
||||||
except tf.errors.OutOfRangeError:
|
|
||||||
break
|
|
||||||
|
|
||||||
step_count += 1
|
# First pass, compute losses and transposed logits for decoding
|
||||||
bar.update(step_count)
|
while True:
|
||||||
|
try:
|
||||||
|
logits, loss_, lengths, transcripts = session.run([transposed, loss, batch_x_len, batch_y])
|
||||||
|
except tf.errors.OutOfRangeError:
|
||||||
|
break
|
||||||
|
|
||||||
logitses.append(logits)
|
step_count += 1
|
||||||
losses.extend(loss_)
|
bar.update(step_count)
|
||||||
seq_lengths.append(lengths)
|
|
||||||
ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))
|
|
||||||
|
|
||||||
bar.finish()
|
logitses.append(logits)
|
||||||
|
losses.extend(loss_)
|
||||||
|
seq_lengths.append(lengths)
|
||||||
|
ground_truths.extend(sparse_tensor_value_to_texts(transcripts, Config.alphabet))
|
||||||
|
|
||||||
predictions = []
|
bar.finish()
|
||||||
|
|
||||||
# Get number of accessible CPU cores for this process
|
predictions = []
|
||||||
try:
|
|
||||||
num_processes = cpu_count()
|
|
||||||
except NotImplementedError:
|
|
||||||
num_processes = 1
|
|
||||||
|
|
||||||
print('Decoding predictions...')
|
bar = create_progressbar(max_value=step_count,
|
||||||
bar = progressbar.ProgressBar(max_value=step_count,
|
prefix='Decoding predictions | ').start()
|
||||||
widget=progressbar.AdaptiveETA)
|
log_progress('Decoding predictions...')
|
||||||
|
|
||||||
# Second pass, decode logits and compute WER and edit distance metrics
|
# Second pass, decode logits and compute WER and edit distance metrics
|
||||||
for logits, seq_length in bar(zip(logitses, seq_lengths)):
|
for logits, seq_length in bar(zip(logitses, seq_lengths)):
|
||||||
decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
|
decoded = ctc_beam_search_decoder_batch(logits, seq_length, Config.alphabet, FLAGS.beam_width,
|
||||||
num_processes=num_processes, scorer=scorer)
|
num_processes=num_processes, scorer=scorer)
|
||||||
predictions.extend(d[0][1] for d in decoded)
|
predictions.extend(d[0][1] for d in decoded)
|
||||||
|
|
||||||
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
|
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
|
||||||
|
|
||||||
wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
|
wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
|
||||||
mean_loss = np.mean(losses)
|
mean_loss = np.mean(losses)
|
||||||
|
|
||||||
# Take only the first report_count items
|
# Take only the first report_count items
|
||||||
report_samples = itertools.islice(samples, FLAGS.report_count)
|
report_samples = itertools.islice(samples, FLAGS.report_count)
|
||||||
|
|
||||||
print('Test - WER: %f, CER: %f, loss: %f' %
|
print('Test on %s - WER: %f, CER: %f, loss: %f' %
|
||||||
(wer, cer, mean_loss))
|
(dataset, wer, cer, mean_loss))
|
||||||
print('-' * 80)
|
print('-' * 80)
|
||||||
for sample in report_samples:
|
for sample in report_samples:
|
||||||
print('WER: %f, CER: %f, loss: %f' %
|
print('WER: %f, CER: %f, loss: %f' %
|
||||||
(sample.wer, sample.distance, sample.loss))
|
(sample.wer, sample.distance, sample.loss))
|
||||||
print(' - src: "%s"' % sample.src)
|
print(' - src: "%s"' % sample.src)
|
||||||
print(' - res: "%s"' % sample.res)
|
print(' - res: "%s"' % sample.res)
|
||||||
print('-' * 80)
|
print('-' * 80)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
samples = []
|
||||||
|
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||||
|
print('Testing model on {}'.format(csv))
|
||||||
|
samples.extend(run_test(init_op, dataset=csv))
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
|
@ -63,7 +63,7 @@ def to_sparse_tuple(sequence):
|
|||||||
return indices, sequence, shape
|
return indices, sequence, shape
|
||||||
|
|
||||||
|
|
||||||
def create_dataset(csvs, batch_size, cache_path):
|
def create_dataset(csvs, batch_size, cache_path=''):
|
||||||
df = read_csvs(csvs)
|
df = read_csvs(csvs)
|
||||||
df.sort_values(by='wav_filesize', inplace=True)
|
df.sort_values(by='wav_filesize', inplace=True)
|
||||||
|
|
||||||
|
@ -16,9 +16,7 @@ def create_flags():
|
|||||||
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
|
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
|
||||||
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
|
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
|
||||||
|
|
||||||
f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
|
f.DEFINE_string('feature_cache', '', 'path where cached features extracted from --train_files will be saved. If empty, caching will be done in memory and no files will be written.')
|
||||||
f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
|
|
||||||
f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
|
|
||||||
|
|
||||||
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
|
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
|
||||||
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
|
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import progressbar
|
||||||
|
import sys
|
||||||
|
|
||||||
from util.flags import FLAGS
|
from util.flags import FLAGS
|
||||||
|
|
||||||
|
|
||||||
@ -28,3 +31,19 @@ def log_warn(message):
|
|||||||
def log_error(message):
|
def log_error(message):
|
||||||
if FLAGS.log_level <= 3:
|
if FLAGS.log_level <= 3:
|
||||||
prefix_print('E ', message)
|
prefix_print('E ', message)
|
||||||
|
|
||||||
|
|
||||||
|
def create_progressbar(*args, **kwargs):
|
||||||
|
# Progress bars in stdout by default
|
||||||
|
if 'fd' not in kwargs:
|
||||||
|
kwargs['fd'] = sys.stdout
|
||||||
|
|
||||||
|
if FLAGS.show_progressbar:
|
||||||
|
return progressbar.ProgressBar(*args, **kwargs)
|
||||||
|
|
||||||
|
return progressbar.NullBar(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def log_progress(message):
|
||||||
|
if not FLAGS.show_progressbar:
|
||||||
|
log_info(message)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user