diff --git a/evaluate.py b/evaluate.py
index 2a781f49..8df73966 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -8,7 +8,6 @@ import sys
 from multiprocessing import cpu_count
 
 import absl.app
-import numpy as np
 import progressbar
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
@@ -17,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
 
 from util.config import Config, initialize_globals
-from util.evaluate_tools import calculate_report, print_report
+from util.evaluate_tools import calculate_and_print_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
 from util.logging import create_progressbar, log_error, log_progress
@@ -132,15 +131,8 @@ def evaluate(test_csvs, create_model, try_loading):
             bar.finish()
 
             # Print test summary
-            wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
-            mean_loss = np.mean(losses)
-            print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
-            print('-' * 80)
-
-            # Print some examples
-            print_report(samples)
-
-            return samples
+            test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
+            return test_samples
 
         samples = []
         for csv, init_op in zip(test_csvs, test_init_ops):
diff --git a/evaluate_tflite.py b/evaluate_tflite.py
index 9f16918f..8b2ba453 100644
--- a/evaluate_tflite.py
+++ b/evaluate_tflite.py
@@ -15,7 +15,7 @@ from six.moves import zip, range
 from multiprocessing import JoinableQueue, Process, cpu_count, Manager
 
 from deepspeech import Model
-from util.evaluate_tools import calculate_report, print_report
+from util.evaluate_tools import calculate_and_print_report
 from util.flags import create_flags
 
 r'''
@@ -99,13 +99,7 @@ def main(args, _):
         wavlist.append(msg['wav'])
 
     # Print test summary
-    wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
-    mean_loss = np.mean(losses)
-    print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss))
-    print('-' * 80)
-
-    # Print some examples
-    print_report(samples)
+    _ = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, args.csv)
 
     if args.dump:
         with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index 2bef89b3..d3ad8379 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function
 
 from multiprocessing.dummy import Pool
 
+import numpy as np
 from attrdict import AttrDict
 
 from util.flags import FLAGS
@@ -54,9 +55,9 @@ def process_decode_result(item):
     })
 
-def calculate_report(wav_filenames, labels, decodings, losses):
+def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset):
     r'''
-    This routine will calculate a WER report.
+    This routine will calculate and print a WER report.
     It'll compute the `mean` WER and create ``Sample`` objects of the
     ``report_count`` top lowest loss items from the provided WER results
     tuple (only items with WER!=0 and ordered by their WER).
     '''
@@ -75,11 +76,19 @@ def calculate_report(wav_filenames, labels, decodings, losses):
     else:
         samples.sort(key=lambda s: s.wer)
 
-    return samples_wer, samples_cer, samples
+    # Print the report
+    print_report(samples, losses, samples_wer, samples_cer, dataset)
+
+    return samples
 
 
-def print_report(samples):
-    """ Print a report with samples of best, median and worst results """
+def print_report(samples, losses, wer, cer, dataset):
+    """ Print a report summary and samples of best, median and worst results """
+
+    # Print summary
+    mean_loss = np.mean(losses)
+    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
+    print('-' * 80)
 
     best_samples = samples[:FLAGS.report_count]
     worst_samples = samples[-FLAGS.report_count:]
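
Note (not part of the patch): the summary block that evaluate.py and evaluate_tflite.py previously duplicated now lives in print_report. A minimal, self-contained sketch of that summary output, with made-up WER/CER/loss values and a placeholder dataset name since the real numbers come from the CTC decoder loop, would behave like this:

import numpy as np

def print_summary(dataset, wer, cer, losses):
    # Same summary format the patch moves into util/evaluate_tools.print_report()
    mean_loss = np.mean(losses)
    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
    print('-' * 80)

if __name__ == '__main__':
    # Illustrative values only
    print_summary('data/test.csv', 0.12, 0.05, [12.3, 9.8, 14.1])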