Centralize WER report code into evaluate.py, call it from DeepSpeech.py

Reuben Morais 2018-11-08 18:24:36 -02:00
parent 60fb5ad04c
commit 56dc024d29
10 changed files with 1070 additions and 1342 deletions
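The DeepSpeech.py half of the change is suppressed below because its diff is too large, so the exact call site is not visible in this view. Judging from the new evaluate.py API, the test phase of DeepSpeech.py presumably boils down to something like the following sketch; apart from evaluate(), create_inference_graph(), FLAGS and C, every name in it is an illustrative assumption.

# Hypothetical sketch, not part of the commit: how DeepSpeech.py could hand its
# test set and inference graph to the centralized WER report code in evaluate.py.
from evaluate import evaluate
from util.coordinator import C
from util.flags import FLAGS

def test():
    # build_sorted_test_set() is an assumed helper standing in for whatever
    # DeepSpeech.py actually uses to produce the length-sorted test DataFrame.
    test_data = build_sorted_test_set(FLAGS.test_files, FLAGS.test_batch_size)
    graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
    samples = evaluate(test_data, graph, C.alphabet)
    return samples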


@@ -21,5 +21,4 @@ python3 -u DeepSpeech.py \
--display_step 0 \
--validation_step 1 \
--checkpoint_dir "../keep" \
--summary_dir "../keep/summaries" \
--decoder_library_path "../tmp/native_client/libctc_decoder_with_kenlm.so"
--summary_dir "../keep/summaries"


@@ -7,4 +7,10 @@ pip install tensorflow-gpu==1.12.0rc2
python3 util/taskcluster.py --arch gpu --target ../tmp/native_client
# Install ds_ctcdecoder package from TaskCluster
VERSION=$(python -c 'import pkg_resources; print(pkg_resources.safe_version(open("VERSION").read()))')
PYVER=$(python -c 'import sys; print("cp{0}{1}-cp{0}{1}m".format(sys.version_info.major, sys.version_info.minor))')
python3 util/taskcluster.py --arch cpu --target ../tmp --artifact "ds_ctcdecoder-${VERSION}-${PYVER}-manylinux1_x86_64.whl"
pip install ../tmp/ds_ctcdecoder-*.whl
mkdir -p ../keep/summaries

File diff suppressed because it is too large


@@ -15,8 +15,10 @@ import tensorflow as tf
from attrdict import AttrDict
from collections import namedtuple
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from DeepSpeech import initialize_globals, create_flags, log_debug, log_info, log_warn, log_error, create_inference_graph
from multiprocessing import Pool
from util.flags import create_flags
from util.coordinator import C, initialize_globals
from util.logging import log_debug, log_info, log_warn, log_error
from multiprocessing import Pool, cpu_count
from six.moves import zip, range
from util.audio import audiofile_to_input_vector
from util.text import Alphabet, ctc_label_dense_to_sparse, wer, levenshtein
@@ -86,31 +88,11 @@ def calculate_report(labels, decodings, distances, losses):
return samples_wer, samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
exit(1)
global alphabet
alphabet = Alphabet(FLAGS.alphabet_config_path)
scorer = Scorer(FLAGS.lm_weight, FLAGS.valid_word_count_weight,
def evaluate(test_data, inference_graph, alphabet):
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
alphabet)
C.alphabet)
# sort examples by length, improves packing of batches and timesteps
test_data = preprocess(
FLAGS.test_files.split(','),
FLAGS.test_batch_size,
alphabet=alphabet,
numcep=N_FEATURES,
numcontext=N_CONTEXT,
hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
by="features_len",
ascending=False)
def create_windows(features):
num_strides = len(features) - (N_CONTEXT * 2)
@@ -130,7 +112,7 @@ def main(_):
test_data['features'] = test_data['features'].apply(create_windows)
with tf.Session() as session:
inputs, outputs, layers = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
inputs, outputs, layers = inference_graph
# Transpose to batch major for decoder
transposed = tf.transpose(outputs['outputs'], [1, 0, 2])
@@ -192,7 +174,10 @@ def main(_):
widget=progressbar.AdaptiveETA)
# Get number of accessible CPU cores for this process
num_processes = len(os.sched_getaffinity(0))
try:
num_processes = cpu_count()
except:
num_processes = 1
# Second pass, decode logits and compute WER and edit distance metrics
for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
@@ -221,7 +206,38 @@ def main(_):
print(' - res: "%s"' % sample.res)
print('-' * 80)
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
exit(1)
global alphabet
alphabet = Alphabet(FLAGS.alphabet_config_path)
# sort examples by length, improves packing of batches and timesteps
test_data = preprocess(
FLAGS.test_files.split(','),
FLAGS.test_batch_size,
alphabet=alphabet,
numcep=N_FEATURES,
numcontext=N_CONTEXT,
hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
by="features_len",
ascending=False)
from DeepSpeech import create_inference_graph
graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
samples = evaluate(test_data, graph, alphabet)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=lambda x: float(x))
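The default= hook in that json.dump call is what the comment above refers to: the loss and WER values in each decoded sample come back as NumPy scalars, which the json module refuses to serialize on its own. A minimal, standalone illustration (not part of the commit):

import json
import numpy as np

sample = {'src': 'hello world', 'res': 'hello word',
          'wer': np.float32(0.5), 'loss': np.float32(1.73)}
# json.dumps(sample) would raise TypeError, since NumPy scalars are not JSON serializable.
print(json.dumps(sample, default=lambda x: float(x)))  # coerces them to plain Python floats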


@@ -13,3 +13,4 @@ bs4
six
requests
tables
attrdict

util/coordinator.py (new file, 706 lines)

@@ -0,0 +1,706 @@
from __future__ import absolute_import, division, print_function
import os
import pickle
import time
import numpy as np
import tensorflow as tf
from attrdict import AttrDict
from datetime import datetime
from six.moves import zip, range, filter, urllib, BaseHTTPServer
from threading import Thread, Lock
from util.gpu import get_available_gpus
from util.flags import FLAGS
from util.logging import *
from util.text import Alphabet
from xdg import BaseDirectory as xdg
# Execution
# =========
# For reporting we also need a standard way to do time measurements.
def stopwatch(start_duration=0):
r'''
This function will toggle a stopwatch.
The first call starts it, second call stops it, third call continues it etc.
So if you want to measure the accumulated time spent in a certain area of the code,
you can surround that code by stopwatch-calls like this:
.. code:: python
fun_time = 0 # initializes a stopwatch
[...]
for i in range(10):
[...]
# Starts/continues the stopwatch - fun_time is now a point in time (again)
fun_time = stopwatch(fun_time)
fun()
# Pauses the stopwatch - fun_time is now a duration
fun_time = stopwatch(fun_time)
[...]
# The following line only makes sense after an even call of :code:`fun_time = stopwatch(fun_time)`.
print('Time spent in fun():', format_duration(fun_time))
'''
if start_duration == 0:
return datetime.utcnow()
else:
return datetime.utcnow() - start_duration
def format_duration(duration):
'''Formats the result of an even stopwatch call as hours:minutes:seconds'''
duration = duration if isinstance(duration, int) else duration.seconds
m, s = divmod(duration, 60)
h, m = divmod(m, 60)
return '%d:%02d:%02d' % (h, m, s)
# String constants for different services of the web handler
PREFIX_NEXT_INDEX = '/next_index_'
PREFIX_GET_JOB = '/get_job_'
# Global ID counter for all objects requiring an ID
id_counter = 0
def new_id():
'''Returns a new ID that is unique on process level. Not thread-safe.
Returns:
int. The new ID
'''
global id_counter
id_counter += 1
return id_counter
class WorkerJob(object):
'''Represents a job that should be executed by a worker.
Args:
epoch_id (int): the ID of the 'parent' epoch
index (int): the epoch index of the 'parent' epoch
set_name (str): the name of the data-set - one of 'train', 'dev'
steps (int): the number of `session.run` calls
'''
def __init__(self, epoch_id, index, set_name, steps):
self.id = new_id()
self.epoch_id = epoch_id
self.index = index
self.worker = -1
self.set_name = set_name
self.steps = steps
self.loss = -1
self.samples = []
def __str__(self):
return 'Job (ID: %d, worker: %d, epoch: %d, set_name: %s)' % (self.id, self.worker, self.index, self.set_name)
class Epoch(object):
'''Represents an epoch that should be executed by the Training Coordinator.
Creates `num_jobs` `WorkerJob` instances in state 'open'.
Args:
index (int): the epoch index of the 'parent' epoch
num_jobs (int): the number of jobs in this epoch
Kwargs:
set_name (str): the name of the data-set - one of 'train', 'dev'
'''
def __init__(self, index, num_jobs, set_name='train'):
self.id = new_id()
self.index = index
self.num_jobs = num_jobs
self.set_name = set_name
self.loss = -1
self.jobs_open = []
self.jobs_running = []
self.jobs_done = []
for i in range(self.num_jobs):
self.jobs_open.append(WorkerJob(self.id, self.index, self.set_name, FLAGS.iters_per_worker))
def name(self):
'''Gets a printable name for this epoch.
Returns:
str. printable name for this epoch
'''
if self.index >= 0:
ename = ' of Epoch %d' % self.index
else:
ename = ''
if self.set_name == 'train':
return 'Training%s' % ename
else:
return 'Validation%s' % ename
def get_job(self, worker):
'''Gets the next open job from this epoch. The job will be marked as 'running'.
Args:
worker (int): index of the worker that takes the job
Returns:
WorkerJob. job that has been marked as running for this worker
'''
if len(self.jobs_open) > 0:
job = self.jobs_open.pop(0)
self.jobs_running.append(job)
job.worker = worker
return job
else:
return None
def finish_job(self, job):
'''Finishes a running job. Removes it from the running jobs list and adds it to the done jobs list.
Args:
job (WorkerJob): the job to put into state 'done'
'''
index = next((i for i in range(len(self.jobs_running)) if self.jobs_running[i].id == job.id), -1)
if index >= 0:
self.jobs_running.pop(index)
self.jobs_done.append(job)
log_traffic('%s - Moved %s from running to done.' % (self.name(), job))
else:
log_warn('%s - There is no job with ID %d registered as running.' % (self.name(), job.id))
def done(self):
'''Checks if all jobs of the epoch are in state 'done'.
Returns:
bool. if all jobs of the epoch are 'done'
'''
if len(self.jobs_open) == 0 and len(self.jobs_running) == 0:
num_jobs = len(self.jobs_done)
if num_jobs > 0:
jobs = self.jobs_done
self.jobs_done = []
if not self.num_jobs == num_jobs:
log_warn('%s - Number of steps not equal to number of jobs done.' % (self.name()))
agg_loss = 0.0
for i in range(num_jobs):
job = jobs.pop(0)
agg_loss += job.loss
self.loss = agg_loss / num_jobs
# if the job was for validation dataset then append it to the COORD's _loss for early stop verification
if (FLAGS.early_stop is True) and (self.set_name == 'dev'):
COORD._dev_losses.append(self.loss)
return True
return False
def job_status(self):
'''Provides a printable overview of the states of the jobs of this epoch.
Returns:
str. printable overall job state
'''
return '%s - jobs open: %d, jobs running: %d, jobs done: %d' % (self.name(), len(self.jobs_open), len(self.jobs_running), len(self.jobs_done))
def __str__(self):
if not self.done():
return self.job_status()
return '%s - loss: %f' % (self.name(), self.loss)
class TrainingCoordinator(object):
''' Central training coordination class.
Used for distributing jobs among workers of a cluster.
Instantiated on all workers; calls from non-chief workers are transparently
HTTP-forwarded to the chief worker instance.
'''
class TrainingCoordinationHandler(BaseHTTPServer.BaseHTTPRequestHandler):
'''Handles HTTP requests from remote workers to the Training Coordinator.
'''
def _send_answer(self, data=None):
self.send_response(200)
self.send_header('content-type', 'text/plain')
self.end_headers()
if data:
self.wfile.write(data)
def do_GET(self):
if COORD.started:
if self.path.startswith(PREFIX_NEXT_INDEX):
index = COORD.get_next_index(self.path[len(PREFIX_NEXT_INDEX):])
if index >= 0:
self._send_answer(str(index).encode("utf-8"))
return
elif self.path.startswith(PREFIX_GET_JOB):
job = COORD.get_job(worker=int(self.path[len(PREFIX_GET_JOB):]))
if job:
self._send_answer(pickle.dumps(job))
return
self.send_response(204) # end of training
else:
self.send_response(202) # not ready yet
self.end_headers()
def do_POST(self):
if COORD.started:
src = self.rfile.read(int(self.headers['content-length']))
job = COORD.next_job(pickle.loads(src))
if job:
self._send_answer(pickle.dumps(job))
return
self.send_response(204) # end of training
else:
self.send_response(202) # not ready yet
self.end_headers()
def log_message(self, format, *args):
'''Overriding base method to suppress web handler messages on stdout.
'''
return
def __init__(self, is_chief):
self._init()
self._lock = Lock()
self.started = False
self.is_chief = is_chief
if is_chief:
self._httpd = BaseHTTPServer.HTTPServer((FLAGS.coord_host, FLAGS.coord_port), TrainingCoordinator.TrainingCoordinationHandler)
def _reset_counters(self):
self._index_train = 0
self._index_dev = 0
def _init(self):
self._epochs_running = []
self._epochs_done = []
self._reset_counters()
self._dev_losses = []
def _log_all_jobs(self):
'''Use this to debug-print epoch state'''
log_debug('Epochs - running: %d, done: %d' % (len(self._epochs_running), len(self._epochs_done)))
for epoch in self._epochs_running:
log_debug(' - running: ' + epoch.job_status())
def start_coordination(self, model_feeder, step=0):
'''Starts to coordinate epochs and jobs among workers on the basis of
data-set sizes, the (global) step and FLAGS parameters.
Args:
model_feeder (ModelFeeder): data-sets to be used for coordinated training
Kwargs:
step (int): global step of a loaded model to determine starting point
'''
with self._lock:
self._init()
# Number of GPUs per worker - fixed for now by local reality or cluster setup
gpus_per_worker = len(C.available_devices)
# Number of batches processed per job per worker
batches_per_job = gpus_per_worker * max(1, FLAGS.iters_per_worker)
# Number of batches per global step
batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg)
# Number of global steps per epoch - to be at least 1
steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step)
# The start epoch of our training
self._epoch = step // steps_per_epoch
# Number of additional 'jobs' trained already 'on top of' our start epoch
jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job
# Total number of train/dev jobs covering their respective whole sets (one epoch)
self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job)
self._num_jobs_dev = max(1, model_feeder.dev.total_batches // batches_per_job)
if FLAGS.epoch < 0:
# A negative epoch means to add its absolute number to the epochs already computed
self._target_epoch = self._epoch + abs(FLAGS.epoch)
else:
self._target_epoch = FLAGS.epoch
# State variables
# We only have to train, if we are told so and are not at the target epoch yet
self._train = FLAGS.train and self._target_epoch > self._epoch
if self._train:
# The total number of jobs for all additional epochs to be trained
# Will be decremented for each job that is produced/put into state 'open'
self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained
log_info('STARTING Optimization')
self._training_time = stopwatch()
# Important for debugging
log_debug('step: %d' % step)
log_debug('epoch: %d' % self._epoch)
log_debug('target epoch: %d' % self._target_epoch)
log_debug('steps per epoch: %d' % steps_per_epoch)
log_debug('number of batches in train set: %d' % model_feeder.train.total_batches)
log_debug('batches per job: %d' % batches_per_job)
log_debug('batches per step: %d' % batches_per_step)
log_debug('number of jobs in train set: %d' % self._num_jobs_train)
log_debug('number of jobs already trained in first epoch: %d' % jobs_trained)
self._next_epoch()
# The coordinator is ready to serve
self.started = True
def _next_epoch(self):
# State-machine of the coordination process
# Indicates if there were 'new' epoch(s) provided
result = False
# Make sure that early stop is enabled and validation part is enabled
if (FLAGS.early_stop is True) and (FLAGS.validation_step > 0) and (len(self._dev_losses) >= FLAGS.earlystop_nsteps):
# Calculate the mean of losses for past epochs
mean_loss = np.mean(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
# Calculate the standard deviation for losses from validation part in the past epochs
std_loss = np.std(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
# Update the list of losses incurred
self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:]
log_debug('Checking for early stopping (last %d steps) validation loss: %f, with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))
# Check if validation loss has started increasing or is no longer decreasing substantially, while making sure that slight fluctuations don't interfere with the early stopping check
if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh):
# Time to early stop
log_info('Early stop triggered as (for last %d steps) validation loss: %f with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))
self._dev_losses = []
self._end_training()
self._train = False
if self._train:
# We are in train mode
if self._num_jobs_train_left > 0:
# There are still jobs left
num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train)
self._num_jobs_train_left -= num_jobs_train
# Let's try our best to keep the notion of curriculum learning
self._reset_counters()
# Append the training epoch
self._epochs_running.append(Epoch(self._epoch, num_jobs_train, set_name='train'))
if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0:
# The current epoch should also have a validation part
self._epochs_running.append(Epoch(self._epoch, self._num_jobs_dev, set_name='dev'))
# Indicating that there were 'new' epoch(s) provided
result = True
else:
# No jobs left, but still in train mode: concluding training
self._end_training()
self._train = False
if result:
# Increment the epoch index
self._epoch += 1
return result
def _end_training(self):
self._training_time = stopwatch(self._training_time)
log_info('FINISHED Optimization - training time: %s' % format_duration(self._training_time))
def start(self):
'''Starts Training Coordinator. If chief, it starts a web server for
communication with non-chief instances.
'''
if self.is_chief:
log_debug('Starting coordinator...')
self._thread = Thread(target=self._httpd.serve_forever)
self._thread.daemon = True
self._thread.start()
log_debug('Coordinator started.')
def stop(self, wait_for_running_epochs=True):
'''Stops Training Coordinator. If chief, it waits for all epochs to be
'done' and then shuts down the web server.
'''
if self.is_chief:
if wait_for_running_epochs:
while len(self._epochs_running) > 0:
log_traffic('Coordinator is waiting for epochs to finish...')
time.sleep(5)
log_debug('Stopping coordinator...')
self._httpd.shutdown()
log_debug('Coordinator stopped.')
def _talk_to_chief(self, path, data=None, default=None):
tries = 0
while tries < FLAGS.coord_retries:
tries += 1
try:
url = 'http://%s:%d%s' % (FLAGS.coord_host, FLAGS.coord_port, path)
log_traffic('Contacting coordinator - url: %s, tries: %d ...' % (url, tries-1))
res = urllib.request.urlopen(urllib.request.Request(url, data, { 'content-type': 'text/plain' }))
str = res.read()
status = res.getcode()
log_traffic('Coordinator responded - url: %s, status: %s' % (url, status))
if status == 200:
return str
if status == 204: # We use 204 (no content) to indicate end of training
return default
except urllib.error.HTTPError as error:
log_traffic('Problem reaching coordinator - url: %s, HTTP code: %d' % (url, error.code))
pass
time.sleep(10)
return default
def get_next_index(self, set_name):
'''Retrieves a new cluster-unique batch index for a given set-name.
Prevents applying one batch multiple times per epoch.
Args:
set_name (str): name of the data set - one of 'train', 'dev'
Returns:
int. new data set index
'''
with self._lock:
if self.is_chief:
member = '_index_' + set_name
value = getattr(self, member, -1)
setattr(self, member, value + 1)
return value
else:
# We are a remote worker and have to hand over to the chief worker by HTTP
log_traffic('Asking for next index...')
value = int(self._talk_to_chief(PREFIX_NEXT_INDEX + set_name))
log_traffic('Got index %d.' % value)
return value
def _get_job(self, worker=0):
job = None
# Find first running epoch that provides a next job
for epoch in self._epochs_running:
job = epoch.get_job(worker)
if job:
return job
# No next job found
return None
def get_job(self, worker=0):
'''Retrieves the first job for a worker.
Kwargs:
worker (int): index of the worker to get the first job for
Returns:
WorkerJob. a job of one of the running epochs that will get
associated with the given worker and put into state 'running'
'''
# Let's ensure that this does not interfere with other workers/requests
with self._lock:
if self.is_chief:
# First try to get a next job
job = self._get_job(worker)
if job is None:
# If there was no next job, we give it a second chance by triggering the epoch state machine
if self._next_epoch():
# Epoch state machine got a new epoch
# Second try to get a next job
job = self._get_job(worker)
if job is None:
# Although the epoch state machine got a new epoch, that epoch had no new job for us
log_error('Unexpected case - no job for worker %d.' % (worker))
return job
# Epoch state machine has no new epoch
# This happens at the end of the whole training - nothing to worry about
log_traffic('No jobs left for worker %d.' % (worker))
self._log_all_jobs()
return None
# We got a new job from one of the currently running epochs
log_traffic('Got new %s' % job)
return job
# We are a remote worker and have to hand over to the chief worker by HTTP
result = self._talk_to_chief(PREFIX_GET_JOB + str(FLAGS.task_index))
if result:
result = pickle.loads(result)
return result
def next_job(self, job):
'''Sends a finished job back to the coordinator and retrieves in exchange the next one.
Kwargs:
job (WorkerJob): job that was finished by a worker and whose results are to be
digested by the coordinator
Returns:
WorkerJob. next job of one of the running epochs that will get
associated with the worker from the finished job and put into state 'running'
'''
if self.is_chief:
# Try to find the epoch the job belongs to
epoch = next((epoch for epoch in self._epochs_running if epoch.id == job.epoch_id), None)
if epoch:
# We are going to manipulate things - let's avoid undefined state
with self._lock:
# Let the epoch finish the job
epoch.finish_job(job)
# Check if epoch is done now
if epoch.done():
# If it declares itself done, move it from 'running' to 'done' collection
self._epochs_running.remove(epoch)
self._epochs_done.append(epoch)
log_info('%s' % epoch)
else:
# There was no running epoch found for this job - this should never happen.
log_error('There is no running epoch of ID %d for job with ID %d. This should never happen.' % (job.epoch_id, job.id))
return self.get_job(job.worker)
# We are a remote worker and have to hand over to the chief worker by HTTP
result = self._talk_to_chief('', data=pickle.dumps(job))
if result:
result = pickle.loads(result)
return result
class GlobalConfig:
_config = None
def __getattr__(self, name):
if not GlobalConfig._config:
raise RuntimeError("Global configuration not yet initialized.")
if not hasattr(GlobalConfig._config, name):
raise RuntimeError("Configuration option {} not found in config.".format(name))
return GlobalConfig._config[name]
C = GlobalConfig()
def initialize_globals():
c = AttrDict()
# ps and worker hosts required for p2p cluster setup
FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(',')))
FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(',')))
# The absolute number of computing nodes - regardless of cluster or single mode
c.num_workers = max(1, len(FLAGS.worker_hosts))
# Create a cluster from the parameter server and worker hosts.
c.cluster = tf.train.ClusterSpec({'ps': FLAGS.ps_hosts, 'worker': FLAGS.worker_hosts})
# If replica numbers are negative, we multiply their absolute values with the number of workers
if FLAGS.replicas < 0:
FLAGS.replicas = c.num_workers * -FLAGS.replicas
if FLAGS.replicas_to_agg < 0:
FLAGS.replicas_to_agg = c.num_workers * -FLAGS.replicas_to_agg
# The device path base for this node
c.worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index)
# This node's CPU device
c.cpu_device = c.worker_device + '/cpu:0'
# This node's available GPU devices
c.available_devices = [c.worker_device + gpu for gpu in get_available_gpus()]
# If there is no GPU available, we fall back to CPU based operation
if 0 == len(c.available_devices):
c.available_devices = [c.cpu_device]
# Set default dropout rates
if FLAGS.dropout_rate2 < 0:
FLAGS.dropout_rate2 = FLAGS.dropout_rate
if FLAGS.dropout_rate3 < 0:
FLAGS.dropout_rate3 = FLAGS.dropout_rate
if FLAGS.dropout_rate6 < 0:
FLAGS.dropout_rate6 = FLAGS.dropout_rate
c.no_dropout = [ 0.0 ] * 6
# Set default checkpoint dir
if len(FLAGS.checkpoint_dir) == 0:
FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))
# Set default summary dir
if len(FLAGS.summary_dir) == 0:
FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))
# Standard session configuration that'll be used for all new sessions.
c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
# Geometric Constants
# ===================
# For an explanation of the meaning of the geometric constants, please refer to
# doc/Geometry.md
# Number of MFCC features
c.n_input = 26 # TODO: Determine this programmatically from the sample rate
# The number of frames in the context
c.n_context = 9 # TODO: Determine the optimal value using a validation data set
# Number of units in hidden layers
c.n_hidden = FLAGS.n_hidden
c.n_hidden_1 = c.n_hidden
c.n_hidden_2 = c.n_hidden
c.n_hidden_5 = c.n_hidden
# LSTM cell state dimension
c.n_cell_dim = c.n_hidden
# The number of units in the third layer, which feeds in to the LSTM
c.n_hidden_3 = c.n_cell_dim
# The number of characters in the target language plus one
c.n_character = c.alphabet.size() + 1 # +1 for CTC blank label
# The number of units in the sixth layer
c.n_hidden_6 = c.n_character
# Queues that are used to gracefully stop parameter servers.
# Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting.
# Each ps will dequeue as many tokens as there are workers before joining/quitting.
# This ensures parameter servers won't quit, if still required by at least one worker and
# also won't wait forever (like with a standard `server.join()`).
c.done_queues = []
for i, ps in enumerate(FLAGS.ps_hosts):
# Queues are hosted by their respective owners
with tf.device('/job:ps/task:%d' % i):
c.done_queues.append(tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i)))
# Placeholder to pass in the worker's index as token
c.token_placeholder = tf.placeholder(tf.int32)
# Enqueue operations for each parameter server
c.done_enqueues = [queue.enqueue(c.token_placeholder) for queue in c.done_queues]
# Dequeue operations for each parameter server
c.done_dequeues = [queue.dequeue() for queue in c.done_queues]
if len(FLAGS.one_shot_infer) > 0:
FLAGS.train = False
FLAGS.test = False
FLAGS.export_dir = ''
if not os.path.exists(FLAGS.one_shot_infer):
log_error('Path specified in --one_shot_infer is not a valid file.')
exit(1)
# Determine if we are the chief worker
c.is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker')
# Initializing and starting the training coordinator
c.COORD = TrainingCoordinator(c.is_chief)
c.COORD.start()
GlobalConfig._config = c
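The GlobalConfig/C pair above replaces the module-level globals that previously lived in DeepSpeech.py. A short usage sketch, assumed rather than taken from the diff, of how a caller such as evaluate.py is expected to drive it:

# Sketch only: mirrors the imports at the top of evaluate.py in this commit.
import tensorflow as tf
from util.coordinator import C, initialize_globals
from util.flags import create_flags

def main(_):
    # Any C.<attr> access before this call raises
    # RuntimeError('Global configuration not yet initialized.')
    initialize_globals()
    print(C.n_input, C.n_context, C.alphabet.size())  # 26, 9, and the alphabet size

if __name__ == '__main__':
    create_flags()
    tf.app.run(main)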


@@ -12,14 +12,13 @@ class ModelFeeder(object):
'''
Feeds data into a model.
Feeding is parallelized by independent units called tower feeders (usually one per GPU).
Each tower feeder provides data from three runtime switchable sources (train, dev, test).
These sources are to be provided by three DataSet instances whos references are kept.
Each tower feeder provides data from runtime switchable sources (train, dev).
These sources are to be provided by the DataSet instances whose references are kept.
Creates, owns and delegates to tower_feeder_count internal tower feeder objects.
'''
def __init__(self,
train_set,
dev_set,
test_set,
numcep,
numcontext,
alphabet,
@@ -28,8 +27,7 @@ class ModelFeeder(object):
self.train = train_set
self.dev = dev_set
self.test = test_set
self.sets = [train_set, dev_set, test_set]
self.sets = [train_set, dev_set]
self.numcep = numcep
self.numcontext = numcontext
self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count

util/flags.py (new file, 142 lines)

@@ -0,0 +1,142 @@
from __future__ import print_function
import tensorflow as tf
from xdg import BaseDirectory as xdg
FLAGS = tf.app.flags.FLAGS
def create_flags():
# Importer
# ========
tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')
tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'path of a file in which to cache the features extracted from --train_files')
tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'path of a file in which to cache the features extracted from --dev_files')
tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'path of a file in which to cache the features extracted from --test_files')
# Cluster configuration
# =====================
tf.app.flags.DEFINE_string ('ps_hosts', '', 'parameter servers - comma separated list of hostname:port pairs')
tf.app.flags.DEFINE_string ('worker_hosts', '', 'workers - comma separated list of hostname:port pairs')
tf.app.flags.DEFINE_string ('job_name', 'localhost', 'job name - one of localhost (default), worker, ps')
tf.app.flags.DEFINE_integer ('task_index', 0, 'index of task within the job - worker with index 0 will be the chief')
tf.app.flags.DEFINE_integer ('replicas', -1, 'total number of replicas - if negative, its absolute value is multiplied by the number of workers')
tf.app.flags.DEFINE_integer ('replicas_to_agg', -1, 'number of replicas to aggregate - if negative, its absolute value is multiplied by the number of workers')
tf.app.flags.DEFINE_integer ('coord_retries', 100, 'number of tries of workers connecting to training coordinator before failing')
tf.app.flags.DEFINE_string ('coord_host', 'localhost', 'coordination server host')
tf.app.flags.DEFINE_integer ('coord_port', 2500, 'coordination server port')
tf.app.flags.DEFINE_integer ('iters_per_worker', 1, 'number of train or inference iterations per worker before results are sent back to coordinator')
# Global Constants
# ================
tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network')
tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network')
tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')
tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
tf.app.flags.DEFINE_float ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
tf.app.flags.DEFINE_float ('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
tf.app.flags.DEFINE_float ('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
tf.app.flags.DEFINE_float ('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
tf.app.flags.DEFINE_float ('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')
tf.app.flags.DEFINE_float ('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')
# Adam optimizer (http://arxiv.org/abs/1412.6980) parameters
tf.app.flags.DEFINE_float ('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
tf.app.flags.DEFINE_float ('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
tf.app.flags.DEFINE_float ('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
tf.app.flags.DEFINE_float ('learning_rate', 0.001, 'learning rate of Adam optimizer')
# Batch sizes
tf.app.flags.DEFINE_integer ('train_batch_size', 1, 'number of elements in a training batch')
tf.app.flags.DEFINE_integer ('dev_batch_size', 1, 'number of elements in a validation batch')
tf.app.flags.DEFINE_integer ('test_batch_size', 1, 'number of elements in a test batch')
tf.app.flags.DEFINE_integer ('export_batch_size', 1, 'number of elements per batch on the exported graph')
# Performance (UNSUPPORTED)
tf.app.flags.DEFINE_integer ('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
tf.app.flags.DEFINE_integer ('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
# Sample limits
tf.app.flags.DEFINE_integer ('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')
# Step widths
tf.app.flags.DEFINE_integer ('validation_step', 0, 'number of epochs we cycle through before validating the model - 0 means no validation steps')
# Checkpointing
tf.app.flags.DEFINE_string ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
# Exporting
tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
# Reporting
tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
tf.app.flags.DEFINE_boolean ('log_traffic', False, 'log cluster transaction and traffic information during debug logging')
tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console')
tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report')
tf.app.flags.DEFINE_string ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
tf.app.flags.DEFINE_integer ('summary_secs', 0, 'interval in seconds for saving TensorBoard summaries - if 0, no summaries will be written')
# Geometry
tf.app.flags.DEFINE_integer ('n_hidden', 2048, 'layer width to use when initialising layers')
# Initialization
tf.app.flags.DEFINE_integer ('random_seed', 4567, 'default random seed that is used to initialize variables')
# Early Stopping
tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable the early stopping mechanism over the validation dataset. Requires --validation_step to be greater than 0 for this to work')
# This parameter is independent of the time a single epoch takes and of the checkpoint saving interval.
# It is possible that early stopping triggers long after the best checkpoint has already been replaced by the checkpoint saving mechanism.
# The parameters (earlystop_nsteps, checkpoint_secs) therefore have to be aligned with the time an epoch takes on the dataset at hand.
tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps to consider for early stopping. The loss history is not stored in checkpoints, so it is rebuilt from scratch after a checkpoint is restored')
tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
# Decoder
tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_alpha', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float ('lm_beta', 2.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
# Inference mode
tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.')
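The --epoch flag above uses the convention that a negative value means a number of additional epochs on top of what has already been trained (see _target_epoch in util/coordinator.py); --replicas and --replicas_to_agg treat a negative value as a per-worker multiplier instead. A tiny worked example of the --epoch case (illustration only, the helper name is made up):

# For --epoch: a negative value adds its absolute number to the epochs already trained.
def resolve_target_epoch(epoch_flag, current_epoch):
    return current_epoch + abs(epoch_flag) if epoch_flag < 0 else epoch_flag

print(resolve_target_epoch(75, 12))  # absolute target: train until epoch 75
print(resolve_target_epoch(-5, 12))  # relative target: 12 + 5 = 17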

util/logging.py (new file, 35 lines)

@@ -0,0 +1,35 @@
from __future__ import print_function
from util.flags import FLAGS
# Logging functions
# =================
def prefix_print(prefix, message):
print(prefix + ('\n' + prefix).join(message.split('\n')))
def log_debug(message):
if FLAGS.log_level == 0:
prefix_print('D ', message)
def log_traffic(message):
if FLAGS.log_traffic:
log_debug(message)
def log_info(message):
if FLAGS.log_level <= 1:
prefix_print('I ', message)
def log_warn(message):
if FLAGS.log_level <= 2:
prefix_print('W ', message)
def log_error(message):
if FLAGS.log_level <= 3:
prefix_print('E ', message)
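For reference, a small usage sketch (assumed, and requiring that the flags have already been parsed, e.g. inside tf.app.run): with the default --log_level of 1, log_debug() stays silent while the other helpers print prefixed lines.

from util.logging import log_debug, log_info, log_warn, log_error

log_debug('feeding batch 42')       # suppressed unless --log_level 0
log_info('STARTING Optimization')   # prints: I STARTING Optimization
log_warn('low disk space')          # prints: W low disk space
log_error('checkpoint not found')   # prints: E checkpoint not found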


@@ -1,46 +0,0 @@
from __future__ import print_function
from __future__ import absolute_import
from util.gpu import get_available_gpus
from ctypes import cdll
from sys import platform as _platform
def get_cupti_libname():
if _platform == 'linux' or _platform == 'linux2':
return 'libcupti.so'
elif _platform == 'darwin':
return 'libcupti.dylib'
elif _platform == 'win32':
return 'libcupti.dll'
def check_cupti():
# We want to ensure that the user has properly configured their libs.
# We do this because the dso load of libcupti only happens after a lot of
# computation has already been done, so it is easy to miss and lose time.
libname = get_cupti_libname()
cupti = check_so(libname)
if cupti is None:
print("INFO: No %s because no GPU, go ahead." % libname)
elif cupti is True:
print("INFO: Found %s." % libname)
else:
print("WARNING: Running on GPU but no %s could be found ; will be unable to report GPU VRAM usage." % libname)
def check_so(soname):
"""
Verify that we do have the 'soname' lib present in the system, and that it
can be loaded.
"""
if len(get_available_gpus()) == 0:
return None
# Try to force load lib, this would fail if the lib is not there :)
try:
lib = cdll.LoadLibrary(soname)
print("INFO: Found so as", lib)
assert lib.__class__.__name__ == 'CDLL'
assert lib._name == soname
return True
except OSError as ex:
print("WARNING:", ex)
return False