parent
afcf7dbf54
commit
c0cd365544
40
evaluate.py
40
evaluate.py
@ -12,7 +12,6 @@ import sys
|
||||
import tables
|
||||
import tensorflow as tf
|
||||
|
||||
from attrdict import AttrDict
|
||||
from collections import namedtuple
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from multiprocessing import Pool, cpu_count
|
||||
@ -21,9 +20,9 @@ from util.audio import audiofile_to_input_vector
|
||||
from util.config import Config, initialize_globals
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error
|
||||
from util.preprocess import pmap, preprocess
|
||||
from util.text import Alphabet, wer_cer_batch, levenshtein
|
||||
|
||||
from util.preprocess import preprocess
|
||||
from util.text import Alphabet, levenshtein
|
||||
from util.evaluate_tools import process_decode_result, calculate_report
|
||||
|
||||
def split_data(dataset, batch_size):
|
||||
remainder = len(dataset) % batch_size
|
||||
@ -45,39 +44,6 @@ def pad_to_dense(jagged):
|
||||
return padded
|
||||
|
||||
|
||||
def process_decode_result(item):
|
||||
label, decoding, distance, loss = item
|
||||
word_distance = levenshtein(label.split(), decoding.split())
|
||||
word_length = float(len(label.split()))
|
||||
return AttrDict({
|
||||
'src': label,
|
||||
'res': decoding,
|
||||
'loss': loss,
|
||||
'distance': distance,
|
||||
'wer': word_distance / word_length,
|
||||
})
|
||||
|
||||
|
||||
def calculate_report(labels, decodings, distances, losses):
|
||||
r'''
|
||||
This routine will calculate a WER report.
|
||||
It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
|
||||
loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
|
||||
'''
|
||||
samples = pmap(process_decode_result, zip(labels, decodings, distances, losses))
|
||||
|
||||
# Getting the WER and CER from the accumulated edit distances and lengths
|
||||
samples_wer, samples_cer = wer_cer_batch(labels, decodings)
|
||||
|
||||
# Order the remaining items by their loss (lowest loss on top)
|
||||
samples.sort(key=lambda s: s.loss)
|
||||
|
||||
# Then order by WER (highest WER on top)
|
||||
samples.sort(key=lambda s: s.wer, reverse=True)
|
||||
|
||||
return samples_wer, samples_cer, samples
|
||||
|
||||
|
||||
def evaluate(test_data, inference_graph):
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
||||
|
105
evaluate_tflite.py
Normal file
105
evaluate_tflite.py
Normal file
@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import wave
|
||||
import csv
|
||||
import sys
|
||||
|
||||
from six.moves import zip, range
|
||||
from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count
|
||||
from deepspeech import Model
|
||||
|
||||
from util.text import levenshtein
|
||||
from util.evaluate_tools import process_decode_result, calculate_report
|
||||
|
||||
r'''
|
||||
This module should be self-contained:
|
||||
- build libdeepspeech.so with TFLite:
|
||||
- add a dep in native_client/BUILD against TFlite: '//tensorflow:linux_x86_64': [ "//tensorflow/contrib/lite/kernels:builtin_ops" ]
|
||||
- bazel build [...] --copt=-DUSE_TFLITE [...] //native_client:libdeepspeech.so
|
||||
- make -C native_client/python/ TFDIR=... bindings
|
||||
- setup a virtualenv
|
||||
- pip install native_client/python/dist/deepspeech*.whl
|
||||
- pip install -r requirements_eval_tflite.txt
|
||||
|
||||
Then run with a TF Lite model, alphabet, LM/trie and a CSV test file
|
||||
'''
|
||||
|
||||
BEAM_WIDTH = 500
|
||||
LM_ALPHA = 0.75
|
||||
LM_BETA = 1.85
|
||||
N_FEATURES = 26
|
||||
N_CONTEXT = 9
|
||||
|
||||
def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out):
|
||||
ds = Model(model, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
|
||||
ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)
|
||||
|
||||
while True:
|
||||
msg = queue_in.get()
|
||||
|
||||
fin = wave.open(msg['filename'], 'rb')
|
||||
fs = fin.getframerate()
|
||||
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
||||
audio_length = fin.getnframes() * (1/16000)
|
||||
fin.close()
|
||||
|
||||
decoded = ds.stt(audio, fs)
|
||||
|
||||
queue_out.put({'prediction': decoded, 'ground_truth': msg['transcript']})
|
||||
queue_in.task_done()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to the model (protocol buffer binary file)')
|
||||
parser.add_argument('--alphabet', required=True,
|
||||
help='Path to the configuration file specifying the alphabet used by the network')
|
||||
parser.add_argument('--lm', required=True,
|
||||
help='Path to the language model binary file')
|
||||
parser.add_argument('--trie', required=True,
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--csv', required=True,
|
||||
help='Path to the CSV source file')
|
||||
args = parser.parse_args()
|
||||
|
||||
work_todo = JoinableQueue() # this is where we are going to store input data
|
||||
work_done = Queue() # this where we are gonna push them out
|
||||
|
||||
processes = []
|
||||
for i in range(cpu_count()):
|
||||
worker_process = Process(target=tflite_worker, args=(args.model, args.alphabet, args.lm, args.trie, work_todo, work_done), daemon=True, name='tflite_process_{}'.format(i))
|
||||
worker_process.start() # Launch reader() as a separate python process
|
||||
processes.append(worker_process)
|
||||
|
||||
print([x.name for x in processes])
|
||||
|
||||
ground_truths = []
|
||||
predictions = []
|
||||
losses = []
|
||||
|
||||
with open(args.csv, 'r') as csvfile:
|
||||
csvreader = csv.DictReader(csvfile)
|
||||
for row in csvreader:
|
||||
work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
|
||||
work_todo.join()
|
||||
|
||||
while (not work_done.empty()):
|
||||
msg = work_done.get()
|
||||
losses.append(0.0)
|
||||
ground_truths.append(msg['ground_truth'])
|
||||
predictions.append(msg['prediction'])
|
||||
|
||||
distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
|
||||
|
||||
wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
|
||||
mean_loss = np.mean(losses)
|
||||
|
||||
print('Test - WER: %f, CER: %f, loss: %f' %
|
||||
(wer, cer, mean_loss))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
7
requirements_eval_tflite.txt
Normal file
7
requirements_eval_tflite.txt
Normal file
@ -0,0 +1,7 @@
|
||||
attrdict==2.0.0
|
||||
deepspeech
|
||||
numpy==1.16.0
|
||||
pkg-resources==0.0.0
|
||||
progressbar2==3.39.2
|
||||
python-utils==2.3.0
|
||||
six==1.12.0
|
45
util/evaluate_tools.py
Normal file
45
util/evaluate_tools.py
Normal file
@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from attrdict import AttrDict
|
||||
from multiprocessing.dummy import Pool
|
||||
from util.text import wer_cer_batch, levenshtein
|
||||
|
||||
def pmap(fun, iterable):
|
||||
pool = Pool()
|
||||
results = pool.map(fun, iterable)
|
||||
pool.close()
|
||||
return results
|
||||
|
||||
def process_decode_result(item):
|
||||
label, decoding, distance, loss = item
|
||||
word_distance = levenshtein(label.split(), decoding.split())
|
||||
word_length = float(len(label.split()))
|
||||
return AttrDict({
|
||||
'src': label,
|
||||
'res': decoding,
|
||||
'loss': loss,
|
||||
'distance': distance,
|
||||
'wer': word_distance / word_length,
|
||||
})
|
||||
|
||||
|
||||
def calculate_report(labels, decodings, distances, losses):
|
||||
r'''
|
||||
This routine will calculate a WER report.
|
||||
It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
|
||||
loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
|
||||
'''
|
||||
samples = pmap(process_decode_result, zip(labels, decodings, distances, losses))
|
||||
|
||||
# Getting the WER and CER from the accumulated edit distances and lengths
|
||||
samples_wer, samples_cer = wer_cer_batch(labels, decodings)
|
||||
|
||||
# Order the remaining items by their loss (lowest loss on top)
|
||||
samples.sort(key=lambda s: s.loss)
|
||||
|
||||
# Then order by WER (highest WER on top)
|
||||
samples.sort(key=lambda s: s.wer, reverse=True)
|
||||
|
||||
return samples_wer, samples_cer, samples
|
Loading…
x
Reference in New Issue
Block a user