Remove unused benchmark_nc
This commit is contained in:
		
							parent
							
								
									0428c846f2
								
							
						
					
					
						commit
						dc25509950
					
				| @ -1,504 +0,0 @@ | ||||
| #!/usr/bin/env python | ||||
| # -*- coding: utf-8 -*- | ||||
| 
 | ||||
| from __future__ import absolute_import, division, print_function | ||||
| 
 | ||||
| import os | ||||
| import sys | ||||
| 
 | ||||
| # To use util.tc | ||||
| sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(sys.argv[0])))) | ||||
| import util.taskcluster as tcu | ||||
| from util.helpers import keep_only_digits | ||||
| 
 | ||||
| import paramiko | ||||
| import argparse | ||||
| import tempfile | ||||
| import shutil | ||||
| import subprocess | ||||
| import stat | ||||
| import numpy | ||||
| import matplotlib.pyplot as plt | ||||
| import scipy.stats as scipy_stats | ||||
| import csv | ||||
| import getpass | ||||
| import zipfile | ||||
| 
 | ||||
| from six import iteritems | ||||
| from six.moves import range, map | ||||
| from functools import cmp_to_key | ||||
| 
 | ||||
| r''' | ||||
|  Tool to: | ||||
|   - remote local or remote (ssh) native_client | ||||
|   - handles copying models (as protocolbuffer files) | ||||
|   - run native_client in benchmark mode | ||||
|   - collect timing results | ||||
|   - compute mean values (with wariances) | ||||
|   - output as CSV | ||||
| ''' | ||||
| 
 | ||||
| ssh_conn = None | ||||
| def exec_command(command, cwd=None): | ||||
|     r''' | ||||
|     Helper to exec locally (subprocess) or remotely (paramiko) | ||||
|     ''' | ||||
| 
 | ||||
|     rc = None | ||||
|     stdout = stderr = None | ||||
|     if ssh_conn is None: | ||||
|         ld_library_path = {'LD_LIBRARY_PATH': '.:%s' % os.environ.get('LD_LIBRARY_PATH', '')} | ||||
|         p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=ld_library_path, cwd=cwd) | ||||
|         stdout, stderr = p.communicate() | ||||
|         rc = p.returncode | ||||
|     else: | ||||
|         # environment= requires paramiko >= 2.1 (fails with 2.0.2) | ||||
|         final_command = command if cwd is None else 'cd %s && %s %s' % (cwd, 'LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH', command) | ||||
|         ssh_stdin, ssh_stdout, ssh_stderr = ssh_conn.exec_command(final_command) | ||||
|         stdout = ''.join(ssh_stdout.readlines()) | ||||
|         stderr = ''.join(ssh_stderr.readlines()) | ||||
|         rc = ssh_stdout.channel.recv_exit_status() | ||||
| 
 | ||||
|     return rc, stdout, stderr | ||||
| 
 | ||||
| def assert_valid_dir(dir): | ||||
|     if dir is None: | ||||
|         raise AssertionError('Invalid temp directory') | ||||
|     return True | ||||
| 
 | ||||
| def get_arch_string(): | ||||
|     r''' | ||||
|     Check local or remote system arch, to produce TaskCluster proper link. | ||||
|     ''' | ||||
|     rc, stdout, stderr = exec_command('uname -sm') | ||||
|     if rc > 0: | ||||
|         raise AssertionError('Error checking OS') | ||||
| 
 | ||||
|     stdout = stdout.lower().strip() | ||||
|     if not b'linux' in stdout: | ||||
|         raise AssertionError('Unsupported OS') | ||||
| 
 | ||||
|     if b'armv7l' in stdout: | ||||
|         return 'arm' | ||||
| 
 | ||||
|     if b'x86_64' in stdout: | ||||
|         nv_rc, nv_stdout, nv_stderr = exec_command('nvidia-smi') | ||||
|         nv_stdout = nv_stdout.lower().strip() | ||||
|         if b'NVIDIA-SMI' in nv_stdout: | ||||
|             return 'gpu' | ||||
|         else: | ||||
|             return 'cpu' | ||||
| 
 | ||||
|     raise AssertionError('Unsupported arch:', stdout) | ||||
| 
 | ||||
| def maybe_download_binaries(dir): | ||||
|     assert_valid_dir(dir) | ||||
|     tcu.maybe_download_tc(target_dir=dir, tc_url=tcu.get_tc_url(get_arch_string()), progress=True) | ||||
| 
 | ||||
| def extract_native_client_tarball(dir): | ||||
|     r''' | ||||
|     Download a native_client.tar.xz file from TaskCluster and extract it to dir. | ||||
|     ''' | ||||
|     assert_valid_dir(dir) | ||||
| 
 | ||||
|     target_tarball = os.path.join(dir, 'native_client.tar.xz') | ||||
|     if os.path.isfile(target_tarball) and os.stat(target_tarball).st_size == 0: | ||||
|         return | ||||
| 
 | ||||
