Merge pull request #2664 from mozilla/evaluate_tflite_fixes

evaluate_tflite.py fixes
This commit is contained in:
Reuben Morais 2020-01-13 10:03:40 +01:00 committed by GitHub
commit faed282cfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 26 deletions

View File

@ -2,17 +2,21 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import absl.app
import argparse
import numpy as np
import wave
import csv
import os
import sys
from functools import partial
from six.moves import zip, range
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from deepspeech import Model
from util.evaluate_tools import calculate_report
from util.flags import create_flags
r'''
This module should be self-contained:
@ -40,36 +44,20 @@ def tflite_worker(model, lm, trie, queue_in, queue_out, gpu_mask):
msg = queue_in.get()
filename = msg['filename']
wavname = os.path.splitext(os.path.basename(filename))[0]
fin = wave.open(filename, 'rb')
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
fin.close()
decoded = ds.stt(audio)
queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']})
queue_out.put({'wav': filename, 'prediction': decoded, 'ground_truth': msg['transcript']})
except FileNotFoundError as ex:
print('FileNotFoundError: ', ex)
print(queue_out.qsize(), end='\r') # Update the current progress
queue_in.task_done()
def main():
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
parser.add_argument('--model', required=True,
help='Path to the model (protocol buffer binary file)')
parser.add_argument('--lm', required=True,
help='Path to the language model binary file')
parser.add_argument('--trie', required=True,
help='Path to the language model trie file created with native_client/generate_trie')
parser.add_argument('--csv', required=True,
help='Path to the CSV source file')
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
help='Number of processes to spawn, defaulting to number of CPUs')
parser.add_argument('--dump', required=False, action='store_true', default=False,
help='Dump the results as text file, with one line for each wav: "wav transcription"')
args = parser.parse_args()
def main(args, _):
manager = Manager()
work_todo = JoinableQueue() # this is where we are going to store input data
work_done = manager.Queue() # this where we are gonna push them out
@ -93,6 +81,9 @@ def main():
count = 0
for row in csvreader:
    count += 1
    # Relative paths are relative to the folder the CSV file is in
    if not os.path.isabs(row['wav_filename']):
        row['wav_filename'] = os.path.join(os.path.dirname(args.csv), row['wav_filename'])
    # Queue one work item per CSV row for the worker processes.
    work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
    # append(), not extend(): extend() iterates its argument, so passing a
    # string would add the path to the list one character at a time.
    wav_filenames.append(row['wav_filename'])
@ -114,12 +105,32 @@ def main():
(wer, cer, mean_loss))
if args.dump:
with open(args.csv + '.txt', 'w') as ftxt, open(args.csv + '.out', 'w') as fout:
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
for wav, txt, out in zip(wavlist, ground_truths, predictions):
ftxt.write('%s %s\n' % (wav, txt))
fout.write('%s %s\n' % (wav, out))
print('Reference texts dumped to %s.txt' % args.csv)
print('Transcription dumped to %s.out' % args.csv)
print('Reference texts dumped to %s.txt' % args.dump)
print('Transcription dumped to %s.out' % args.dump)
def parse_args(argv=None):
    '''
    Parse the options owned by this script and leave the rest for absl.

    Options that this parser does not recognise are not treated as errors:
    they are written back into ``sys.argv`` (right after the program name)
    so that absl.flags can consume them afterwards.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to None, i.e. the process arguments.

    Returns:
        argparse.Namespace with ``model``, ``lm``, ``trie``, ``csv``,
        ``proc`` (int) and ``dump`` (a path string, or None when not given).

    Side effects:
        Replaces ``sys.argv`` with ``[sys.argv[0]] + <unrecognised args>``.
    '''
    parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
    parser.add_argument('--model', required=True,
                        help='Path to the model (protocol buffer binary file)')
    parser.add_argument('--lm', required=True,
                        help='Path to the language model binary file')
    parser.add_argument('--trie', required=True,
                        help='Path to the language model trie file created with native_client/generate_trie')
    parser.add_argument('--csv', required=True,
                        help='Path to the CSV source file')
    parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
                        help='Number of processes to spawn, defaulting to number of CPUs')
    parser.add_argument('--dump', required=False,
                        help='Path to dump the results as text file, with one line for each wav: "wav transcription".')
    # parse_known_args() returns the leftovers instead of failing on them.
    args, unknown = parser.parse_known_args(argv)
    # Reconstruct argv for absl.flags
    sys.argv = [sys.argv[0]] + unknown
    return args
if __name__ == '__main__':
main()
create_flags()
absl.app.run(partial(main, parse_args()))

View File

@ -1,7 +1,8 @@
attrdict==2.0.0
absl-py==0.9.0
attrdict==2.0.1
deepspeech
numpy==1.16.0
pkg-resources==0.0.0
progressbar2==3.39.2
progressbar2==3.47.0
python-utils==2.3.0
six==1.12.0
six==1.13.0
pandas==0.25.3

View File

@ -0,0 +1,54 @@
#!/bin/bash
# CI test driver: sets up a pyenv virtualenv, installs the deepspeech Python
# package into it, and runs evaluate_tflite.py on the data/smoke_test files.
# Arguments:
#   $1  Python version spec, handed to extract_python_versions
#   $2  bitrate, handed to set_ldc_sample_filename
# -x traces every command; -e aborts the script on the first failing command.
set -xe
# The helper functions called below (extract_python_versions, install_pyenv,
# download_data, ...) are not defined in this script; presumably they come
# from tc-tests-utils.sh, which lives next to it.
source $(dirname "$0")/tc-tests-utils.sh
# NOTE(review): appears to populate the variables named here (pyver,
# pyver_pkg, py_unicode_type, pyconf, pyalias), which are read further down;
# confirm against the helper.
extract_python_versions "$1" "pyver" "pyver_pkg" "py_unicode_type" "pyconf" "pyalias"
bitrate=$2
set_ldc_sample_filename "${bitrate}"
# Drop any inherited Python settings before pyenv is set up.
unset PYTHON_BIN_PATH
unset PYTHONPATH
# Put PYENV_ROOT under pyenv.cache/ when that directory exists.
if [ -d "${DS_ROOT_TASK}/pyenv.cache/" ]; then
export PYENV_ROOT="${DS_ROOT_TASK}/pyenv.cache/ds-test/.pyenv"
else
export PYENV_ROOT="${DS_ROOT_TASK}/ds-test/.pyenv"
fi;
export PATH="${PYENV_ROOT}/bin:$PATH"
# "|| true" keeps "set -e" from aborting the script if mkdir fails.
mkdir -p ${PYENV_ROOT} || true
download_data
install_pyenv "${PYENV_ROOT}"
install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv"
maybe_ssl102_py37 ${pyver}
maybe_numpy_min_version_winamd64 ${pyver}
PYENV_NAME=deepspeech-test
# PY37_LDPATH, PY37_OPENSSL and EXTRA_PYTHON_CONFIGURE_OPTS are not set in this
# script; presumably the maybe_* helpers above or the environment provide them.
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH PYTHON_CONFIGURE_OPTS="--enable-unicode=${pyconf} ${PY37_OPENSSL} ${EXTRA_PYTHON_CONFIGURE_OPTS}" pyenv_install ${pyver} ${pyalias}
setup_pyenv_virtualenv "${pyalias}" "${PYENV_NAME}"
virtualenv_activate "${pyalias}" "${PYENV_NAME}"
deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
# With pipefail on, a failing "pip install" fails the whole pipeline (and so
# the script, because of "set -e") even though its output is piped into cat.
set -o pipefail
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} --upgrade ${deepspeech_pkg_url} | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
set +o pipefail
# Sanity check: the deepspeech command must be on PATH and must run.
which deepspeech
deepspeech --version
pushd ${HOME}/DeepSpeech/ds/
# Presumably import_ldc93s1.py produces data/smoke_test/ldc93s1.csv, which the
# next command passes as --csv.
python bin/import_ldc93s1.py data/smoke_test
# model_name_mmap is not set in this script — TODO confirm where it is defined.
python evaluate_tflite.py --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --lm data/smoke_test/vocab.pruned.lm --trie data/smoke_test/vocab.trie --csv data/smoke_test/ldc93s1.csv
popd
virtualenv_deactivate "${pyalias}" "${PYENV_NAME}"

View File

@ -0,0 +1,14 @@
# Task definition for the evaluate_tflite.py test on Linux/AMD64, CPU only,
# Python 3.6, 16kHz (see the metadata entries at the bottom).
build:
template_file: test-linux-opt-base.tyml
# The task named in test_model_task is also listed as a dependency;
# presumably it produces the model that this test evaluates.
dependencies:
- "linux-amd64-cpu-opt"
- "test-training_16k-linux-amd64-py36m-opt"
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
system_setup:
>
apt-get -qq -y install ${python.packages_trusty.apt}
# The test script is invoked with two arguments: "3.6.4:m" and "16k".
args:
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-evaluate_tflite.sh 3.6.4:m 16k"
metadata:
name: "DeepSpeech Linux AMD64 CPU evaluate_tflite.py Py3.6 (16kHz)"
description: "Test evaluate_tflite.py on Linux/AMD64 using upstream TensorFlow Python 3.6, CPU only, optimized version"