parent
afcf7dbf54
commit
c0cd365544
40
evaluate.py
40
evaluate.py
@ -12,7 +12,6 @@ import sys
|
|||||||
import tables
|
import tables
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from attrdict import AttrDict
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||||
from multiprocessing import Pool, cpu_count
|
from multiprocessing import Pool, cpu_count
|
||||||
@ -21,9 +20,9 @@ from util.audio import audiofile_to_input_vector
|
|||||||
from util.config import Config, initialize_globals
|
from util.config import Config, initialize_globals
|
||||||
from util.flags import create_flags, FLAGS
|
from util.flags import create_flags, FLAGS
|
||||||
from util.logging import log_error
|
from util.logging import log_error
|
||||||
from util.preprocess import pmap, preprocess
|
from util.preprocess import preprocess
|
||||||
from util.text import Alphabet, wer_cer_batch, levenshtein
|
from util.text import Alphabet, levenshtein
|
||||||
|
from util.evaluate_tools import process_decode_result, calculate_report
|
||||||
|
|
||||||
def split_data(dataset, batch_size):
|
def split_data(dataset, batch_size):
|
||||||
remainder = len(dataset) % batch_size
|
remainder = len(dataset) % batch_size
|
||||||
@ -45,39 +44,6 @@ def pad_to_dense(jagged):
|
|||||||
return padded
|
return padded
|
||||||
|
|
||||||
|
|
||||||
def process_decode_result(item):
    r'''
    Turn one ``(label, decoding, distance, loss)`` tuple into a report sample.

    The per-sample WER is the word-level Levenshtein distance between the
    label and the decoding, divided by the number of words in the label.

    Returns an ``AttrDict`` with the keys ``src`` (label), ``res`` (decoding),
    ``loss``, ``distance`` (passed through unchanged) and ``wer``.
    '''
    label, decoding, distance, loss = item
    label_words = label.split()
    word_distance = levenshtein(label_words, decoding.split())
    # An empty reference transcript would otherwise raise ZeroDivisionError and
    # abort the whole report; normalising by at least one word makes the WER of
    # such a sample equal to the number of (spuriously) decoded words.
    word_length = float(max(len(label_words), 1))
    return AttrDict({
        'src': label,
        'res': decoding,
        'loss': loss,
        'distance': distance,
        'wer': word_distance / word_length,
    })
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_report(labels, decodings, distances, losses):
    r'''
    Build a WER/CER report for a batch of decoded samples.

    Each ``(label, decoding, distance, loss)`` tuple is converted into a sample
    by ``process_decode_result``; the aggregate WER and CER are obtained from
    ``wer_cer_batch(labels, decodings)``.

    Returns ``(wer, cer, samples)`` where ``samples`` is ordered by WER
    (highest first) and, for equal WER, by loss (lowest first).
    '''
    report_samples = pmap(process_decode_result,
                          zip(labels, decodings, distances, losses))

    # Aggregate WER and CER over the whole batch
    batch_wer, batch_cer = wer_cer_batch(labels, decodings)

    # One stable sort: WER descending first, loss ascending as tie-breaker
    report_samples.sort(key=lambda sample: (-sample.wer, sample.loss))

    return batch_wer, batch_cer, report_samples
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(test_data, inference_graph):
|
def evaluate(test_data, inference_graph):
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||||
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
||||||
|
105
evaluate_tflite.py
Normal file
105
evaluate_tflite.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import wave
|
||||||
|
import csv
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from six.moves import zip, range
|
||||||
|
from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count
|
||||||
|
from deepspeech import Model
|
||||||
|
|
||||||
|
from util.text import levenshtein
|
||||||
|
from util.evaluate_tools import process_decode_result, calculate_report
|
||||||
|
|
||||||
|
r'''
|
||||||
|
This module should be self-contained:
|
||||||
|
- build libdeepspeech.so with TFLite:
|
||||||
|
- add a dep in native_client/BUILD against TFlite: '//tensorflow:linux_x86_64': [ "//tensorflow/contrib/lite/kernels:builtin_ops" ]
|
||||||
|
- bazel build [...] --copt=-DUSE_TFLITE [...] //native_client:libdeepspeech.so
|
||||||
|
- make -C native_client/python/ TFDIR=... bindings
|
||||||
|
- setup a virtualenv
|
||||||
|
- pip install native_client/python/dist/deepspeech*.whl
|
||||||
|
- pip install -r requirements_eval_tflite.txt
|
||||||
|
|
||||||
|
Then run with a TF Lite model, alphabet, LM/trie and a CSV test file
|
||||||
|
'''
|
||||||
|
|
||||||
|
BEAM_WIDTH = 500
|
||||||
|
LM_ALPHA = 0.75
|
||||||
|
LM_BETA = 1.85
|
||||||
|
N_FEATURES = 26
|
||||||
|
N_CONTEXT = 9
|
||||||
|
|
||||||
|
def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out):
    r'''
    Worker loop that transcribes WAV files taken from ``queue_in``.

    A ``Model`` is created once from ``model`` and ``alphabet`` and its
    language-model decoder is enabled with ``lm`` and ``trie``. The worker then
    loops forever: it takes a message (a dict with ``filename`` and
    ``transcript``) from ``queue_in``, runs speech-to-text on the file, pushes
    ``{'prediction': ..., 'ground_truth': ...}`` to ``queue_out`` and marks the
    input task as done. The function never returns.
    '''
    ds = Model(model, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
    ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)

    while True:
        msg = queue_in.get()

        fin = wave.open(msg['filename'], 'rb')
        try:
            fs = fin.getframerate()
            # Samples are interpreted as 16-bit integers
            audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
        finally:
            # Release the file handle even if reading the frames fails
            fin.close()

        decoded = ds.stt(audio, fs)

        # Publish the result before acknowledging the task, so that every
        # acknowledged task has a matching entry on the output queue.
        queue_out.put({'prediction': decoded, 'ground_truth': msg['transcript']})
        queue_in.task_done()
