Clean up evaluate.py

This commit is contained in:
Reuben Morais 2018-11-11 18:04:00 -02:00
parent 5cb1aff531
commit 0dcb1f87c5

View File

@ -161,7 +161,6 @@ def evaluate(test_data, inference_graph, alphabet):
ground_truths = []
predictions = []
distances = []
print('Decoding predictions...')
bar = progressbar.ProgressBar(max_value=batch_count,
@ -181,7 +180,8 @@ def evaluate(test_data, inference_graph, alphabet):
ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
predictions.extend(d[0][1] for d in decoded)
distances.extend(levenshtein(a, b) for a, b in zip(labels, predictions))
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
wer, samples = calculate_report(ground_truths, predictions, distances, losses)
mean_edit_distance = np.mean(distances)
@ -190,12 +190,12 @@ def evaluate(test_data, inference_graph, alphabet):
# Take only the first report_count items
report_samples = itertools.islice(samples, FLAGS.report_count)
print('Test - WER: %f, loss: %f, mean edit distance: %f' %
(wer, mean_loss, mean_edit_distance))
print('Test - WER: %f, CER: %f, loss: %f' %
(wer, mean_edit_distance, mean_loss))
print('-' * 80)
for sample in report_samples:
print('WER: %f, loss: %f, edit distance: %f' %
(sample.wer, sample.loss, sample.distance))
print('WER: %f, CER: %f, loss: %f' %
(sample.wer, sample.distance, sample.loss))
print(' - src: "%s"' % sample.src)
print(' - res: "%s"' % sample.res)
print('-' * 80)