From a2f05ccabe85d2b822f32aedb1cfe20aa1ab02e6 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Wed, 5 Feb 2020 17:50:27 +0100
Subject: [PATCH 01/12] Print best and worst results in a WER report.

---
 evaluate.py   | 14 ++++++++++----
 util/flags.py |  2 +-
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index 4b0602c1..ea94fcb5 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -2,7 +2,6 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import, division, print_function
 
-import itertools
 import json
 import sys
 
@@ -135,13 +134,16 @@ def evaluate(test_csvs, create_model, try_loading):
             wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
             mean_loss = np.mean(losses)
 
-            # Take only the first report_count items
-            report_samples = itertools.islice(samples, FLAGS.report_count)
+            # Take only the first and last report_count items
+            best_samples = samples[-FLAGS.report_count:]
+            worst_samples = samples[:FLAGS.report_count]
+            report_samples = best_samples
+            report_samples.extend(reversed(worst_samples))
 
             print('Test on %s - WER: %f, CER: %f, loss: %f' %
                   (dataset, wer, cer, mean_loss))
             print('-' * 80)
-            for sample in report_samples:
+            for i, sample in enumerate(report_samples):
                 print('WER: %f, CER: %f, loss: %f' %
                       (sample.wer, sample.cer, sample.loss))
                 print(' - wav: file://%s' % sample.wav_filename)
@@ -149,6 +151,10 @@ def evaluate(test_csvs, create_model, try_loading):
                 print(' - res: "%s"' % sample.res)
                 print('-' * 80)
 
+                if (i == FLAGS.report_count - 1):
+                    print('[...]')
+                    print('-' * 80)
+
             return samples
 
         samples = []
diff --git a/util/flags.py b/util/flags.py
index f2d8e75d..25903f1a 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -118,7 +118,7 @@ def create_flags():
     f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
     f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
 
-    f.DEFINE_integer('report_count', 10, 'number of phrases with lowest WER(best matching) to print out during a WER report')
+    f.DEFINE_integer('report_count', 7, 'number of phrases with best WER and with worst WER to print out during a WER report')
 
     f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
 

From 272ed99d24e1f19e2742a8fd5bb3a378e1e15d5c Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 10:55:48 +0100
Subject: [PATCH 02/12] Add median examples. Fix sorting.
---
 evaluate.py            | 47 +++++++++++++++++++++++++++++++++++--------------------
 util/evaluate_tools.py |  7 +++----
 util/flags.py          |  4 +++-
 3 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index ea94fcb5..d0bec134 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -4,7 +4,6 @@ from __future__ import absolute_import, division, print_function
 
 import json
 import sys
-
 from multiprocessing import cpu_count
 
 import absl.app
@@ -12,16 +11,16 @@ import numpy as np
 import progressbar
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
-
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
-
 from util.config import Config, initialize_globals
 from util.evaluate_tools import calculate_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
-from util.logging import log_error, log_progress, create_progressbar
-from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
+from util.helpers import check_ctcdecoder_version;
+from util.logging import create_progressbar, log_error, log_progress
+
+check_ctcdecoder_version()
 
 
 def sparse_tensor_value_to_texts(value, alphabet):
@@ -131,29 +130,37 @@ def evaluate(test_csvs, create_model, try_loading):
 
             bar.finish()
 
+            # Print test summary
             wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
             mean_loss = np.mean(losses)
-
-            # Take only the first and last report_count items
-            best_samples = samples[-FLAGS.report_count:]
-            worst_samples = samples[:FLAGS.report_count]
-            report_samples = best_samples
-            report_samples.extend(reversed(worst_samples))
-
-            print('Test on %s - WER: %f, CER: %f, loss: %f' %
-                  (dataset, wer, cer, mean_loss))
+            print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
             print('-' * 80)
-            for i, sample in enumerate(report_samples):
-                print('WER: %f, CER: %f, loss: %f' %
-                      (sample.wer, sample.cer, sample.loss))
+
+            # Take only the first, median and last report_count items
+            best_samples = samples[:FLAGS.report_count]
+            worst_samples = samples[-FLAGS.report_count:]
+            median_index = int(len(samples) / 1.5)
+            median_left = int(FLAGS.report_count / 2)
+            median_right = FLAGS.report_count - median_left
+            median_samples = samples[median_index - median_left:median_index + median_right]
+
+            def print_single_sample(sample):
+                print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss))
                 print(' - wav: file://%s' % sample.wav_filename)
                 print(' - src: "%s"' % sample.src)
                 print(' - res: "%s"' % sample.res)
                 print('-' * 80)
 
-                if (i == FLAGS.report_count - 1):
-                    print('[...]')
-                    print('-' * 80)
+            for s in best_samples:
+                print_single_sample(s)
+            print('[...]', '\n' + '-' * 80)
+
+            for s in median_samples:
+                print_single_sample(s)
+            print('[...]', '\n' + '-' * 80)
+
+            for s in worst_samples:
+                print_single_sample(s)
 
             return samples
 
diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index 7f6a8ffb..d50c3e93 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -5,7 +5,6 @@ from __future__ import absolute_import, division, print_function
 from multiprocessing.dummy import Pool
 
 from attrdict import AttrDict
-
 from util.flags import FLAGS
 from util.text import levenshtein
 
@@ -68,10 +67,10 @@ def calculate_report(wav_filenames, labels, decodings, losses):
     # Order the remaining items by their loss (lowest loss on top)
     samples.sort(key=lambda s: s.loss)
 
-    # Then order by descending WER/CER
+    # Then order by ascending WER/CER
     if FLAGS.utf8:
-        samples.sort(key=lambda s: s.cer, reverse=True)
+        samples.sort(key=lambda s: s.cer)
     else:
-        samples.sort(key=lambda s: s.wer, reverse=True)
+        samples.sort(key=lambda s: s.wer)
 
     return samples_wer, samples_cer, samples
diff --git a/util/flags.py b/util/flags.py
index 25903f1a..437e263f 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
 import os
+
 import absl.flags
 
 FLAGS = absl.flags.FLAGS
@@ -118,7 +119,8 @@ def create_flags():
     f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
     f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
 
-    f.DEFINE_integer('report_count', 7, 'number of phrases with best WER and with worst WER to print out during a WER report')
+    f.DEFINE_integer('report_count', 5,
+                     'number of phrases with best WER, median WER and with worst WER to print out during a WER report')
 
     f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
 

From 369e3c9fc3f080bdf6d181200ef207e9108bf297 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 11:01:37 +0100
Subject: [PATCH 03/12] Revert linebreak.

---
 util/flags.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/util/flags.py b/util/flags.py
index 437e263f..1172a8d4 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -1,7 +1,6 @@
 from __future__ import absolute_import, division, print_function
 
 import os
-
 import absl.flags
 
 FLAGS = absl.flags.FLAGS
@@ -119,8 +118,7 @@ def create_flags():
     f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
     f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
 
-    f.DEFINE_integer('report_count', 5,
-                     'number of phrases with best WER, median WER and with worst WER to print out during a WER report')
+    f.DEFINE_integer('report_count', 5, 'number of phrases with best WER, median WER and with worst WER to print out during a WER report')
 
     f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
 

From 320e815bb7d1d842fe64e1119b74e068bf5a9829 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 11:05:15 +0100
Subject: [PATCH 04/12] Remove semicolon.

---
 evaluate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/evaluate.py b/evaluate.py
index d0bec134..d3447605 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -17,7 +17,7 @@ from util.config import Config, initialize_globals
 from util.evaluate_tools import calculate_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
-from util.helpers import check_ctcdecoder_version;
+from util.helpers import check_ctcdecoder_version
 from util.logging import create_progressbar, log_error, log_progress
 
 check_ctcdecoder_version()

From a0b5d3e7e0d93849c16f9a156db16320282c6ca9 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 11:07:11 +0100
Subject: [PATCH 05/12] Restore order of imports.
---
 evaluate.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index d3447605..fc0147ca 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -17,10 +17,8 @@ from util.config import Config, initialize_globals
 from util.evaluate_tools import calculate_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
-from util.helpers import check_ctcdecoder_version
 from util.logging import create_progressbar, log_error, log_progress
-
-check_ctcdecoder_version()
+from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
 
 
 def sparse_tensor_value_to_texts(value, alphabet):

From 63a07e6834c2a7a00e54e810926faa6a96855f7b Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 11:39:31 +0100
Subject: [PATCH 06/12] Added summary to evaluate_tflite.py and moved method to evaluate_tools.py.

---
 evaluate.py            | 32 ++++++--------------------------
 evaluate_tflite.py     | 11 +++++++----
 util/evaluate_tools.py | 29 +++++++++++++++++++++++++++++
 util/flags.py          |  2 +-
 4 files changed, 43 insertions(+), 31 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index fc0147ca..2a781f49 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function
 
 import json
 import sys
+
 from multiprocessing import cpu_count
 
 import absl.app
@@ -11,10 +12,12 @@ import numpy as np
 import progressbar
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
+
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
+
 from util.config import Config, initialize_globals
-from util.evaluate_tools import calculate_report
+from util.evaluate_tools import calculate_report, print_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
 from util.logging import create_progressbar, log_error, log_progress
@@ -134,31 +137,8 @@ def evaluate(test_csvs, create_model, try_loading):
             print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
             print('-' * 80)
 
-            # Take only the first, median and last report_count items
-            best_samples = samples[:FLAGS.report_count]
-            worst_samples = samples[-FLAGS.report_count:]
-            median_index = int(len(samples) / 1.5)
-            median_left = int(FLAGS.report_count / 2)
-            median_right = FLAGS.report_count - median_left
-            median_samples = samples[median_index - median_left:median_index + median_right]
-
-            def print_single_sample(sample):
-                print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss))
-                print(' - wav: file://%s' % sample.wav_filename)
-                print(' - src: "%s"' % sample.src)
-                print(' - res: "%s"' % sample.res)
-                print('-' * 80)
-
-            for s in best_samples:
-                print_single_sample(s)
-            print('[...]', '\n' + '-' * 80)
-
-            for s in median_samples:
-                print_single_sample(s)
-            print('[...]', '\n' + '-' * 80)
-
-            for s in worst_samples:
-                print_single_sample(s)
+            # Print some examples
+            print_report(samples)
 
             return samples
 
diff --git a/evaluate_tflite.py b/evaluate_tflite.py
index ec80383c..91204646 100644
--- a/evaluate_tflite.py
+++ b/evaluate_tflite.py
@@ -15,8 +15,8 @@ from six.moves import zip, range
 from multiprocessing import JoinableQueue, Process, cpu_count, Manager
 from deepspeech import Model
 
-from util.evaluate_tools import calculate_report
-from util.flags import create_flags
+from util.evaluate_tools import calculate_report, print_report
+from util.flags import create_flags, FLAGS
 r'''
 This module should be self-contained:
 
@@ -98,11 +98,14 @@ def main(args, _):
         predictions.append(msg['prediction'])
         wavlist.append(msg['wav'])
 
+    # Print test summary
     wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
     mean_loss = np.mean(losses)
+    print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss))
+    print('-' * 80)
 
-    print('Test - WER: %f, CER: %f, loss: %f' %
-          (wer, cer, mean_loss))
+    # Print some examples
+    print_report(samples)
 
     if args.dump:
         with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index d50c3e93..69970847 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -74,3 +74,32 @@ def calculate_report(wav_filenames, labels, decodings, losses):
         samples.sort(key=lambda s: s.wer)
 
     return samples_wer, samples_cer, samples
+
+
+def print_report(samples):
+    """ Print a report with samples of best, median and worst results """
+
+    best_samples = samples[:FLAGS.report_count]
+    worst_samples = samples[-FLAGS.report_count:]
+    median_index = int(len(samples) / 2)
+    median_left = int(FLAGS.report_count / 2)
+    median_right = FLAGS.report_count - median_left
+    median_samples = samples[median_index - median_left:median_index + median_right]
+
+    def print_single_sample(sample):
+        print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss))
+        print(' - wav: file://%s' % sample.wav_filename)
+        print(' - src: "%s"' % sample.src)
+        print(' - res: "%s"' % sample.res)
+        print('-' * 80)
+
+    for s in best_samples:
+        print_single_sample(s)
+    print('[...]', '\n' + '-' * 80)
+
+    for s in median_samples:
+        print_single_sample(s)
+    print('[...]', '\n' + '-' * 80)
+
+    for s in worst_samples:
+        print_single_sample(s)
diff --git a/util/flags.py b/util/flags.py
index 1172a8d4..49d54fd0 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -118,7 +118,7 @@ def create_flags():
     f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
     f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
 
-    f.DEFINE_integer('report_count', 5, 'number of phrases with best WER, median WER and with worst WER to print out during a WER report')
+    f.DEFINE_integer('report_count', 5, 'number of phrases for each of best WER, median WER and worst WER to print out during a WER report')
 
     f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
 

From 9ec88b7f28334672b9be8f91fa2b0d6cf6b5b995 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 11:41:40 +0100
Subject: [PATCH 07/12] Add whitespace again.

---
 util/evaluate_tools.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index 69970847..af07d0af 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -5,6 +5,7 @@ from __future__ import absolute_import, division, print_function
 from multiprocessing.dummy import Pool
 
 from attrdict import AttrDict
+
 from util.flags import FLAGS
 from util.text import levenshtein
 

From f5145526f0690fd24728a4353a6530b1a56eb4b2 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 11:43:48 +0100
Subject: [PATCH 08/12] Don't need flags.
---
 evaluate_tflite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/evaluate_tflite.py b/evaluate_tflite.py
index 91204646..9f16918f 100644
--- a/evaluate_tflite.py
+++ b/evaluate_tflite.py
@@ -16,7 +16,7 @@ from multiprocessing import JoinableQueue, Process, cpu_count, Manager
 from deepspeech import Model
 
 from util.evaluate_tools import calculate_report, print_report
-from util.flags import create_flags, FLAGS
+from util.flags import create_flags
 r'''
 This module should be self-contained:
 

From 4186cbef883400b6d6c30ec64284454a1232f938 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 12:49:46 +0100
Subject: [PATCH 09/12] Reverse ordered loss again.

---
 util/evaluate_tools.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index af07d0af..1a8c7638 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -65,8 +65,9 @@ def calculate_report(wav_filenames, labels, decodings, losses):
     # Getting the WER and CER from the accumulated edit distances and lengths
     samples_wer, samples_cer = wer_cer_batch(samples)
 
-    # Order the remaining items by their loss (lowest loss on top)
-    samples.sort(key=lambda s: s.loss)
+    # Reversed because the worst WER with the best loss is useful for identifying systemic issues, where the acoustic model is confident,
+    # yet the result is completely off the mark. This can point to transcription errors and similar problems in the data.
+    samples.sort(key=lambda s: s.loss, reverse=True)
 
     # Then order by ascending WER/CER
     if FLAGS.utf8:

From de92142986804ec781a8ad053095c7728f4318be Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 13:31:30 +0100
Subject: [PATCH 10/12] Named example sections.

---
 util/evaluate_tools.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index 1a8c7638..2bef89b3 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -95,13 +95,14 @@ def print_report(samples):
         print(' - res: "%s"' % sample.res)
         print('-' * 80)
 
+    print('Best WER:', '\n' + '-' * 80)
     for s in best_samples:
         print_single_sample(s)
-    print('[...]', '\n' + '-' * 80)
 
+    print('Median WER:', '\n' + '-' * 80)
     for s in median_samples:
         print_single_sample(s)
-    print('[...]', '\n' + '-' * 80)
 
+    print('Worst WER:', '\n' + '-' * 80)
     for s in worst_samples:
         print_single_sample(s)

From 8cc91fafb23178f73e15a2458ac66ddbb21e2fb9 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 14:44:41 +0100
Subject: [PATCH 11/12] Moved summary printing into sample report printing.
---
 evaluate.py            | 14 +++-----------
 evaluate_tflite.py     | 10 ++--------
 util/evaluate_tools.py | 19 ++++++++++++++-----
 3 files changed, 19 insertions(+), 24 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index 2a781f49..8df73966 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -8,7 +8,6 @@ import sys
 from multiprocessing import cpu_count
 
 import absl.app
-import numpy as np
 import progressbar
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
@@ -17,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
 
 from util.config import Config, initialize_globals
-from util.evaluate_tools import calculate_report, print_report
+from util.evaluate_tools import calculate_and_print_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
 from util.logging import create_progressbar, log_error, log_progress
@@ -132,15 +131,8 @@ def evaluate(test_csvs, create_model, try_loading):
             bar.finish()
 
             # Print test summary
-            wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
-            mean_loss = np.mean(losses)
-            print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
-            print('-' * 80)
-
-            # Print some examples
-            print_report(samples)
-
-            return samples
+            test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
+            return test_samples
 
         samples = []
         for csv, init_op in zip(test_csvs, test_init_ops):
diff --git a/evaluate_tflite.py b/evaluate_tflite.py
index 9f16918f..8b2ba453 100644
--- a/evaluate_tflite.py
+++ b/evaluate_tflite.py
@@ -15,7 +15,7 @@ from six.moves import zip, range
 from multiprocessing import JoinableQueue, Process, cpu_count, Manager
 from deepspeech import Model
 
-from util.evaluate_tools import calculate_report, print_report
+from util.evaluate_tools import calculate_and_print_report
 from util.flags import create_flags
 r'''
 This module should be self-contained:
@@ -99,13 +99,7 @@ def main(args, _):
         wavlist.append(msg['wav'])
 
     # Print test summary
-    wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses)
-    mean_loss = np.mean(losses)
-    print('Test - WER: %f, CER: %f, loss: %f' % (wer, cer, mean_loss))
-    print('-' * 80)
-
-    # Print some examples
-    print_report(samples)
+    _ = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, args.csv)
 
     if args.dump:
         with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index 2bef89b3..d3ad8379 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -3,6 +3,7 @@
 from __future__ import absolute_import, division, print_function
 
 from multiprocessing.dummy import Pool
+import numpy as np
 
 from attrdict import AttrDict
 
@@ -54,9 +55,9 @@ def process_decode_result(item):
     })
 
 
-def calculate_report(wav_filenames, labels, decodings, losses):
+def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset):
     r'''
-    This routine will calculate a WER report.
+    This routine will calculate and print a WER report.
     It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
     loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
     '''
@@ -75,11 +76,19 @@ def calculate_report(wav_filenames, labels, decodings, losses):
     else:
         samples.sort(key=lambda s: s.wer)
 
-    return samples_wer, samples_cer, samples
+    # Print the report
+    print_report(samples, losses, samples_wer, samples_cer, dataset)
+
+    return samples
 
 
-def print_report(samples):
-    """ Print a report with samples of best, median and worst results """
+def print_report(samples, losses, wer, cer, dataset):
+    """ Print a report summary and samples of best, median and worst results """
+
+    # Print summary
+    mean_loss = np.mean(losses)
+    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
+    print('-' * 80)
 
     best_samples = samples[:FLAGS.report_count]
     worst_samples = samples[-FLAGS.report_count:]

From 726cc20586e2f8afc0b87f5ee84afc941e70d2d6 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Feb 2020 14:47:59 +0100
Subject: [PATCH 12/12] Rename dataset param.

---
 util/evaluate_tools.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/util/evaluate_tools.py b/util/evaluate_tools.py
index d3ad8379..13994b8b 100644
--- a/util/evaluate_tools.py
+++ b/util/evaluate_tools.py
@@ -55,7 +55,7 @@ def process_decode_result(item):
     })
 
 
-def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset):
+def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name):
     r'''
     This routine will calculate and print a WER report.
     It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
@@ -77,17 +77,17 @@ def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset
         samples.sort(key=lambda s: s.wer)
 
     # Print the report
-    print_report(samples, losses, samples_wer, samples_cer, dataset)
+    print_report(samples, losses, samples_wer, samples_cer, dataset_name)
 
     return samples
 
 
-def print_report(samples, losses, wer, cer, dataset):
+def print_report(samples, losses, wer, cer, dataset_name):
     """ Print a report summary and samples of best, median and worst results """
 
     # Print summary
     mean_loss = np.mean(losses)
-    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss))
+    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset_name, wer, cer, mean_loss))
     print('-' * 80)
 
     best_samples = samples[:FLAGS.report_count]
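
For reference, the sample-selection scheme that print_report converges on over these patches can be exercised in isolation. The standalone sketch below is not part of the patch series; the Sample namedtuple and the fabricated values are made up purely so the example runs. It mirrors the final slicing logic: the list is assumed to be sorted by ascending WER, the first and last report_count entries become the best and worst sections, and a report_count-wide window centred on the middle of the list becomes the median section.

from collections import namedtuple

# Illustrative stand-in for the AttrDict-based samples built in util/evaluate_tools.py.
Sample = namedtuple('Sample', ['src', 'res', 'wer', 'cer', 'loss'])


def pick_report_sections(samples, report_count):
    """Split a list sorted by ascending WER into best, median and worst sections."""
    best = samples[:report_count]
    worst = samples[-report_count:]
    # Centre a window of report_count items on the middle of the list.
    median_index = len(samples) // 2
    median_left = report_count // 2
    median_right = report_count - median_left
    median = samples[median_index - median_left:median_index + median_right]
    return best, median, worst


if __name__ == '__main__':
    # Fabricated, pre-sorted dummy data (ascending WER), purely for illustration.
    samples = [Sample('src %d' % i, 'res %d' % i, wer=i / 10.0, cer=i / 20.0, loss=10.0 - i)
               for i in range(10)]
    for title, section in zip(('Best WER:', 'Median WER:', 'Worst WER:'),
                              pick_report_sections(samples, report_count=3)):
        print(title, '\n' + '-' * 80)
        for s in section:
            print('WER: %f, CER: %f, loss: %f' % (s.wer, s.cer, s.loss))
            print('-' * 80)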