Merge pull request #2704 from lissyx/remove-benchmark-nc
Remove unused benchmark_nc
This commit is contained in:
commit
5d0e4cc8ed
|
@ -1,504 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# To use util.tc
|
||||
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(sys.argv[0]))))
|
||||
import util.taskcluster as tcu
|
||||
from util.helpers import keep_only_digits
|
||||
|
||||
import paramiko
|
||||
import argparse
|
||||
import tempfile
|
||||
import shutil
|
||||
import subprocess
|
||||
import stat
|
||||
import numpy
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.stats as scipy_stats
|
||||
import csv
|
||||
import getpass
|
||||
import zipfile
|
||||
|
||||
from six import iteritems
|
||||
from six.moves import range, map
|
||||
from functools import cmp_to_key
|
||||
|
||||
r'''
|
||||
Tool to:
|
||||
- remote local or remote (ssh) native_client
|
||||
- handles copying models (as protocolbuffer files)
|
||||
- run native_client in benchmark mode
|
||||
- collect timing results
|
||||
- compute mean values (with wariances)
|
||||
- output as CSV
|
||||
'''
|
||||
|
||||
ssh_conn = None
|
||||
def exec_command(command, cwd=None):
|
||||
r'''
|
||||
Helper to exec locally (subprocess) or remotely (paramiko)
|
||||
'''
|
||||
|
||||
rc = None
|
||||
stdout = stderr = None
|
||||
if ssh_conn is None:
|
||||
ld_library_path = {'LD_LIBRARY_PATH': '.:%s' % os.environ.get('LD_LIBRARY_PATH', '')}
|
||||
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=ld_library_path, cwd=cwd)
|
||||
stdout, stderr = p.communicate()
|
||||
rc = p.returncode
|
||||
else:
|
||||
# environment= requires paramiko >= 2.1 (fails with 2.0.2)
|
||||
final_command = command if cwd is None else 'cd %s && %s %s' % (cwd, 'LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH', command)
|
||||
ssh_stdin, ssh_stdout, ssh_stderr = ssh_conn.exec_command(final_command)
|
||||
stdout = ''.join(ssh_stdout.readlines())
|
||||
stderr = ''.join(ssh_stderr.readlines())
|
||||
rc = ssh_stdout.channel.recv_exit_status()
|
||||
|
||||
return rc, stdout, stderr
|
||||
|
||||
def assert_valid_dir(dir):
|
||||
if dir is None:
|
||||
raise AssertionError('Invalid temp directory')
|
||||
return True
|
||||
|
||||
def get_arch_string():
|
||||
r'''
|
||||
Check local or remote system arch, to produce TaskCluster proper link.
|
||||
'''
|
||||
rc, stdout, stderr = exec_command('uname -sm')
|
||||
if rc > 0:
|
||||
raise AssertionError('Error checking OS')
|
||||
|
||||
stdout = stdout.lower().strip()
|
||||
if not b'linux' in stdout:
|
||||
raise AssertionError('Unsupported OS')
|
||||
|
||||
if b'armv7l' in stdout:
|
||||
return 'arm'
|
||||
|
||||
if b'x86_64' in stdout:
|
||||
nv_rc, nv_stdout, nv_stderr = exec_command('nvidia-smi')
|
||||
nv_stdout = nv_stdout.lower().strip()
|
||||
if b'NVIDIA-SMI' in nv_stdout:
|
||||
return 'gpu'
|
||||
else:
|
||||
return 'cpu'
|
||||
|
||||
raise AssertionError('Unsupported arch:', stdout)
|
||||
|
||||
def maybe_download_binaries(dir):
|
||||
assert_valid_dir(dir)
|
||||
tcu.maybe_download_tc(target_dir=dir, tc_url=tcu.get_tc_url(get_arch_string()), progress=True)
|
||||
|
||||
def extract_native_client_tarball(dir):
|
||||
r'''
|
||||
Download a native_client.tar.xz file from TaskCluster and extract it to dir.
|
||||
'''
|
||||
assert_valid_dir(dir)
|
||||
|
||||
target_tarball = os.path.join(dir, 'native_client.tar.xz')
|
||||
if os.path.isfile(target_tarball) and os.stat(target_tarball).st_size == 0:
|
||||
return
|
||||
|
||||
subprocess.check_call(['pixz', '-d', 'native_client.tar.xz'], cwd=dir)
|
||||
subprocess.check_call(['tar', 'xf', 'native_client.tar'], cwd=dir)
|
||||
os.unlink(os.path.join(dir, 'native_client.tar'))
|
||||
open(target_tarball, 'w').close()
|
||||
|
||||
def is_zip_file(models):
|
||||
r'''
|
||||
Ensure that a path is a zip file by:
|
||||
- checking length is 1
|
||||
- checking extension is '.zip'
|
||||
'''
|
||||
ext = os.path.splitext(models[0])[1]
|
||||
return (len(models) == 1) and (ext == '.zip')
|
||||
|
||||
def maybe_inspect_zip(models):
|
||||
r'''
|
||||
Detect if models is a list of protocolbuffer files or a ZIP file.
|
||||
If the latter, then unzip it and return the list of protocolbuffer files
|
||||
that were inside.
|
||||
'''
|
||||
|
||||
if not(is_zip_file(models)):
|
||||
return models
|
||||
|
||||
if len(models) > 1:
|
||||
return models
|
||||
|
||||
if len(models) < 1:
|
||||
raise AssertionError('No models at all')
|
||||
|
||||
return zipfile.ZipFile(models[0]).namelist()
|
||||
|
||||
def all_files(models=[]):
|
||||
r'''
|
||||
Return a list of full path of files matching 'models', sorted in human
|
||||
numerical order (i.e., 0 1 2 ..., 10 11 12, ..., 100, ..., 1000).
|
||||
|
||||
Files are supposed to be named identically except one variable component
|
||||
e.g. the list,
|
||||
test.weights.e5.lstm1200.ldc93s1.pb
|
||||
test.weights.e5.lstm1000.ldc93s1.pb
|
||||
test.weights.e5.lstm800.ldc93s1.pb
|
||||
gets sorted:
|
||||
test.weights.e5.lstm800.ldc93s1.pb
|
||||
test.weights.e5.lstm1000.ldc93s1.pb
|
||||
test.weights.e5.lstm1200.ldc93s1.pb
|
||||
'''
|
||||
|
||||
def nsort(a, b):
|
||||
fa = os.path.basename(a).split('.')
|
||||
fb = os.path.basename(b).split('.')
|
||||
elements_to_remove = []
|
||||
|
||||
assert len(fa) == len(fb)
|
||||
|
||||
for i in range(0, len(fa)):
|
||||
if fa[i] == fb[i]:
|
||||
elements_to_remove.append(fa[i])
|
||||
|
||||
for e in elements_to_remove:
|
||||
fa.remove(e)
|
||||
fb.remove(e)
|
||||
|
||||
assert len(fa) == len(fb)
|
||||
assert len(fa) == 1
|
||||
|
||||
fa = int(keep_only_digits(fa[0]))
|
||||
fb = int(keep_only_digits(fb[0]))
|
||||
|
||||
if fa < fb:
|
||||
return -1
|
||||
if fa == fb:
|
||||
return 0
|
||||
if fa > fb:
|
||||
return 1
|
||||
|
||||
base = list(map(lambda x: os.path.abspath(x), maybe_inspect_zip(models)))
|
||||
base.sort(key=cmp_to_key(nsort))
|
||||
|
||||
return base
|
||||
|
||||
def copy_tree(dir):
|
||||
assert_valid_dir(dir)
|
||||
|
||||
sftp = ssh_conn.open_sftp()
|
||||
# IOError will get triggered if the path does not exists remotely
|
||||
try:
|
||||
if stat.S_ISDIR(sftp.stat(dir).st_mode):
|
||||
print('Directory already existent: %s' % dir)
|
||||
except IOError:
|
||||
print('Creating remote directory: %s' % dir)
|
||||
sftp.mkdir(dir)
|
||||
|
||||
print('Copy files to remote')
|
||||
for fname in os.listdir(dir):
|
||||
fullpath = os.path.join(dir, fname)
|
||||
local_stat = os.stat(fullpath)
|
||||
try:
|
||||
remote_mode = sftp.stat(fullpath).st_mode
|
||||
except IOError:
|
||||
remote_mode = 0
|
||||
|
||||
if not stat.S_ISREG(remote_mode):
|
||||
print('Copying %s ...' % fullpath)
|
||||
remote_mode = sftp.put(fullpath, fullpath, confirm=True).st_mode
|
||||
|
||||
if local_stat.st_mode != remote_mode:
|
||||
print('Setting proper remote mode: %s' % local_stat.st_mode)
|
||||
sftp.chmod(fullpath, local_stat.st_mode)
|
||||
|
||||
sftp.close()
|
||||
|
||||
def delete_tree(dir):
|
||||
assert_valid_dir(dir)
|
||||
|
||||
sftp = ssh_conn.open_sftp()
|
||||
# IOError will get triggered if the path does not exists remotely
|
||||
try:
|
||||
if stat.S_ISDIR(sftp.stat(dir).st_mode):
|
||||
print('Removing remote files')
|
||||
for fname in sftp.listdir(dir):
|
||||
fullpath = os.path.join(dir, fname)
|
||||
remote_stat = sftp.stat(fullpath)
|
||||
if stat.S_ISREG(remote_stat.st_mode):
|
||||
print('Removing %s ...' % fullpath)
|
||||
sftp.remove(fullpath)
|
||||
|
||||
print('Removing directory %s ...' % dir)
|
||||
sftp.rmdir(dir)
|
||||
|
||||
sftp.close()
|
||||
except IOError:
|
||||
print('No remote directory: %s' % dir)
|
||||
|
||||
def setup_tempdir(dir, models, wav, lm_binary, trie, binaries):
|
||||
r'''
|
||||
Copy models, libs and binary to a directory (new one if dir is None)
|
||||
'''
|
||||
if dir is None:
|
||||
dir = tempfile.mkdtemp(suffix='dsbench')
|
||||
|
||||
sorted_models = all_files(models=models)
|
||||
if binaries is None:
|
||||
maybe_download_binaries(dir)
|
||||
else:
|
||||
print('Using local binaries: %s' % (binaries))
|
||||
shutil.copy2(binaries, dir)
|
||||
extract_native_client_tarball(dir)
|
||||
|
||||
filenames = map(lambda x: os.path.join(dir, os.path.basename(x)), sorted_models)
|
||||
missing_models = list(filter(lambda x: not os.path.isfile(x), filenames))
|
||||
if len(missing_models) > 0:
|
||||
# If we have a ZIP file, directly extract it to the proper path
|
||||
if is_zip_file(models):
|
||||
print('Extracting %s to %s' % (models[0], dir))
|
||||
zipfile.ZipFile(models[0]).extractall(path=dir)
|
||||
print('Extracted %s.' % models[0])
|
||||
else:
|
||||
# If one model is missing, let's copy everything again. Be safe.
|
||||
for f in sorted_models:
|
||||
print('Copying %s to %s' % (f, dir))
|
||||
shutil.copy2(f, dir)
|
||||
|
||||
for extra_file in [ wav, lm_binary, trie ]:
|
||||
if extra_file and not os.path.isfile(os.path.join(dir, os.path.basename(extra_file))):
|
||||
print('Copying %s to %s' % (extra_file, dir))
|
||||
shutil.copy2(extra_file, dir)
|
||||
|
||||
if ssh_conn:
|
||||
copy_tree(dir)
|
||||
|
||||
return dir, sorted_models
|
||||
|
||||
def teardown_tempdir(dir):
|
||||
r'''
|
||||
Cleanup temporary directory.
|
||||
'''
|
||||
|
||||
if ssh_conn:
|
||||
delete_tree(dir)
|
||||
|
||||
assert_valid_dir(dir)
|
||||
shutil.rmtree(dir)
|
||||
|
||||
def get_sshconfig():
|
||||
r'''
|
||||
Read user's SSH configuration file
|
||||
'''
|
||||
|
||||
with open(os.path.expanduser('~/.ssh/config')) as f:
|
||||
cfg = paramiko.SSHConfig()
|
||||
cfg.parse(f)
|
||||
ret_dict = {}
|
||||
for d in cfg._config:
|
||||
_copy = dict(d)
|
||||
# Avoid buggy behavior with strange host definitions, we need
|
||||
# Hostname and not Host.
|
||||
del _copy['host']
|
||||
for host in d['host']:
|
||||
ret_dict[host] = _copy['config']
|
||||
|
||||
return ret_dict
|
||||
|
||||
def establish_ssh(target=None, auto_trust=False, allow_agent=True, look_keys=True):
|
||||
r'''
|
||||
Establish a SSH connection to a remote host. It should be able to use
|
||||
SSH's config file Host name declarations. By default, will not automatically
|
||||
add trust for hosts, will use SSH agent and will try to load keys.
|
||||
'''
|
||||
|
||||
def password_prompt(username, hostname):
|
||||
r'''
|
||||
If the Host is relying on password authentication, lets ask it.
|
||||
Relying on SSH itself to take care of that would not work when the
|
||||
remote authentication is password behind a SSH-key+2FA jumphost.
|
||||
'''
|
||||
return getpass.getpass('No SSH key for %s@%s, please provide password: ' % (username, hostname))
|
||||
|
||||
ssh_conn = None
|
||||
if target is not None:
|
||||
ssh_conf = get_sshconfig()
|
||||
cfg = {
|
||||
'hostname': None,
|
||||
'port': 22,
|
||||
'allow_agent': allow_agent,
|
||||
'look_for_keys': look_keys
|
||||
}
|
||||
if ssh_conf.has_key(target):
|
||||
user_config = ssh_conf.get(target)
|
||||
|
||||
# If ssh_config file's Host defined 'User' instead of 'Username'
|
||||
if user_config.has_key('user') and not user_config.has_key('username'):
|
||||
user_config['username'] = user_config['user']
|
||||
del user_config['user']
|
||||
|
||||
for k in ('username', 'hostname', 'port'):
|
||||
if k in user_config:
|
||||
cfg[k] = user_config[k]
|
||||
|
||||
# Assume Password auth. If we don't do that, then when connecting
|
||||
# through a jumphost we will run into issues and the user will
|
||||
# not be able to input his password to the SSH prompt.
|
||||
if 'identityfile' in user_config:
|
||||
cfg['key_filename'] = user_config['identityfile']
|
||||
else:
|
||||
cfg['password'] = password_prompt(cfg['username'], cfg['hostname'] or target)
|
||||
|
||||
# Should be the last one, since ProxyCommand will issue connection to remote host
|
||||
if 'proxycommand' in user_config:
|
||||
cfg['sock'] = paramiko.ProxyCommand(user_config['proxycommand'])
|
||||
|
||||
else:
|
||||
cfg['username'] = target.split('@')[0]
|
||||
cfg['hostname'] = target.split('@')[1].split(':')[0]
|
||||
cfg['password'] = password_prompt(cfg['username'], cfg['hostname'])
|
||||
try:
|
||||
cfg['port'] = int(target.split('@')[1].split(':')[1])
|
||||
except IndexError:
|
||||
# IndexError will happen if no :PORT is there.
|
||||
# Default value 22 is defined above in 'cfg'.
|
||||
pass
|
||||
|
||||
ssh_conn = paramiko.SSHClient()
|
||||
if auto_trust:
|
||||
ssh_conn.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
ssh_conn.connect(**cfg)
|
||||
|
||||
return ssh_conn
|
||||
|
||||
def run_benchmarks(dir, models, wav, lm_binary=None, trie=None, iters=-1):
|
||||
r'''
|
||||
Core of the running of the benchmarks. We will run on all of models, against
|
||||
the WAV file provided as wav.
|
||||
'''
|
||||
|
||||
assert_valid_dir(dir)
|
||||
|
||||
inference_times = [ ]
|
||||
|
||||
for model in models:
|
||||
model_filename = model
|
||||
|
||||
current_model = {
|
||||
'name': model,
|
||||
'iters': [ ],
|
||||
'mean': numpy.infty,
|
||||
'stddev': numpy.infty
|
||||
}
|
||||
|
||||
if lm_binary and trie:
|
||||
cmdline = './deepspeech --model "%s" --lm "%s" --trie "%s" --audio "%s" -t' % (model_filename, lm_binary, trie, wav)
|
||||
else:
|
||||
cmdline = './deepspeech --model "%s" --audio "%s" -t' % (model_filename, wav)
|
||||
|
||||
for it in range(iters):
|
||||
sys.stdout.write('\rRunning %s: %d/%d' % (os.path.basename(model), (it+1), iters))
|
||||
sys.stdout.flush()
|
||||
rc, stdout, stderr = exec_command(cmdline, cwd=dir)
|
||||
if rc == 0:
|
||||
inference_time = float(stdout.split(b'\n')[1].split(b'=')[-1])
|
||||
# print("[%d] model=%s inference=%f" % (it, model, inference_time))
|
||||
current_model['iters'].append(inference_time)
|
||||
else:
|
||||
print('exec_command("%s") failed with rc=%d' % (cmdline, rc))
|
||||
print('stdout: %s' % stdout)
|
||||
print('stderr: %s' % stderr)
|
||||
raise AssertionError('Execution failure: rc=%d' % (rc))
|
||||
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.flush()
|
||||
current_model['mean'] = numpy.mean(current_model['iters'])
|
||||
current_model['stddev'] = numpy.std(current_model['iters'])
|
||||
inference_times.append(current_model)
|
||||
|
||||
return inference_times
|
||||
|
||||
def produce_csv(input, output):
|
||||
r'''
|
||||
Take an input dictionnary and write it to the object-file output.
|
||||
'''
|
||||
output.write('"model","mean","std"\n')
|
||||
for model_data in input:
|
||||
output.write('"%s",%f,%f\n' % (model_data['name'], model_data['mean'], model_data['stddev']))
|
||||
output.flush()
|
||||
output.close()
|
||||
print("Wrote as %s" % output.name)
|
||||
|
||||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
|
||||
parser.add_argument('--target', required=False,
|
||||
help='SSH user:pass@host string for remote benchmarking. This can also be a name of a matching \'Host\' in your SSH config.')
|
||||
parser.add_argument('--autotrust', action='store_true', default=False,
|
||||
help='SSH Paramiko policy to automatically trust unknown keys.')
|
||||
parser.add_argument('--allowagent', action='store_true', dest='allowagent',
|
||||
help='Allow the use of a SSH agent.')
|
||||
parser.add_argument('--no-allowagent', action='store_false', dest='allowagent',
|
||||
help='Disallow the use of a SSH agent.')
|
||||
parser.add_argument('--lookforkeys', action='store_true', dest='lookforkeys',
|
||||
help='Allow to look for SSH keys in ~/.ssh/.')
|
||||
parser.add_argument('--no-lookforkeys', action='store_false', dest='lookforkeys',
|
||||
help='Disallow to look for SSH keys in ~/.ssh/.')
|
||||
parser.add_argument('--dir', required=False, default=None,
|
||||
help='Local directory where to copy stuff. This will be mirrored to the remote system if needed (make sure to use path that will work on both).')
|
||||
parser.add_argument('--models', nargs='+', required=False,
|
||||
help='List of files (protocolbuffer) to work on. Might be a zip file.')
|
||||
parser.add_argument('--wav', required=False,
|
||||
help='WAV file to pass to native_client. Supply again in plotting mode to draw realine line.')
|
||||
parser.add_argument('--lm_binary', required=False,
|
||||
help='Path to the LM binary file used by the decoder.')
|
||||
parser.add_argument('--trie', required=False,
|
||||
help='Path to the trie file used by the decoder.')
|
||||
parser.add_argument('--iters', type=int, required=False, default=5,
|
||||
help='How many iterations to perfom on each model.')
|
||||
parser.add_argument('--keep', required=False, action='store_true',
|
||||
help='Keeping run files (binaries & models).')
|
||||
parser.add_argument('--csv', type=argparse.FileType('w'), required=False,
|
||||
help='Target CSV file where to dump data.')
|
||||
parser.add_argument('--binaries', required=False, default=None,
|
||||
help='Specify non TaskCluster native_client.tar.xz to use')
|
||||
return parser.parse_args()
|
||||
|
||||
def do_main():
|
||||
cli_args = handle_args()
|
||||
|
||||
if not cli_args.models or not cli_args.wav:
|
||||
raise AssertionError('Missing arguments (models or wav)')
|
||||
|
||||
if cli_args.dir is not None and not os.path.isdir(cli_args.dir):
|
||||
raise AssertionError('Inexistent temp directory')
|
||||
|
||||
if cli_args.binaries is not None and cli_args.binaries.find('native_client.tar.xz') == -1:
|
||||
raise AssertionError('Local binaries must be bundled in a native_client.tar.xz file')
|
||||
|
||||
global ssh_conn
|
||||
ssh_conn = establish_ssh(target=cli_args.target, auto_trust=cli_args.autotrust, allow_agent=cli_args.allowagent, look_keys=cli_args.lookforkeys)
|
||||
|
||||
tempdir, sorted_models = setup_tempdir(dir=cli_args.dir, models=cli_args.models, wav=cli_args.wav, lm_binary=cli_args.lm_binary, trie=cli_args.trie, binaries=cli_args.binaries)
|
||||
|
||||
dest_sorted_models = list(map(lambda x: os.path.join(tempdir, os.path.basename(x)), sorted_models))
|
||||
dest_wav = os.path.join(tempdir, os.path.basename(cli_args.wav))
|
||||
|
||||
if cli_args.lm_binary and cli_args.trie:
|
||||
dest_lm_binary = os.path.join(tempdir, os.path.basename(cli_args.lm_binary))
|
||||
dest_trie = os.path.join(tempdir, os.path.basename(cli_args.trie))
|
||||
inference_times = run_benchmarks(dir=tempdir, models=dest_sorted_models, wav=dest_wav, lm_binary=dest_lm_binary, trie=dest_trie, iters=cli_args.iters)
|
||||
else:
|
||||
inference_times = run_benchmarks(dir=tempdir, models=dest_sorted_models, wav=dest_wav, iters=cli_args.iters)
|
||||
|
||||
if cli_args.csv:
|
||||
produce_csv(input=inference_times, output=cli_args.csv)
|
||||
|
||||
if not cli_args.keep:
|
||||
teardown_tempdir(dir=tempdir)
|
||||
|
||||
if __name__ == '__main__' :
|
||||
do_main()
|
|
@ -1,145 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# To use util.tc
|
||||
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(sys.argv[0]))))
|
||||
from util.helpers import keep_only_digits
|
||||
|
||||
import argparse
|
||||
import numpy
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.stats as scipy_stats
|
||||
import scipy.io.wavfile as wav
|
||||
import csv
|
||||
import getpass
|
||||
|
||||
from six import iteritems
|
||||
from six.moves import range, map
|
||||
|
||||
r'''
|
||||
Tool to:
|
||||
- ingest CSV file produced by benchmark_nc and produce nice plots
|
||||
'''
|
||||
|
||||
def reduce_filename(f):
|
||||
r'''
|
||||
Expects something like /tmp/tmpAjry4Gdsbench/test.weights.e5.XXX.YYY.pb
|
||||
Where XXX is a variation on the model size for example
|
||||
And where YYY is a const related to the training dataset
|
||||
'''
|
||||
|
||||
f = os.path.basename(f).split('.')
|
||||
return int(keep_only_digits(f[-3]))
|
||||
|
||||
def ingest_csv(datasets=None, range=None):
|
||||
existing_files = filter(lambda x: os.path.isfile(x[1]), datasets)
|
||||
assert len(list(datasets)) == len(list(existing_files))
|
||||
|
||||
if range:
|
||||
range = map(int, range.split(','))
|
||||
|
||||
data = {}
|
||||
for (dsname, dsfile) in datasets:
|
||||
print('Reading %s from %s' % (dsname, dsfile))
|
||||
with open(dsfile) as f:
|
||||
d = csv.DictReader(f)
|
||||
data[dsname] = []
|
||||
for e in d:
|
||||
if range:
|
||||
re = reduce_filename(e['model'])
|
||||
in_range = (re >= range[0] and re <= range[1])
|
||||
if in_range:
|
||||
data[dsname].append(e)
|
||||
else:
|
||||
data[dsname].append(e)
|
||||
|
||||
return data
|
||||
|
||||
def produce_plot(input=None, output=None):
|
||||
x = range(len(input))
|
||||
xlabels = list(map(lambda a: a['name'], input))
|
||||
y = list(map(lambda a: a['mean'], input))
|
||||
yerr = list(map(lambda a: a['stddev'], input))
|
||||
|
||||
print('y=', y)
|
||||
print('yerr=', yerr)
|
||||
plt.errorbar(x, y, yerr=yerr)
|
||||
plt.show()
|
||||
print("Wrote as %s" % output.name)
|
||||
|
||||
def produce_plot_multiseries(input=None, output=None, title=None, size=None, fig_dpi=None, source_wav=None):
|
||||
fig, ax = plt.subplots()
|
||||
# float() required because size.split()[] is a string
|
||||
fig.set_figwidth(float(size.split('x')[0]) / fig_dpi)
|
||||
fig.set_figheight(float(size.split('x')[1]) / fig_dpi)
|
||||
|
||||
nb_items = len(input[list(input.keys())[0]])
|
||||
x_all = list(range(nb_items))
|
||||
for serie, serie_values in iteritems(input):
|
||||
xtics = list(map(lambda a: reduce_filename(a['model']), serie_values))
|
||||
y = list(map(lambda a: float(a['mean']), serie_values))
|
||||
yerr = list(map(lambda a: float(a['std']), serie_values))
|
||||
linreg = scipy_stats.linregress(x_all, y)
|
||||
ylin = linreg.intercept + linreg.slope * numpy.asarray(x_all)
|
||||
|
||||
ax.errorbar(x_all, y, yerr=yerr, label=('%s' % serie), fmt='-', capsize=4, elinewidth=1)
|
||||
ax.plot(x_all, ylin, label=('%s ~= %0.4f*x+%0.4f (R=%0.4f)' % (serie, linreg.slope, linreg.intercept, linreg.rvalue)))
|
||||
|
||||
plt.xticks(x_all, xtics, rotation=60)
|
||||
|
||||
if source_wav:
|
||||
audio = wav.read(source_wav)
|
||||
print('Adding realtime')
|
||||
for rt_factor in [ 0.5, 1.0, 1.5, 2.0 ]:
|
||||
rt_secs = len(audio[1]) / audio[0] * rt_factor
|
||||
y_rt = numpy.repeat(rt_secs, nb_items)
|
||||
ax.plot(x_all, y_rt, label=('Realtime: %0.4f secs [%0.1f]' % (rt_secs, rt_factor)))
|
||||
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel('Model size')
|
||||
ax.set_ylabel('Execution time (s)')
|
||||
legend = ax.legend(loc='best')
|
||||
|
||||
plot_format = os.path.splitext(output.name)[-1].split('.')[-1]
|
||||
|
||||
plt.grid()
|
||||
plt.tight_layout()
|
||||
plt.savefig(output, transparent=False, frameon=True, dpi=fig_dpi, format=plot_format)
|
||||
|
||||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Benchmarking tooling for DeepSpeech native_client.')
|
||||
parser.add_argument('--wav', required=False,
|
||||
help='WAV file to pass to native_client. Supply again in plotting mode to draw realine line.')
|
||||
parser.add_argument('--dataset', action='append', nargs=2, metavar=('name','source'),
|
||||
help='Include dataset NAME from file SOURCE. Repeat the option to add more datasets.')
|
||||
parser.add_argument('--title', default=None, help='Title of the plot.')
|
||||
parser.add_argument('--plot', type=argparse.FileType('wb'), required=False,
|
||||
help='Target file where to plot data. Format will be deduced from extension.')
|
||||
parser.add_argument('--size', default='800x600',
|
||||
help='Size (px) of the resulting plot.')
|
||||
parser.add_argument('--dpi', type=int, default=96,
|
||||
help='Set plot DPI.')
|
||||
parser.add_argument('--range', default=None,
|
||||
help='Range of model size to use. Comma-separated string of boundaries: min,max')
|
||||
return parser.parse_args()
|
||||
|
||||
def do_main():
|
||||
cli_args = handle_args()
|
||||
|
||||
if not cli_args.dataset or not cli_args.plot:
|
||||
raise AssertionError('Missing arguments (dataset or target file)')
|
||||
|
||||
# This is required to avoid errors about missing DISPLAY env var
|
||||
plt.switch_backend('agg')
|
||||
all_inference_times = ingest_csv(datasets=cli_args.dataset, range=cli_args.range)
|
||||
|
||||
if cli_args.plot:
|
||||
produce_plot_multiseries(input=all_inference_times, output=cli_args.plot, title=cli_args.title, size=cli_args.size, fig_dpi=cli_args.dpi, source_wav=cli_args.wav)
|
||||
|
||||
if __name__ == '__main__' :
|
||||
do_main()
|
|
@ -20,8 +20,3 @@ bs4
|
|||
requests
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
# Miscellaneous scripts
|
||||
paramiko >= 2.1
|
||||
scipy
|
||||
matplotlib
|
||||
|
|
|
@ -1,106 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/tc-tests-utils.sh
|
||||
|
||||
exec_benchmark()
|
||||
{
|
||||
model_file="$1"
|
||||
run_postfix=$2
|
||||
|
||||
mkdir -p /tmp/bench-ds/ || true
|
||||
mkdir -p /tmp/bench-ds-nolm/ || true
|
||||
|
||||
csv=${TASKCLUSTER_ARTIFACTS}/benchmark-${run_postfix}.csv
|
||||
csv_nolm=${TASKCLUSTER_ARTIFACTS}/benchmark-nolm-${run_postfix}.csv
|
||||
png=${TASKCLUSTER_ARTIFACTS}/benchmark-${run_postfix}.png
|
||||
svg=${TASKCLUSTER_ARTIFACTS}/benchmark-${run_postfix}.svg
|
||||
|
||||
python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_nc.py \
|
||||
--dir /tmp/bench-ds/ \
|
||||
--models ${model_file} \
|
||||
--wav /tmp/LDC93S1.wav \
|
||||
--lm_binary /tmp/lm.binary \
|
||||
--trie /tmp/trie \
|
||||
--csv ${csv}
|
||||
|
||||
python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_nc.py \
|
||||
--dir /tmp/bench-ds-nolm/ \
|
||||
--models ${model_file} \
|
||||
--wav /tmp/LDC93S1.wav \
|
||||
--csv ${csv_nolm}
|
||||
|
||||
python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_plotter.py \
|
||||
--dataset "TaskCluster model" ${csv} \
|
||||
--dataset "TaskCluster model (no LM)" ${csv_nolm} \
|
||||
--title "TaskCluster model benchmark" \
|
||||
--wav /tmp/LDC93S1.wav \
|
||||
--plot ${png} \
|
||||
--size 1280x720
|
||||
|
||||
python ${DS_ROOT_TASK}/DeepSpeech/ds/bin/benchmark_plotter.py \
|
||||
--dataset "TaskCluster model" ${csv} \
|
||||
--dataset "TaskCluster model (no LM)" ${csv_nolm} \
|
||||
--title "TaskCluster model benchmark" \
|
||||
--wav /tmp/LDC93S1.wav \
|
||||
--plot ${svg} \
|
||||
--size 1280x720
|
||||
}
|
||||
|
||||
pyver=3.5.6
|
||||
|
||||
unset PYTHON_BIN_PATH
|
||||
unset PYTHONPATH
|
||||
export PYENV_ROOT="${HOME}/ds-test/.pyenv"
|
||||
export PATH="${PYENV_ROOT}/bin:$PATH"
|
||||
|
||||
mkdir -p ${TASKCLUSTER_ARTIFACTS} || true
|
||||
mkdir -p ${PYENV_ROOT} || true
|
||||
|
||||
# We still need to get model, wav and alphabet
|
||||
download_data
|
||||
|
||||
# Follow benchmark naming from parameters in bin/run-tc-ldc93s1.sh
|
||||
# Okay, it's not really the real LSTM sizes, just a way to verify how things
|
||||
# actually behave.
|
||||
for size in 100 200 300 400 500 600 700 800 900;
|
||||
do
|
||||
cp /tmp/${model_name} /tmp/test.frozen.e75.lstm${size}.ldc93s1.pb
|
||||
done;
|
||||
|
||||
# Let's make it a ZIP file. We don't want the directory structure.
|
||||
zip --junk-paths -r9 /tmp/test.frozen.e75.lstm100-900.ldc93s1.zip /tmp/test.frozen.e75.lstm*.ldc93s1.pb && rm /tmp/test.frozen.e75.lstm*.ldc93s1.pb
|
||||
|
||||
# And prepare for multiple files on the CLI
|
||||
model_list=""
|
||||
for size in 10 20 30 40 50 60 70 80 90;
|
||||
do
|
||||
cp /tmp/${model_name} /tmp/test.frozen.e75.lstm${size}.ldc93s1.pb
|
||||
model_list="${model_list} /tmp/test.frozen.e75.lstm${size}.ldc93s1.pb"
|
||||
done;
|
||||
|
||||
# Let's prepare another model for single-model codepath
|
||||
mv /tmp/${model_name} /tmp/test.frozen.e75.lstm494.ldc93s1.pb
|
||||
|
||||
export TASKCLUSTER_SCHEME=${DEEPSPEECH_ARTIFACTS_ROOT}/native_client.tar.xz
|
||||
|
||||
install_pyenv "${PYENV_ROOT}"
|
||||
install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv"
|
||||
|
||||
PYENV_NAME=deepspeech-test
|
||||
pyenv install ${pyver}
|
||||
pyenv virtualenv ${pyver} ${PYENV_NAME}
|
||||
source ${PYENV_ROOT}/versions/${pyver}/envs/${PYENV_NAME}/bin/activate
|
||||
|
||||
set -o pipefail
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install -r ${DS_ROOT_TASK}/DeepSpeech/ds/requirements.txt | cat
|
||||
set +o pipefail
|
||||
|
||||
exec_benchmark "/tmp/test.frozen.e75.lstm494.ldc93s1.pb" "single-model"
|
||||
exec_benchmark "/tmp/test.frozen.e75.lstm100-900.ldc93s1.zip" "zipfile-model"
|
||||
exec_benchmark "${model_list}" "multi-model"
|
||||
|
||||
deactivate
|
||||
pyenv uninstall --force ${PYENV_NAME}
|
|
@ -1,14 +0,0 @@
|
|||
build:
|
||||
template_file: test-linux-opt-base.tyml
|
||||
dependencies:
|
||||
- "linux-amd64-cpu-opt"
|
||||
- "test-training_16k-linux-amd64-py36m-opt"
|
||||
test_model_task: "test-training_16k-linux-amd64-py36m-opt"
|
||||
system_setup:
|
||||
>
|
||||
apt-get -qq -y install ${python.packages_trusty.apt} zip
|
||||
args:
|
||||
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-benchmark-tests.sh"
|
||||
metadata:
|
||||
name: "DeepSpeech Linux AMD64 CPU benchmarking"
|
||||
description: "Benchmarking DeepSpeech for Linux/AMD64, CPU only, optimized version"
|
Loading…
Reference in New Issue