|     subprocess.check_call(['pixz', '-d', 'native_client.tar.xz'], cwd=dir) | ||||
|     subprocess.check_call(['tar', 'xf', 'native_client.tar'], cwd=dir) | ||||
|     os.unlink(os.path.join(dir, 'native_client.tar')) | ||||
|     open(target_tarball, 'w').close() | ||||
| 
 | ||||
| def is_zip_file(models): | ||||
|     r''' | ||||
|     Ensure that a path is a zip file by: | ||||
|      - checking length is 1 | ||||
|      - checking extension is '.zip' | ||||
|     ''' | ||||
|     ext = os.path.splitext(models[0])[1] | ||||
|     return (len(models) == 1) and (ext == '.zip') | ||||
| 
 | ||||
| def maybe_inspect_zip(models): | ||||
|     r''' | ||||
|     Detect if models is a list of protocolbuffer files or a ZIP file. | ||||
|     If the latter, then unzip it and return the list of protocolbuffer files | ||||
|     that were inside. | ||||
|     ''' | ||||
| 
 | ||||
|     if not(is_zip_file(models)): | ||||
|         return models | ||||
| 
 | ||||
|     if len(models) > 1: | ||||
|         return models | ||||
| 
 | ||||
|     if len(models) < 1: | ||||
|         raise AssertionError('No models at all') | ||||
| 
 | ||||
|     return zipfile.ZipFile(models[0]).namelist() | ||||
| 
 | ||||
| def all_files(models=[]): | ||||
|     r''' | ||||
|     Return a list of full path of files matching 'models', sorted in human | ||||
|     numerical order (i.e., 0 1 2 ..., 10 11 12, ..., 100, ..., 1000). | ||||
| 
 | ||||
|     Files are supposed to be named identically except one variable component | ||||
|     e.g. the list, | ||||
|       test.weights.e5.lstm1200.ldc93s1.pb | ||||
|       test.weights.e5.lstm1000.ldc93s1.pb | ||||
|       test.weights.e5.lstm800.ldc93s1.pb | ||||
|     gets sorted: | ||||
|       test.weights.e5.lstm800.ldc93s1.pb | ||||
|       test.weights.e5.lstm1000.ldc93s1.pb | ||||
|       test.weights.e5.lstm1200.ldc93s1.pb | ||||
|     ''' | ||||
| 
 | ||||
|     def nsort(a, b): | ||||
|         fa = os.path.basename(a).split('.') | ||||
|         fb = os.path.basename(b).split('.') | ||||
|         elements_to_remove = [] | ||||
| 
 | ||||
|         assert len(fa) == len(fb) | ||||
| 
 | ||||
|         for i in range(0, len(fa)): | ||||
|             if fa[i] == fb[i]: | ||||
|                 elements_to_remove.append(fa[i]) | ||||
| 
 | ||||
|         for e in elements_to_remove: | ||||
|             fa.remove(e) | ||||
|             fb.remove(e) | ||||
| 
 | ||||
|         assert len(fa) == len(fb) | ||||
|         assert len(fa) == 1 | ||||
| 
 | ||||
|         fa = int(keep_only_digits(fa[0])) | ||||
|         fb = int(keep_only_digits(fb[0])) | ||||
| 
 | ||||
|         if fa < fb: | ||||
|             return -1 | ||||
|         if fa == fb: | ||||
|             return 0 | ||||
|         if fa > fb: | ||||
|             return 1 | ||||
| 
 | ||||
|     base = list(map(lambda x: os.path.abspath(x), maybe_inspect_zip(models))) | ||||
|     base.sort(key=cmp_to_key(nsort)) | ||||
| 
 | ||||
|     return base | ||||
| 
 | ||||
| def copy_tree(dir): | ||||
|     assert_valid_dir(dir) | ||||
| 
 | ||||
|     sftp = ssh_conn.open_sftp() | ||||
|     # IOError will get triggered if the path does not exists remotely | ||||
|     try: | ||||
|         if stat.S_ISDIR(sftp.stat(dir).st_mode): | ||||
|             print('Directory already existent: %s' % dir) | ||||
|     except IOError: | ||||
|         print('Creating remote directory: %s' % dir) | ||||
|         sftp.mkdir(dir) | ||||
| 
 | ||||
|     print('Copy files to remote') | ||||
|     for fname in os.listdir(dir): | ||||
|         fullpath = os.path.join(dir, fname) | ||||
|         local_stat  = os.stat(fullpath) | ||||
|         try: | ||||
|             remote_mode = sftp.stat(fullpath).st_mode | ||||
|         except IOError: | ||||
|             remote_mode = 0 | ||||
| 
 | ||||
|         if not stat.S_ISREG(remote_mode): | ||||
|             print('Copying %s ...' % fullpath) | ||||
|             remote_mode = sftp.put(fullpath, fullpath, confirm=True).st_mode | ||||
| 
 | ||||
|         if local_stat.st_mode != remote_mode: | ||||
|             print('Setting proper remote mode: %s' % local_stat.st_mode) | ||||
|             sftp.chmod(fullpath, local_stat.st_mode) | ||||
| 
 | ||||
|     sftp.close() | ||||
| 
 | ||||
