From 0dcb1f87c59263a07b1836a39824badeb71a9719 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Sun, 11 Nov 2018 18:04:00 -0200
Subject: [PATCH] Clean up evaluate.py

---
 evaluate.py | 46 +++++++++++++++++++++++-----------------------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index 50234f17..1f0fe291 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -159,29 +159,29 @@ def evaluate(test_data, inference_graph, alphabet):
             logitses.append(logits)
             losses.extend(loss)
 
-        ground_truths = []
-        predictions = []
-        distances = []
+    ground_truths = []
+    predictions = []
 
-        print('Decoding predictions...')
-        bar = progressbar.ProgressBar(max_value=batch_count,
-                                      widget=progressbar.AdaptiveETA)
+    print('Decoding predictions...')
+    bar = progressbar.ProgressBar(max_value=batch_count,
+                                  widget=progressbar.AdaptiveETA)
 
-        # Get number of accessible CPU cores for this process
-        try:
-            num_processes = cpu_count()
-        except:
-            num_processes = 1
+    # Get number of accessible CPU cores for this process
+    try:
+        num_processes = cpu_count()
+    except:
+        num_processes = 1
 
-        # Second pass, decode logits and compute WER and edit distance metrics
-        for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
-            seq_lengths = batch['features_len'].values.astype(np.int32)
-            decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
-                                                    num_processes=num_processes, scorer=scorer)
+    # Second pass, decode logits and compute WER and edit distance metrics
+    for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
+        seq_lengths = batch['features_len'].values.astype(np.int32)
+        decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
+                                                num_processes=num_processes, scorer=scorer)
 
-            ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
-            predictions.extend(d[0][1] for d in decoded)
-            distances.extend(levenshtein(a, b) for a, b in zip(labels, predictions))
+        ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
+        predictions.extend(d[0][1] for d in decoded)
+
+    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
 
     wer, samples = calculate_report(ground_truths, predictions, distances, losses)
     mean_edit_distance = np.mean(distances)
@@ -190,12 +190,12 @@ def evaluate(test_data, inference_graph, alphabet):
     # Take only the first report_count items
     report_samples = itertools.islice(samples, FLAGS.report_count)
 
-    print('Test - WER: %f, loss: %f, mean edit distance: %f' %
-          (wer, mean_loss, mean_edit_distance))
+    print('Test - WER: %f, CER: %f, loss: %f' %
+          (wer, mean_edit_distance, mean_loss))
     print('-' * 80)
     for sample in report_samples:
-        print('WER: %f, loss: %f, edit distance: %f' %
-              (sample.wer, sample.loss, sample.distance))
+        print('WER: %f, CER: %f, loss: %f' %
+              (sample.wer, sample.distance, sample.loss))
         print(' - src: "%s"' % sample.src)
         print(' - res: "%s"' % sample.res)
         print('-' * 80)
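
For review: the substantive change in the first hunk is easy to miss among the reindentation. The old second pass extended distances inside the decode loop via zip(labels, predictions), where labels is never defined in that scope (the loop binds batch), and predictions already held every sample accumulated so far, so earlier pairs would have been re-counted on each batch even if the name had resolved. The patch instead derives distances once, after the loop, from the aligned ground_truths/predictions pairs. Below is a minimal sketch of the post-patch computation; the toy levenshtein is only a stand-in for the project's edit-distance helper, and the sample strings are invented.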
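
# Sketch only -- not part of the patch. Toy strings; levenshtein() is a
# stand-in for the project's edit-distance helper.

def levenshtein(a, b):
    # Standard dynamic-programming edit distance over two sequences.
    if not a:
        return len(b)
    if not b:
        return len(a)
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (ca != cb)))   # substitution
        prev = cur
    return prev[-1]

ground_truths = ['hello world', 'good morning']
predictions = ['hello word', 'good mourning']

# Post-patch behavior: one distance per aligned (truth, prediction) pair,
# computed once after decoding has filled both lists.
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
assert distances == [1, 1]

# The pre-patch line would instead have re-zipped the whole predictions
# list on every batch, against a name (labels) the loop never defines.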
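
Computing the distances in one pass after decoding also keeps distances index-aligned with ground_truths and predictions, which calculate_report appears to expect. The second hunk then only relabels the output: the same mean_edit_distance (and each sample's distance) is now printed under a CER heading alongside WER and loss.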