Centralize WER report code into evaluate.py, call it from DeepSpeech.py
parent 60fb5ad04c
commit 56dc024d29

.compute (3 changed lines):
@@ -21,5 +21,4 @@ python3 -u DeepSpeech.py \
 --display_step 0 \
 --validation_step 1 \
 --checkpoint_dir "../keep" \
---summary_dir "../keep/summaries" \
---decoder_library_path "../tmp/native_client/libctc_decoder_with_kenlm.so"
+--summary_dir "../keep/summaries"
.install (6 changed lines):
@@ -7,4 +7,10 @@ pip install tensorflow-gpu==1.12.0rc2
 
 python3 util/taskcluster.py --arch gpu --target ../tmp/native_client
 
+# Install ds_ctcdecoder package from TaskCluster
+VERSION=$(python -c 'import pkg_resources; print(pkg_resources.safe_version(open("VERSION").read()))')
+PYVER=$(python -c 'import sys; print("cp{0}{1}-cp{0}{1}m".format(sys.version_info.major, sys.version_info.minor))')
+python3 util/taskcluster.py --arch cpu --target ../tmp --artifact "ds_ctcdecoder-${VERSION}-${PYVER}-manylinux1_x86_64.whl"
+pip install ../tmp/ds_ctcdecoder-*.whl
+
 mkdir -p ../keep/summaries
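For reference, the PYVER expression evaluates to the CPython ABI tag of the running interpreter, for example cp36-cp36m on Python 3.6, so the requested artifact resolves to a wheel name like ds_ctcdecoder-<version>-cp36-cp36m-manylinux1_x86_64.whl.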
DeepSpeech.py (1395 changed lines): file diff suppressed because it is too large.
evaluate.py (70 changed lines):
@@ -15,8 +15,10 @@ import tensorflow as tf
 from attrdict import AttrDict
 from collections import namedtuple
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
-from DeepSpeech import initialize_globals, create_flags, log_debug, log_info, log_warn, log_error, create_inference_graph
-from multiprocessing import Pool
+from util.flags import create_flags
+from util.coordinator import C, initialize_globals
+from util.logging import log_debug, log_info, log_warn, log_error
+from multiprocessing import Pool, cpu_count
 from six.moves import zip, range
 from util.audio import audiofile_to_input_vector
 from util.text import Alphabet, ctc_label_dense_to_sparse, wer, levenshtein
@@ -86,31 +88,11 @@ def calculate_report(labels, decodings, distances, losses):
     return samples_wer, samples
 
 
-def main(_):
-    initialize_globals()
-
-    if not FLAGS.test_files:
-        log_error('You need to specify what files to use for evaluation via '
-                  'the --test_files flag.')
-        exit(1)
-
-    global alphabet
-    alphabet = Alphabet(FLAGS.alphabet_config_path)
-
-    scorer = Scorer(FLAGS.lm_weight, FLAGS.valid_word_count_weight,
+def evaluate(test_data, inference_graph, alphabet):
+    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                     FLAGS.lm_binary_path, FLAGS.lm_trie_path,
-                    alphabet)
+                    C.alphabet)
-
-    # sort examples by length, improves packing of batches and timesteps
-    test_data = preprocess(
-        FLAGS.test_files.split(','),
-        FLAGS.test_batch_size,
-        alphabet=alphabet,
-        numcep=N_FEATURES,
-        numcontext=N_CONTEXT,
-        hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
-        by="features_len",
-        ascending=False)
 
     def create_windows(features):
         num_strides = len(features) - (N_CONTEXT * 2)
@@ -130,7 +112,7 @@ def main(_):
     test_data['features'] = test_data['features'].apply(create_windows)
 
     with tf.Session() as session:
-        inputs, outputs, layers = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
+        inputs, outputs, layers = inference_graph
 
         # Transpose to batch major for decoder
         transposed = tf.transpose(outputs['outputs'], [1, 0, 2])
@@ -192,7 +174,10 @@ def main(_):
                            widget=progressbar.AdaptiveETA)
 
     # Get number of accessible CPU cores for this process
-    num_processes = len(os.sched_getaffinity(0))
+    try:
+        num_processes = cpu_count()
+    except:
+        num_processes = 1
 
     # Second pass, decode logits and compute WER and edit distance metrics
     for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
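Portability note: os.sched_getaffinity() exists only on Linux, while multiprocessing.cpu_count() is available everywhere, and the fallback to 1 covers platforms where even cpu_count() raises NotImplementedError. The trade-off is that cpu_count() reports all cores, not just the ones this process is actually allowed to use.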
@@ -221,7 +206,38 @@ def main(_):
         print(' - res: "%s"' % sample.res)
         print('-' * 80)
 
+    return samples
+
+
+def main(_):
+    initialize_globals()
+
+    if not FLAGS.test_files:
+        log_error('You need to specify what files to use for evaluation via '
+                  'the --test_files flag.')
+        exit(1)
+
+    global alphabet
+    alphabet = Alphabet(FLAGS.alphabet_config_path)
+
+    # sort examples by length, improves packing of batches and timesteps
+    test_data = preprocess(
+        FLAGS.test_files.split(','),
+        FLAGS.test_batch_size,
+        alphabet=alphabet,
+        numcep=N_FEATURES,
+        numcontext=N_CONTEXT,
+        hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
+        by="features_len",
+        ascending=False)
+
+    from DeepSpeech import create_inference_graph
+    graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
+
+    samples = evaluate(test_data, graph, alphabet)
+
+    if FLAGS.test_output_file:
+        # Save decoded tuples as JSON, converting NumPy floats to Python floats
+        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=lambda x: float(x))
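The DeepSpeech.py side of this change is suppressed above, so for orientation here is a minimal sketch of how a caller can drive the new shared entry point. run_test_epoch is a hypothetical name, not part of this commit, and test_data/alphabet are prepared exactly as in main() above:

    # Hypothetical wrapper (sketch only) - run_test_epoch is not a name from this commit.
    from evaluate import evaluate
    from DeepSpeech import create_inference_graph
    from util.flags import FLAGS

    def run_test_epoch(test_data, alphabet):
        # Build an unrolled inference graph sized for the test batch...
        graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)
        # ...and hand everything to evaluate(), which prints the WER report
        # and returns the decoded samples.
        return evaluate(test_data, graph, alphabet)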
@@ -13,3 +13,4 @@ bs4
 six
 requests
 tables
+attrdict
util/coordinator.py (new file, 706 lines):

@@ -0,0 +1,706 @@
from __future__ import absolute_import, division, print_function

import os
import pickle
import time

import numpy as np
import tensorflow as tf

from attrdict import AttrDict
from datetime import datetime
from six.moves import zip, range, filter, urllib, BaseHTTPServer
from threading import Thread, Lock
from util.gpu import get_available_gpus
from util.flags import FLAGS
from util.logging import *
from util.text import Alphabet
from xdg import BaseDirectory as xdg

# Execution
# =========

# For reporting we also need a standard way to do time measurements.
def stopwatch(start_duration=0):
    r'''
    This function will toggle a stopwatch.
    The first call starts it, the second call stops it, the third call continues it, and so on.
    So if you want to measure the accumulated time spent in a certain area of the code,
    you can surround that code with stopwatch calls like this:

    .. code:: python

        fun_time = 0 # initializes a stopwatch
        [...]
        for i in range(10):
            [...]
            # Starts/continues the stopwatch - fun_time is now a point in time (again)
            fun_time = stopwatch(fun_time)
            fun()
            # Pauses the stopwatch - fun_time is now a duration
            fun_time = stopwatch(fun_time)
        [...]
        # The following line only makes sense after an even number of calls of :code:`fun_time = stopwatch(fun_time)`.
        print('Time spent in fun():', format_duration(fun_time))

    '''
    if start_duration == 0:
        return datetime.utcnow()
    else:
        return datetime.utcnow() - start_duration

def format_duration(duration):
    '''Formats the result of an even stopwatch call as hours:minutes:seconds'''
    duration = duration if isinstance(duration, int) else duration.seconds
    m, s = divmod(duration, 60)
    h, m = divmod(m, 60)
    return '%d:%02d:%02d' % (h, m, s)

# String constants for different services of the web handler
PREFIX_NEXT_INDEX = '/next_index_'
PREFIX_GET_JOB = '/get_job_'

# Global ID counter for all objects requiring an ID
id_counter = 0

def new_id():
    '''Returns a new ID that is unique on process level. Not thread-safe.

    Returns:
        int. The new ID
    '''
    global id_counter
    id_counter += 1
    return id_counter

class WorkerJob(object):
    '''Represents a job that should be executed by a worker.

    Args:
        epoch_id (int): the ID of the 'parent' epoch
        index (int): the epoch index of the 'parent' epoch
        set_name (str): the name of the data-set - one of 'train', 'dev'
        steps (int): the number of `session.run` calls
    '''
    def __init__(self, epoch_id, index, set_name, steps):
        self.id = new_id()
        self.epoch_id = epoch_id
        self.index = index
        self.worker = -1
        self.set_name = set_name
        self.steps = steps
        self.loss = -1
        self.samples = []

    def __str__(self):
        return 'Job (ID: %d, worker: %d, epoch: %d, set_name: %s)' % (self.id, self.worker, self.index, self.set_name)

class Epoch(object):
    '''Represents an epoch that should be executed by the Training Coordinator.
    Creates `num_jobs` `WorkerJob` instances in state 'open'.

    Args:
        index (int): the epoch index of the 'parent' epoch
        num_jobs (int): the number of jobs in this epoch

    Kwargs:
        set_name (str): the name of the data-set - one of 'train', 'dev'
    '''
    def __init__(self, index, num_jobs, set_name='train'):
        self.id = new_id()
        self.index = index
        self.num_jobs = num_jobs
        self.set_name = set_name
        self.loss = -1
        self.jobs_open = []
        self.jobs_running = []
        self.jobs_done = []
        for i in range(self.num_jobs):
            self.jobs_open.append(WorkerJob(self.id, self.index, self.set_name, FLAGS.iters_per_worker))

    def name(self):
        '''Gets a printable name for this epoch.

        Returns:
            str. printable name for this epoch
        '''
        if self.index >= 0:
            ename = ' of Epoch %d' % self.index
        else:
            ename = ''
        if self.set_name == 'train':
            return 'Training%s' % ename
        else:
            return 'Validation%s' % ename

    def get_job(self, worker):
        '''Gets the next open job from this epoch. The job will be marked as 'running'.

        Args:
            worker (int): index of the worker that takes the job

        Returns:
            WorkerJob. job that has been marked as running for this worker
        '''
        if len(self.jobs_open) > 0:
            job = self.jobs_open.pop(0)
            self.jobs_running.append(job)
            job.worker = worker
            return job
        else:
            return None

    def finish_job(self, job):
        '''Finishes a running job. Removes it from the running jobs list and adds it to the done jobs list.

        Args:
            job (WorkerJob): the job to put into state 'done'
        '''
        index = next((i for i in range(len(self.jobs_running)) if self.jobs_running[i].id == job.id), -1)
        if index >= 0:
            self.jobs_running.pop(index)
            self.jobs_done.append(job)
            log_traffic('%s - Moved %s from running to done.' % (self.name(), job))
        else:
            log_warn('%s - There is no job with ID %d registered as running.' % (self.name(), job.id))

    def done(self):
        '''Checks if all jobs of the epoch are in state 'done'.

        Returns:
            bool. if all jobs of the epoch are 'done'
        '''
        if len(self.jobs_open) == 0 and len(self.jobs_running) == 0:
            num_jobs = len(self.jobs_done)
            if num_jobs > 0:
                jobs = self.jobs_done
                self.jobs_done = []
                if not self.num_jobs == num_jobs:
                    log_warn('%s - Number of steps not equal to number of jobs done.' % (self.name()))

                agg_loss = 0.0

                for i in range(num_jobs):
                    job = jobs.pop(0)
                    agg_loss += job.loss

                self.loss = agg_loss / num_jobs

                # if the job was for the validation dataset, append its loss to COORD's _dev_losses for early-stop verification
                if (FLAGS.early_stop is True) and (self.set_name == 'dev'):
                    COORD._dev_losses.append(self.loss)

            return True
        return False

    def job_status(self):
        '''Provides a printable overview of the states of the jobs of this epoch.

        Returns:
            str. printable overall job state
        '''
        return '%s - jobs open: %d, jobs running: %d, jobs done: %d' % (self.name(), len(self.jobs_open), len(self.jobs_running), len(self.jobs_done))

    def __str__(self):
        if not self.done():
            return self.job_status()

        return '%s - loss: %f' % (self.name(), self.loss)
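For orientation, the job lifecycle that Epoch manages (open -> running -> done) can be exercised in isolation like this; a sketch that assumes flags have already been created and parsed, since Epoch reads FLAGS.iters_per_worker and the logging helpers read the logging flags:

    # Sketch: exercising the Epoch job lifecycle with two jobs and one worker.
    epoch = Epoch(index=0, num_jobs=2, set_name='train')
    job = epoch.get_job(worker=0)    # 'open' -> 'running'
    job.loss = 0.5
    epoch.finish_job(job)            # 'running' -> 'done'
    print(epoch.done())              # False - the second job is still open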

class TrainingCoordinator(object):
    '''Central training coordination class.
    Used for distributing jobs among workers of a cluster.
    Instantiated on all workers; calls on non-chief workers will transparently be
    HTTP-forwarded to the chief worker instance.
    '''

    class TrainingCoordinationHandler(BaseHTTPServer.BaseHTTPRequestHandler):
        '''Handles HTTP requests from remote workers to the Training Coordinator.
        '''
        def _send_answer(self, data=None):
            self.send_response(200)
            self.send_header('content-type', 'text/plain')
            self.end_headers()
            if data:
                self.wfile.write(data)

        def do_GET(self):
            if COORD.started:
                if self.path.startswith(PREFIX_NEXT_INDEX):
                    index = COORD.get_next_index(self.path[len(PREFIX_NEXT_INDEX):])
                    if index >= 0:
                        self._send_answer(str(index).encode("utf-8"))
                        return
                elif self.path.startswith(PREFIX_GET_JOB):
                    job = COORD.get_job(worker=int(self.path[len(PREFIX_GET_JOB):]))
                    if job:
                        self._send_answer(pickle.dumps(job))
                        return
                self.send_response(204) # end of training
            else:
                self.send_response(202) # not ready yet
            self.end_headers()

        def do_POST(self):
            if COORD.started:
                src = self.rfile.read(int(self.headers['content-length']))
                job = COORD.next_job(pickle.loads(src))
                if job:
                    self._send_answer(pickle.dumps(job))
                    return
                self.send_response(204) # end of training
            else:
                self.send_response(202) # not ready yet
            self.end_headers()

        def log_message(self, format, *args):
            '''Overriding base method to suppress web handler messages on stdout.
            '''
            return

    def __init__(self, is_chief):
        self._init()
        self._lock = Lock()
        self.started = False
        self.is_chief = is_chief
        if is_chief:
            self._httpd = BaseHTTPServer.HTTPServer((FLAGS.coord_host, FLAGS.coord_port), TrainingCoordinator.TrainingCoordinationHandler)

    def _reset_counters(self):
        self._index_train = 0
        self._index_dev = 0

    def _init(self):
        self._epochs_running = []
        self._epochs_done = []
        self._reset_counters()
        self._dev_losses = []

    def _log_all_jobs(self):
        '''Use this to debug-print epoch state'''
        log_debug('Epochs - running: %d, done: %d' % (len(self._epochs_running), len(self._epochs_done)))
        for epoch in self._epochs_running:
            log_debug('       - running: ' + epoch.job_status())

    def start_coordination(self, model_feeder, step=0):
        '''Starts to coordinate epochs and jobs among workers on the basis of
        data-set sizes, the (global) step and FLAGS parameters.

        Args:
            model_feeder (ModelFeeder): data-sets to be used for coordinated training

        Kwargs:
            step (int): global step of a loaded model to determine starting point
        '''
        with self._lock:
            self._init()

            # Number of GPUs per worker - fixed for now by local reality or cluster setup
            gpus_per_worker = len(C.available_devices)

            # Number of batches processed per job per worker
            batches_per_job = gpus_per_worker * max(1, FLAGS.iters_per_worker)

            # Number of batches per global step
            batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg)

            # Number of global steps per epoch - to be at least 1
            steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step)

            # The start epoch of our training
            self._epoch = step // steps_per_epoch

            # Number of additional 'jobs' trained already 'on top of' our start epoch
            jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job

            # Total number of train/dev jobs covering their respective whole sets (one epoch)
            self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job)
            self._num_jobs_dev = max(1, model_feeder.dev.total_batches // batches_per_job)

            if FLAGS.epoch < 0:
                # A negative epoch means to add its absolute number to the epochs already computed
                self._target_epoch = self._epoch + abs(FLAGS.epoch)
            else:
                self._target_epoch = FLAGS.epoch

            # State variables
            # We only have to train if we are told so and are not at the target epoch yet
            self._train = FLAGS.train and self._target_epoch > self._epoch

            if self._train:
                # The total number of jobs for all additional epochs to be trained
                # Will be decremented for each job that is produced/put into state 'open'
                self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained
                log_info('STARTING Optimization')
                self._training_time = stopwatch()

            # Important for debugging
            log_debug('step: %d' % step)
            log_debug('epoch: %d' % self._epoch)
            log_debug('target epoch: %d' % self._target_epoch)
            log_debug('steps per epoch: %d' % steps_per_epoch)
            log_debug('number of batches in train set: %d' % model_feeder.train.total_batches)
            log_debug('batches per job: %d' % batches_per_job)
            log_debug('batches per step: %d' % batches_per_step)
            log_debug('number of jobs in train set: %d' % self._num_jobs_train)
            log_debug('number of jobs already trained in first epoch: %d' % jobs_trained)

            self._next_epoch()

        # The coordinator is ready to serve
        self.started = True

    def _next_epoch(self):
        # State machine of the coordination process

        # Indicates if there were 'new' epoch(s) provided
        result = False

        # Make sure that early stopping and the validation part are enabled
        if (FLAGS.early_stop is True) and (FLAGS.validation_step > 0) and (len(self._dev_losses) >= FLAGS.earlystop_nsteps):

            # Calculate the mean of losses for past epochs
            mean_loss = np.mean(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
            # Calculate the standard deviation of validation losses in the past epochs
            std_loss = np.std(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
            # Update the list of losses incurred
            self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:]
            log_debug('Checking for early stopping (last %d steps) validation loss: %f, with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))

            # Check if validation loss has started increasing or is not decreasing substantially,
            # making sure slight fluctuations don't keep early stopping from working
            if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh):
                # Time to early-stop
                log_info('Early stop triggered as (for last %d steps) validation loss: %f with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))
                self._dev_losses = []
                self._end_training()
                self._train = False

        if self._train:
            # We are in train mode
            if self._num_jobs_train_left > 0:
                # There are still jobs left
                num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train)
                self._num_jobs_train_left -= num_jobs_train

                # Let's try our best to keep the notion of curriculum learning
                self._reset_counters()

                # Append the training epoch
                self._epochs_running.append(Epoch(self._epoch, num_jobs_train, set_name='train'))

                if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0:
                    # The current epoch should also have a validation part
                    self._epochs_running.append(Epoch(self._epoch, self._num_jobs_dev, set_name='dev'))

                # Indicating that there were 'new' epoch(s) provided
                result = True
            else:
                # No jobs left, but still in train mode: concluding training
                self._end_training()
                self._train = False

        if result:
            # Increment the epoch index
            self._epoch += 1
        return result

    def _end_training(self):
        self._training_time = stopwatch(self._training_time)
        log_info('FINISHED Optimization - training time: %s' % format_duration(self._training_time))

    def start(self):
        '''Starts the Training Coordinator. If chief, it starts a web server for
        communication with non-chief instances.
        '''
        if self.is_chief:
            log_debug('Starting coordinator...')
            self._thread = Thread(target=self._httpd.serve_forever)
            self._thread.daemon = True
            self._thread.start()
            log_debug('Coordinator started.')

    def stop(self, wait_for_running_epochs=True):
        '''Stops the Training Coordinator. If chief, it waits for all epochs to be
        'done' and then shuts down the web server.
        '''
        if self.is_chief:
            if wait_for_running_epochs:
                while len(self._epochs_running) > 0:
                    log_traffic('Coordinator is waiting for epochs to finish...')
                    time.sleep(5)
            log_debug('Stopping coordinator...')
            self._httpd.shutdown()
            log_debug('Coordinator stopped.')

    def _talk_to_chief(self, path, data=None, default=None):
        tries = 0
        while tries < FLAGS.coord_retries:
            tries += 1
            try:
                url = 'http://%s:%d%s' % (FLAGS.coord_host, FLAGS.coord_port, path)
                log_traffic('Contacting coordinator - url: %s, tries: %d ...' % (url, tries-1))
                res = urllib.request.urlopen(urllib.request.Request(url, data, { 'content-type': 'text/plain' }))
                str = res.read()
                status = res.getcode()
                log_traffic('Coordinator responded - url: %s, status: %s' % (url, status))
                if status == 200:
                    return str
                if status == 204: # We use 204 (no content) to indicate end of training
                    return default
            except urllib.error.HTTPError as error:
                log_traffic('Problem reaching coordinator - url: %s, HTTP code: %d' % (url, error.code))
                pass
            time.sleep(10)
        return default

    def get_next_index(self, set_name):
        '''Retrieves a new cluster-unique batch index for a given set name.
        Prevents applying one batch multiple times per epoch.

        Args:
            set_name (str): name of the data set - one of 'train', 'dev'

        Returns:
            int. new data set index
        '''
        with self._lock:
            if self.is_chief:
                member = '_index_' + set_name
                value = getattr(self, member, -1)
                setattr(self, member, value + 1)
                return value
            else:
                # We are a remote worker and have to hand over to the chief worker by HTTP
                log_traffic('Asking for next index...')
                value = int(self._talk_to_chief(PREFIX_NEXT_INDEX + set_name))
                log_traffic('Got index %d.' % value)
                return value

    def _get_job(self, worker=0):
        job = None
        # Find the first running epoch that provides a next job
        for epoch in self._epochs_running:
            job = epoch.get_job(worker)
            if job:
                return job
        # No next job found
        return None

    def get_job(self, worker=0):
        '''Retrieves the first job for a worker.

        Kwargs:
            worker (int): index of the worker to get the first job for

        Returns:
            WorkerJob. a job of one of the running epochs that will get
                associated with the given worker and put into state 'running'
        '''
        # Let's ensure that this does not interfere with other workers/requests
        with self._lock:
            if self.is_chief:
                # First try to get a next job
                job = self._get_job(worker)

                if job is None:
                    # If there was no next job, we give it a second chance by triggering the epoch state machine
                    if self._next_epoch():
                        # Epoch state machine got a new epoch
                        # Second try to get a next job
                        job = self._get_job(worker)
                        if job is None:
                            # Although the epoch state machine got a new epoch, the epoch had no new job for us
                            log_error('Unexpected case - no job for worker %d.' % (worker))
                        return job

                    # Epoch state machine has no new epoch
                    # This happens at the end of the whole training - nothing to worry about
                    log_traffic('No jobs left for worker %d.' % (worker))
                    self._log_all_jobs()
                    return None

                # We got a new job from one of the currently running epochs
                log_traffic('Got new %s' % job)
                return job

            # We are a remote worker and have to hand over to the chief worker by HTTP
            result = self._talk_to_chief(PREFIX_GET_JOB + str(FLAGS.task_index))
            if result:
                result = pickle.loads(result)
            return result

    def next_job(self, job):
        '''Sends a finished job back to the coordinator and retrieves in exchange the next one.

        Kwargs:
            job (WorkerJob): job that was finished by a worker and whose results are to be
                digested by the coordinator

        Returns:
            WorkerJob. next job of one of the running epochs that will get
                associated with the worker from the finished job and put into state 'running'
        '''
        if self.is_chief:
            # Try to find the epoch the job belongs to
            epoch = next((epoch for epoch in self._epochs_running if epoch.id == job.epoch_id), None)
            if epoch:
                # We are going to manipulate things - let's avoid undefined state
                with self._lock:
                    # Let the epoch finish the job
                    epoch.finish_job(job)
                    # Check if the epoch is done now
                    if epoch.done():
                        # If it declares itself done, move it from the 'running' to the 'done' collection
                        self._epochs_running.remove(epoch)
                        self._epochs_done.append(epoch)
                        log_info('%s' % epoch)
            else:
                # There was no running epoch found for this job - this should never happen.
                log_error('There is no running epoch of ID %d for job with ID %d. This should never happen.' % (job.epoch_id, job.id))
            return self.get_job(job.worker)

        # We are a remote worker and have to hand over to the chief worker by HTTP
        result = self._talk_to_chief('', data=pickle.dumps(job))
        if result:
            result = pickle.loads(result)
        return result

class GlobalConfig:
    _config = None

    def __getattr__(self, name):
        if not GlobalConfig._config:
            raise RuntimeError("Global configuration not yet initialized.")
        if not hasattr(GlobalConfig._config, name):
            raise RuntimeError("Configuration option {} not found in config.".format(name))
        return GlobalConfig._config[name]

C = GlobalConfig()
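The C proxy above only resolves attributes after initialize_globals() below has populated GlobalConfig._config; a small sketch of the intended access pattern (assuming flags have been created and parsed first):

    from util.coordinator import C, initialize_globals

    initialize_globals()         # fills GlobalConfig._config with the AttrDict built below
    print(C.n_input)             # 26, resolved through GlobalConfig.__getattr__
    print(C.alphabet.size())     # nested config objects work the same way
    C.no_such_option             # raises RuntimeError: option not found in config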

def initialize_globals():
    c = AttrDict()

    # ps and worker hosts required for p2p cluster setup
    FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(',')))
    FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(',')))

    # The absolute number of computing nodes - regardless of cluster or single mode
    c.num_workers = max(1, len(FLAGS.worker_hosts))

    # Create a cluster from the parameter server and worker hosts.
    c.cluster = tf.train.ClusterSpec({'ps': FLAGS.ps_hosts, 'worker': FLAGS.worker_hosts})

    # If replica numbers are negative, we multiply their absolute values with the number of workers
    if FLAGS.replicas < 0:
        FLAGS.replicas = c.num_workers * -FLAGS.replicas
    if FLAGS.replicas_to_agg < 0:
        FLAGS.replicas_to_agg = c.num_workers * -FLAGS.replicas_to_agg

    # The device path base for this node
    c.worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index)

    # This node's CPU device
    c.cpu_device = c.worker_device + '/cpu:0'

    # This node's available GPU devices
    c.available_devices = [c.worker_device + gpu for gpu in get_available_gpus()]

    # If there is no GPU available, we fall back to CPU based operation
    if 0 == len(c.available_devices):
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    c.no_dropout = [ 0.0 ] * 6

    # Set default checkpoint dir
    if len(FLAGS.checkpoint_dir) == 0:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))

    # Set default summary dir
    if len(FLAGS.summary_dir) == 0:
        FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
                                      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
                                      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)

    c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26 # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9 # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden

    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # The number of characters in the target language plus one
    c.n_character = c.alphabet.size() + 1 # +1 for CTC blank label

    # The number of units in the sixth layer
    c.n_hidden_6 = c.n_character

    # Queues that are used to gracefully stop parameter servers.
    # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting.
    # Each ps will dequeue as many tokens as there are workers before joining/quitting.
    # This ensures parameter servers won't quit, if still required by at least one worker and
    # also won't wait forever (like with a standard `server.join()`).
    c.done_queues = []
    for i, ps in enumerate(FLAGS.ps_hosts):
        # Queues are hosted by their respective owners
        with tf.device('/job:ps/task:%d' % i):
            c.done_queues.append(tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i)))

    # Placeholder to pass in the worker's index as token
    c.token_placeholder = tf.placeholder(tf.int32)

    # Enqueue operations for each parameter server
    c.done_enqueues = [queue.enqueue(c.token_placeholder) for queue in c.done_queues]

    # Dequeue operations for each parameter server
    c.done_dequeues = [queue.dequeue() for queue in c.done_queues]

    if len(FLAGS.one_shot_infer) > 0:
        FLAGS.train = False
        FLAGS.test = False
        FLAGS.export_dir = ''
        if not os.path.exists(FLAGS.one_shot_infer):
            log_error('Path specified in --one_shot_infer is not a valid file.')
            exit(1)

    # Determine if we are the chief worker
    c.is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker')

    # Initializing and starting the training coordinator
    c.COORD = TrainingCoordinator(c.is_chief)
    c.COORD.start()

    GlobalConfig._config = c
@@ -12,14 +12,13 @@ class ModelFeeder(object):
     '''
     Feeds data into a model.
     Feeding is parallelized by independent units called tower feeders (usually one per GPU).
-    Each tower feeder provides data from three runtime switchable sources (train, dev, test).
-    These sources are to be provided by three DataSet instances whos references are kept.
+    Each tower feeder provides data from runtime switchable sources (train, dev).
+    These sources are to be provided by the DataSet instances whose references are kept.
     Creates, owns and delegates to tower_feeder_count internal tower feeder objects.
     '''
     def __init__(self,
                  train_set,
                  dev_set,
-                 test_set,
                  numcep,
                  numcontext,
                  alphabet,
@@ -28,8 +27,7 @@ class ModelFeeder(object):
 
         self.train = train_set
         self.dev = dev_set
-        self.test = test_set
-        self.sets = [train_set, dev_set, test_set]
+        self.sets = [train_set, dev_set]
         self.numcep = numcep
         self.numcontext = numcontext
         self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count
util/flags.py (new file, 142 lines):

@@ -0,0 +1,142 @@
from __future__ import print_function

import tensorflow as tf

from xdg import BaseDirectory as xdg


FLAGS = tf.app.flags.FLAGS


def create_flags():
    # Importer
    # ========

    tf.app.flags.DEFINE_string  ('train_files', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
    tf.app.flags.DEFINE_string  ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
    tf.app.flags.DEFINE_string  ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
    tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')

    tf.app.flags.DEFINE_string  ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
    tf.app.flags.DEFINE_string  ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
    tf.app.flags.DEFINE_string  ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')

    # Cluster configuration
    # =====================

    tf.app.flags.DEFINE_string  ('ps_hosts', '', 'parameter servers - comma separated list of hostname:port pairs')
    tf.app.flags.DEFINE_string  ('worker_hosts', '', 'workers - comma separated list of hostname:port pairs')
    tf.app.flags.DEFINE_string  ('job_name', 'localhost', 'job name - one of localhost (default), worker, ps')
    tf.app.flags.DEFINE_integer ('task_index', 0, 'index of task within the job - worker with index 0 will be the chief')
    tf.app.flags.DEFINE_integer ('replicas', -1, 'total number of replicas - if negative, its absolute value is multiplied by the number of workers')
    tf.app.flags.DEFINE_integer ('replicas_to_agg', -1, 'number of replicas to aggregate - if negative, its absolute value is multiplied by the number of workers')
    tf.app.flags.DEFINE_integer ('coord_retries', 100, 'number of tries of workers connecting to training coordinator before failing')
    tf.app.flags.DEFINE_string  ('coord_host', 'localhost', 'coordination server host')
    tf.app.flags.DEFINE_integer ('coord_port', 2500, 'coordination server port')
    tf.app.flags.DEFINE_integer ('iters_per_worker', 1, 'number of train or inference iterations per worker before results are sent back to coordinator')

    # Global Constants
    # ================

    tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network')
    tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network')
    tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')

    tf.app.flags.DEFINE_float   ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
    tf.app.flags.DEFINE_float   ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
    tf.app.flags.DEFINE_float   ('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
    tf.app.flags.DEFINE_float   ('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
    tf.app.flags.DEFINE_float   ('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
    tf.app.flags.DEFINE_float   ('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')

    tf.app.flags.DEFINE_float   ('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')

    # Adam optimizer (http://arxiv.org/abs/1412.6980) parameters

    tf.app.flags.DEFINE_float   ('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
    tf.app.flags.DEFINE_float   ('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
    tf.app.flags.DEFINE_float   ('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
    tf.app.flags.DEFINE_float   ('learning_rate', 0.001, 'learning rate of Adam optimizer')

    # Batch sizes

    tf.app.flags.DEFINE_integer ('train_batch_size', 1, 'number of elements in a training batch')
    tf.app.flags.DEFINE_integer ('dev_batch_size', 1, 'number of elements in a validation batch')
    tf.app.flags.DEFINE_integer ('test_batch_size', 1, 'number of elements in a test batch')

    tf.app.flags.DEFINE_integer ('export_batch_size', 1, 'number of elements per batch on the exported graph')

    # Performance (UNSUPPORTED)
    tf.app.flags.DEFINE_integer ('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
    tf.app.flags.DEFINE_integer ('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')

    # Sample limits

    tf.app.flags.DEFINE_integer ('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
    tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
    tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')

    # Step widths

    tf.app.flags.DEFINE_integer ('validation_step', 0, 'number of epochs we cycle through before validating the model - 0 means no validation steps')

    # Checkpointing

    tf.app.flags.DEFINE_string  ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
    tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
    tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')

    # Exporting

    tf.app.flags.DEFINE_string  ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
    tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
    tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
    tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
    tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
    tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')

    # Reporting

    tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
    tf.app.flags.DEFINE_boolean ('log_traffic', False, 'log cluster transaction and traffic information during debug logging')
    tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')

    tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console')
    tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report')

    tf.app.flags.DEFINE_string  ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
    tf.app.flags.DEFINE_integer ('summary_secs', 0, 'interval in seconds for saving TensorBoard summaries - if 0, no summaries will be written')

    # Geometry

    tf.app.flags.DEFINE_integer ('n_hidden', 2048, 'layer width to use when initialising layers')

    # Initialization

    tf.app.flags.DEFINE_integer ('random_seed', 4567, 'default random seed that is used to initialize variables')

    # Early Stopping

    tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset. Make sure that dev FLAG is enabled for this to work')

    # This parameter is independent of the time taken by a single epoch to complete and of checkpoint saving intervals.
    # It is possible that early stopping is triggered long after the best checkpoint has already been replaced by the checkpoint saving interval mechanism.
    # One has to align the parameters (earlystop_nsteps, checkpoint_secs) according to the time taken by an epoch on different datasets.

    tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps to consider for early stopping. Loss is not stored in the checkpoint, so when a checkpoint is restored the loss calculation starts over from that point')
    tf.app.flags.DEFINE_float   ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
    tf.app.flags.DEFINE_float   ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')

    # Decoder

    tf.app.flags.DEFINE_string  ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
    tf.app.flags.DEFINE_string  ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
    tf.app.flags.DEFINE_string  ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
    tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
    tf.app.flags.DEFINE_float   ('lm_alpha', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
    tf.app.flags.DEFINE_float   ('lm_beta', 2.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')

    # Inference mode

    tf.app.flags.DEFINE_string  ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.')
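Typical wiring for these flags, as DeepSpeech.py and evaluate.py use them (a sketch assuming TensorFlow 1.x, where tf.app.run() parses argv before invoking main):

    import tensorflow as tf
    from util.flags import create_flags, FLAGS

    def main(_):
        # Flags are parsed by now and readable as attributes
        print(FLAGS.n_hidden, FLAGS.lm_alpha, FLAGS.lm_beta)

    if __name__ == '__main__':
        create_flags()
        tf.app.run()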
util/logging.py (new file, 35 lines):

@@ -0,0 +1,35 @@
from __future__ import print_function

from util.flags import FLAGS


# Logging functions
# =================

def prefix_print(prefix, message):
    print(prefix + ('\n' + prefix).join(message.split('\n')))


def log_debug(message):
    if FLAGS.log_level == 0:
        prefix_print('D ', message)


def log_traffic(message):
    if FLAGS.log_traffic:
        log_debug(message)


def log_info(message):
    if FLAGS.log_level <= 1:
        prefix_print('I ', message)


def log_warn(message):
    if FLAGS.log_level <= 2:
        prefix_print('W ', message)


def log_error(message):
    if FLAGS.log_level <= 3:
        prefix_print('E ', message)
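A quick sketch of how the log_level flag gates these helpers (level values per util/flags.py above; assumes the flags are already defined and parsed):

    from util.flags import FLAGS
    from util.logging import log_info, log_error

    FLAGS.log_level = 2          # ERROR: suppresses INFO and WARN
    log_info('not printed')
    log_error('printed, with an "E " prefix on every line')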
(deleted file, 46 lines)

@@ -1,46 +0,0 @@
from __future__ import print_function
from __future__ import absolute_import
from util.gpu import get_available_gpus
from ctypes import cdll
from sys import platform as _platform

def get_cupti_libname():
    if _platform == 'linux' or _platform == 'linux2':
        return 'libcupti.so'
    elif _platform == 'darwin':
        return 'libcupti.dylib'
    elif _platform == 'win32':
        return 'libcupti.dll'

def check_cupti():
    # We want to ensure that the user has properly configured their libs.
    # We do this because the dso load of libcupti happens after a lot of
    # computation has already happened, so it is easy to miss and lose time.
    libname = get_cupti_libname()
    cupti = check_so(libname)
    if cupti is None:
        print("INFO: No %s because no GPU, go ahead." % libname)
    elif cupti is True:
        print("INFO: Found %s." % libname)
    else:
        print("WARNING: Running on GPU but no %s could be found ; will be unable to report GPU VRAM usage." % libname)

def check_so(soname):
    """
    Verify that we do have the 'soname' lib present in the system, and that it
    can be loaded.
    """

    if len(get_available_gpus()) == 0:
        return None

    # Try to force load lib, this would fail if the lib is not there :)
    try:
        lib = cdll.LoadLibrary(soname)
        print("INFO: Found so as", lib)
        assert lib.__class__.__name__ == 'CDLL'
        assert lib._name == soname
        return True
    except OSError as ex:
        print("WARNING:", ex)
        return False