|
||||||
|
|
||||||
|
def main():
    r'''
    Command-line entry point: compute WER/CER of a TFLite model on a CSV test set.

    Starts one ``tflite_worker`` process per CPU, feeds every CSV row
    (``wav_filename``, ``transcript``) to the workers, collects one prediction
    per row and prints the aggregate WER, CER and mean loss. The loss is
    recorded as 0.0 for every sample.
    '''
    parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
    parser.add_argument('--model', required=True,
                        help='Path to the model (protocol buffer binary file)')
    parser.add_argument('--alphabet', required=True,
                        help='Path to the configuration file specifying the alphabet used by the network')
    parser.add_argument('--lm', required=True,
                        help='Path to the language model binary file')
    parser.add_argument('--trie', required=True,
                        help='Path to the language model trie file created with native_client/generate_trie')
    parser.add_argument('--csv', required=True,
                        help='Path to the CSV source file')
    args = parser.parse_args()

    work_todo = JoinableQueue()   # input: one message per CSV row
    work_done = Queue()           # output: one message per transcribed file

    processes = []
    for i in range(cpu_count()):
        worker_process = Process(target=tflite_worker,
                                 args=(args.model, args.alphabet, args.lm, args.trie, work_todo, work_done),
                                 daemon=True,
                                 name='tflite_process_{}'.format(i))
        worker_process.start()    # Launch tflite_worker() as a separate python process
        processes.append(worker_process)

    print([x.name for x in processes])

    ground_truths = []
    predictions = []
    losses = []

    submitted = 0
    with open(args.csv, 'r') as csvfile:
        csvreader = csv.DictReader(csvfile)
        for row in csvreader:
            work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
            submitted += 1
    work_todo.join()

    # Queue.empty() is not reliable across processes: a result may still be in
    # a worker's feeder buffer when the matching task is acknowledged. Each
    # task yields exactly one result, so fetch exactly `submitted` results with
    # a blocking get() instead of polling empty().
    for _ in range(submitted):
        msg = work_done.get()
        losses.append(0.0)
        ground_truths.append(msg['ground_truth'])
        predictions.append(msg['prediction'])

    # Character-level edit distance between each transcript and its prediction
    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

    wer, cer, _ = calculate_report(ground_truths, predictions, distances, losses)
    mean_loss = np.mean(losses)

    print('Test - WER: %f, CER: %f, loss: %f' %
          (wer, cer, mean_loss))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
7
requirements_eval_tflite.txt
Normal file
7
requirements_eval_tflite.txt
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
attrdict==2.0.0
|
||||||
|
deepspeech
|
||||||
|
numpy==1.16.0
|
||||||
|
pkg-resources==0.0.0
|
||||||
|
progressbar2==3.39.2
|
||||||
|
python-utils==2.3.0
|
||||||
|
six==1.12.0
|
45
util/evaluate_tools.py
Normal file
45
util/evaluate_tools.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
from attrdict import AttrDict
|
||||||
|
from multiprocessing.dummy import Pool
|
||||||
|
from util.text import wer_cer_batch, levenshtein
|
||||||
|
|
||||||
|
def pmap(fun, iterable):
    r'''
    Apply ``fun`` to every item of ``iterable`` using a pool of workers.

    Returns the list of results in the same order as the input. The pool is
    always closed and joined, even when ``fun`` raises, so its workers do not
    outlive the call.
    '''
    pool = Pool()
    try:
        return pool.map(fun, iterable)
    finally:
        pool.close()
        pool.join()
|
||||||
|
|
||||||
|
def process_decode_result(item):
    r'''
    Turn one ``(label, decoding, distance, loss)`` tuple into a report sample.

    The per-sample WER is the word-level Levenshtein distance between the
    label and the decoding, divided by the number of words in the label.

    Returns an ``AttrDict`` with the keys ``src`` (label), ``res`` (decoding),
    ``loss``, ``distance`` (passed through unchanged) and ``wer``.
    '''
    label, decoding, distance, loss = item
    label_words = label.split()
    word_distance = levenshtein(label_words, decoding.split())
    # An empty reference transcript would otherwise raise ZeroDivisionError and
    # abort the whole report; normalising by at least one word makes the WER of
    # such a sample equal to the number of (spuriously) decoded words.
    word_length = float(max(len(label_words), 1))
    return AttrDict({
        'src': label,
        'res': decoding,
        'loss': loss,
        'distance': distance,
        'wer': word_distance / word_length,
    })
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_report(labels, decodings, distances, losses):
    r'''
    Build a WER/CER report for a batch of decoded samples.

    Each ``(label, decoding, distance, loss)`` tuple is converted into a sample
    by ``process_decode_result``; the aggregate WER and CER are obtained from
    ``wer_cer_batch(labels, decodings)``.

    Returns ``(wer, cer, samples)`` where ``samples`` is ordered by WER
    (highest first) and, for equal WER, by loss (lowest first).
    '''
    report_samples = pmap(process_decode_result,
                          zip(labels, decodings, distances, losses))

    # Aggregate WER and CER over the whole batch
    batch_wer, batch_cer = wer_cer_batch(labels, decodings)

    # One stable sort: WER descending first, loss ascending as tie-breaker
    report_samples.sort(key=lambda sample: (-sample.wer, sample.loss))

    return batch_wer, batch_cer, report_samples
|
Loading…
x
Reference in New Issue
Block a user