STT/util/evaluate_tools.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

from multiprocessing.dummy import Pool

from attrdict import AttrDict

from util.text import wer_cer_batch, levenshtein


def pmap(fun, iterable):
    '''Map `fun` over `iterable` in parallel using a thread pool.'''
    pool = Pool()
    results = pool.map(fun, iterable)
    pool.close()
    return results


def process_decode_result(item):
    '''Turn a single (label, decoding, distance, loss) tuple into a result record.'''
    label, decoding, distance, loss = item
    word_distance = levenshtein(label.split(), decoding.split())
    word_length = float(len(label.split()))
    return AttrDict({
        'src': label,
        'res': decoding,
        'loss': loss,
        'distance': distance,
        'wer': word_distance / word_length,
    })


def calculate_report(labels, decodings, distances, losses):
    r'''
    Calculate a WER report for a batch of decoding results.

    Computes the batch-level WER and CER and builds one result record per sample,
    sorted first by loss (ascending) and then, stably, by WER (descending), so the
    highest-WER samples come first and ties are broken by lower loss.
    '''
    samples = pmap(process_decode_result, zip(labels, decodings, distances, losses))

    # Getting the WER and CER from the accumulated edit distances and lengths
    samples_wer, samples_cer = wer_cer_batch(labels, decodings)

    # Order the remaining items by their loss (lowest loss on top)
    samples.sort(key=lambda s: s.loss)

    # Then order by WER (highest WER on top)
    samples.sort(key=lambda s: s.wer, reverse=True)

    return samples_wer, samples_cer, samples
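

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how calculate_report might be driven. The transcripts,
# decodings, edit distances and losses below are made-up values, and running
# this assumes `attrdict` and `util.text` are importable from the environment.
if __name__ == '__main__':
    example_labels = ['hello world', 'good morning']
    example_decodings = ['hello word', 'good morning']
    example_distances = [1, 0]      # character-level edit distances (example values)
    example_losses = [12.3, 4.2]    # per-sample losses (example values)

    wer, cer, samples = calculate_report(example_labels, example_decodings,
                                         example_distances, example_losses)
    print('WER: {:.4f}, CER: {:.4f}'.format(wer, cer))
    for sample in samples:
        print('loss={:.2f} wer={:.2f} src="{}" res="{}"'.format(
            sample.loss, sample.wer, sample.src, sample.res))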