Merge pull request #2268 from tilmankamp/reportfilename

Fix #2180 - Added wav_filename to WER report
commit daa6167829
Tilman Kamp, 2019-07-23 17:19:32 +02:00 (committed by GitHub)
4 changed files with 16 additions and 11 deletions


@@ -214,7 +214,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
     the decoded result and the batch's original Y.
     '''
     # Obtain the next batch of data
-    (batch_x, batch_seq_len), batch_y = iterator.get_next()
+    _, (batch_x, batch_seq_len), batch_y = iterator.get_next()
 
     if FLAGS.use_cudnn_rnn:
         rnn_impl = rnn_impl_cudnn_rnn
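
Note: after this change every element coming out of the feeding pipeline starts with its wav filename, which the training graph simply discards via the leading `_`. A minimal pure-Python sketch of the new unpacking pattern (the filenames and values below are hypothetical placeholders, not DeepSpeech tensors):

# Each element now has the structure (wav_filename, (features, features_len), transcript).
elements = [
    ('clips/sample-0001.wav', ([[0.1, 0.2]], 1), 'hello'),  # hypothetical values
    ('clips/sample-0002.wav', ([[0.3, 0.4]], 1), 'world'),
]

for element in elements:
    # Training discards the filename, exactly like the `_` in the hunk above ...
    _, (features, features_len), transcript = element
    # ... while evaluation keeps it so each report entry can link back to its audio file.
    wav_filename = element[0]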


@@ -52,7 +52,7 @@ def evaluate(test_csvs, create_model, try_loading):
                                                  output_classes=tfv1.data.get_output_classes(test_sets[0]))
     test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
 
-    (batch_x, batch_x_len), batch_y = iterator.get_next()
+    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
 
     # One rate per layer
     no_dropout = [None] * 6
@@ -89,6 +89,7 @@ def evaluate(test_csvs, create_model, try_loading):
             exit(1)
 
         def run_test(init_op, dataset):
+            wav_filenames = []
             losses = []
             predictions = []
             ground_truths = []
@@ -105,8 +106,8 @@ def evaluate(test_csvs, create_model, try_loading):
             # First pass, compute losses and transposed logits for decoding
             while True:
                 try:
-                    batch_logits, batch_loss, batch_lengths, batch_transcripts = \
-                        session.run([transposed, loss, batch_x_len, batch_y])
+                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
+                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                 except tf.errors.OutOfRangeError:
                     break
@@ -114,6 +115,7 @@ def evaluate(test_csvs, create_model, try_loading):
                                                         num_processes=num_processes, scorer=scorer)
                 predictions.extend(d[0][1] for d in decoded)
                 ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
+                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                 losses.extend(batch_loss)
 
                 step_count += 1
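
Note: session.run returns the filename tensor as a batch of byte strings, which is why the new line decodes each entry to UTF-8 before collecting it. A stand-alone sketch of that pattern with a hypothetical batch (no TensorFlow session involved):

# TF string tensors come back from session.run as arrays of bytes objects,
# so the evaluation loop decodes them before storing per-sample filenames.
batch_wav_filenames = [b'/data/clips/a.wav', b'/data/clips/b.wav']  # hypothetical batch

wav_filenames = []
wav_filenames.extend(name.decode('UTF-8') for name in batch_wav_filenames)
print(wav_filenames)  # ['/data/clips/a.wav', '/data/clips/b.wav']
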
@@ -121,7 +123,7 @@ def evaluate(test_csvs, create_model, try_loading):
             bar.finish()
 
-            wer, cer, samples = calculate_report(ground_truths, predictions, losses)
+            wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
             mean_loss = np.mean(losses)
 
             # Take only the first report_count items
@@ -133,6 +135,7 @@ def evaluate(test_csvs, create_model, try_loading):
             for sample in report_samples:
                 print('WER: %f, CER: %f, loss: %f' %
                       (sample.wer, sample.cer, sample.loss))
+                print(' - wav: file://%s' % sample.wav_filename)
                 print(' - src: "%s"' % sample.src)
                 print(' - res: "%s"' % sample.res)
                 print('-' * 80)
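
Note: the added print gives every reported sample a file:// link back to its source audio. A self-contained sketch of the report block, with a hypothetical namedtuple standing in for the AttrDict samples produced by calculate_report:

from collections import namedtuple

# Hypothetical stand-in for a report sample; field names match those used above.
Sample = namedtuple('Sample', 'wer cer loss wav_filename src res')
sample = Sample(0.25, 0.08, 3.14, '/data/clips/sample-0001.wav',
                'hello world', 'hello word')

print('WER: %f, CER: %f, loss: %f' %
      (sample.wer, sample.cer, sample.loss))
print(' - wav: file://%s' % sample.wav_filename)
print(' - src: "%s"' % sample.src)
print(' - res: "%s"' % sample.res)
print('-' * 80)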


@@ -34,12 +34,13 @@ def wer_cer_batch(samples):
 
 def process_decode_result(item):
-    ground_truth, prediction, loss = item
+    wav_filename, ground_truth, prediction, loss = item
     char_distance = levenshtein(ground_truth, prediction)
     char_length = len(ground_truth)
     word_distance = levenshtein(ground_truth.split(), prediction.split())
     word_length = len(ground_truth.split())
     return AttrDict({
+        'wav_filename': wav_filename,
         'src': ground_truth,
         'res': prediction,
         'loss': loss,
@@ -52,13 +53,13 @@ def process_decode_result(item):
     })
 
 
-def calculate_report(labels, decodings, losses):
+def calculate_report(wav_filenames, labels, decodings, losses):
     r'''
     This routine will calculate a WER report.
     It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
     loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
     '''
-    samples = pmap(process_decode_result, zip(labels, decodings, losses))
+    samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))
 
     # Getting the WER and CER from the accumulated edit distances and lengths
     samples_wer, samples_cer = wer_cer_batch(samples)
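
Note: because process_decode_result now unpacks a 4-tuple, calculate_report has to zip the filenames in as the first iterable. A pure-Python sketch of that ordering with a minimal edit-distance helper (DeepSpeech's own levenshtein, pmap and AttrDict utilities are not used here, and the inputs are hypothetical):

def levenshtein(a, b):
    # Classic dynamic-programming edit distance; works on strings and on word lists.
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        curr = [i]
        for j, y in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]

def process_decode_result(item):
    # Mirrors the new 4-tuple unpacking; returns a plain dict instead of an AttrDict.
    wav_filename, ground_truth, prediction, loss = item
    return {
        'wav_filename': wav_filename,
        'src': ground_truth,
        'res': prediction,
        'loss': loss,
        'char_distance': levenshtein(ground_truth, prediction),
        'word_distance': levenshtein(ground_truth.split(), prediction.split()),
    }

# Hypothetical inputs, zipped in the same order as the new calculate_report signature.
wav_filenames = ['a.wav', 'b.wav']
labels = ['hello world', 'good morning']
decodings = ['hello word', 'good morning']
losses = [3.2, 1.1]

samples = [process_decode_result(item)
           for item in zip(wav_filenames, labels, decodings, losses)]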


@@ -52,7 +52,7 @@ def audiofile_to_features(wav_filename):
 def entry_to_features(wav_filename, transcript):
     # https://bugs.python.org/issue32117
     features, features_len = audiofile_to_features(wav_filename)
-    return features, features_len, tf.SparseTensor(*transcript)
+    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
 
 
 def to_sparse_tuple(sequence):
@@ -82,12 +82,13 @@ def create_dataset(csvs, batch_size, cache_path=''):
         shape = sparse.dense_shape
         return tf.sparse.reshape(sparse, [shape[0], shape[2]])
 
-    def batch_fn(features, features_len, transcripts):
+    def batch_fn(wav_filenames, features, features_len, transcripts):
         features = tf.data.Dataset.zip((features, features_len))
         features = features.padded_batch(batch_size,
                                          padded_shapes=([None, Config.n_input], []))
         transcripts = transcripts.batch(batch_size).map(sparse_reshape)
-        return tf.data.Dataset.zip((features, transcripts))
+        wav_filenames = wav_filenames.batch(batch_size)
+        return tf.data.Dataset.zip((wav_filenames, features, transcripts))
 
     num_gpus = len(Config.available_devices)
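
Note: batch_fn now batches the filenames alongside the padded features and the reshaped transcripts, so each batched element keeps the (wav_filenames, features, transcripts) structure unpacked in the first two files above. A toy tf.data sketch of that zip-of-batched-datasets pattern, using TF 2 eager iteration for brevity and omitting transcripts and padding (names and shapes are illustrative, not DeepSpeech's):

import tensorflow as tf

batch_size = 2

# Hypothetical per-example datasets standing in for the ones built by create_dataset.
wav_filenames = tf.data.Dataset.from_tensor_slices(['a.wav', 'b.wav', 'c.wav', 'd.wav'])
features = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 5, 3]))  # 4 examples of [time, n_input]
features_len = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5])

# Batch filenames and (features, features_len) separately, then zip them back together,
# mirroring the structure returned by the new batch_fn (minus the transcripts).
wav_batches = wav_filenames.batch(batch_size)
feature_batches = tf.data.Dataset.zip((features, features_len)).batch(batch_size)
dataset = tf.data.Dataset.zip((wav_batches, feature_batches))

for batch_wav_filenames, (batch_x, batch_x_len) in dataset:
    print(batch_wav_filenames.numpy(), batch_x_len.numpy())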