Moved summary printing to samples printing.
This commit is contained in:
parent
de92142986
commit
8cc91fafb2
14
evaluate.py
14
evaluate.py
@ -8,7 +8,6 @@ import sys
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import absl.app
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
@ -17,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from six.moves import zip
|
||||
|
||||
from util.config import Config, initialize_globals
|
||||
from util.evaluate_tools import calculate_report, print_report
|
||||
from util.evaluate_tools import calculate_and_print_report
|
||||
from util.feeding import create_dataset
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import create_progressbar, log_error, log_progress
|
||||
@ -132,15 +131,8 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
bar.finish()
|
||||
|
||||
# Print test summary
|
||||
wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
|
||||
mean_loss = np.mean(losses)
|
||||
print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
|
||||
print('-' * 80)
|
||||
|
||||
# Print some examples
|
||||
print_report(samples)
|
||||
|
||||
return samples
|
||||
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
||||
return test_samples
|
||||
|
||||
samples = []
|
||||
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||
|
@ -15,7 +15,7 @@ from six.moves import zip, range
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from deepspeech import Model
|
||||
|
||||
from util.evaluate_tools import calculate_report, print_report
|
||||
from util.evaluate_tools import calculate_and_print_report
|
||||
from util.flags import create_flags
|
||||
|
||||
r'''
|
||||
@ -99,13 +99,7 @@ def main(args, _):
|
||||
wavlist.append(msg['wav'])
|
||||
|
||||
# Print test summary
|
||||
wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
|
||||
mean_loss = np.mean(losses)
|
||||
print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss))
|
||||
print('-' * 80)
|
||||
|
||||
# Print some examples
|
||||
print_report(samples)
|
||||
_ = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, args.csv)
|
||||
|
||||
if args.dump:
|
||||
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from multiprocessing.dummy import Pool
|
||||
import numpy as np
|
||||
|
||||
from attrdict import AttrDict
|
||||
|
||||
@ -54,9 +55,9 @@ def process_decode_result(item):
|
||||
})
|
||||
|
||||
|
||||
def calculate_report(wav_filenames, labels, decodings, losses):
|
||||
def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset):
|
||||
r'''
|
||||
This routine will calculate a WER report.
|
||||
This routine will calculate and print a WER report.
|
||||
It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
|
||||
loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
|
||||
'''
|
||||
@ -75,11 +76,19 @@ def calculate_report(wav_filenames, labels, decodings, losses):
|
||||
else:
|
||||
samples.sort(key=lambda s: s.wer)
|
||||
|
||||
return samples_wer, samples_cer, samples
|
||||
# Print the report
|
||||
print_report(samples, losses, samples_wer, samples_cer, dataset)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def print_report(samples):
|
||||
""" Print a report with samples of best, median and worst results """
|
||||
def print_report(samples, losses, wer, cer, dataset):
|
||||
""" Print a report summary and samples of best, median and worst results """
|
||||
|
||||
# Print summary
|
||||
mean_loss = np.mean(losses)
|
||||
print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
|
||||
print('-' * 80)
|
||||
|
||||
best_samples = samples[:FLAGS.report_count]
|
||||
worst_samples = samples[-FLAGS.report_count:]
|
||||
|
Loading…
x
Reference in New Issue
Block a user