| def delete_tree(dir): | ||||
|     assert_valid_dir(dir) | ||||
| 
 | ||||
|     sftp = ssh_conn.open_sftp() | ||||
|     # IOError will get triggered if the path does not exists remotely | ||||
|     try: | ||||
|         if stat.S_ISDIR(sftp.stat(dir).st_mode): | ||||
|             print('Removing remote files') | ||||
|             for fname in sftp.listdir(dir): | ||||
|                 fullpath = os.path.join(dir, fname) | ||||
|                 remote_stat = sftp.stat(fullpath) | ||||
|                 if stat.S_ISREG(remote_stat.st_mode): | ||||
|                     print('Removing %s ...' % fullpath) | ||||
|                     sftp.remove(fullpath) | ||||
| 
 | ||||
|             print('Removing directory %s ...' % dir) | ||||
|             sftp.rmdir(dir) | ||||
| 
 | ||||
|         sftp.close() | ||||
|     except IOError: | ||||
|         print('No remote directory: %s' % dir) | ||||
| 
 | ||||
| def setup_tempdir(dir, models, wav, lm_binary, trie, binaries): | ||||
|     r''' | ||||
|     Copy models, libs and binary to a directory (new one if dir is None) | ||||
|     ''' | ||||
|     if dir is None: | ||||
|         dir = tempfile.mkdtemp(suffix='dsbench') | ||||
| 
 | ||||
|     sorted_models = all_files(models=models) | ||||
|     if binaries is None: | ||||
|         maybe_download_binaries(dir) | ||||
|     else: | ||||
|         print('Using local binaries: %s' % (binaries)) | ||||
|         shutil.copy2(binaries, dir) | ||||
|     extract_native_client_tarball(dir) | ||||
| 
 | ||||
|     filenames = map(lambda x: os.path.join(dir, os.path.basename(x)), sorted_models) | ||||
|     missing_models = list(filter(lambda x: not os.path.isfile(x), filenames)) | ||||
|     if len(missing_models) > 0: | ||||
|         # If we have a ZIP file, directly extract it to the proper path | ||||
|         if is_zip_file(models): | ||||
|             print('Extracting %s to %s' % (models[0], dir)) | ||||
|             zipfile.ZipFile(models[0]).extractall(path=dir) | ||||
|             print('Extracted %s.' % models[0]) | ||||
|         else: | ||||
|             # If one model is missing, let's copy everything again. Be safe. | ||||
|             for f in sorted_models: | ||||
|                 print('Copying %s to %s' % (f, dir)) | ||||
|                 shutil.copy2(f, dir) | ||||
| 
 | ||||
|     for extra_file in [ wav, lm_binary, trie ]: | ||||
|         if extra_file and not os.path.isfile(os.path.join(dir, os.path.basename(extra_file))): | ||||
|             print('Copying %s to %s' % (extra_file, dir)) | ||||
|             shutil.copy2(extra_file, dir) | ||||
| 
 | ||||
|     if ssh_conn: | ||||
|         copy_tree(dir) | ||||
| 
 | ||||
|     return dir, sorted_models | ||||
| 
 | ||||
| def teardown_tempdir(dir): | ||||
|     r''' | ||||
|     Cleanup temporary directory. | ||||
|     ''' | ||||
| 
 | ||||
|     if ssh_conn: | ||||
|         delete_tree(dir) | ||||
| 
 | ||||
|     assert_valid_dir(dir) | ||||
|     shutil.rmtree(dir) | ||||
| 
 | ||||
| def get_sshconfig(): | ||||
|     r''' | ||||
|     Read user's SSH configuration file | ||||
|     ''' | ||||
| 
 | ||||
|     with open(os.path.expanduser('~/.ssh/config')) as f: | ||||
|         cfg = paramiko.SSHConfig() | ||||
|         cfg.parse(f) | ||||
|         ret_dict = {} | ||||
|         for d in cfg._config: | ||||
|             _copy = dict(d) | ||||
|             # Avoid buggy behavior with strange host definitions, we need | ||||
|             # Hostname and not Host. | ||||
|             del _copy['host'] | ||||
|             for host in d['host']: | ||||
|                 ret_dict[host] = _copy['config'] | ||||
| 
 | ||||
|         return ret_dict | ||||
| 
 | ||||
| def establish_ssh(target=None, auto_trust=False, allow_agent=True, look_keys=True): | ||||
|     r''' | ||||
|     Establish a SSH connection to a remote host. It should be able to use | ||||
|     SSH's config file Host name declarations. By default, will not automatically | ||||
|     add trust for hosts, will use SSH agent and will try to load keys. | ||||
|     ''' | ||||
| 
 | ||||
|     def password_prompt(username, hostname): | ||||
|         r''' | ||||
|         If the Host is relying on password authentication, lets ask it. | ||||
|         Relying on SSH itself to take care of that would not work when the | ||||
|         remote authentication is password behind a SSH-key+2FA jumphost. | ||||
|         ''' | ||||
|         return getpass.getpass('No SSH key for %s@%s, please provide password: ' % (username, hostname)) | ||||
| 
 | ||||
