STT/training/deepspeech_training/util/evaluate_tools.py
Catalin Voss 6640cf2341
Remote training I/O once more (#3437)
* Redo remote I/O changes once more; this time without messing with taskcluster

* Add bin changes

* Fix merge-induced issue?

* For the interleaved case with multiple collections, unpack audio on the fly

To reproduce the previous failure

rm data/smoke_test/ldc93s1.csv
rm data/smoke_test/ldc93s1.sdb
rm -rf /tmp/ldc93s1_cache_sdb_csv
rm -rf /tmp/ckpt_sdb_csv
rm -rf /tmp/train_sdb_csv

./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 16000
python -u DeepSpeech.py --noshow_progressbar --noearly_stop --train_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --train_batch_size 1 --feature_cache /tmp/ldc93s1_cache_sdb_csv --dev_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --dev_batch_size 1 --test_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --test_batch_size 1 --n_hidden 100 --epochs 109 --max_to_keep 1 --checkpoint_dir /tmp/ckpt_sdb_csv --learning_rate 0.001 --dropout_rate 0.05 --export_dir /tmp/train_sdb_csv --scorer_path data/smoke_test/pruned_lm.scorer --audio_sample_rate 16000

* Attempt to preserve length information with a wrapper around `map()`… this gets pretty python-y (a rough sketch of the idea follows these notes)

* Call the right `__next__()`

* Properly implement the rest of the map wrappers here……

* Fix trailing whitespace situation and other linter complaints

* Remove data accidentally checked in

* Fix overlay augmentations

* Wavs must be open in rb mode if we're passing in an external file pointer -- this confused me

* Lint whitespace

* Revert "Fix trailing whitespace situation and other linter complaints"

This reverts commit c3c45397a2f98e9b00d00c18c4ced4fc52475032.

* Fix linter issue but without such an aggressive diff

* Move unpack_maybe into sample_collections

* Use unpack_maybe in place of duplicate lambda

* Fix confusing comment

* Add clarifying comment for on-the-fly unpacking
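
A rough sketch of the length-preserving `map()` wrapper idea mentioned above (illustration only; the class name and details here are made up, not the helper that actually landed):

    class LenMap:
        """Applies fn lazily to each element while still reporting len() of the source."""
        def __init__(self, fn, source):
            self.fn = fn
            self.source = source

        def __len__(self):
            return len(self.source)

        def __iter__(self):
            return (self.fn(item) for item in self.source)

    # Wrapping a sample collection this way keeps len()-based machinery such as epoch
    # sizing and progress bars working, e.g.:
    #   wrapped = LenMap(unpack_maybe, samples)
    #   assert len(wrapped) == len(samples)
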
2020-12-07 13:07:34 +01:00

129 lines
4.4 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import json

from multiprocessing.dummy import Pool

import numpy as np
from attrdict import AttrDict

from .flags import FLAGS
from .text import levenshtein
from .io import open_remote
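
# Note: multiprocessing.dummy provides the Pool API backed by threads rather than
# separate processes, so pmap() below maps fun over the iterable concurrently
# without having to pickle the samples.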

def pmap(fun, iterable):
    pool = Pool()
    results = pool.map(fun, iterable)
    pool.close()
    return results


def wer_cer_batch(samples):
    r"""
    The WER is defined as the edit/Levenshtein distance on word level divided by
    the number of words in the original text.
    In case of the original having more words (N) than the result and both
    being totally different (all N words resulting in 1 edit operation each),
    the WER will always be 1 (N / N = 1).
    """
    wer = sum(s.word_distance for s in samples) / sum(s.word_length for s in samples)
    cer = sum(s.char_distance for s in samples) / sum(s.char_length for s in samples)

    wer = min(wer, 1.0)
    cer = min(cer, 1.0)

    return wer, cer
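
# Worked example (illustrative): a batch of two samples with word_distance/word_length
# of 1/4 and 3/6 gives WER = (1 + 3) / (4 + 6) = 0.4, i.e. edit distances and lengths
# are accumulated across the batch before dividing, rather than averaging the
# per-sample WERs (which would be (0.25 + 0.5) / 2 = 0.375).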


def process_decode_result(item):
    wav_filename, ground_truth, prediction, loss = item
    char_distance = levenshtein(ground_truth, prediction)
    char_length = len(ground_truth)
    word_distance = levenshtein(ground_truth.split(), prediction.split())
    word_length = len(ground_truth.split())
    return AttrDict({
        'wav_filename': wav_filename,
        'src': ground_truth,
        'res': prediction,
        'loss': loss,
        'char_distance': char_distance,
        'char_length': char_length,
        'word_distance': word_distance,
        'word_length': word_length,
        'cer': char_distance / char_length,
        'wer': word_distance / word_length,
    })
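
# Illustrative example (not part of the original module): for the item
# ('sample.wav', 'the quick brown fox', 'the quick brown box', 3.2), levenshtein()
# yields char_distance 1 over char_length 19 (CER ~= 0.053) and word_distance 1
# over word_length 4 (WER = 0.25).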


def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name):
    r'''
    This routine will calculate and print a WER report.
    It computes the batch WER and CER over all samples, sorts the per-sample results
    by loss and WER/CER, prints the best, median and worst ``FLAGS.report_count``
    samples, and returns the sorted list of samples.
    '''
    samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))

    # Getting the WER and CER from the accumulated edit distances and lengths
    samples_wer, samples_cer = wer_cer_batch(samples)

    # Reverse the loss sort so that, after the stable WER/CER sort below, the worst WER
    # with the best (lowest) loss ends up last. Those samples identify systemic issues:
    # the acoustic model is confident, yet the result is completely off the mark, which
    # can point to transcription errors in the data.
    samples.sort(key=lambda s: s.loss, reverse=True)

    # Then order by ascending WER/CER
    if FLAGS.bytes_output_mode:
        samples.sort(key=lambda s: s.cer)
    else:
        samples.sort(key=lambda s: s.wer)

    # Print the report
    print_report(samples, losses, samples_wer, samples_cer, dataset_name)

    return samples
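
# Hedged usage sketch (argument values are illustrative; nothing here beyond the two
# functions being called is defined by this module):
#   samples = calculate_and_print_report(wav_filenames, transcripts, decodings, losses, 'test set')
#   save_samples_json(samples, '/tmp/report.json')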


def print_report(samples, losses, wer, cer, dataset_name):
    """ Print a report summary and samples of best, median and worst results """

    # Print summary
    mean_loss = np.mean(losses)
    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset_name, wer, cer, mean_loss))
    print('-' * 80)

    best_samples = samples[:FLAGS.report_count]
    worst_samples = samples[-FLAGS.report_count:]
    median_index = int(len(samples) / 2)
    median_left = int(FLAGS.report_count / 2)
    median_right = FLAGS.report_count - median_left
    median_samples = samples[median_index - median_left:median_index + median_right]

    def print_single_sample(sample):
        print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss))
        print(' - wav: file://%s' % sample.wav_filename)
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print('-' * 80)

    print('Best WER:', '\n' + '-' * 80)
    for s in best_samples:
        print_single_sample(s)

    print('Median WER:', '\n' + '-' * 80)
    for s in median_samples:
        print_single_sample(s)

    print('Worst WER:', '\n' + '-' * 80)
    for s in worst_samples:
        print_single_sample(s)


def save_samples_json(samples, output_path):
    ''' Save decoded tuples as JSON, converting NumPy floats to Python floats.
        We set ensure_ascii=False to keep json from escaping non-ASCII chars
        in the texts.
    '''
    with open_remote(output_path, 'w') as fout:
        json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
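
# Note: json.dump can serialize the samples directly because AttrDict is a dict
# subclass; default=float takes care of the NumPy float loss values.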