Merge pull request #2664 from mozilla/evaluate_tflite_fixes
evaluate_tflite.py fixes
This commit is contained in:
		
						commit
						faed282cfc
					
				| @ -2,17 +2,21 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| from __future__ import absolute_import, division, print_function | ||||
| 
 | ||||
| import absl.app | ||||
| import argparse | ||||
| import numpy as np | ||||
| import wave | ||||
| import csv | ||||
| import os | ||||
| import sys | ||||
| 
 | ||||
| from functools import partial | ||||
| from six.moves import zip, range | ||||
| from multiprocessing import JoinableQueue, Process, cpu_count, Manager | ||||
| from deepspeech import Model | ||||
| 
 | ||||
| from util.evaluate_tools import calculate_report | ||||
| from util.flags import create_flags | ||||
| 
 | ||||
| r''' | ||||
| This module should be self-contained: | ||||
| @ -40,36 +44,20 @@ def tflite_worker(model, lm, trie, queue_in, queue_out, gpu_mask): | ||||
|             msg = queue_in.get() | ||||
| 
 | ||||
|             filename = msg['filename'] | ||||
|             wavname = os.path.splitext(os.path.basename(filename))[0] | ||||
|             fin = wave.open(filename, 'rb') | ||||
|             audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) | ||||
|             fin.close() | ||||
| 
 | ||||
|             decoded = ds.stt(audio) | ||||
| 
 | ||||
|             queue_out.put({'wav': wavname, 'prediction': decoded, 'ground_truth': msg['transcript']}) | ||||
|             queue_out.put({'wav': filename, 'prediction': decoded, 'ground_truth': msg['transcript']}) | ||||
|         except FileNotFoundError as ex: | ||||
|             print('FileNotFoundError: ', ex) | ||||
| 
 | ||||
|         print(queue_out.qsize(), end='\r') # Update the current progress | ||||
|         queue_in.task_done() | ||||
| 
 | ||||
| def main(): | ||||
|     parser = argparse.ArgumentParser(description='Computing TFLite accuracy') | ||||
|     parser.add_argument('--model', required=True, | ||||
|                         help='Path to the model (protocol buffer binary file)') | ||||
|     parser.add_argument('--lm', required=True, | ||||
|                         help='Path to the language model binary file') | ||||
|     parser.add_argument('--trie', required=True, | ||||
|                         help='Path to the language model trie file created with native_client/generate_trie') | ||||
|     parser.add_argument('--csv', required=True, | ||||
|                         help='Path to the CSV source file') | ||||
|     parser.add_argument('--proc', required=False, default=cpu_count(), type=int, | ||||
|                         help='Number of processes to spawn, defaulting to number of CPUs') | ||||
|     parser.add_argument('--dump', required=False, action='store_true', default=False, | ||||
|                         help='Dump the results as text file, with one line for each wav: "wav transcription"') | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
| def main(args, _): | ||||
|     manager = Manager() | ||||
|     work_todo = JoinableQueue()   # this is where we are going to store input data | ||||
|     work_done = manager.Queue()  # this where we are gonna push them out | ||||
| @ -93,6 +81,9 @@ def main(): | ||||
|         count = 0 | ||||
|         for row in csvreader: | ||||
|             count += 1 | ||||
|             # Relative paths are relative to the folder the CSV file is in | ||||
|             if not os.path.isabs(row['wav_filename']): | ||||
|                 row['wav_filename'] = os.path.join(os.path.dirname(args.csv), row['wav_filename']) | ||||
|             work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']}) | ||||
|             wav_filenames.extend(row['wav_filename']) | ||||
| 
 | ||||
| @ -114,12 +105,32 @@ def main(): | ||||
|           (wer, cer, mean_loss)) | ||||
| 
 | ||||
|     if args.dump: | ||||
|         with open(args.csv + '.txt', 'w') as ftxt, open(args.csv + '.out', 'w') as fout: | ||||
|         with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout: | ||||
|             for wav, txt, out in zip(wavlist, ground_truths, predictions): | ||||
|                 ftxt.write('%s %s\n' % (wav, txt)) | ||||
|                 fout.write('%s %s\n' % (wav, out)) | ||||
|             print('Reference texts dumped to %s.txt' % args.csv) | ||||
|             print('Transcription   dumped to %s.out' % args.csv) | ||||
|             print('Reference texts dumped to %s.txt' % args.dump) | ||||
|             print('Transcription   dumped to %s.out' % args.dump) | ||||
| 
 | ||||
| def parse_args(): | ||||
|     parser = argparse.ArgumentParser(description='Computing TFLite accuracy') | ||||
|     parser.add_argument('--model', required=True, | ||||
|                         help='Path to the model (protocol buffer binary file)') | ||||
|     parser.add_argument('--lm', required=True, | ||||
|                         help='Path to the language model binary file') | ||||
|     parser.add_argument('--trie', required=True, | ||||
|                         help='Path to the language model trie file created with native_client/generate_trie') | ||||
|     parser.add_argument('--csv', required=True, | ||||
|                         help='Path to the CSV source file') | ||||
|     parser.add_argument('--proc', required=False, default=cpu_count(), type=int, | ||||
|                         help='Number of processes to spawn, defaulting to number of CPUs') | ||||
|     parser.add_argument('--dump', required=False, | ||||
|                         help='Path to dump the results as text file, with one line for each wav: "wav transcription".') | ||||
|     args, unknown = parser.parse_known_args() | ||||
|     # Reconstruct argv for absl.flags | ||||
|     sys.argv = [sys.argv[0]] + unknown | ||||
|     return args | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     main() | ||||
|     create_flags() | ||||
|     absl.app.run(partial(main, parse_args())) | ||||
|  | ||||
| @ -1,7 +1,8 @@ | ||||
| attrdict==2.0.0 | ||||
| absl-py==0.9.0 | ||||
| attrdict==2.0.1 | ||||
| deepspeech | ||||
| numpy==1.16.0 | ||||
| pkg-resources==0.0.0 | ||||
| progressbar2==3.39.2 | ||||
| progressbar2==3.47.0 | ||||
| python-utils==2.3.0 | ||||
| six==1.12.0 | ||||
| six==1.13.0 | ||||
| pandas==0.25.3 | ||||
|  | ||||
							
								
								
									
										54
									
								
								taskcluster/tc-evaluate_tflite.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										54
									
								
								taskcluster/tc-evaluate_tflite.sh
									
									
									
									
									
										Executable file
									
								
							| @ -0,0 +1,54 @@ | ||||
| #!/bin/bash | ||||
| 
 | ||||
| set -xe | ||||
| 
 | ||||
| source $(dirname "$0")/tc-tests-utils.sh | ||||
| 
 | ||||
| extract_python_versions "$1" "pyver" "pyver_pkg" "py_unicode_type" "pyconf" "pyalias" | ||||
| 
 | ||||
| bitrate=$2 | ||||
| set_ldc_sample_filename "${bitrate}" | ||||
| 
 | ||||
| unset PYTHON_BIN_PATH | ||||
| unset PYTHONPATH | ||||
| 
 | ||||
| if [ -d "${DS_ROOT_TASK}/pyenv.cache/" ]; then | ||||
|   export PYENV_ROOT="${DS_ROOT_TASK}/pyenv.cache/ds-test/.pyenv" | ||||
| else | ||||
|   export PYENV_ROOT="${DS_ROOT_TASK}/ds-test/.pyenv" | ||||
| fi; | ||||
| 
 | ||||
| export PATH="${PYENV_ROOT}/bin:$PATH" | ||||
| 
 | ||||
| mkdir -p ${PYENV_ROOT} || true | ||||
| 
 | ||||
| download_data | ||||
| 
 | ||||
| install_pyenv "${PYENV_ROOT}" | ||||
| install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv" | ||||
| 
 | ||||
| maybe_ssl102_py37 ${pyver} | ||||
| 
 | ||||
| maybe_numpy_min_version_winamd64 ${pyver} | ||||
| 
 | ||||
| PYENV_NAME=deepspeech-test | ||||
| LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH PYTHON_CONFIGURE_OPTS="--enable-unicode=${pyconf} ${PY37_OPENSSL} ${EXTRA_PYTHON_CONFIGURE_OPTS}" pyenv_install ${pyver} ${pyalias} | ||||
| 
 | ||||
| setup_pyenv_virtualenv "${pyalias}" "${PYENV_NAME}" | ||||
| virtualenv_activate "${pyalias}" "${PYENV_NAME}" | ||||
| 
 | ||||
| deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type}) | ||||
| set -o pipefail | ||||
| LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} --upgrade ${deepspeech_pkg_url} | cat | ||||
| pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat | ||||
| set +o pipefail | ||||
| 
 | ||||
| which deepspeech | ||||
| deepspeech --version | ||||
| 
 | ||||
| pushd ${HOME}/DeepSpeech/ds/ | ||||
|     python bin/import_ldc93s1.py data/smoke_test | ||||
|     python evaluate_tflite.py --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --lm data/smoke_test/vocab.pruned.lm --trie data/smoke_test/vocab.trie --csv data/smoke_test/ldc93s1.csv | ||||
| popd | ||||
| 
 | ||||
| virtualenv_deactivate "${pyalias}" "${PYENV_NAME}" | ||||
							
								
								
									
										14
									
								
								taskcluster/test-evaluate_tflite-linux-amd64-py36m-opt.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								taskcluster/test-evaluate_tflite-linux-amd64-py36m-opt.yml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | ||||
| build: | ||||
|   template_file: test-linux-opt-base.tyml | ||||
|   dependencies: | ||||
|     - "linux-amd64-cpu-opt" | ||||
|     - "test-training_16k-linux-amd64-py36m-opt" | ||||
|   test_model_task: "test-training_16k-linux-amd64-py36m-opt" | ||||
|   system_setup: | ||||
|     > | ||||
|       apt-get -qq -y install ${python.packages_trusty.apt} | ||||
|   args: | ||||
|     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-evaluate_tflite.sh 3.6.4:m 16k" | ||||
|   metadata: | ||||
|     name: "DeepSpeech Linux AMD64 CPU evaluate_tflite.py Py3.6 (16kHz)" | ||||
|     description: "Test evaluate_tflite.py on Linux/AMD64 using upstream TensorFlow Python 3.6, CPU only, optimized version" | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user