|     ssh_conn = None | ||||
|     if target is not None: | ||||
|         ssh_conf = get_sshconfig() | ||||
|         cfg = { | ||||
|             'hostname': None, | ||||
|             'port': 22, | ||||
|             'allow_agent': allow_agent, | ||||
|             'look_for_keys': look_keys | ||||
|         } | ||||
|         if ssh_conf.has_key(target): | ||||
|             user_config = ssh_conf.get(target) | ||||
| 
 | ||||
|             # If ssh_config file's Host defined 'User' instead of 'Username' | ||||
|             if user_config.has_key('user') and not user_config.has_key('username'): | ||||
|                 user_config['username'] = user_config['user'] | ||||
|                 del user_config['user'] | ||||
| 
 | ||||
|             for k in ('username', 'hostname', 'port'): | ||||
|                 if k in user_config: | ||||
|                     cfg[k] = user_config[k] | ||||
| 
 | ||||
|             # Assume Password auth. If we don't do that, then when connecting | ||||
|             # through a jumphost we will run into issues and the user will | ||||
|             # not be able to input his password to the SSH prompt. | ||||
|             if 'identityfile' in user_config: | ||||
|                 cfg['key_filename'] = user_config['identityfile'] | ||||
|             else: | ||||
|                 cfg['password'] = password_prompt(cfg['username'], cfg['hostname'] or target) | ||||
| 
 | ||||
|             # Should be the last one, since ProxyCommand will issue connection to remote host | ||||
|             if 'proxycommand' in user_config: | ||||
|                 cfg['sock'] = paramiko.ProxyCommand(user_config['proxycommand']) | ||||
| 
 | ||||
|         else: | ||||
|             cfg['username'] = target.split('@')[0] | ||||
|             cfg['hostname'] = target.split('@')[1].split(':')[0] | ||||
|             cfg['password'] = password_prompt(cfg['username'], cfg['hostname']) | ||||
|             try: | ||||
|                 cfg['port'] = int(target.split('@')[1].split(':')[1]) | ||||
|             except IndexError: | ||||
|                 # IndexError will happen if no :PORT is there. | ||||
|                 # Default value 22 is defined above in 'cfg'. | ||||
|                 pass | ||||
| 
 | ||||
|         ssh_conn = paramiko.SSHClient() | ||||
|         if auto_trust: | ||||
|             ssh_conn.set_missing_host_key_policy(paramiko.AutoAddPolicy()) | ||||
| 
 | ||||
|         ssh_conn.connect(**cfg) | ||||
| 
 | ||||
|     return ssh_conn | ||||
| 
 | ||||
| def run_benchmarks(dir, models, wav, lm_binary=None, trie=None, iters=-1): | ||||
|     r''' | ||||
|     Core of the running of the benchmarks. We will run on all of models, against | ||||
|     the WAV file provided as wav. | ||||
|     ''' | ||||
| 
 | ||||
|     assert_valid_dir(dir) | ||||
| 
 | ||||
|     inference_times = [ ] | ||||
| 
 | ||||
|     for model in models: | ||||
|         model_filename = model | ||||
| 
 | ||||
|         current_model = { | ||||
|           'name':   model, | ||||
|           'iters':  [ ], | ||||
|           'mean':   numpy.infty, | ||||
|           'stddev': numpy.infty | ||||
|         } | ||||
| 
 | ||||
|         if lm_binary and trie: | ||||
|             cmdline = './deepspeech --model "%s" --lm "%s" --trie "%s" --audio "%s" -t' % (model_filename, lm_binary, trie, wav) | ||||
|         else: | ||||
|             cmdline = './deepspeech --model "%s" --audio "%s" -t' % (model_filename, wav) | ||||
| 
 | ||||
|         for it in range(iters): | ||||
|             sys.stdout.write('\rRunning %s: %d/%d' % (os.path.basename(model), (it+1), iters)) | ||||
|             sys.stdout.flush() | ||||
|             rc, stdout, stderr = exec_command(cmdline, cwd=dir) | ||||
|             if rc == 0: | ||||
|                 inference_time = float(stdout.split(b'\n')[1].split(b'=')[-1]) | ||||
|                 # print("[%d] model=%s inference=%f" % (it, model, inference_time)) | ||||
|                 current_model['iters'].append(inference_time) | ||||
|             else: | ||||
|                 print('exec_command("%s") failed with rc=%d' % (cmdline, rc)) | ||||
|                 print('stdout: %s' % stdout) | ||||
|                 print('stderr: %s' % stderr) | ||||
|                 raise AssertionError('Execution failure: rc=%d' % (rc)) | ||||
| 
 | ||||
|         sys.stdout.write('\n') | ||||
|         sys.stdout.flush() | ||||
|         current_model['mean']   = numpy.mean(current_model['iters']) | ||||
|         current_model['stddev'] = numpy.std(current_model['iters']) | ||||
|         inference_times.append(current_model) | ||||
| 
 | ||||
|     return inference_times | ||||
| 
 | ||||
