Clean up evaluate.py
This commit is contained in:
parent
5cb1aff531
commit
0dcb1f87c5
46
evaluate.py
46
evaluate.py
@ -159,29 +159,29 @@ def evaluate(test_data, inference_graph, alphabet):
|
|||||||
logitses.append(logits)
|
logitses.append(logits)
|
||||||
losses.extend(loss)
|
losses.extend(loss)
|
||||||
|
|
||||||
ground_truths = []
|
ground_truths = []
|
||||||
predictions = []
|
predictions = []
|
||||||
distances = []
|
|
||||||
|
|
||||||
print('Decoding predictions...')
|
print('Decoding predictions...')
|
||||||
bar = progressbar.ProgressBar(max_value=batch_count,
|
bar = progressbar.ProgressBar(max_value=batch_count,
|
||||||
widget=progressbar.AdaptiveETA)
|
widget=progressbar.AdaptiveETA)
|
||||||
|
|
||||||
# Get number of accessible CPU cores for this process
|
# Get number of accessible CPU cores for this process
|
||||||
try:
|
try:
|
||||||
num_processes = cpu_count()
|
num_processes = cpu_count()
|
||||||
except:
|
except:
|
||||||
num_processes = 1
|
num_processes = 1
|
||||||
|
|
||||||
# Second pass, decode logits and compute WER and edit distance metrics
|
# Second pass, decode logits and compute WER and edit distance metrics
|
||||||
for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
|
for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
|
||||||
seq_lengths = batch['features_len'].values.astype(np.int32)
|
seq_lengths = batch['features_len'].values.astype(np.int32)
|
||||||
decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
|
decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
|
||||||
num_processes=num_processes, scorer=scorer)
|
num_processes=num_processes, scorer=scorer)
|
||||||
|
|
||||||
ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
|
ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
|
||||||
predictions.extend(d[0][1] for d in decoded)
|
predictions.extend(d[0][1] for d in decoded)
|
||||||
distances.extend(levenshtein(a, b) for a, b in zip(labels, predictions))
|
|
||||||
|
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
|
||||||
|
|
||||||
wer, samples = calculate_report(ground_truths, predictions, distances, losses)
|
wer, samples = calculate_report(ground_truths, predictions, distances, losses)
|
||||||
mean_edit_distance = np.mean(distances)
|
mean_edit_distance = np.mean(distances)
|
||||||
@ -190,12 +190,12 @@ def evaluate(test_data, inference_graph, alphabet):
|
|||||||
# Take only the first report_count items
|
# Take only the first report_count items
|
||||||
report_samples = itertools.islice(samples, FLAGS.report_count)
|
report_samples = itertools.islice(samples, FLAGS.report_count)
|
||||||
|
|
||||||
print('Test - WER: %f, loss: %f, mean edit distance: %f' %
|
print('Test - WER: %f, CER: %f, loss: %f' %
|
||||||
(wer, mean_loss, mean_edit_distance))
|
(wer, mean_edit_distance, mean_loss))
|
||||||
print('-' * 80)
|
print('-' * 80)
|
||||||
for sample in report_samples:
|
for sample in report_samples:
|
||||||
print('WER: %f, loss: %f, edit distance: %f' %
|
print('WER: %f, CER: %f, loss: %f' %
|
||||||
(sample.wer, sample.loss, sample.distance))
|
(sample.wer, sample.distance, sample.loss))
|
||||||
print(' - src: "%s"' % sample.src)
|
print(' - src: "%s"' % sample.src)
|
||||||
print(' - res: "%s"' % sample.res)
|
print(' - res: "%s"' % sample.res)
|
||||||
print('-' * 80)
|
print('-' * 80)
|
||||||
|
Loading…
Reference in New Issue
Block a user