Merge pull request #2724 from DanBmh/master
Print best and worst results in a WER report.
This commit is contained in:
commit
33efd9b7ff
27
evaluate.py
27
evaluate.py
@ -2,14 +2,12 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import itertools
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
import absl.app
|
import absl.app
|
||||||
import numpy as np
|
|
||||||
import progressbar
|
import progressbar
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.compat.v1 as tfv1
|
import tensorflow.compat.v1 as tfv1
|
||||||
@ -18,10 +16,10 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
|||||||
from six.moves import zip
|
from six.moves import zip
|
||||||
|
|
||||||
from util.config import Config, initialize_globals
|
from util.config import Config, initialize_globals
|
||||||
from util.evaluate_tools import calculate_report
|
from util.evaluate_tools import calculate_and_print_report
|
||||||
from util.feeding import create_dataset
|
from util.feeding import create_dataset
|
||||||
from util.flags import create_flags, FLAGS
|
from util.flags import create_flags, FLAGS
|
||||||
from util.logging import log_error, log_progress, create_progressbar
|
from util.logging import create_progressbar, log_error, log_progress
|
||||||
from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
|
from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
|
||||||
|
|
||||||
|
|
||||||
@ -132,24 +130,9 @@ def evaluate(test_csvs, create_model, try_loading):
|
|||||||
|
|
||||||
bar.finish()
|
bar.finish()
|
||||||
|
|
||||||
wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
|
# Print test summary
|
||||||
mean_loss = np.mean(losses)
|
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
||||||
|
return test_samples
|
||||||
# Take only the first report_count items
|
|
||||||
report_samples = itertools.islice(samples, FLAGS.report_count)
|
|
||||||
|
|
||||||
print('Test on %s - WER: %f, CER: %f, loss: %f' %
|
|
||||||
(dataset, wer, cer, mean_loss))
|
|
||||||
print('-' * 80)
|
|
||||||
for sample in report_samples:
|
|
||||||
print('WER: %f, CER: %f, loss: %f' %
|
|
||||||
(sample.wer, sample.cer, sample.loss))
|
|
||||||
print(' - wav: file://%s' % sample.wav_filename)
|
|
||||||
print(' - src: "%s"' % sample.src)
|
|
||||||
print(' - res: "%s"' % sample.res)
|
|
||||||
print('-' * 80)
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
samples = []
|
samples = []
|
||||||
for csv, init_op in zip(test_csvs, test_init_ops):
|
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||||
|
@ -15,7 +15,7 @@ from six.moves import zip, range
|
|||||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||||
from deepspeech import Model
|
from deepspeech import Model
|
||||||
|
|
||||||
from util.evaluate_tools import calculate_report
|
from util.evaluate_tools import calculate_and_print_report
|
||||||
from util.flags import create_flags
|
from util.flags import create_flags
|
||||||
|
|
||||||
r'''
|
r'''
|
||||||
@ -98,11 +98,8 @@ def main(args, _):
|
|||||||
predictions.append(msg['prediction'])
|
predictions.append(msg['prediction'])
|
||||||
wavlist.append(msg['wav'])
|
wavlist.append(msg['wav'])
|
||||||
|
|
||||||
wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
|
# Print test summary
|
||||||
mean_loss = np.mean(losses)
|
_ = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, args.csv)
|
||||||
|
|
||||||
print('Test - WER: %f, CER: %f, loss: %f' %
|
|
||||||
(wer, cer, mean_loss))
|
|
||||||
|
|
||||||
if args.dump:
|
if args.dump:
|
||||||
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
|
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
from multiprocessing.dummy import Pool
|
from multiprocessing.dummy import Pool
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from attrdict import AttrDict
|
from attrdict import AttrDict
|
||||||
|
|
||||||
@ -54,9 +55,9 @@ def process_decode_result(item):
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def calculate_report(wav_filenames, labels, decodings, losses):
|
def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name):
|
||||||
r'''
|
r'''
|
||||||
This routine will calculate a WER report.
|
This routine will calculate and print a WER report.
|
||||||
It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
|
It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
|
||||||
loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
|
loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
|
||||||
'''
|
'''
|
||||||
@ -65,13 +66,52 @@ def calculate_report(wav_filenames, labels, decodings, losses):
|
|||||||
# Getting the WER and CER from the accumulated edit distances and lengths
|
# Getting the WER and CER from the accumulated edit distances and lengths
|
||||||
samples_wer, samples_cer = wer_cer_batch(samples)
|
samples_wer, samples_cer = wer_cer_batch(samples)
|
||||||
|
|
||||||
# Order the remaining items by their loss (lowest loss on top)
|
# Reversed because pairing the worst WER with the best loss helps to identify systemic issues, where the acoustic model is confident,
|
||||||
samples.sort(key=lambda s: s.loss)
|
# yet the result is completely off the mark. This can point to transcription errors and similar problems.
|
||||||
|
samples.sort(key=lambda s: s.loss, reverse=True)
|
||||||
|
|
||||||
# Then order by descending WER/CER
|
# Then order by ascending WER/CER
|
||||||
if FLAGS.utf8:
|
if FLAGS.utf8:
|
||||||
samples.sort(key=lambda s: s.cer, reverse=True)
|
samples.sort(key=lambda s: s.cer)
|
||||||
else:
|
else:
|
||||||
samples.sort(key=lambda s: s.wer, reverse=True)
|
samples.sort(key=lambda s: s.wer)
|
||||||
|
|
||||||
return samples_wer, samples_cer, samples
|
# Print the report
|
||||||
|
print_report(samples, losses, samples_wer, samples_cer, dataset_name)
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def print_report(samples, losses, wer, cer, dataset_name):
|
||||||
|
""" Print a report summary and samples of best, median and worst results """
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
mean_loss = np.mean(losses)
|
||||||
|
print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset_name, wer, cer, mean_loss))
|
||||||
|
print('-' * 80)
|
||||||
|
|
||||||
|
best_samples = samples[:FLAGS.report_count]
|
||||||
|
worst_samples = samples[-FLAGS.report_count:]
|
||||||
|
median_index = int(len(samples) / 2)
|
||||||
|
median_left = int(FLAGS.report_count / 2)
|
||||||
|
median_right = FLAGS.report_count - median_left
|
||||||
|
median_samples = samples[median_index - median_left:median_index + median_right]
|
||||||
|
|
||||||
|
def print_single_sample(sample):
|
||||||
|
print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss))
|
||||||
|
print(' - wav: file://%s' % sample.wav_filename)
|
||||||
|
print(' - src: "%s"' % sample.src)
|
||||||
|
print(' - res: "%s"' % sample.res)
|
||||||
|
print('-' * 80)
|
||||||
|
|
||||||
|
print('Best WER:', '\n' + '-' * 80)
|
||||||
|
for s in best_samples:
|
||||||
|
print_single_sample(s)
|
||||||
|
|
||||||
|
print('Median WER:', '\n' + '-' * 80)
|
||||||
|
for s in median_samples:
|
||||||
|
print_single_sample(s)
|
||||||
|
|
||||||
|
print('Worst WER:', '\n' + '-' * 80)
|
||||||
|
for s in worst_samples:
|
||||||
|
print_single_sample(s)
|
||||||
|
@ -118,7 +118,7 @@ def create_flags():
|
|||||||
f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
|
f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
|
||||||
|
|
||||||
f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
|
f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
|
||||||
f.DEFINE_integer('report_count', 10, 'number of phrases with lowest WER(best matching) to print out during a WER report')
|
f.DEFINE_integer('report_count', 5, 'number of phrases for each of best WER, median WER and worst WER to print out during a WER report')
|
||||||
|
|
||||||
f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
|
f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user