Fix #2180 - Added wav_filename to WER report
parent 84e1fa98b9
commit 007e512c00
DeepSpeech.py
@@ -214,7 +214,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, batch_size, reuse):
     the decoded result and the batch's original Y.
     '''
     # Obtain the next batch of data
-    (batch_x, batch_seq_len), batch_y = iterator.get_next()
+    _, (batch_x, batch_seq_len), batch_y = iterator.get_next()
 
     if FLAGS.use_cudnn_rnn:
         rnn_impl = rnn_impl_cudnn_rnn
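Note: after this change every element produced by the input pipeline carries the wav filename as its first component, and the training path simply discards it with `_`. A minimal sketch of the new element structure (TF 2 eager style; toy tensors stand in for the real create_dataset() output, and all names here are illustrative):

import tensorflow as tf

# Hypothetical stand-ins for the pipeline's three components.
wavs = tf.data.Dataset.from_tensor_slices([b'a.wav', b'b.wav'])
feats = tf.data.Dataset.from_tensor_slices(tf.zeros([2, 10, 26]))      # dummy features
lens = tf.data.Dataset.from_tensor_slices(tf.constant([10, 10]))       # their lengths
labels = tf.data.Dataset.from_tensor_slices(tf.constant([[1], [2]]))   # dummy transcripts

# Elements are now (wav_filename, (features, features_len), transcript) ...
dataset = tf.data.Dataset.zip((wavs, tf.data.Dataset.zip((feats, lens)), labels))

# ... so training unpacks and ignores the filename, as in the hunk above.
for _, (batch_x, batch_seq_len), batch_y in dataset:
    pass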
evaluate.py
@@ -52,7 +52,7 @@ def evaluate(test_csvs, create_model, try_loading):
                                                  output_classes=tfv1.data.get_output_classes(test_sets[0]))
     test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
 
-    (batch_x, batch_x_len), batch_y = iterator.get_next()
+    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
 
     # One rate per layer
     no_dropout = [None] * 6
@@ -89,6 +89,7 @@ def evaluate(test_csvs, create_model, try_loading):
         exit(1)
 
     def run_test(init_op, dataset):
+        wav_filenames = []
         losses = []
         predictions = []
         ground_truths = []
@@ -105,8 +106,8 @@ def evaluate(test_csvs, create_model, try_loading):
         # First pass, compute losses and transposed logits for decoding
         while True:
             try:
-                batch_logits, batch_loss, batch_lengths, batch_transcripts = \
-                    session.run([transposed, loss, batch_x_len, batch_y])
+                batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
+                    session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
             except tf.errors.OutOfRangeError:
                 break
 
@@ -114,6 +115,7 @@ def evaluate(test_csvs, create_model, try_loading):
                                                     num_processes=num_processes, scorer=scorer)
             predictions.extend(d[0][1] for d in decoded)
             ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
+            wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
             losses.extend(batch_loss)
 
             step_count += 1
@@ -121,7 +123,7 @@ def evaluate(test_csvs, create_model, try_loading):
 
         bar.finish()
 
-        wer, cer, samples = calculate_report(ground_truths, predictions, losses)
+        wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
         mean_loss = np.mean(losses)
 
         # Take only the first report_count items
@@ -133,6 +135,7 @@ def evaluate(test_csvs, create_model, try_loading):
         for sample in report_samples:
             print('WER: %f, CER: %f, loss: %f' %
                   (sample.wer, sample.cer, sample.loss))
+            print(' - wav: file://%s' % sample.wav_filename)
             print(' - src: "%s"' % sample.src)
             print(' - res: "%s"' % sample.res)
             print('-' * 80)
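Taken together, the evaluate.py changes mean each entry of the per-sample report now names its source file. An entry comes out roughly like this (path, transcripts, and numbers invented for illustration):

WER: 0.250000, CER: 0.052632, loss: 12.345678
 - wav: file:///data/clips/sample-0001.wav
 - src: "the quick brown fox"
 - res: "the quick brown box"
--------------------------------------------------------------------------------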
util/evaluate_tools.py
@@ -34,12 +34,13 @@ def wer_cer_batch(samples):
 
 
 def process_decode_result(item):
-    ground_truth, prediction, loss = item
+    wav_filename, ground_truth, prediction, loss = item
     char_distance = levenshtein(ground_truth, prediction)
     char_length = len(ground_truth)
     word_distance = levenshtein(ground_truth.split(), prediction.split())
     word_length = len(ground_truth.split())
     return AttrDict({
+        'wav_filename': wav_filename,
         'src': ground_truth,
         'res': prediction,
         'loss': loss,
@@ -52,13 +53,13 @@ def process_decode_result(item):
     })
 
 
-def calculate_report(labels, decodings, losses):
+def calculate_report(wav_filenames, labels, decodings, losses):
     r'''
     This routine will calculate a WER report.
    It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
    loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
    '''
-    samples = pmap(process_decode_result, zip(labels, decodings, losses))
+    samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))
 
     # Getting the WER and CER from the accumulated edit distances and lengths
     samples_wer, samples_cer = wer_cer_batch(samples)
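For reference, the per-sample numbers that travel with each wav_filename boil down to character- and word-level Levenshtein distances. A self-contained sketch of the same computation (plain Python; sample_report is a hypothetical stand-in for process_decode_result, without the repo's pmap/AttrDict machinery):

def levenshtein(a, b):
    # Classic single-row dynamic-programming edit distance over sequences.
    m, n = len(a), len(b)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                          # deletion
                        dp[j - 1] + 1,                      # insertion
                        prev + (a[i - 1] != b[j - 1]))      # substitution
            prev = cur
    return dp[n]

def sample_report(wav_filename, ground_truth, prediction, loss):
    # Mirrors the per-sample fields the report prints: filename, texts, loss, CER, WER.
    return {
        'wav_filename': wav_filename,
        'src': ground_truth,
        'res': prediction,
        'loss': loss,
        'cer': levenshtein(ground_truth, prediction) / len(ground_truth),
        'wer': levenshtein(ground_truth.split(), prediction.split()) / len(ground_truth.split()),
    }

print(sample_report('a.wav', 'the quick brown fox', 'the quick brown box', 12.3))

On this toy pair the single substituted word gives wer = 0.25 and the single substituted character gives cer = 1/19 ≈ 0.053.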
util/feeding.py
@@ -52,7 +52,7 @@ def audiofile_to_features(wav_filename):
 def entry_to_features(wav_filename, transcript):
     # https://bugs.python.org/issue32117
     features, features_len = audiofile_to_features(wav_filename)
-    return features, features_len, tf.SparseTensor(*transcript)
+    return wav_filename, features, features_len, tf.SparseTensor(*transcript)
 
 
 def to_sparse_tuple(sequence):
@@ -82,12 +82,13 @@ def create_dataset(csvs, batch_size, cache_path=''):
         shape = sparse.dense_shape
         return tf.sparse.reshape(sparse, [shape[0], shape[2]])
 
-    def batch_fn(features, features_len, transcripts):
+    def batch_fn(wav_filenames, features, features_len, transcripts):
         features = tf.data.Dataset.zip((features, features_len))
         features = features.padded_batch(batch_size,
                                          padded_shapes=([None, Config.n_input], []))
         transcripts = transcripts.batch(batch_size).map(sparse_reshape)
-        return tf.data.Dataset.zip((features, transcripts))
+        wav_filenames = wav_filenames.batch(batch_size)
+        return tf.data.Dataset.zip((wav_filenames, features, transcripts))
 
     num_gpus = len(Config.available_devices)
 
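The batching change is the subtle part of this hunk: features still go through padded_batch, while the filename strings only need plain batching, and the per-batch datasets are then re-zipped. A stripped-down illustration of the pattern with toy data (TF 2 eager mode; all names and shapes here are made up):

import tensorflow as tf

batch_size = 2
wav_filenames = tf.data.Dataset.from_tensor_slices([b'a.wav', b'b.wav', b'c.wav', b'd.wav'])

# Variable-length dummy "feature" sequences stand in for the real MFCC windows.
features = tf.data.Dataset.from_generator(
    lambda: ([[1.0]] * n for n in (3, 5, 2, 4)),
    output_signature=tf.TensorSpec(shape=[None, 1], dtype=tf.float32))

# Pad features to the longest sequence in each batch; batch filenames as-is.
features = features.padded_batch(batch_size, padded_shapes=[None, 1])
wav_filenames = wav_filenames.batch(batch_size)

for names, feats in tf.data.Dataset.zip((wav_filenames, features)):
    print(names.numpy(), feats.shape)   # e.g. [b'a.wav' b'b.wav'] (2, 5, 1)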