| def produce_csv(input, output): | ||||
|     r''' | ||||
|     Take an input dictionnary and write it to the object-file output. | ||||
|     ''' | ||||
|     output.write('"model","mean","std"\n') | ||||
|     for model_data in input: | ||||
|         output.write('"%s",%f,%f\n' % (model_data['name'], model_data['mean'], model_data['stddev'])) | ||||
|     output.flush() | ||||
|     output.close() | ||||
|     print("Wrote as %s" % output.name) | ||||
| 
 | ||||
| def handle_args(): | ||||
|     parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.') | ||||
|     parser.add_argument('--target', required=False, | ||||
|                                  help='SSH user:pass@host string for remote benchmarking. This can also be a name of a matching \'Host\' in your SSH config.') | ||||
|     parser.add_argument('--autotrust', action='store_true', default=False, | ||||
|                                  help='SSH Paramiko policy to automatically trust unknown keys.') | ||||
|     parser.add_argument('--allowagent', action='store_true', dest='allowagent', | ||||
|                                  help='Allow the use of a SSH agent.') | ||||
|     parser.add_argument('--no-allowagent', action='store_false', dest='allowagent', | ||||
|                                  help='Disallow the use of a SSH agent.') | ||||
|     parser.add_argument('--lookforkeys', action='store_true', dest='lookforkeys', | ||||
|                                  help='Allow to look for SSH keys in ~/.ssh/.') | ||||
|     parser.add_argument('--no-lookforkeys', action='store_false', dest='lookforkeys', | ||||
|                                  help='Disallow to look for SSH keys in ~/.ssh/.') | ||||
|     parser.add_argument('--dir', required=False, default=None, | ||||
|                                  help='Local directory where to copy stuff. This will be mirrored to the remote system if needed (make sure to use path that will work on both).') | ||||
|     parser.add_argument('--models', nargs='+', required=False, | ||||
|                                  help='List of files (protocolbuffer) to work on. Might be a zip file.') | ||||
|     parser.add_argument('--wav', required=False, | ||||
|                                  help='WAV file to pass to native_client. Supply again in plotting mode to draw realine line.') | ||||
|     parser.add_argument('--lm_binary', required=False, | ||||
|                                  help='Path to the LM binary file used by the decoder.') | ||||
|     parser.add_argument('--trie', required=False, | ||||
|                                  help='Path to the trie file used by the decoder.') | ||||
|     parser.add_argument('--iters', type=int, required=False, default=5, | ||||
|                                  help='How many iterations to perfom on each model.') | ||||
|     parser.add_argument('--keep', required=False, action='store_true', | ||||
|                                  help='Keeping run files (binaries & models).') | ||||
|     parser.add_argument('--csv', type=argparse.FileType('w'), required=False, | ||||
|                                  help='Target CSV file where to dump data.') | ||||
|     parser.add_argument('--binaries', required=False, default=None, | ||||
|                                  help='Specify non TaskCluster native_client.tar.xz to use') | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
| def do_main(): | ||||
|     cli_args = handle_args() | ||||
| 
 | ||||
|     if not cli_args.models or not cli_args.wav: | ||||
|         raise AssertionError('Missing arguments (models or wav)') | ||||
| 
 | ||||
|     if cli_args.dir is not None and not os.path.isdir(cli_args.dir): | ||||
|         raise AssertionError('Inexistent temp directory') | ||||
| 
 | ||||
|     if cli_args.binaries is not None and cli_args.binaries.find('native_client.tar.xz') == -1: | ||||
|         raise AssertionError('Local binaries must be bundled in a native_client.tar.xz file') | ||||
| 
 | ||||
|     global ssh_conn | ||||
|     ssh_conn = establish_ssh(target=cli_args.target, auto_trust=cli_args.autotrust, allow_agent=cli_args.allowagent, look_keys=cli_args.lookforkeys) | ||||
| 
 | ||||
|     tempdir, sorted_models = setup_tempdir(dir=cli_args.dir, models=cli_args.models, wav=cli_args.wav, lm_binary=cli_args.lm_binary, trie=cli_args.trie, binaries=cli_args.binaries) | ||||
| 
 | ||||
|     dest_sorted_models = list(map(lambda x: os.path.join(tempdir, os.path.basename(x)), sorted_models)) | ||||
|     dest_wav = os.path.join(tempdir, os.path.basename(cli_args.wav)) | ||||
| 
 | ||||
|     if cli_args.lm_binary and cli_args.trie: | ||||
|         dest_lm_binary = os.path.join(tempdir, os.path.basename(cli_args.lm_binary)) | ||||
|         dest_trie = os.path.join(tempdir, os.path.basename(cli_args.trie)) | ||||
|         inference_times = run_benchmarks(dir=tempdir, models=dest_sorted_models, wav=dest_wav, lm_binary=dest_lm_binary, trie=dest_trie, iters=cli_args.iters) | ||||
|     else: | ||||
|         inference_times = run_benchmarks(dir=tempdir, models=dest_sorted_models, wav=dest_wav, iters=cli_args.iters) | ||||
| 
 | ||||
