Merge pull request #2664 from mozilla/evaluate_tflite_fixes
evaluate_tflite.py fixes
This commit is contained in:
commit
faed282cfc
@ -2,17 +2,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import absl.app
|
||||
import argparse
|
||||
import numpy as np
|
||||
import wave
|
||||
import csv
|
||||
import os
|
||||
import sys
|
||||
|
||||
from functools import partial
|
||||
from six.moves import zip, range
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from deepspeech import Model
|
||||
|
||||
from util.evaluate_tools import calculate_report
|
||||
from util.flags import create_flags
|
||||
|
||||
r'''
|
||||
This module should be self-contained:
|
||||
@ -40,36 +44,20 @@ def tflite_worker(model, lm, trie, queue_in, queue_out, gpu_mask):
|
||||
msg = queue_in.get()
|
||||
|
||||
filename = msg['filename']
|
||||
wavname = os.path.splitext(os.path.basename(filename))[0]
|
||||
fin = wave.open(filename, 'rb')
|
||||
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
|
||||
fin.close()
|
||||
|
||||
decoded = ds.stt(audio)
|
||||
|
||||
queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']})
|
||||
queue_out.put({'wav': filename, 'prediction': decoded, 'ground_truth': msg['transcript']})
|
||||
except FileNotFoundError as ex:
|
||||
print('FileNotFoundError: ', ex)
|
||||
|
||||
print(queue_out.qsize(), end='\r') # Update the current progress
|
||||
queue_in.task_done()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to the model (protocol buffer binary file)')
|
||||
parser.add_argument('--lm', required=True,
|
||||
help='Path to the language model binary file')
|
||||
parser.add_argument('--trie', required=True,
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--csv', required=True,
|
||||
help='Path to the CSV source file')
|
||||
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
|
||||
help='Number of processes to spawn, defaulting to number of CPUs')
|
||||
parser.add_argument('--dump', required=False, action='store_true', default=False,
|
||||
help='Dump the results as text file, with one line for each wav: "wav transcription"')
|
||||
args = parser.parse_args()
|
||||
|
||||
def main(args, _):
|
||||
manager = Manager()
|
||||
work_todo = JoinableQueue() # this is where we are going to store input data
|
||||
work_done = manager.Queue() # this is where we are going to push them out
|
||||
@ -93,6 +81,9 @@ def main():
|
||||
count = 0
|
||||
for row in csvreader:
|
||||
count += 1
|
||||
# Relative paths are relative to the folder the CSV file is in
|
||||
if not os.path.isabs(row['wav_filename']):
|
||||
row['wav_filename'] = os.path.join(os.path.dirname(args.csv), row['wav_filename'])
|
||||
work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
|
||||
wav_filenames.extend(row['wav_filename'])
|
||||
|
||||
@ -114,12 +105,32 @@ def main():
|
||||
(wer, cer, mean_loss))
|
||||
|
||||
if args.dump:
|
||||
with open(args.csv + '.txt', 'w') as ftxt, open(args.csv + '.out', 'w') as fout:
|
||||
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout:
|
||||
for wav, txt, out in zip(wavlist, ground_truths, predictions):
|
||||
ftxt.write('%s %s\n' % (wav, txt))
|
||||
fout.write('%s %s\n' % (wav, out))
|
||||
print('Reference texts dumped to %s.txt' % args.csv)
|
||||
print('Transcription dumped to %s.out' % args.csv)
|
||||
print('Reference texts dumped to %s.txt' % args.dump)
|
||||
print('Transcription dumped to %s.out' % args.dump)
|
||||
|
||||
def parse_args():
    '''
    Parse the options owned by this script and hand the rest over to absl.

    Arguments that argparse does not recognise are not treated as errors:
    they are written back into ``sys.argv`` (right after the program name)
    so that absl.flags can parse them afterwards.

    Returns the argparse namespace carrying ``model``, ``lm``, ``trie``,
    ``csv``, ``proc`` and ``dump``.
    '''
    cli = argparse.ArgumentParser(description='Computing TFLite accuracy')

    # Mandatory path options, registered in the order they show up in --help.
    mandatory_paths = (
        ('--model', 'Path to the model (protocol buffer binary file)'),
        ('--lm', 'Path to the language model binary file'),
        ('--trie', 'Path to the language model trie file created with native_client/generate_trie'),
        ('--csv', 'Path to the CSV source file'),
    )
    for flag, description in mandatory_paths:
        cli.add_argument(flag, required=True, help=description)

    # Optional tuning / output options.
    cli.add_argument(
        '--proc',
        required=False,
        default=cpu_count(),
        type=int,
        help='Number of processes to spawn, defaulting to number of CPUs',
    )
    cli.add_argument(
        '--dump',
        required=False,
        help='Path to dump the results as text file, with one line for each wav: "wav transcription".',
    )

    parsed, leftover = cli.parse_known_args()

    # Reconstruct argv for absl.flags
    program_name = sys.argv[0]
    sys.argv = [program_name] + leftover

    return parsed
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
create_flags()
|
||||
absl.app.run(partial(main, parse_args()))
|
||||
|
@ -1,7 +1,8 @@
|
||||
attrdict==2.0.0
|
||||
absl-py==0.9.0
|
||||
attrdict==2.0.1
|
||||
deepspeech
|
||||
numpy==1.16.0
|
||||
pkg-resources==0.0.0
|
||||
progressbar2==3.39.2
|
||||
progressbar2==3.47.0
|
||||
python-utils==2.3.0
|
||||
six==1.12.0
|
||||
six==1.13.0
|
||||
pandas==0.25.3
|
||||
|
54
taskcluster/tc-evaluate_tflite.sh
Executable file
54
taskcluster/tc-evaluate_tflite.sh
Executable file
@ -0,0 +1,54 @@
|
||||
#!/bin/bash

# Taskcluster test driver for evaluate_tflite.py.
#
# Usage: tc-evaluate_tflite.sh <python-version-spec> <bitrate>
#   $1  Python version spec handed to extract_python_versions (e.g. "3.6.4:m").
#   $2  Sample bitrate handed to set_ldc_sample_filename (e.g. "16k").
#
# Steps: set up a pyenv virtualenv, install the deepspeech package plus the
# repository requirements, import the LDC93S1 smoke-test sample, and run
# evaluate_tflite.py against it.
#
# NOTE(review): the helper functions called below, and the DS_ROOT_TASK,
# TASKCLUSTER_TMP_DIR, model_name_mmap, PY37_* and EXTRA_PYTHON_CONFIGURE_OPTS
# variables, are not defined in this script; presumably they come from
# tc-tests-utils.sh or the task environment -- confirm.

# Trace every command and abort on the first failure.
set -xe

# Quoted so a checkout path containing whitespace still resolves.
source "$(dirname "$0")/tc-tests-utils.sh"

# Fills the variables named by the trailing arguments (pyver, pyver_pkg, ...).
extract_python_versions "$1" "pyver" "pyver_pkg" "py_unicode_type" "pyconf" "pyalias"

bitrate=$2
set_ldc_sample_filename "${bitrate}"

unset PYTHON_BIN_PATH
unset PYTHONPATH

# Use the pyenv tree under pyenv.cache/ when that directory exists.
if [ -d "${DS_ROOT_TASK}/pyenv.cache/" ]; then
    export PYENV_ROOT="${DS_ROOT_TASK}/pyenv.cache/ds-test/.pyenv"
else
    export PYENV_ROOT="${DS_ROOT_TASK}/ds-test/.pyenv"
fi

export PATH="${PYENV_ROOT}/bin:$PATH"

mkdir -p "${PYENV_ROOT}" || true

download_data

install_pyenv "${PYENV_ROOT}"
install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv"

maybe_ssl102_py37 ${pyver}

maybe_numpy_min_version_winamd64 ${pyver}

PYENV_NAME=deepspeech-test
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH PYTHON_CONFIGURE_OPTS="--enable-unicode=${pyconf} ${PY37_OPENSSL} ${EXTRA_PYTHON_CONFIGURE_OPTS}" pyenv_install ${pyver} ${pyalias}

setup_pyenv_virtualenv "${pyalias}" "${PYENV_NAME}"
virtualenv_activate "${pyalias}" "${PYENV_NAME}"

deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})

# pip output is piped through cat; pipefail makes a pip failure still abort
# the script (under set -e) instead of being masked by cat's exit status.
set -o pipefail
# ${PY37_SOURCE_PACKAGE} is deliberately left unquoted: it may be empty or
# expand to several pip options (assumption -- confirm against its definition).
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} --upgrade ${deepspeech_pkg_url} | cat
pip install --upgrade -r "${HOME}/DeepSpeech/ds/requirements.txt" | cat
set +o pipefail

# Sanity check: the deepspeech entry point is on PATH and runs.
which deepspeech
deepspeech --version

# The commands below use paths relative to the repository checkout.
pushd "${HOME}/DeepSpeech/ds/"
    python bin/import_ldc93s1.py data/smoke_test
    python evaluate_tflite.py --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --lm data/smoke_test/vocab.pruned.lm --trie data/smoke_test/vocab.trie --csv data/smoke_test/ldc93s1.csv
popd

virtualenv_deactivate "${pyalias}" "${PYENV_NAME}"
|
14
taskcluster/test-evaluate_tflite-linux-amd64-py36m-opt.yml
Normal file
14
taskcluster/test-evaluate_tflite-linux-amd64-py36m-opt.yml
Normal file
@ -0,0 +1,14 @@
|
||||
build:
|
||||
template_file: test-linux-opt-base.tyml
|
||||
dependencies:
|
||||
- "linux-amd64-cpu-opt"
|
||||
- "test-training_16k-linux-amd64-py36m-opt"
|
||||
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
|
||||
system_setup:
|
||||
>
|
||||
apt-get -qq -y install ${python.packages_trusty.apt}
|
||||
args:
|
||||
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-evaluate_tflite.sh 3.6.4:m 16k"
|
||||
metadata:
|
||||
name: "DeepSpeech Linux AMD64 CPU evaluate_tflite.py Py3.6 (16kHz)"
|
||||
description: "Test evaluate_tflite.py on Linux/AMD64 using upstream TensorFlow Python 3.6, CPU only, optimized version"
|
Loading…
x
Reference in New Issue
Block a user