diff --git a/evaluate.py b/evaluate.py
index 4b0602c1..8df73966 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -2,14 +2,12 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, division, print_function
 
-import itertools
 import json
 import sys
 
 from multiprocessing import cpu_count
 
 import absl.app
-import numpy as np
 import progressbar
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
@@ -18,10 +16,10 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
 
 from util.config import Config, initialize_globals
-from util.evaluate_tools import calculate_report
+from util.evaluate_tools import calculate_and_print_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
-from util.logging import log_error, log_progress, create_progressbar
+from util.logging import create_progressbar, log_error, log_progress
 from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
@@ -132,24 +130,9 @@ def evaluate(test_csvs, create_model, try_loading):
 
         bar.finish()
 
-        wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
-        mean_loss = np.mean(losses)
-
-        # Take only the first report_count items
-        report_samples = itertools.islice(samples, FLAGS.report_count)
-
-        print('Test on %s - WER: %f, CER: %f, loss: %f' %
-              (dataset, wer, cer, mean_loss))
-        print('-' * 80)
-        for sample in report_samples:
-            print('WER: %f, CER: %f, loss: %f' %
-                  (sample.wer, sample.cer, sample.loss))
-            print(' - wav: file://%s' % sample.wav_filename)
-            print(' - src: "%s"' % sample.src)
-            print(' - res: "%s"' % sample.res)
-            print('-' * 80)
-
-        return samples
+        # Print test summary
+        test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
+        return test_samples
 
     samples = []
     for csv, init_op in zip(test_csvs, test_init_ops):
diff --git a/evaluate_tflite.py b/evaluate_tflite.py
index ec80383c..8b2ba453 100644
--- a/evaluate_tflite.py
+++ b/evaluate_tflite.py
@@ -15,7 +15,7 @@ from six.moves import zip, range
 from multiprocessing import JoinableQueue, Process, cpu_count, Manager
 from deepspeech import Model
 
-from util.evaluate_tools import calculate_report
+from util.evaluate_tools import calculate_and_print_report
 from util.flags import create_flags
 
 r'''
@@ -98,11 +98,8 @@ def main(args, _):
             predictions.append(msg['prediction'])
             wavlist.append(msg['wav'])
 
-    wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
-    mean_loss = np.mean(losses)
-
-    print('Test - WER: %f, CER: %f, loss: %f' %
-          (wer, cer, mean_loss))
+    # Print test summary
+    _ = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, args.csv)
 
     if args.dump:
         with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index 7f6a8ffb..13994b8b 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function
 
 from multiprocessing.dummy import Pool
 
+import numpy as np
 from attrdict import AttrDict
 
@@ -54,9 +55,9 @@ def process_decode_result(item):
     })
 
 
-def calculate_report(wav_filenames, labels, decodings, losses):
+def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name):
     r'''
-    This routine will calculate a WER report.
+    This routine will calculate and print a WER report.
     It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
     loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
     '''
@@ -65,13 +66,52 @@ def calculate_report(wav_filenames, labels, decodings, losses):
     # Getting the WER and CER from the accumulated edit distances and lengths
     samples_wer, samples_cer = wer_cer_batch(samples)
 
-    # Order the remaining items by their loss (lowest loss on top)
-    samples.sort(key=lambda s: s.loss)
+    # Reversed because the worst WER with the best loss is what identifies systemic issues: the acoustic model is confident,
+    # yet the result is completely off the mark. This can point to problems such as transcription errors in the dataset.
+    samples.sort(key=lambda s: s.loss, reverse=True)
 
-    # Then order by descending WER/CER
+    # Then order by ascending WER/CER
     if FLAGS.utf8:
-        samples.sort(key=lambda s: s.cer, reverse=True)
+        samples.sort(key=lambda s: s.cer)
     else:
-        samples.sort(key=lambda s: s.wer, reverse=True)
+        samples.sort(key=lambda s: s.wer)
 
-    return samples_wer, samples_cer, samples
+    # Print the report
+    print_report(samples, losses, samples_wer, samples_cer, dataset_name)
+
+    return samples
+
+
+def print_report(samples, losses, wer, cer, dataset_name):
+    """ Print a report summary and samples of best, median and worst results """
+
+    # Print summary
+    mean_loss = np.mean(losses)
+    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset_name, wer, cer, mean_loss))
+    print('-' * 80)
+
+    best_samples = samples[:FLAGS.report_count]
+    worst_samples = samples[-FLAGS.report_count:]
+    median_index = int(len(samples) / 2)
+    median_left = int(FLAGS.report_count / 2)
+    median_right = FLAGS.report_count - median_left
+    median_samples = samples[median_index - median_left:median_index + median_right]
+
+    def print_single_sample(sample):
+        print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss))
+        print(' - wav: file://%s' % sample.wav_filename)
+        print(' - src: "%s"' % sample.src)
+        print(' - res: "%s"' % sample.res)
+        print('-' * 80)
+
+    print('Best WER:', '\n' + '-' * 80)
+    for s in best_samples:
+        print_single_sample(s)
+
+    print('Median WER:', '\n' + '-' * 80)
+    for s in median_samples:
+        print_single_sample(s)
+
+    print('Worst WER:', '\n' + '-' * 80)
+    for s in worst_samples:
+        print_single_sample(s)
diff --git a/util/flags.py b/util/flags.py
index f2d8e75d..49d54fd0 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -118,7 +118,7 @@ def create_flags():
     f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
     f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
 
-    f.DEFINE_integer('report_count', 10, 'number of phrases with lowest WER(best matching) to print out during a WER report')
+    f.DEFINE_integer('report_count', 5, 'number of phrases for each of best WER, median WER and worst WER to print out during a WER report')
 
     f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
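Reviewer note: the window slicing in the new print_report() is easiest to verify with a toy example. The standalone sketch below (not part of the patch; the integer list and the hard-coded report_count are stand-ins for real Sample objects and FLAGS.report_count) reproduces the best/median/worst selection:

# Sketch of the best/median/worst windows computed in print_report().
# Assumes `samples` is already sorted by ascending WER, best first.
report_count = 5                 # stand-in for FLAGS.report_count (the new default)
samples = list(range(20))        # toy stand-in for 20 Sample objects

best_samples = samples[:report_count]              # lowest WER
worst_samples = samples[-report_count:]            # highest WER
median_index = int(len(samples) / 2)
median_left = int(report_count / 2)
median_right = report_count - median_left
median_samples = samples[median_index - median_left:median_index + median_right]

print(best_samples)    # [0, 1, 2, 3, 4]
print(median_samples)  # [8, 9, 10, 11, 12]
print(worst_samples)   # [15, 16, 17, 18, 19]

With fewer than report_count test samples the three windows overlap, so the same sample can be printed under more than one heading; the report repeats it rather than failing.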