|     if cli_args.csv: | ||||
|         produce_csv(input=inference_times, output=cli_args.csv) | ||||
| 
 | ||||
|     if not cli_args.keep: | ||||
|         teardown_tempdir(dir=tempdir) | ||||
| 
 | ||||
| if __name__ == '__main__' : | ||||
|     do_main() | ||||
| @ -1,145 +0,0 @@ | ||||
| #!/usr/bin/env python | ||||
| # -*- coding: utf-8 -*- | ||||
| 
 | ||||
| from __future__ import absolute_import, division, print_function | ||||
| 
 | ||||
| import os | ||||
| import sys | ||||
| 
 | ||||
| # To use util.tc | ||||
| sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(sys.argv[0])))) | ||||
| from util.helpers import keep_only_digits | ||||
| 
 | ||||
| import argparse | ||||
| import numpy | ||||
| import matplotlib.pyplot as plt | ||||
| import scipy.stats as scipy_stats | ||||
| import scipy.io.wavfile as wav | ||||
| import csv | ||||
| import getpass | ||||
| 
 | ||||
| from six import iteritems | ||||
| from six.moves import range, map | ||||
| 
 | ||||
| r''' | ||||
|  Tool to: | ||||
|   - ingest CSV file produced by benchmark_nc and produce nice plots | ||||
| ''' | ||||
| 
 | ||||
| def reduce_filename(f): | ||||
|     r''' | ||||
|     Expects something like /tmp/tmpAjry4Gdsbench/test.weights.e5.XXX.YYY.pb | ||||
|     Where XXX is a variation on the model size for example | ||||
|     And where YYY is a const related to the training dataset | ||||
|     ''' | ||||
| 
 | ||||
|     f = os.path.basename(f).split('.') | ||||
|     return int(keep_only_digits(f[-3])) | ||||
| 
 | ||||
| def ingest_csv(datasets=None, range=None): | ||||
|     existing_files = filter(lambda x: os.path.isfile(x[1]), datasets) | ||||
|     assert len(list(datasets)) == len(list(existing_files)) | ||||
| 
 | ||||
|     if range: | ||||
|         range = map(int, range.split(',')) | ||||
| 
 | ||||
|     data = {} | ||||
|     for (dsname, dsfile) in datasets: | ||||
|         print('Reading %s from %s' % (dsname, dsfile)) | ||||
|         with open(dsfile) as f: | ||||
|             d = csv.DictReader(f) | ||||
|             data[dsname] = [] | ||||
|             for e in d: | ||||
|                 if range: | ||||
|                     re       = reduce_filename(e['model']) | ||||
|                     in_range = (re >= range[0] and re <= range[1]) | ||||
|                     if in_range: | ||||
|                         data[dsname].append(e) | ||||
|                 else: | ||||
|                     data[dsname].append(e) | ||||
| 
 | ||||
|     return data | ||||
| 
 | ||||
| def produce_plot(input=None, output=None): | ||||
|     x = range(len(input)) | ||||
|     xlabels = list(map(lambda a: a['name'], input)) | ||||
|     y = list(map(lambda a: a['mean'], input)) | ||||
|     yerr = list(map(lambda a: a['stddev'], input)) | ||||
| 
 | ||||
|     print('y=', y) | ||||
|     print('yerr=', yerr) | ||||
|     plt.errorbar(x, y, yerr=yerr) | ||||
|     plt.show() | ||||
|     print("Wrote as %s" % output.name) | ||||
| 
 | ||||
| def produce_plot_multiseries(input=None, output=None, title=None, size=None, fig_dpi=None, source_wav=None): | ||||
|     fig, ax = plt.subplots() | ||||
|     # float() required because size.split()[] is a string | ||||
|     fig.set_figwidth(float(size.split('x')[0]) / fig_dpi) | ||||
|     fig.set_figheight(float(size.split('x')[1]) / fig_dpi) | ||||
| 
 | ||||
|     nb_items = len(input[list(input.keys())[0]]) | ||||
|     x_all    = list(range(nb_items)) | ||||
|     for serie, serie_values in iteritems(input): | ||||
|         xtics  = list(map(lambda a: reduce_filename(a['model']), serie_values)) | ||||
|         y      = list(map(lambda a: float(a['mean']), serie_values)) | ||||
|         yerr   = list(map(lambda a: float(a['std']), serie_values)) | ||||
|         linreg = scipy_stats.linregress(x_all, y) | ||||
|         ylin   = linreg.intercept + linreg.slope * numpy.asarray(x_all) | ||||
| 
 | ||||
|         ax.errorbar(x_all, y, yerr=yerr, label=('%s' % serie), fmt='-', capsize=4, elinewidth=1) | ||||
|         ax.plot(x_all, ylin, label=('%s ~= %0.4f*x+%0.4f (R=%0.4f)' % (serie, linreg.slope, linreg.intercept, linreg.rvalue))) | ||||
| 
 | ||||
|         plt.xticks(x_all, xtics, rotation=60) | ||||
| 
 | ||||
