134 lines
4.4 KiB
Python
134 lines
4.4 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import json
|
|
from multiprocessing.dummy import Pool
|
|
|
|
import numpy as np
|
|
from attrdict import AttrDict
|
|
|
|
from .config import Config
|
|
from .io import open_remote
|
|
from .text import levenshtein
|
|
|
|
|
|
def pmap(fun, iterable):
    """Map ``fun`` over ``iterable`` in parallel and return a list of results.

    Uses ``multiprocessing.dummy.Pool`` (a thread pool, not processes), so
    ``fun`` runs in threads of the current process and may touch shared state.
    Results are returned in input order, like the builtin ``map``.

    The pool is managed with a ``with`` block so it is closed and joined on
    the normal path and also cleaned up if ``fun`` raises — the original
    version called ``close()`` without ``join()`` and leaked the pool on
    exceptions.
    """
    with Pool() as pool:
        return pool.map(fun, iterable)
|
|
|
|
|
|
def wer_cer_batch(samples):
    r"""Compute batch-level WER and CER from per-sample edit statistics.

    WER is the total word-level edit (Levenshtein) distance divided by the
    total number of words in the reference texts; CER is the analogous ratio
    at character level. Because an N-word reference that is entirely wrong
    costs exactly N edits, the ratio is capped at 1.0.

    Each sample must expose ``word_distance``, ``word_length``,
    ``char_distance`` and ``char_length`` attributes.

    Returns a ``(wer, cer)`` tuple of floats in [0, 1].
    """
    total_word_edits = 0
    total_words = 0
    total_char_edits = 0
    total_chars = 0
    for sample in samples:
        total_word_edits += sample.word_distance
        total_words += sample.word_length
        total_char_edits += sample.char_distance
        total_chars += sample.char_length

    batch_wer = min(total_word_edits / total_words, 1.0)
    batch_cer = min(total_char_edits / total_chars, 1.0)
    return batch_wer, batch_cer
|
|
|
|
|
|
def process_decode_result(item):
    """Build a result record for one decoded sample.

    ``item`` is a ``(wav_filename, ground_truth, prediction, loss)`` tuple.
    Computes character- and word-level Levenshtein distances between the
    reference and the prediction and returns an ``AttrDict`` with the raw
    counts plus the per-sample ``cer`` and ``wer`` ratios.

    NOTE(review): assumes the reference text is non-empty — an empty
    ``ground_truth`` would divide by zero in ``cer``/``wer``.
    """
    wav_filename, ground_truth, prediction, loss = item
    truth_words = ground_truth.split()

    char_distance = levenshtein(ground_truth, prediction)
    word_distance = levenshtein(truth_words, prediction.split())
    char_length = len(ground_truth)
    word_length = len(truth_words)

    record = {
        "wav_filename": wav_filename,
        "src": ground_truth,
        "res": prediction,
        "loss": loss,
        "char_distance": char_distance,
        "char_length": char_length,
        "word_distance": word_distance,
        "word_length": word_length,
        "cer": char_distance / char_length,
        "wer": word_distance / word_length,
    }
    return AttrDict(record)
|
|
|
|
|
|
def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name):
    r"""Compute WER/CER over a decoded dataset, print a report, return samples.

    Each aligned ``(wav_filename, label, decoding, loss)`` tuple is turned
    into a per-sample record (in parallel), aggregate WER and CER are
    computed over the whole batch, and a human-readable report is printed
    via :func:`print_report`.

    Returns the list of sample records, sorted as described below.
    """
    samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))

    # Aggregate WER/CER from the accumulated edit distances and lengths.
    samples_wer, samples_cer = wer_cer_batch(samples)

    # First sort by descending loss: the worst WER combined with the best
    # (lowest) loss highlights systemic issues where the acoustic model is
    # confident yet completely wrong — often a sign of transcription errors.
    samples.sort(key=lambda s: s.loss, reverse=True)

    # Then stable-sort by ascending error rate, so within equal error rates
    # the loss ordering above is preserved.
    error_key = (lambda s: s.cer) if Config.bytes_output_mode else (lambda s: s.wer)
    samples.sort(key=error_key)

    print_report(samples, losses, samples_wer, samples_cer, dataset_name)

    return samples
|
|
|
|
|
|
def print_report(samples, losses, wer, cer, dataset_name):
    """ Print a report summary and samples of best, median and worst results """
    separator = "-" * 80

    # Summary line with aggregate metrics over the whole dataset.
    mean_loss = np.mean(losses)
    print(
        "Test on %s - WER: %f, CER: %f, loss: %f" % (dataset_name, wer, cer, mean_loss)
    )
    print(separator)

    # Pick report_count samples from the head (best), the middle (median)
    # and the tail (worst) of the already-sorted sample list.
    count = Config.report_count
    mid = int(len(samples) / 2)
    left = int(count / 2)
    right = count - left

    sections = (
        ("Best WER:", samples[:count]),
        ("Median WER:", samples[mid - left : mid + right]),
        ("Worst WER:", samples[-count:]),
    )

    def show(sample):
        # One sample: metrics, audio path, reference and hypothesis texts.
        print("WER: %f, CER: %f, loss: %f" % (sample.wer, sample.cer, sample.loss))
        print(" - wav: file://%s" % sample.wav_filename)
        print(' - src: "%s"' % sample.src)
        print(' - res: "%s"' % sample.res)
        print(separator)

    for title, subset in sections:
        print(title, "\n" + separator)
        for sample in subset:
            show(sample)
|
|
|
|
|
|
def save_samples_json(samples, output_path):
    """Save decoded sample records as JSON to ``output_path``.

    ``default=float`` converts otherwise non-serializable numeric values
    (e.g. NumPy float losses) to plain Python floats.

    ``ensure_ascii=False`` keeps non-ASCII characters in the transcripts
    verbatim instead of escaping them to ``\\uXXXX`` sequences.
    (The previous docstring claimed ``ensure_ascii=True`` — the code has
    always passed ``False``.)
    """
    with open_remote(output_path, "w") as fout:
        json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)
|