* Redo remote I/O changes once more; this time without messing with taskcluster
* Add bin changes
* Fix merge-induced issue?
* For the interleaved case with multiple collections, unpack audio on the fly

  To reproduce the previous failure:

      rm data/smoke_test/ldc93s1.csv
      rm data/smoke_test/ldc93s1.sdb
      rm -rf /tmp/ldc93s1_cache_sdb_csv
      rm -rf /tmp/ckpt_sdb_csv
      rm -rf /tmp/train_sdb_csv
      ./bin/run-tc-ldc93s1_new_sdb_csv.sh 109 16000
      python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
        --train_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --train_batch_size 1 \
        --feature_cache /tmp/ldc93s1_cache_sdb_csv \
        --dev_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --dev_batch_size 1 \
        --test_files ./data/smoke_test/ldc93s1.sdb,./data/smoke_test/ldc93s1.csv --test_batch_size 1 \
        --n_hidden 100 --epochs 109 --max_to_keep 1 --checkpoint_dir /tmp/ckpt_sdb_csv \
        --learning_rate 0.001 --dropout_rate 0.05 --export_dir /tmp/train_sdb_csv \
        --scorer_path data/smoke_test/pruned_lm.scorer --audio_sample_rate 16000

* Attempt to preserve length information with a wrapper around `map()`… this gets pretty python-y (see the sketch after this list)
* Call the right `__next__()`
* Properly implement the rest of the map wrappers here……
* Fix trailing whitespace situation and other linter complaints
* Remove data accidentally checked in
* Fix overlay augmentations
* Wavs must be open in rb mode if we're passing in an external file pointer -- this confused me
* Lint whitespace
* Revert "Fix trailing whitespace situation and other linter complaints"

  This reverts commit c3c45397a2f98e9b00d00c18c4ced4fc52475032.

* Fix linter issue but without such an aggressive diff
* Move unpack_maybe into sample_collections
* Use unpack_maybe in place of duplicate lambda
* Fix confusing comment
* Add clarifying comment for on-the-fly unpacking
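The length-preserving `map()` wrapper mentioned in the list above can be pictured roughly as follows. This is a minimal, hypothetical sketch with names of my own choosing, not the actual implementation from this change set:

```python
class LenMap:
    """Apply `fun` lazily to each item while still reporting len() of the source.

    Illustrative sketch only; the class name and details are not taken from the
    actual change set.
    """

    def __init__(self, fun, iterable):
        self.fun = fun
        self.iterable = iterable

    def __len__(self):
        # A plain map() object has no length; delegating to the source keeps it.
        return len(self.iterable)

    def __iter__(self):
        for item in self.iterable:
            yield self.fun(item)


# Usage: len() keeps working, unlike with a plain map() object.
mapped = LenMap(str.upper, ['a', 'b', 'c'])
assert len(mapped) == 3
assert list(mapped) == ['A', 'B', 'C']
```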
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import json
from multiprocessing.dummy import Pool

import numpy as np
from attrdict import AttrDict

from .flags import FLAGS
from .text import levenshtein
from .io import open_remote


def pmap(fun, iterable):
    # Thread-based parallel map: multiprocessing.dummy provides a thread pool.
    pool = Pool()
    results = pool.map(fun, iterable)
    pool.close()
    return results


def wer_cer_batch(samples):
    r"""
    The WER is defined as the word-level edit/Levenshtein distance divided by
    the number of words in the original (ground-truth) text.
    In case of the original having more words (N) than the result and both
    being totally different (all N words resulting in 1 edit operation each),
    the WER will always be 1 (N / N = 1).
    """
    wer = sum(s.word_distance for s in samples) / sum(s.word_length for s in samples)
    cer = sum(s.char_distance for s in samples) / sum(s.char_length for s in samples)

    # Insertions can push the edit distance above the reference length, so clamp to 1.0
    wer = min(wer, 1.0)
    cer = min(cer, 1.0)

    return wer, cer
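For illustration only (not part of the module), here is how the pooled batch WER above differs from simply averaging per-sample WERs, using made-up distances and lengths:

```python
from attrdict import AttrDict

# Two hypothetical samples with fields as produced by process_decode_result().
samples = [
    AttrDict(word_distance=1, word_length=2, char_distance=2, char_length=10),
    AttrDict(word_distance=0, word_length=8, char_distance=0, char_length=40),
]

# Pooling the edit distances before dividing gives (1 + 0) / (2 + 8) = 0.1,
# whereas averaging the per-sample WERs would give (0.5 + 0.0) / 2 = 0.25.
wer = sum(s.word_distance for s in samples) / sum(s.word_length for s in samples)
cer = sum(s.char_distance for s in samples) / sum(s.char_length for s in samples)
assert wer == 0.1 and cer == 0.04
```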
def process_decode_result(item):
    wav_filename, ground_truth, prediction, loss = item
    char_distance = levenshtein(ground_truth, prediction)
    char_length = len(ground_truth)
    word_distance = levenshtein(ground_truth.split(), prediction.split())
    word_length = len(ground_truth.split())
    return AttrDict({
        'wav_filename': wav_filename,
        'src': ground_truth,
        'res': prediction,
        'loss': loss,
        'char_distance': char_distance,
        'char_length': char_length,
        'word_distance': word_distance,
        'word_length': word_length,
        'cer': char_distance / char_length,
        'wer': word_distance / word_length,
    })


def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name):
    r'''
    This routine will calculate and print a WER report.
    It computes the batch WER and CER from the accumulated edit distances and
    lengths, builds a result object per sample, and sorts the samples by loss
    (descending) and then by WER/CER (ascending) so that the best, median and
    worst ``report_count`` items can be printed.
    '''
    samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))

    # Getting the WER and CER from the accumulated edit distances and lengths
    samples_wer, samples_cer = wer_cer_batch(samples)

    # Sort by loss in descending order first, so that within each WER/CER group
    # the lowest-loss samples end up last. The worst WER paired with the best
    # loss identifies systemic issues, where the acoustic model is confident,
    # yet the result is completely off the mark. This can point to
    # transcription errors and the like.
    samples.sort(key=lambda s: s.loss, reverse=True)

    # Then order by ascending WER/CER
    if FLAGS.bytes_output_mode:
        samples.sort(key=lambda s: s.cer)
    else:
        samples.sort(key=lambda s: s.wer)

    # Print the report
    print_report(samples, losses, samples_wer, samples_cer, dataset_name)

    return samples
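A tiny, hypothetical illustration of the two-pass sort above: because Python's sort is stable, sorting by loss (descending) first and then by WER (ascending) leaves, within each WER group, the lowest-loss samples last, so the tail of the list is "worst WER, best loss":

```python
from attrdict import AttrDict

samples = [
    AttrDict(wer=1.0, loss=0.2),
    AttrDict(wer=0.0, loss=5.0),
    AttrDict(wer=1.0, loss=9.0),
]
samples.sort(key=lambda s: s.loss, reverse=True)  # highest loss first
samples.sort(key=lambda s: s.wer)                 # stable: ties keep loss order
assert [(s.wer, s.loss) for s in samples] == [(0.0, 5.0), (1.0, 9.0), (1.0, 0.2)]
```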
def print_report(samples, losses, wer, cer, dataset_name):
    """ Print a report summary and samples of best, median and worst results """

    # Print summary
    mean_loss = np.mean(losses)
    print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset_name, wer, cer, mean_loss))
    print('-' * 80)

    best_samples = samples[:FLAGS.report_count]
    worst_samples = samples[-FLAGS.report_count:]
    median_index = int(len(samples) / 2)
    median_left = int(FLAGS.report_count / 2)
    median_right = FLAGS.report_count - median_left
    median_samples = samples[median_index - median_left:median_index + median_right]

    def print_single_sample(sample):
        print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss))
        print(' - wav: file://%s' % sample.wav_filename)
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print('-' * 80)

    print('Best WER:', '\n' + '-' * 80)
    for s in best_samples:
        print_single_sample(s)

    print('Median WER:', '\n' + '-' * 80)
    for s in median_samples:
        print_single_sample(s)

    print('Worst WER:', '\n' + '-' * 80)
    for s in worst_samples:
        print_single_sample(s)
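To make the median window above concrete (hypothetical numbers, for illustration only): with 1000 sorted samples and `--report_count 10`, the median block is `samples[495:505]`:

```python
samples = list(range(1000))   # stand-in for 1000 sorted result samples
report_count = 10             # stand-in for FLAGS.report_count

median_index = int(len(samples) / 2)        # 500
median_left = int(report_count / 2)         # 5
median_right = report_count - median_left   # 5
assert samples[median_index - median_left:median_index + median_right] == list(range(495, 505))
```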
def save_samples_json(samples, output_path):
    ''' Save decoded tuples as JSON, converting NumPy floats to Python floats.

    We set ensure_ascii=False to prevent json from escaping non-ASCII chars
    in the texts.
    '''
    with open_remote(output_path, 'w') as fout:
        json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
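As a small, illustrative example of why `default=float` and `ensure_ascii=False` are used above: NumPy scalars are not JSON-serializable out of the box, and `ensure_ascii=False` keeps non-ASCII transcript characters readable:

```python
import json

import numpy as np

record = {'wer': np.float32(0.25), 'res': 'héllo'}

# default=float is called for objects json cannot serialize (the NumPy scalar);
# ensure_ascii=False writes 'héllo' as-is instead of escaping it as '\u00e9'.
print(json.dumps(record, default=float, ensure_ascii=False))
# {"wer": 0.25, "res": "héllo"}
```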