|     if source_wav: | ||||
|         audio = wav.read(source_wav) | ||||
|         print('Adding realtime') | ||||
|         for rt_factor in [ 0.5, 1.0, 1.5, 2.0 ]: | ||||
|             rt_secs = len(audio[1]) / audio[0] * rt_factor | ||||
|             y_rt    = numpy.repeat(rt_secs, nb_items) | ||||
|             ax.plot(x_all, y_rt, label=('Realtime: %0.4f secs [%0.1f]' % (rt_secs, rt_factor))) | ||||
| 
 | ||||
|     ax.set_title(title) | ||||
|     ax.set_xlabel('Model size') | ||||
|     ax.set_ylabel('Execution time (s)') | ||||
|     legend = ax.legend(loc='best') | ||||
| 
 | ||||
|     plot_format = os.path.splitext(output.name)[-1].split('.')[-1] | ||||
| 
 | ||||
|     plt.grid() | ||||
|     plt.tight_layout() | ||||
|     plt.savefig(output, transparent=False, frameon=True, dpi=fig_dpi, format=plot_format) | ||||
| 
 | ||||
| def handle_args(): | ||||
|     parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.') | ||||
|     parser.add_argument('--wav', required=False, | ||||
|                                  help='WAV file to pass to native_client. Supply again in plotting mode to draw realine line.') | ||||
|     parser.add_argument('--dataset', action='append', nargs=2, metavar=('name','source'), | ||||
|                                 help='Include dataset NAME from file SOURCE. Repeat the option to add more datasets.') | ||||
|     parser.add_argument('--title', default=None, help='Title of the plot.') | ||||
|     parser.add_argument('--plot', type=argparse.FileType('wb'), required=False, | ||||
|                                 help='Target file where to plot data. Format will be deduced from extension.') | ||||
|     parser.add_argument('--size', default='800x600', | ||||
|                                 help='Size (px) of the resulting plot.') | ||||
|     parser.add_argument('--dpi', type=int, default=96, | ||||
|                                 help='Set plot DPI.') | ||||
|     parser.add_argument('--range', default=None, | ||||
|                                 help='Range of model size to use. Comma-separated string of boundaries: min,max') | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
| def do_main(): | ||||
|     cli_args = handle_args() | ||||
| 
 | ||||
|     if not cli_args.dataset or not cli_args.plot: | ||||
|         raise AssertionError('Missing arguments (dataset or target file)') | ||||
| 
 | ||||
|     # This is required to avoid errors about missing DISPLAY env var | ||||
|     plt.switch_backend('agg') | ||||
|     all_inference_times = ingest_csv(datasets=cli_args.dataset, range=cli_args.range) | ||||
| 
 | ||||
|     if cli_args.plot: | ||||
|         produce_plot_multiseries(input=all_inference_times, output=cli_args.plot, title=cli_args.title, size=cli_args.size, fig_dpi=cli_args.dpi, source_wav=cli_args.wav) | ||||
| 
 | ||||
| if __name__ == '__main__' : | ||||
|     do_main() | ||||
| @ -20,8 +20,3 @@ bs4 | ||||
| requests | ||||
| librosa | ||||
| soundfile | ||||
| 
 | ||||
| # Miscellaneous scripts | ||||
| paramiko >= 2.1 | ||||
| scipy | ||||
| matplotlib | ||||
|  | ||||
| @ -1,106 +0,0 @@ | ||||
| #!/bin/bash | ||||
| 
 | ||||
| set -xe | ||||
| 
 | ||||
| source $(dirname "$0")/tc-tests-utils.sh | ||||
| 
 | ||||
| exec_benchmark() | ||||
| { | ||||
|     model_file="$1" | ||||
|     run_postfix=$2 | ||||
| 
 | ||||
|     mkdir -p /tmp/bench-ds/ || true | ||||
|     mkdir -p /tmp/bench-ds-nolm/ || true | ||||
| 
 | ||||
|     csv=${TASKCLUSTER_ARTIFACTS}/benchmark-${run_postfix}.csv | ||||
|     csv_nolm=${TASKCLUSTER_ARTIFACTS}/benchmark-nolm-${run_postfix}.csv | ||||
|     png=${TASKCLUSTER_ARTIFACTS}/benchmark-${run_postfix}.png | ||||
|     svg=${TASKCLUSTER_ARTIFACTS}/benchmark-${run_postfix}.svg | ||||
| 
 | ||||
|     python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_nc.py \ | ||||
|         --dir /tmp/bench-ds/ \ | ||||
|         --models ${model_file} \ | ||||
|         --wav /tmp/LDC93S1.wav \ | ||||
|         --lm_binary /tmp/lm.binary \ | ||||
|         --trie /tmp/trie \ | ||||
|         --csv ${csv} | ||||
| 
 | ||||
|     python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_nc.py \ | ||||
|         --dir /tmp/bench-ds-nolm/ \ | ||||
|         --models ${model_file} \ | ||||
|         --wav /tmp/LDC93S1.wav \ | ||||
|         --csv ${csv_nolm} | ||||
| 
 | ||||
|     python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_plotter.py \ | ||||
|         --dataset "TaskCluster model" ${csv} \ | ||||
|         --dataset "TaskCluster model (no LM)" ${csv_nolm} \ | ||||
|         --title "TaskCluster model benchmark" \ | ||||
|         --wav /tmp/LDC93S1.wav \ | ||||
|         --plot ${png} \ | ||||
|         --size 1280x720 | ||||
| 
 | ||||
|     python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_plotter.py \ | ||||
|         --dataset "TaskCluster model" ${csv} \ | ||||
|         --dataset "TaskCluster model (no LM)" ${csv_nolm} \ | ||||
|         --title "TaskCluster model benchmark" \ | ||||
|         --wav /tmp/LDC93S1.wav \ | ||||
|         --plot ${svg} \ | ||||
|         --size 1280x720 | ||||
| } | ||||
| 
 | ||||
| pyver=3.5.6 | ||||
| 
 | ||||
| unset PYTHON_BIN_PATH | ||||
| unset PYTHONPATH | ||||
| export PYENV_ROOT="${HOME}/ds-test/.pyenv" | ||||
| export PATH="${PYENV_ROOT}/bin:$PATH" | ||||
| 
 | ||||
| mkdir -p ${TASKCLUSTER_ARTIFACTS} || true | ||||
| mkdir -p ${PYENV_ROOT} || true | ||||
| 
 | ||||
| # We still need to get model, wav and alphabet | ||||
| download_data | ||||
| 
 | ||||
| # Follow benchmark naming from parameters in bin/run-tc-ldc93s1.sh | ||||
| # Okay, it's not really the real LSTM sizes, just a way to verify how things | ||||
| # actually behave. | ||||
| for size in 100 200 300 400 500 600 700 800 900; | ||||
| do | ||||
|     cp /tmp/${model_name} /tmp/test.frozen.e75.lstm${size}.ldc93s1.pb | ||||
| done; | ||||
| 
 | ||||
| # Let's make it a ZIP file. We don't want the directory structure. | ||||
| zip --junk-paths -r9 /tmp/test.frozen.e75.lstm100-900.ldc93s1.zip /tmp/test.frozen.e75.lstm*.ldc93s1.pb && rm /tmp/test.frozen.e75.lstm*.ldc93s1.pb | ||||
| 
 | ||||
| # And prepare for multiple files on the CLI | ||||
| model_list="" | ||||
| for size in 10 20 30 40 50 60 70 80 90; | ||||
| do | ||||
|     cp /tmp/${model_name} /tmp/test.frozen.e75.lstm${size}.ldc93s1.pb | ||||
|     model_list="${model_list} /tmp/test.frozen.e75.lstm${size}.ldc93s1.pb" | ||||
| done; | ||||
| 
 | ||||
| # Let's prepare another model for single-model codepath | ||||
| mv /tmp/${model_name} /tmp/test.frozen.e75.lstm494.ldc93s1.pb | ||||
| 
 | ||||
| export TASKCLUSTER_SCHEME=${DEEPSPEECH_ARTIFACTS_ROOT}/native_client.tar.xz | ||||
| 
 | ||||
| install_pyenv "${PYENV_ROOT}" | ||||
| install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv" | ||||
| 
 | ||||
| PYENV_NAME=deepspeech-test | ||||
| pyenv install ${pyver} | ||||
| pyenv virtualenv ${pyver} ${PYENV_NAME} | ||||
| source ${PYENV_ROOT}/versions/${pyver}/envs/${PYENV_NAME}/bin/activate | ||||
| 
 | ||||
| set -o pipefail | ||||
| pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat | ||||
| pip install -r ${DS_ROOT_TASK}/DeepSpeech/ds/requirements.txt | cat | ||||
| set +o pipefail | ||||
| 
 | ||||
| exec_benchmark "/tmp/test.frozen.e75.lstm494.ldc93s1.pb" "single-model" | ||||
| exec_benchmark "/tmp/test.frozen.e75.lstm100-900.ldc93s1.zip" "zipfile-model" | ||||
| exec_benchmark "${model_list}" "multi-model" | ||||
| 
 | ||||
| deactivate | ||||
| pyenv uninstall --force ${PYENV_NAME} | ||||
| @ -1,14 +0,0 @@ | ||||
| build: | ||||
|   template_file: test-linux-opt-base.tyml | ||||
|   dependencies: | ||||
|     - "linux-amd64-cpu-opt" | ||||
|     - "test-training_16k-linux-amd64-py36m-opt" | ||||
|   test_model_task: "test-training_16k-linux-amd64-py36m-opt" | ||||
|   system_setup: | ||||
|     > | ||||
|       apt-get -qq -y install ${python.packages_trusty.apt} zip | ||||
|   args: | ||||
|     tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-benchmark-tests.sh" | ||||
|   metadata: | ||||
|     name: "DeepSpeech Linux AMD64 CPU benchmarking" | ||||
|     description: "Benchmarking DeepSpeech for Linux/AMD64, CPU only, optimized version" | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user