Fix #1986 - Remove distributed training support

parent a009361e47
commit a179a2389f

4 .compute
@ -7,7 +7,7 @@ python3 -m venv /tmp/venv
source /tmp/venv/bin/activate

pip install -r <(grep -v tensorflow requirements.txt)
- pip install tensorflow-gpu==1.13.0-rc2
+ pip install tensorflow-gpu==1.13.1

# Install ds_ctcdecoder package from TaskCluster
pip install $(python3 util/taskcluster.py --decoder)
@ -30,7 +30,5 @@ python -u DeepSpeech.py \
--learning_rate 0.0001 \
--dropout_rate 0.2 \
--epoch 13 \
- --display_step 0 \
- --validation_step 1 \
--checkpoint_dir "../keep" \
--summary_dir "../keep/summaries"

1 .gitignore vendored
@ -8,6 +8,7 @@
/runs
/logs
/exports
/data/ldc93s1
/native_client/setup.cfg
/native_client/build
/native_client/*.egg-info

508 DeepSpeech.py
@ -8,42 +8,35 @@ import sys
|
||||
log_level_index = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_index > 0 and log_level_index < len(sys.argv) else '3'
|
||||
|
||||
import time
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import shutil
|
||||
import tensorflow as tf
|
||||
import traceback
|
||||
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from six.moves import zip, range
|
||||
from tensorflow.python.tools import freeze_graph
|
||||
from util.audio import audiofile_to_input_vector
|
||||
from util.config import Config, initialize_globals
|
||||
from util.coordinator import TrainingCoordinator
|
||||
from util.feeding import DataSet, ModelFeeder
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_info, log_error, log_debug, log_warn
|
||||
from util.preprocess import preprocess
|
||||
from util.text import Alphabet
|
||||
|
||||
|
||||
# Graph Creation
|
||||
# ==============
|
||||
|
||||
def variable_on_worker_level(name, shape, initializer):
|
||||
r'''
|
||||
def variable_on_cpu(name, shape, initializer):
|
||||
r"""
|
||||
Next we concern ourselves with graph creation.
|
||||
However, before we do so we must introduce a utility function ``variable_on_worker_level()``
|
||||
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||
used to create a variable in CPU memory.
|
||||
'''
|
||||
# Use the /cpu:0 device on worker_device for scoped operations
|
||||
if len(FLAGS.ps_hosts) == 0:
|
||||
device = Config.worker_device
|
||||
else:
|
||||
device = tf.train.replica_device_setter(worker_device=Config.worker_device, cluster=Config.cluster)
|
||||
|
||||
with tf.device(device):
|
||||
"""
|
||||
# Use the /cpu:0 device for scoped operations
|
||||
with tf.device(Config.cpu_device):
|
||||
# Create or get the appropriate variable
|
||||
var = tf.get_variable(name=name, shape=shape, initializer=initializer)
|
||||
return var
|
||||
@ -82,22 +75,22 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
|
||||
# clipped RELU activation and dropout.
|
||||
|
||||
# 1st layer
|
||||
b1 = variable_on_worker_level('b1', [Config.n_hidden_1], tf.zeros_initializer())
|
||||
h1 = variable_on_worker_level('h1', [Config.n_input + 2*Config.n_input*Config.n_context, Config.n_hidden_1], tf.contrib.layers.xavier_initializer())
|
||||
b1 = variable_on_cpu('b1', [Config.n_hidden_1], tf.zeros_initializer())
|
||||
h1 = variable_on_cpu('h1', [Config.n_input + 2*Config.n_input*Config.n_context, Config.n_hidden_1], tf.contrib.layers.xavier_initializer())
|
||||
layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), FLAGS.relu_clip)
|
||||
layer_1 = tf.nn.dropout(layer_1, rate=dropout[0])
|
||||
layers['layer_1'] = layer_1
|
||||
|
||||
# 2nd layer
|
||||
b2 = variable_on_worker_level('b2', [Config.n_hidden_2], tf.zeros_initializer())
|
||||
h2 = variable_on_worker_level('h2', [Config.n_hidden_1, Config.n_hidden_2], tf.contrib.layers.xavier_initializer())
|
||||
b2 = variable_on_cpu('b2', [Config.n_hidden_2], tf.zeros_initializer())
|
||||
h2 = variable_on_cpu('h2', [Config.n_hidden_1, Config.n_hidden_2], tf.contrib.layers.xavier_initializer())
|
||||
layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), FLAGS.relu_clip)
|
||||
layer_2 = tf.nn.dropout(layer_2, rate=dropout[1])
|
||||
layers['layer_2'] = layer_2
|
||||
|
||||
# 3rd layer
|
||||
b3 = variable_on_worker_level('b3', [Config.n_hidden_3], tf.zeros_initializer())
|
||||
h3 = variable_on_worker_level('h3', [Config.n_hidden_2, Config.n_hidden_3], tf.contrib.layers.xavier_initializer())
|
||||
b3 = variable_on_cpu('b3', [Config.n_hidden_3], tf.zeros_initializer())
|
||||
h3 = variable_on_cpu('h3', [Config.n_hidden_2, Config.n_hidden_3], tf.contrib.layers.xavier_initializer())
|
||||
layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip)
|
||||
layer_3 = tf.nn.dropout(layer_3, rate=dropout[2])
|
||||
layers['layer_3'] = layer_3
|
||||
@ -140,16 +133,16 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
|
||||
layers['rnn_output_state'] = output_state
|
||||
|
||||
# Now we feed `output` to the fifth hidden layer with clipped RELU activation and dropout
|
||||
b5 = variable_on_worker_level('b5', [Config.n_hidden_5], tf.zeros_initializer())
|
||||
h5 = variable_on_worker_level('h5', [Config.n_cell_dim, Config.n_hidden_5], tf.contrib.layers.xavier_initializer())
|
||||
b5 = variable_on_cpu('b5', [Config.n_hidden_5], tf.zeros_initializer())
|
||||
h5 = variable_on_cpu('h5', [Config.n_cell_dim, Config.n_hidden_5], tf.contrib.layers.xavier_initializer())
|
||||
layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(output, h5), b5)), FLAGS.relu_clip)
|
||||
layer_5 = tf.nn.dropout(layer_5, rate=dropout[5])
|
||||
layers['layer_5'] = layer_5
|
||||
|
||||
# Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5`
|
||||
# creating `n_classes` dimensional vectors, the logits.
|
||||
b6 = variable_on_worker_level('b6', [Config.n_hidden_6], tf.zeros_initializer())
|
||||
h6 = variable_on_worker_level('h6', [Config.n_hidden_5, Config.n_hidden_6], tf.contrib.layers.xavier_initializer())
|
||||
b6 = variable_on_cpu('b6', [Config.n_hidden_6], tf.zeros_initializer())
|
||||
h6 = variable_on_cpu('h6', [Config.n_hidden_5, Config.n_hidden_6], tf.contrib.layers.xavier_initializer())
|
||||
layer_6 = tf.add(tf.matmul(layer_5, h6), b6)
|
||||
layers['layer_6'] = layer_6
|
||||
|
||||
@ -244,10 +237,7 @@ def get_tower_results(model_feeder, optimizer, dropout_rates):
|
||||
# Loop over available_devices
|
||||
for i in range(len(Config.available_devices)):
|
||||
# Execute operations of tower i on device i
|
||||
if len(FLAGS.ps_hosts) == 0:
|
||||
device = Config.available_devices[i]
|
||||
else:
|
||||
device = tf.train.replica_device_setter(worker_device=Config.available_devices[i], cluster=Config.cluster)
|
||||
device = Config.available_devices[i]
|
||||
with tf.device(device):
|
||||
# Create a scope for all operations of tower i
|
||||
with tf.name_scope('tower_%d' % i) as scope:
|
||||
@ -350,34 +340,41 @@ def log_grads_and_vars(grads_and_vars):
|
||||
# Helpers
|
||||
# =======
|
||||
|
||||
def send_token_to_ps(session, kill=False):
|
||||
# Sending our token (the task_index as a debug opportunity) to each parameter server.
|
||||
# kill switch tokens are negative and decremented by 1 to deal with task_index 0
|
||||
token = -FLAGS.task_index-1 if kill else FLAGS.task_index
|
||||
kind = 'kill switch' if kill else 'stop'
|
||||
for index, enqueue in enumerate(Config.done_enqueues):
|
||||
log_debug('Sending %s token to ps %d...' % (kind, index))
|
||||
session.run(enqueue, feed_dict={ Config.token_placeholder: token })
|
||||
log_debug('Sent %s token to ps %d.' % (kind, index))
|
||||
|
||||
class SampleIndex:
|
||||
def __init__(self, index=0):
|
||||
self.index = index
|
||||
|
||||
def inc(self, old_index):
|
||||
self.index += 1
|
||||
return self.index
|
||||
|
||||
|
||||
def train(server=None):
|
||||
def try_loading(session, saver, checkpoint_path, caption):
|
||||
try:
|
||||
saver.restore(session, checkpoint_path)
|
||||
log_info('Restored model from %s checkpoint at %s' % (caption, checkpoint_path))
|
||||
return True
|
||||
except tf.errors.InvalidArgumentError as e:
|
||||
log_error(str(e))
|
||||
log_error('The checkpoint in {0} does not match the shapes of the model.'
|
||||
' Did you change alphabet.txt or the --n_hidden parameter'
|
||||
' between train runs using the same checkpoint dir? Try moving'
|
||||
' or removing the contents of {0}.'.format(checkpoint_path))
|
||||
sys.exit(1)
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def train():
|
||||
r'''
|
||||
Trains the network on a given server of a cluster.
If no server is provided, it performs single-process training.
|
||||
'''
|
||||
|
||||
# Initializing and starting the training coordinator
|
||||
coord = TrainingCoordinator(Config.is_chief)
|
||||
coord.start()
|
||||
|
||||
# Create a variable to hold the global_step.
|
||||
# It will automagically get incremented by the optimizer.
|
||||
global_step = tf.Variable(0, trainable=False, name='global_step')
|
||||
|
||||
dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||
|
||||
# Reading training set
|
||||
train_index = SampleIndex()
|
||||
|
||||
train_data = preprocess(FLAGS.train_files.split(','),
|
||||
FLAGS.train_batch_size,
|
||||
Config.n_input,
|
||||
@ -388,9 +385,11 @@ def train(server=None):
|
||||
train_set = DataSet(train_data,
|
||||
FLAGS.train_batch_size,
|
||||
limit=FLAGS.limit_train,
|
||||
next_index=lambda i: coord.get_next_index('train'))
|
||||
next_index=train_index.inc)
|
||||
|
||||
# Reading validation set
|
||||
dev_index = SampleIndex()
|
||||
|
||||
dev_data = preprocess(FLAGS.dev_files.split(','),
|
||||
FLAGS.dev_batch_size,
|
||||
Config.n_input,
|
||||
@ -401,7 +400,7 @@ def train(server=None):
|
||||
dev_set = DataSet(dev_data,
|
||||
FLAGS.dev_batch_size,
|
||||
limit=FLAGS.limit_dev,
|
||||
next_index=lambda i: coord.get_next_index('dev'))
|
||||
next_index=dev_index.inc)
|
||||
|
||||
# Combining all sets to a multi set model feeder
|
||||
model_feeder = ModelFeeder(train_set,
|
||||
@ -411,77 +410,16 @@ def train(server=None):
|
||||
Config.alphabet,
|
||||
tower_feeder_count=len(Config.available_devices))
|
||||
|
||||
# Create the optimizer
|
||||
optimizer = create_optimizer()
|
||||
|
||||
# Synchronous distributed training is facilitated by a special proxy-optimizer
|
||||
if not server is None:
|
||||
optimizer = tf.train.SyncReplicasOptimizer(optimizer,
|
||||
replicas_to_aggregate=FLAGS.replicas_to_agg,
|
||||
total_num_replicas=FLAGS.replicas)
|
||||
|
||||
# Get the data_set specific graph end-points
|
||||
gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
|
||||
# Add summaries of all variables and gradients to log
|
||||
log_grads_and_vars(avg_tower_gradients)
|
||||
|
||||
# Op to merge all summaries for the summary hook
|
||||
merge_all_summaries_op = tf.summary.merge_all()
|
||||
|
||||
# These are saved on every step
|
||||
step_summaries_op = tf.summary.merge_all('step_summaries')
|
||||
|
||||
step_summary_writers = {
|
||||
'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
|
||||
'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
|
||||
# Dropout
|
||||
dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||
dropout_feed_dict = {
|
||||
dropout_rates[0]: FLAGS.dropout_rate,
|
||||
dropout_rates[1]: FLAGS.dropout_rate2,
|
||||
dropout_rates[2]: FLAGS.dropout_rate3,
|
||||
dropout_rates[3]: FLAGS.dropout_rate4,
|
||||
dropout_rates[4]: FLAGS.dropout_rate5,
|
||||
dropout_rates[5]: FLAGS.dropout_rate6,
|
||||
}
|
||||
|
||||
# Apply gradients to modify the model
|
||||
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
|
||||
|
||||
|
||||
if FLAGS.early_stop is True and not FLAGS.validation_step > 0:
|
||||
log_warn('Parameter --validation_step needs to be >0 for early stopping to work')
|
||||
|
||||
class CoordHook(tf.train.SessionRunHook):
|
||||
r'''
|
||||
Embedded coordination hook-class that will use variables of the
|
||||
surrounding Python context.
|
||||
'''
|
||||
def after_create_session(self, session, coord):
|
||||
log_debug('Starting queue runners...')
|
||||
model_feeder.start_queue_threads(session, coord)
|
||||
log_debug('Queue runners started.')
|
||||
|
||||
def end(self, session):
|
||||
# Closing the data_set queues
|
||||
log_debug('Closing queues...')
|
||||
model_feeder.close_queues(session)
|
||||
log_debug('Queues closed.')
|
||||
|
||||
# Telling the ps that we are done
|
||||
send_token_to_ps(session)
|
||||
|
||||
# Collecting the hooks
|
||||
hooks = [CoordHook()]
|
||||
|
||||
# Hook to handle initialization and queues for sync replicas.
|
||||
if not server is None:
|
||||
hooks.append(optimizer.make_session_run_hook(Config.is_chief))
|
||||
|
||||
# Hook to save TensorBoard summaries
|
||||
if FLAGS.summary_secs > 0:
|
||||
hooks.append(tf.train.SummarySaverHook(save_secs=FLAGS.summary_secs, output_dir=FLAGS.summary_dir, summary_op=merge_all_summaries_op))
|
||||
|
||||
# Hook with the number of checkpoint files to keep in checkpoint_dir
|
||||
if FLAGS.train and FLAGS.max_to_keep > 0:
|
||||
saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir, save_secs=FLAGS.checkpoint_secs, saver=saver))
|
||||
|
||||
no_dropout_feed_dict = {
|
||||
dropout_rates[0]: 0.,
|
||||
dropout_rates[1]: 0.,
|
||||
@ -491,156 +429,146 @@ def train(server=None):
|
||||
dropout_rates[5]: 0.,
|
||||
}
|
||||
|
||||
# Progress Bar
|
||||
def update_progressbar(set_name):
|
||||
if not hasattr(update_progressbar, 'current_set_name'):
|
||||
update_progressbar.current_set_name = None
|
||||
# Building the graph
|
||||
optimizer = create_optimizer()
|
||||
gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates)
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
log_grads_and_vars(avg_tower_gradients)
|
||||
# global_step is automagically incremented by the optimizer
|
||||
global_step = tf.Variable(0, trainable=False, name='global_step')
|
||||
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
|
||||
|
||||
if (update_progressbar.current_set_name != set_name or
|
||||
update_progressbar.current_job_index == update_progressbar.total_jobs):
|
||||
# Summaries
|
||||
step_summaries_op = tf.summary.merge_all('step_summaries')
|
||||
step_summary_writers = {
|
||||
'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
|
||||
'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
|
||||
}
|
||||
|
||||
# finish prev pbar if it exists
|
||||
if hasattr(update_progressbar, 'pbar') and update_progressbar.pbar:
|
||||
update_progressbar.pbar.finish()
|
||||
# Checkpointing
|
||||
checkpoint_saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
checkpoint_path = FLAGS.checkpoint_dir
|
||||
best_dev_saver = tf.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev.ckpt')
|
||||
initializer = tf.global_variables_initializer()
|
||||
|
||||
update_progressbar.total_jobs = None
|
||||
update_progressbar.current_job_index = 0
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
tf.get_default_graph().finalize()
|
||||
|
||||
current_epoch = coord._epoch-1
|
||||
|
||||
if set_name == "train":
|
||||
log_info('Training epoch %i...' % current_epoch)
|
||||
update_progressbar.total_jobs = coord._num_jobs_train
|
||||
# Loading or initializing
|
||||
loaded = False
|
||||
if FLAGS.load in ['auto', 'last']:
|
||||
loaded = try_loading(session, checkpoint_saver, checkpoint_path, 'most recent epoch')
|
||||
if not loaded and FLAGS.load in ['auto', 'best']:
|
||||
loaded = try_loading(session, best_dev_saver, best_dev_path, 'best validation')
|
||||
if not loaded:
|
||||
if FLAGS.load in ['auto', 'init']:
|
||||
log_info('Initializing...')
|
||||
session.run(initializer)
|
||||
else:
|
||||
log_info('Validating epoch %i...' % current_epoch)
|
||||
update_progressbar.total_jobs = coord._num_jobs_dev
|
||||
|
||||
# recreate pbar
|
||||
update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs,
|
||||
redirect_stdout=True).start()
|
||||
|
||||
update_progressbar.current_set_name = set_name
|
||||
|
||||
if update_progressbar.pbar:
|
||||
update_progressbar.pbar.update(update_progressbar.current_job_index+1, force=True)
|
||||
|
||||
update_progressbar.current_job_index += 1
|
||||
|
||||
# Initialize update_progressbar()'s child fields to safe values
|
||||
update_progressbar.pbar = None
|
||||
|
||||
# The MonitoredTrainingSession takes care of session initialization,
|
||||
# restoring from a checkpoint, saving to a checkpoint, and closing when done
|
||||
# or an error occurs.
|
||||
try:
|
||||
with tf.train.MonitoredTrainingSession(master='' if server is None else server.target,
|
||||
is_chief=Config.is_chief,
|
||||
hooks=hooks,
|
||||
checkpoint_dir=FLAGS.checkpoint_dir,
|
||||
save_checkpoint_secs=None, # already taken care of by a hook
|
||||
log_step_count_steps=0, # disable logging of steps/s to avoid TF warning in validation sets
|
||||
config=Config.session_config) as session:
|
||||
tf.get_default_graph().finalize()
|
||||
|
||||
try:
|
||||
if Config.is_chief:
|
||||
# Retrieving global_step from the (potentially restored) model
|
||||
model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
|
||||
step = session.run(global_step, feed_dict=no_dropout_feed_dict)
|
||||
coord.start_coordination(model_feeder, step)
|
||||
|
||||
# Get the first job
|
||||
job = coord.get_job()
|
||||
|
||||
while job and not session.should_stop():
|
||||
log_debug('Computing %s...' % job)
|
||||
|
||||
is_train = job.set_name == 'train'
|
||||
|
||||
# The feed_dict (mainly for switching between queues)
|
||||
if is_train:
|
||||
feed_dict = {
|
||||
dropout_rates[0]: FLAGS.dropout_rate,
|
||||
dropout_rates[1]: FLAGS.dropout_rate2,
|
||||
dropout_rates[2]: FLAGS.dropout_rate3,
|
||||
dropout_rates[3]: FLAGS.dropout_rate4,
|
||||
dropout_rates[4]: FLAGS.dropout_rate5,
|
||||
dropout_rates[5]: FLAGS.dropout_rate6,
|
||||
}
|
||||
else:
|
||||
feed_dict = no_dropout_feed_dict
|
||||
|
||||
# Sets the current data_set for the respective placeholder in feed_dict
|
||||
model_feeder.set_data_set(feed_dict, getattr(model_feeder, job.set_name))
|
||||
|
||||
# Initialize loss aggregator
|
||||
total_loss = 0.0
|
||||
|
||||
# Setting the training operation in case of training requested
|
||||
train_op = apply_gradient_op if is_train else []
|
||||
|
||||
# So far the only extra parameter is the feed_dict
|
||||
extra_params = { 'feed_dict': feed_dict }
|
||||
|
||||
step_summary_writer = step_summary_writers.get(job.set_name)
|
||||
|
||||
# Loop over the batches
|
||||
for job_step in range(job.steps):
|
||||
if session.should_stop():
|
||||
break
|
||||
|
||||
log_debug('Starting batch...')
|
||||
# Compute the batch
|
||||
_, current_step, batch_loss, step_summary = session.run([train_op, global_step, loss, step_summaries_op], **extra_params)
|
||||
|
||||
# Log step summaries
|
||||
step_summary_writer.add_summary(step_summary, current_step)
|
||||
|
||||
# Uncomment the next line for debugging race conditions / distributed TF
|
||||
log_debug('Finished batch step %d.' % current_step)
|
||||
|
||||
# Add batch to loss
|
||||
total_loss += batch_loss
|
||||
|
||||
# Gathering job results
|
||||
job.loss = total_loss / job.steps
|
||||
|
||||
# Display progressbar
|
||||
if FLAGS.show_progressbar:
|
||||
update_progressbar(job.set_name)
|
||||
|
||||
# Send the current job to coordinator and receive the next one
|
||||
log_debug('Sending %s...' % job)
|
||||
job = coord.next_job(job)
|
||||
|
||||
if update_progressbar.pbar:
|
||||
update_progressbar.pbar.finish()
|
||||
|
||||
except Exception as e:
|
||||
log_error(str(e))
|
||||
traceback.print_exc()
|
||||
# Calling all hook's end() methods to end blocking calls
|
||||
for hook in hooks:
|
||||
hook.end(session)
|
||||
# Only chief has a SyncReplicasOptimizer queue runner that needs to be stopped for unblocking process exit.
|
||||
# A rather graceful way to do this is by stopping the ps.
|
||||
# Only one party can send it w/o failing.
|
||||
if Config.is_chief:
|
||||
send_token_to_ps(session, kill=True)
|
||||
log_error('Unable to load %s model from specified checkpoint dir'
|
||||
' - consider using load option "auto" or "init".' % FLAGS.load)
|
||||
sys.exit(1)
|
||||
|
||||
log_debug('Session closed.')
|
||||
# Retrieving global_step from restored model and setting training parameters accordingly
|
||||
model_feeder.set_data_set(no_dropout_feed_dict, train_set)
|
||||
step = session.run(global_step, feed_dict=no_dropout_feed_dict)
|
||||
num_gpus = len(Config.available_devices)
|
||||
steps_per_epoch = max(1, train_set.total_batches // num_gpus)
|
||||
steps_trained = step % steps_per_epoch
|
||||
current_epoch = step // steps_per_epoch
|
||||
target_epoch = current_epoch + abs(FLAGS.epoch) if FLAGS.epoch < 0 else FLAGS.epoch
|
||||
train_index.index = steps_trained * num_gpus
|
||||
|
||||
except tf.errors.InvalidArgumentError as e:
|
||||
log_error(str(e))
|
||||
log_error('The checkpoint in {0} does not match the shapes of the model.'
|
||||
' Did you change alphabet.txt or the --n_hidden parameter'
|
||||
' between train runs using the same checkpoint dir? Try moving'
|
||||
' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
|
||||
sys.exit(1)
|
||||
log_debug('step: %d' % step)
|
||||
log_debug('epoch: %d' % current_epoch)
|
||||
log_debug('target epoch: %d' % target_epoch)
|
||||
log_debug('steps per epoch: %d' % steps_per_epoch)
|
||||
log_debug('batches per step (GPUs): %d' % num_gpus)
|
||||
log_debug('number of batches in train set: %d' % train_set.total_batches)
|
||||
log_debug('number of batches already trained in epoch: %d' % train_index.index)
|
||||
|
||||
# Stopping the coordinator
|
||||
coord.stop()
|
||||
def run_set(set_name):
|
||||
data_set = getattr(model_feeder, set_name)
|
||||
is_train = set_name == 'train'
|
||||
train_op = apply_gradient_op if is_train else []
|
||||
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
||||
model_feeder.set_data_set(feed_dict, data_set)
|
||||
total_loss = 0.0
|
||||
step_summary_writer = step_summary_writers.get(set_name)
|
||||
num_steps = max(1, data_set.total_batches // num_gpus)
|
||||
checkpoint_time = time.time()
|
||||
if FLAGS.show_progressbar:
|
||||
pbar = progressbar.ProgressBar(max_value=num_steps, redirect_stdout=True).start()
|
||||
# Batch loop
|
||||
for step_index in range(steps_trained, num_steps):
|
||||
if coord.should_stop():
|
||||
break
|
||||
_, current_step, batch_loss, step_summary = \
|
||||
session.run([train_op, global_step, loss, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
total_loss += batch_loss
|
||||
step_summary_writer.add_summary(step_summary, current_step)
|
||||
if FLAGS.show_progressbar:
|
||||
pbar.update(step_index + 1, force=True)
|
||||
if FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
|
||||
checkpoint_saver.save(session, checkpoint_path)
|
||||
checkpoint_time = time.time()
|
||||
if FLAGS.show_progressbar:
|
||||
pbar.finish()
|
||||
return total_loss / num_steps
|
||||
|
||||
if target_epoch > current_epoch:
|
||||
log_info('STARTING Optimization')
|
||||
best_dev_loss = float('inf')
|
||||
dev_losses = []
|
||||
coord = tf.train.Coordinator()
|
||||
with coord.stop_on_exception():
|
||||
log_debug('Starting queue runners...')
|
||||
model_feeder.start_queue_threads(session, coord=coord)
|
||||
log_debug('Queue runners started.')
|
||||
# Epoch loop
|
||||
for current_epoch in range(current_epoch, target_epoch):
|
||||
# Training
|
||||
if coord.should_stop():
|
||||
break
|
||||
log_info('Training epoch %d ...' % current_epoch)
|
||||
train_loss = run_set('train')
|
||||
log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
|
||||
checkpoint_saver.save(session, checkpoint_path)
|
||||
steps_trained = 0
|
||||
# Validation
|
||||
log_info('Validating epoch %d ...' % current_epoch)
|
||||
dev_loss = run_set('dev')
|
||||
dev_losses.append(dev_loss)
|
||||
log_info('Finished validating epoch %d - loss: %f' % (current_epoch, dev_loss))
|
||||
if dev_loss < best_dev_loss:
|
||||
best_dev_loss = dev_loss
|
||||
save_path = best_dev_saver.save(session, best_dev_path)
|
||||
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
|
||||
# Early stopping
|
||||
if FLAGS.early_stop and len(dev_losses) >= FLAGS.es_steps:
|
||||
mean_loss = np.mean(dev_losses[-FLAGS.es_steps:-1])
|
||||
std_loss = np.std(dev_losses[-FLAGS.es_steps:-1])
|
||||
dev_losses = dev_losses[-FLAGS.es_steps:]
|
||||
log_debug('Checking for early stopping (last %d steps) validation loss: '
|
||||
'%f, with standard deviation: %f and mean: %f' %
|
||||
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
|
||||
if dev_losses[-1] > np.max(dev_losses[:-1]) or \
|
||||
(abs(dev_losses[-1] - mean_loss) < FLAGS.es_mean_th and std_loss < FLAGS.es_std_th):
|
||||
log_info('Early stop triggered as (for last %d steps) validation loss:'
|
||||
' %f with standard deviation: %f and mean: %f' %
|
||||
(FLAGS.es_steps, dev_losses[-1], std_loss, mean_loss))
|
||||
break
|
||||
log_debug('Closing queues...')
|
||||
coord.request_stop()
|
||||
model_feeder.close_queues(session)
|
||||
log_debug('Queues closed.')
|
||||
else:
|
||||
log_info('Target epoch already reached - skipped training.')
|
||||
log_debug('Session closed.')
|
||||
|
||||
|
||||
def test():
|
||||
@ -667,8 +595,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
previous_state = previous_state_c = previous_state_h = None
|
||||
else:
|
||||
if not tflite:
|
||||
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
else:
|
||||
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
@ -878,53 +806,17 @@ def do_single_file_inference(input_file_path):
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if FLAGS.train or FLAGS.test:
|
||||
if len(FLAGS.worker_hosts) == 0:
|
||||
# Only one local task: this process (default case - no cluster)
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
# Now do a final test epoch
|
||||
if FLAGS.test:
|
||||
with tf.Graph().as_default():
|
||||
test()
|
||||
log_debug('Done.')
|
||||
else:
|
||||
# Create and start a server for the local task.
|
||||
server = tf.train.Server(Config.cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
|
||||
if FLAGS.job_name == 'ps':
|
||||
# We are a parameter server and therefore we just wait for all workers to finish
|
||||
# by waiting for their stop tokens.
|
||||
with tf.Session(server.target) as session:
|
||||
for worker in FLAGS.worker_hosts:
|
||||
log_debug('Waiting for stop token...')
|
||||
token = session.run(Config.done_dequeues[FLAGS.task_index])
|
||||
if token < 0:
|
||||
log_debug('Got a kill switch token from worker %i.' % abs(token + 1))
|
||||
break
|
||||
log_debug('Got a stop token from worker %i.' % token)
|
||||
log_debug('Session closed.')
|
||||
if FLAGS.train:
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
|
||||
if FLAGS.test:
|
||||
test()
|
||||
elif FLAGS.job_name == 'worker':
|
||||
# We are a worker and therefore we have to do some work.
|
||||
# Assigns ops to the local worker by default.
|
||||
with tf.device(tf.train.replica_device_setter(
|
||||
worker_device=Config.worker_device,
|
||||
cluster=Config.cluster)):
|
||||
if FLAGS.test:
|
||||
with tf.Graph().as_default():
|
||||
test()
|
||||
|
||||
# Do the training
|
||||
train(server)
|
||||
|
||||
log_debug('Server stopped.')
|
||||
|
||||
# Are we the main process?
|
||||
if Config.is_chief:
|
||||
# Doing solo/post-processing work just on the main process...
|
||||
# Exporting the model
|
||||
if FLAGS.export_dir:
|
||||
export()
|
||||
if FLAGS.export_dir:
|
||||
export()
|
||||
|
||||
if len(FLAGS.one_shot_infer):
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
25 README.md
@ -48,7 +48,6 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech
  - [Checkpointing](#checkpointing)
  - [Exporting a model for inference](#exporting-a-model-for-inference)
  - [Exporting a model for TFLite](#exporting-a-model-for-tflite)
- - [Distributed computing across more than one machine](#distributed-training-across-more-than-one-machine)
  - [Continuing training from a release model](#continuing-training-from-a-release-model)
  - [Contact/Getting Help](#contactgetting-help)

@ -352,30 +351,6 @@ $ convert_graphdef_memmapped_format --in_graph=output_graph.pb --out_graph=outpu
Upon a successful run, it should report the conversion of a non-zero number of nodes. If it reports converting `0` nodes, something is wrong: make sure your model is a frozen one, and that you have not applied any incompatible changes (this includes `quantize_weights`).

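For illustration, a conversion invocation might look like the following sketch. The output file name is an assumption, not something this diff prescribes; any path works as long as the clients are later pointed at the same file.

```bash
# Sketch only - file names are assumptions, not taken from this diff.
convert_graphdef_memmapped_format \
  --in_graph=output_graph.pb \
  --out_graph=output_graph.pbmm
# The command's output should mention a non-zero number of converted nodes;
# if it reports 0, re-check that the input graph is actually frozen.
```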
- ### Distributed training across more than one machine
- 
- DeepSpeech has built-in support for [distributed TensorFlow](https://www.tensorflow.org/deploy/distributed). To get an idea of how this works, you can use the script `bin/run-cluster.sh` for running a cluster with workers just on the local machine.
- 
- ```bash
- $ bin/run-cluster.sh --help
- Usage: run-cluster.sh [--help] [--script script] [p:w:g] <arg>*
- 
- --help      print this help message
- --script    run the provided script instead of DeepSpeech.py
- p           number of local parameter servers
- w           number of local workers
- g           number of local GPUs per worker
- <arg>*      remaining parameters will be forwarded to DeepSpeech.py or a provided script
- 
- Example usage - The following example will create a local DeepSpeech.py cluster
- with 1 parameter server, and 2 workers with 1 GPU each:
- $ run-cluster.sh 1:2:1 --epoch 10
- ```
- 
- Be aware that for the help example to run, you need at least two `CUDA`-capable GPUs (2 workers x 1 GPU). The script uses the environment variable `CUDA_VISIBLE_DEVICES` so that each `DeepSpeech.py` worker only sees its assigned number of GPUs.
- 
- The script is meant to be a template for your own distributed computing instrumentation. Just modify the startup code for the different servers (workers and parameter servers) accordingly. You could use SSH or something similar for running them on your remote hosts.

### Continuing training from a release model

If you'd like to use one of the pre-trained models released by Mozilla to bootstrap your training process (transfer learning, fine tuning), you can do so by using the `--checkpoint_dir` flag in `DeepSpeech.py`. Specify the path where you downloaded the checkpoint from the release, and training will resume from the pre-trained model.
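For illustration, resuming from a downloaded release checkpoint might look like the sketch below. The paths and hyperparameter values are placeholders (assumptions, not values prescribed by this repository); in particular, `--n_hidden` has to match the value the release model was trained with, otherwise the checkpoint shapes will not match the model.

```bash
# Sketch only - paths and values are placeholders.
python -u DeepSpeech.py \
  --checkpoint_dir path/to/release/checkpoint \
  --n_hidden 2048 \
  --train_files my-train.csv \
  --dev_files my-dev.csv \
  --test_files my-test.csv \
  --epoch -3 \
  --learning_rate 0.0001
# A negative --epoch value means "train this many additional epochs
# on top of what the checkpoint has already seen".
```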
@ -1,72 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
#SBATCH --job-name=__NAME__
|
||||
#SBATCH --output=common.log
|
||||
#SBATCH --ntasks=__NODES__
|
||||
#SBATCH --nodes=__NODES__
|
||||
#SBATCH --gres=gpu:__GPUS__
|
||||
|
||||
{
|
||||
|
||||
set -o pipefail
|
||||
# force-killing all sub-processes on process exit, Ctrl-C, kill-signal
|
||||
# "trap - SIGTERM" hinders "kill" from being killed itself
|
||||
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT
|
||||
|
||||
PWD=`pwd`
|
||||
compute_cmd="/bin/bash .compute"
|
||||
srun="srun --exclusive -N1 -n1"
|
||||
|
||||
# unfolding slurm's compact cluster representation
|
||||
nodes_raw=`scontrol show hostname $SLURM_JOB_NODELIST`
|
||||
index=0
|
||||
for node in $nodes_raw
|
||||
do
|
||||
# keeping nodes as array for later index lookup
|
||||
nodes[$index]=$node
|
||||
((index=index + 1))
|
||||
done
|
||||
# comma separated node list (with leading comma)
|
||||
raw_node_list=$(printf ",%s" "${nodes[@]}")
|
||||
|
||||
# exporting COMPUTE variables for the .compute script that will
|
||||
# get executed on every node of the allocated cluster
|
||||
export COMPUTE_NODES=${raw_node_list:1}
|
||||
export COMPUTE_NODES_COUNT=${#nodes[@]}
|
||||
export COMPUTE_ID="__ID__"
|
||||
export COMPUTE_NAME="__NAME__"
|
||||
export COMPUTE_JOB_NUMBER=$SLURM_JOB_ID
|
||||
export COMPUTE_DATA_DIR=/data/shared
|
||||
export COMPUTE_RESULTS_DIR=$PWD/results
|
||||
export COMPUTE_KEEP_DIR=$COMPUTE_RESULTS_DIR/keep
|
||||
export COMPUTE_JOB_LOG="common.log"
|
||||
export COMPUTE_GLOBAL_LOG="__GLOBAL_LOG__"
|
||||
|
||||
for index in $(seq 0 $((COMPUTE_NODES_COUNT-1)));
|
||||
do
|
||||
# will tell every instance of the .compute script, which node of the cluster it represents
|
||||
export COMPUTE_NODE_INDEX=$index
|
||||
# the node has to be specified by "-w" to guarantee execution under the correct COMPUTE_NODE_INDEX
|
||||
# if some log line contains "GLOBAL LOG", the remaining string is emitted to the provided log file
|
||||
$srun -w ${nodes[index]} /bin/bash -c "cd src && $compute_cmd" \
|
||||
| tee >(sed -n -e 's/^.*GLOBAL LOG: //p' >> $COMPUTE_GLOBAL_LOG) &
|
||||
done
|
||||
|
||||
for index in $(seq 1 $COMPUTE_NODES_COUNT);
|
||||
do
|
||||
# "wait -n" waits for any sub-process to exit
|
||||
# doing this COMPUTE_NODES_COUNT times will wait for all node sub-processes to finish
|
||||
# in case of any node sub-process failing, it will exit immediately
|
||||
wait -n
|
||||
code=$?
|
||||
if ((code > 0)); then
|
||||
echo "One node failed with exit code $code."
|
||||
exit $code
|
||||
else
|
||||
echo "One node succeeded."
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Success. Quitting..."
|
||||
|
||||
} 2>&1 | ts "[%Y-%m-%d %H:%M:%.S] [common ]" > common.log
|
@ -1,99 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ ! -f DeepSpeech.py ]; then
|
||||
echo "Please make sure you run this from DeepSpeech's top level directory."
|
||||
exit 1
|
||||
fi;
|
||||
|
||||
ps_count=1
|
||||
worker_count=2
|
||||
gpu_count=0
|
||||
script="python -u DeepSpeech.py"
|
||||
|
||||
if [[ $1 == "--help" ]]; then
|
||||
echo "Usage: run-cluster.sh [--help] [--script script] [p:w:g] <arg>*"
|
||||
echo ""
|
||||
echo "--help print this help message"
|
||||
echo "--script run the provided script instead of DeepSpeech.py"
|
||||
echo "p number of local parameter servers"
|
||||
echo "w number of local workers"
|
||||
echo "g number of local GPUs per worker"
|
||||
echo "<arg>* remaining parameters will be forwarded to DeepSpeech.py or a provided script"
|
||||
echo
|
||||
echo "Example usage - The following example will create a local DeepSpeech.py cluster"
|
||||
echo "with 1 parameter server, and 2 workers with 1 GPU each:"
|
||||
echo "$ run-cluster.sh 1:2:1 --epoch 10"
|
||||
echo
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ $1 == "--script" ]]; then
|
||||
shift 1
|
||||
script=$1
|
||||
shift 1
|
||||
echo "Using script $script..."
|
||||
fi
|
||||
|
||||
if [[ $1 =~ ([0-9]+):([0-9]+):([0-9]+) ]]; then
|
||||
ps_count=${BASH_REMATCH[1]}
|
||||
worker_count=${BASH_REMATCH[2]}
|
||||
gpu_count=${BASH_REMATCH[3]}
|
||||
shift 1
|
||||
fi
|
||||
|
||||
echo "Starting cluster with $ps_count parameter servers and $worker_count workers with $gpu_count GPUs each..."
|
||||
|
||||
# Generating the parameter server addresses
|
||||
index=0
|
||||
while [ "$index" -lt "$ps_count" ]
|
||||
do
|
||||
ps_hosts[$index]="localhost:$((index + 2000))"
|
||||
((index++))
|
||||
done
|
||||
ps_hosts=$(printf ",%s" "${ps_hosts[@]}")
|
||||
ps_hosts=${ps_hosts:1}
|
||||
|
||||
# Generating the worker addresses
|
||||
index=0
|
||||
while [ "$index" -lt "$worker_count" ]
|
||||
do
|
||||
worker_hosts[$index]="localhost:$((index + 3000))"
|
||||
((index++))
|
||||
done
|
||||
worker_hosts=$(printf ",%s" "${worker_hosts[@]}")
|
||||
worker_hosts=${worker_hosts:1}
|
||||
|
||||
|
||||
# Starting the parameter servers
|
||||
index=0
|
||||
while [ "$index" -lt "$ps_count" ]
|
||||
do
|
||||
CUDA_VISIBLE_DEVICES="" $script --ps_hosts $ps_hosts --worker_hosts $worker_hosts --job_name=ps --task_index=$index "$@" 2>&1 | sed 's/^/[ps '"$index"'] /' &
|
||||
echo "Started ps $index"
|
||||
((index++))
|
||||
done
|
||||
|
||||
# Starting the workers
|
||||
start=0
|
||||
index=0
|
||||
while [ "$index" -lt "$worker_count" ]
|
||||
do
|
||||
stop=$((start+gpu_count-1))
|
||||
# Creating a comma-delimited number sequence from $start to $stop
|
||||
cvd=`seq -s, $start $stop`
|
||||
CUDA_VISIBLE_DEVICES=$cvd $script --ps_hosts $ps_hosts --worker_hosts $worker_hosts --job_name=worker --task_index=$index "$@" 2>&1 | sed 's/^/[worker '"$index"'] /' &
|
||||
start=$((start+gpu_count))
|
||||
echo "Started worker $index"
|
||||
((index++))
|
||||
done
|
||||
|
||||
# If we are forced to quit, we kill all remaining jobs/servers
|
||||
function quit {
|
||||
echo
|
||||
echo "Killing whole process group - the hard way..."
|
||||
kill -KILL -$$
|
||||
}
|
||||
trap quit SIGINT SIGTERM
|
||||
|
||||
# Waiting for all running jobs to join
|
||||
while [ `jobs -rp | wc -l` -gt 0 ]; do sleep 1; done
|
@ -16,7 +16,7 @@ else
checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/ldc93s1"))')
fi

- python -u DeepSpeech.py \
+ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files data/ldc93s1/ldc93s1.csv \
--dev_files data/ldc93s1/ldc93s1.csv \
--test_files data/ldc93s1/ldc93s1.csv \

@ -12,7 +12,7 @@ if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;

- python -u DeepSpeech.py --noshow_progressbar \
+ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_csv} --train_batch_size 1 \
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
--test_files ${ldc93s1_csv} --test_batch_size 1 \
@ -22,7 +22,7 @@ python -u DeepSpeech.py --noshow_progressbar \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie' | tee /tmp/resume.log

- if ! grep "Training of Epoch $epoch_count" /tmp/resume.log; then
+ if ! grep "Training epoch $epoch_count" /tmp/resume.log; then
echo "Did not resume training from checkpoint"
exit 1
else

@ -12,7 +12,7 @@ if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;

- python -u DeepSpeech.py \
+ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_csv} --train_batch_size 1 \
--train_cached_features_path "/tmp/ldc93s1.hdf5" \
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \

@ -10,7 +10,7 @@ if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;

- python -u DeepSpeech.py \
+ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_csv} --train_batch_size 1 \
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
--test_files ${ldc93s1_csv} --test_batch_size 1 \
@ -20,7 +20,7 @@ python -u DeepSpeech.py \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie'

- python -u DeepSpeech.py \
+ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--train_files ${ldc93s1_csv} --train_batch_size 1 \
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
--test_files ${ldc93s1_csv} --test_batch_size 1 \

@ -10,7 +10,7 @@ if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;

- python -u DeepSpeech.py \
+ python -u DeepSpeech.py --noshow_progressbar \
--n_hidden 494 \
--checkpoint_dir '/tmp/ckpt' \
--export_dir '/tmp/train' \

@ -4,7 +4,6 @@ import os
|
||||
import tensorflow as tf
|
||||
|
||||
from attrdict import AttrDict
|
||||
from six.moves import zip, range, filter
|
||||
from util.flags import FLAGS
|
||||
from util.gpu import get_available_gpus
|
||||
from util.logging import log_error
|
||||
@ -27,30 +26,11 @@ Config = ConfigSingleton()
|
||||
def initialize_globals():
|
||||
c = AttrDict()
|
||||
|
||||
# ps and worker hosts required for p2p cluster setup
|
||||
FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(',')))
|
||||
FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(',')))
|
||||
# CPU device
|
||||
c.cpu_device = '/cpu:0'
|
||||
|
||||
# Create a cluster from the parameter server and worker hosts.
|
||||
c.cluster = tf.train.ClusterSpec({'ps': FLAGS.ps_hosts, 'worker': FLAGS.worker_hosts})
|
||||
|
||||
# The absolute number of computing nodes - regardless of cluster or single mode
|
||||
num_workers = max(1, len(FLAGS.worker_hosts))
|
||||
|
||||
# If replica numbers are negative, we multiply their absolute values with the number of workers
|
||||
if FLAGS.replicas < 0:
|
||||
FLAGS.replicas = num_workers * -FLAGS.replicas
|
||||
if FLAGS.replicas_to_agg < 0:
|
||||
FLAGS.replicas_to_agg = num_workers * -FLAGS.replicas_to_agg
|
||||
|
||||
# The device path base for this node
|
||||
c.worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index)
|
||||
|
||||
# This node's CPU device
|
||||
c.cpu_device = c.worker_device + '/cpu:0'
|
||||
|
||||
# This node's available GPU devices
|
||||
c.available_devices = [c.worker_device + gpu for gpu in get_available_gpus()]
|
||||
# Available GPU devices
|
||||
c.available_devices = get_available_gpus()
|
||||
|
||||
# If there is no GPU available, we fall back to CPU based operation
|
||||
if 0 == len(c.available_devices):
|
||||
@ -68,6 +48,9 @@ def initialize_globals():
|
||||
if len(FLAGS.checkpoint_dir) == 0:
|
||||
FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))
|
||||
|
||||
if FLAGS.load not in ['last', 'best', 'init', 'auto']:
|
||||
FLAGS.load = 'auto'
|
||||
|
||||
# Set default summary dir
|
||||
if len(FLAGS.summary_dir) == 0:
|
||||
FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))
|
||||
@ -109,26 +92,6 @@ def initialize_globals():
|
||||
# Units in the sixth layer = number of characters in the target language plus one
|
||||
c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
|
||||
|
||||
# Queues that are used to gracefully stop parameter servers.
|
||||
# Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting.
|
||||
# Each ps will dequeue as many tokens as there are workers before joining/quitting.
|
||||
# This ensures parameter servers won't quit, if still required by at least one worker and
|
||||
# also won't wait forever (like with a standard `server.join()`).
|
||||
done_queues = []
|
||||
for i, ps in enumerate(FLAGS.ps_hosts):
|
||||
# Queues are hosted by their respective owners
|
||||
with tf.device('/job:ps/task:%d' % i):
|
||||
done_queues.append(tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i)))
|
||||
|
||||
# Placeholder to pass in the worker's index as token
|
||||
c.token_placeholder = tf.placeholder(tf.int32)
|
||||
|
||||
# Enqueue operations for each parameter server
|
||||
c.done_enqueues = [queue.enqueue(c.token_placeholder) for queue in done_queues]
|
||||
|
||||
# Dequeue operations for each parameter server
|
||||
c.done_dequeues = [queue.dequeue() for queue in done_queues]
|
||||
|
||||
if len(FLAGS.one_shot_infer) > 0:
|
||||
FLAGS.train = False
|
||||
FLAGS.test = False
|
||||
@ -137,7 +100,4 @@ def initialize_globals():
|
||||
log_error('Path specified in --one_shot_infer is not a valid file.')
|
||||
exit(1)
|
||||
|
||||
# Determine if we are the chief worker
|
||||
c.is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker')
|
||||
|
||||
ConfigSingleton._config = c
|
||||
|
@ -1,569 +0,0 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import pickle
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from datetime import datetime
|
||||
from six.moves import zip, range, filter, urllib, BaseHTTPServer
|
||||
from threading import Thread, Lock
|
||||
from util.config import Config
|
||||
from util.flags import FLAGS
|
||||
from util.logging import log_info, log_error, log_debug, log_warn, log_traffic
|
||||
|
||||
|
||||
# Execution
|
||||
# =========
|
||||
|
||||
# For reporting we also need a standard way to do time measurements.
|
||||
def stopwatch(start_duration=0):
|
||||
r'''
|
||||
This function will toggle a stopwatch.
|
||||
The first call starts it, second call stops it, third call continues it etc.
|
||||
So if you want to measure the accumulated time spent in a certain area of the code,
|
||||
you can surround that code by stopwatch-calls like this:
|
||||
|
||||
.. code:: python
|
||||
|
||||
fun_time = 0 # initializes a stopwatch
|
||||
[...]
|
||||
for i in range(10):
|
||||
[...]
|
||||
# Starts/continues the stopwatch - fun_time is now a point in time (again)
|
||||
fun_time = stopwatch(fun_time)
|
||||
fun()
|
||||
# Pauses the stopwatch - fun_time is now a duration
|
||||
fun_time = stopwatch(fun_time)
|
||||
[...]
|
||||
# The following line only makes sense after an even number of calls to :code:`fun_time = stopwatch(fun_time)`.
print('Time spent in fun():', format_duration(fun_time))
|
||||
|
||||
'''
|
||||
if start_duration == 0:
|
||||
return datetime.utcnow()
|
||||
else:
|
||||
return datetime.utcnow() - start_duration
|
||||
|
||||
def format_duration(duration):
|
||||
'''Formats the result of an even stopwatch call as hours:minutes:seconds'''
|
||||
duration = duration if isinstance(duration, int) else duration.seconds
|
||||
m, s = divmod(duration, 60)
|
||||
h, m = divmod(m, 60)
|
||||
return '%d:%02d:%02d' % (h, m, s)
|
||||
|
||||
|
||||
# String constants for different services of the web handler
|
||||
PREFIX_NEXT_INDEX = '/next_index_'
|
||||
PREFIX_GET_JOB = '/get_job_'
|
||||
|
||||
# Global ID counter for all objects requiring an ID
|
||||
id_counter = 0
|
||||
|
||||
def new_id():
|
||||
'''Returns a new ID that is unique on process level. Not thread-safe.
|
||||
|
||||
Returns:
|
||||
int. The new ID
|
||||
'''
|
||||
global id_counter
|
||||
id_counter += 1
|
||||
return id_counter
|
||||
|
||||
class WorkerJob(object):
|
||||
'''Represents a job that should be executed by a worker.
|
||||
|
||||
Args:
|
||||
epoch_id (int): the ID of the 'parent' epoch
|
||||
index (int): the epoch index of the 'parent' epoch
|
||||
set_name (str): the name of the data-set - one of 'train', 'dev'
|
||||
steps (int): the number of `session.run` calls
|
||||
'''
|
||||
def __init__(self, epoch_id, index, set_name, steps):
|
||||
self.id = new_id()
|
||||
self.epoch_id = epoch_id
|
||||
self.index = index
|
||||
self.worker = -1
|
||||
self.set_name = set_name
|
||||
self.steps = steps
|
||||
self.loss = -1
|
||||
self.samples = []
|
||||
|
||||
def __str__(self):
|
||||
return 'Job (ID: %d, worker: %d, epoch: %d, set_name: %s)' % (self.id, self.worker, self.index, self.set_name)
|
||||
|
||||
class Epoch(object):
|
||||
'''Represents an epoch that should be executed by the Training Coordinator.
|
||||
Creates `num_jobs` `WorkerJob` instances in state 'open'.
|
||||
|
||||
Args:
|
||||
index (int): the epoch index of the 'parent' epoch
|
||||
num_jobs (int): the number of jobs in this epoch
|
||||
|
||||
Kwargs:
|
||||
set_name (str): the name of the data-set - one of 'train', 'dev'
|
||||
'''
|
||||
def __init__(self, coord, index, num_jobs, set_name='train'):
|
||||
self.id = new_id()
|
||||
self.coord = coord
|
||||
self.index = index
|
||||
self.num_jobs = num_jobs
|
||||
self.set_name = set_name
|
||||
self.loss = -1
|
||||
self.jobs_open = []
|
||||
self.jobs_running = []
|
||||
self.jobs_done = []
|
||||
for i in range(self.num_jobs):
|
||||
self.jobs_open.append(WorkerJob(self.id, self.index, self.set_name, FLAGS.iters_per_worker))
|
||||
|
||||
def name(self):
|
||||
'''Gets a printable name for this epoch.
|
||||
|
||||
Returns:
|
||||
str. printable name for this epoch
|
||||
'''
|
||||
if self.index >= 0:
|
||||
ename = ' of Epoch %d' % self.index
|
||||
else:
|
||||
ename = ''
|
||||
if self.set_name == 'train':
|
||||
return 'Training%s' % ename
|
||||
else:
|
||||
return 'Validation%s' % ename
|
||||
|
||||
def get_job(self, worker):
|
||||
'''Gets the next open job from this epoch. The job will be marked as 'running'.
|
||||
|
||||
Args:
|
||||
worker (int): index of the worker that takes the job
|
||||
|
||||
Returns:
|
||||
WorkerJob. job that has been marked as running for this worker
|
||||
'''
|
||||
if len(self.jobs_open) > 0:
|
||||
job = self.jobs_open.pop(0)
|
||||
self.jobs_running.append(job)
|
||||
job.worker = worker
|
||||
return job
|
||||
else:
|
||||
return None
|
||||
|
||||
def finish_job(self, job):
|
||||
'''Finishes a running job. Removes it from the running jobs list and adds it to the done jobs list.
|
||||
|
||||
Args:
|
||||
job (WorkerJob): the job to put into state 'done'
|
||||
'''
|
||||
index = next((i for i in range(len(self.jobs_running)) if self.jobs_running[i].id == job.id), -1)
|
||||
if index >= 0:
|
||||
self.jobs_running.pop(index)
|
||||
self.jobs_done.append(job)
|
||||
log_traffic('%s - Moved %s from running to done.' % (self.name(), job))
|
||||
else:
|
||||
log_warn('%s - There is no job with ID %d registered as running.' % (self.name(), job.id))
|
||||
|
||||
def done(self):
|
||||
'''Checks, if all jobs of the epoch are in state 'done'.
|
||||
|
||||
Returns:
|
||||
bool. if all jobs of the epoch are 'done'
|
||||
'''
|
||||
if len(self.jobs_open) == 0 and len(self.jobs_running) == 0:
|
||||
num_jobs = len(self.jobs_done)
|
||||
if num_jobs > 0:
|
||||
jobs = self.jobs_done
|
||||
self.jobs_done = []
|
||||
if not self.num_jobs == num_jobs:
|
||||
log_warn('%s - Number of steps not equal to number of jobs done.' % (self.name()))
|
||||
|
||||
agg_loss = 0.0
|
||||
|
||||
for i in range(num_jobs):
|
||||
job = jobs.pop(0)
|
||||
agg_loss += job.loss
|
||||
|
||||
self.loss = agg_loss / num_jobs
|
||||
|
||||
# if the job was for validation dataset then append it to the COORD's _loss for early stop verification
|
||||
if (FLAGS.early_stop is True) and (self.set_name == 'dev'):
|
||||
self.coord._dev_losses.append(self.loss)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def job_status(self):
|
||||
'''Provides a printable overview of the states of the jobs of this epoch.
|
||||
|
||||
Returns:
|
||||
str. printable overall job state
|
||||
'''
|
||||
return '%s - jobs open: %d, jobs running: %d, jobs done: %d' % (self.name(), len(self.jobs_open), len(self.jobs_running), len(self.jobs_done))
|
||||
|
||||
def __str__(self):
|
||||
if not self.done():
|
||||
return self.job_status()
|
||||
|
||||
return '%s - loss: %f' % (self.name(), self.loss)
|
||||
|
||||
|
||||
class TrainingCoordinator(object):
|
||||
''' Central training coordination class.
|
||||
Used for distributing jobs among workers of a cluster.
|
||||
Instantiated on all workers; calls from non-chief workers are transparently
HTTP-forwarded to the chief worker instance.
|
||||
'''
|
||||
|
||||
def make_handler(coord):
|
||||
class TrainingCoordinationHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
'''Handles HTTP requests from remote workers to the Training Coordinator.
|
||||
'''
|
||||
def _send_answer(self, data=None):
|
||||
self.send_response(200)
|
||||
self.send_header('content-type', 'text/plain')
|
||||
self.end_headers()
|
||||
if data:
|
||||
self.wfile.write(data)
|
||||
|
||||
def do_GET(self):
|
||||
if coord.started:
|
||||
if self.path.startswith(PREFIX_NEXT_INDEX):
|
||||
index = coord.get_next_index(self.path[len(PREFIX_NEXT_INDEX):])
|
||||
if index >= 0:
|
||||
self._send_answer(str(index).encode("utf-8"))
|
||||
return
|
||||
elif self.path.startswith(PREFIX_GET_JOB):
|
||||
job = coord.get_job(worker=int(self.path[len(PREFIX_GET_JOB):]))
|
||||
if job:
|
||||
self._send_answer(pickle.dumps(job))
|
||||
return
|
||||
self.send_response(204) # end of training
|
||||
else:
|
||||
self.send_response(202) # not ready yet
|
||||
self.end_headers()
|
||||
|
||||
def do_POST(self):
|
||||
if coord.started:
|
||||
src = self.rfile.read(int(self.headers['content-length']))
|
||||
job = coord.next_job(pickle.loads(src))
|
||||
if job:
|
||||
self._send_answer(pickle.dumps(job))
|
||||
return
|
||||
self.send_response(204) # end of training
|
||||
else:
|
||||
self.send_response(202) # not ready yet
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
'''Overriding base method to suppress web handler messages on stdout.
|
||||
'''
|
||||
return
|
||||
|
||||
return TrainingCoordinationHandler
|
||||
|
||||
def __init__(self, is_chief):
|
||||
self._init()
|
||||
self._lock = Lock()
|
||||
self._thread = None
|
||||
self.started = False
|
||||
self.is_chief = is_chief
|
||||
if is_chief:
|
||||
self._httpd = BaseHTTPServer.HTTPServer((FLAGS.coord_host, FLAGS.coord_port), TrainingCoordinator.make_handler(self))
|
||||
|
||||
def _reset_counters(self):
|
||||
self._index_train = 0
|
||||
self._index_dev = 0
|
||||
|
||||
def _init(self):
|
||||
self._epochs_running = []
|
||||
self._epochs_done = []
|
||||
self._reset_counters()
|
||||
self._dev_losses = []
|
||||
|
||||
def _log_all_jobs(self):
|
||||
'''Use this to debug-print epoch state'''
|
||||
log_debug('Epochs - running: %d, done: %d' % (len(self._epochs_running), len(self._epochs_done)))
|
||||
for epoch in self._epochs_running:
|
||||
log_debug(' - running: ' + epoch.job_status())
|
||||
|
||||
    def start_coordination(self, model_feeder, step=0):
        '''Starts to coordinate epochs and jobs among workers on the basis of
        data-set sizes, the (global) step and FLAGS parameters.

        Args:
            model_feeder (ModelFeeder): data-sets to be used for coordinated training

        Kwargs:
            step (int): global step of a loaded model to determine starting point
        '''
        with self._lock:
            self._init()

            # Number of GPUs per worker - fixed for now by local reality or cluster setup
            gpus_per_worker = len(Config.available_devices)

            # Number of batches processed per job per worker
            batches_per_job = gpus_per_worker * max(1, FLAGS.iters_per_worker)

            # Number of batches per global step
            batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg)

            # Number of global steps per epoch - to be at least 1
            steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step)

            # The start epoch of our training
            self._epoch = step // steps_per_epoch

            # Number of additional 'jobs' trained already 'on top of' our start epoch
            jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job

            # Total number of train/dev jobs covering their respective whole sets (one epoch)
            self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job)
            self._num_jobs_dev = max(1, model_feeder.dev.total_batches // batches_per_job)

            if FLAGS.epoch < 0:
                # A negative epoch means to add its absolute number to the epochs already computed
                self._target_epoch = self._epoch + abs(FLAGS.epoch)
            else:
                self._target_epoch = FLAGS.epoch

            # State variables
            # We only have to train if we are told so and are not at the target epoch yet
            self._train = FLAGS.train and self._target_epoch > self._epoch

            if self._train:
                # The total number of jobs for all additional epochs to be trained
                # Will be decremented for each job that is produced/put into state 'open'
                self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained
                log_info('STARTING Optimization')
                self._training_time = stopwatch()

            # Important for debugging
            log_debug('step: %d' % step)
            log_debug('epoch: %d' % self._epoch)
            log_debug('target epoch: %d' % self._target_epoch)
            log_debug('steps per epoch: %d' % steps_per_epoch)
            log_debug('number of batches in train set: %d' % model_feeder.train.total_batches)
            log_debug('batches per job: %d' % batches_per_job)
            log_debug('batches per step: %d' % batches_per_step)
            log_debug('number of jobs in train set: %d' % self._num_jobs_train)
            log_debug('number of jobs already trained in first epoch: %d' % jobs_trained)

            self._next_epoch()

        # The coordinator is ready to serve
        self.started = True

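As a worked illustration of the bookkeeping above (made-up numbers, not taken from this diff): with 2 GPUs, iters_per_worker=1, replicas_to_agg=2, 1000 training batches and a restored global step of 120, the derived quantities come out as follows.

gpus_per_worker = 2
batches_per_job = gpus_per_worker * max(1, 1)                   # 2 batches handed out per job
batches_per_step = gpus_per_worker * max(1, 2)                  # 4 batches consumed per global step
total_batches = 1000
steps_per_epoch = max(1, total_batches // batches_per_step)     # 250 global steps per epoch
step = 120
epoch = step // steps_per_epoch                                 # 0 - still inside the first epoch
jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job   # 240 jobs already done
num_jobs_train = max(1, total_batches // batches_per_job)       # 500 jobs cover one full epoch
assert (epoch, jobs_trained, num_jobs_train) == (0, 240, 500)
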
    def _next_epoch(self):
        # State machine of the coordination process

        # Indicates whether 'new' epoch(s) were provided
        result = False

        # Make sure that early stopping and the validation part are both enabled
        if (FLAGS.early_stop is True) and (FLAGS.validation_step > 0) and (len(self._dev_losses) >= FLAGS.earlystop_nsteps):

            # Calculate the mean of losses for past epochs
            mean_loss = np.mean(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
            # Calculate the standard deviation for losses from validation part in the past epochs
            std_loss = np.std(self._dev_losses[-FLAGS.earlystop_nsteps:-1])
            # Update the list of losses incurred
            self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:]
            log_debug('Checking for early stopping (last %d steps) validation loss: %f, with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))

            # Check if validation loss has started increasing or is not decreasing substantially,
            # while making sure slight fluctuations don't keep early stopping from working
            if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh):
                # Time to early stop
                log_info('Early stop triggered as (for last %d steps) validation loss: %f with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss))
                self._dev_losses = []
                self._end_training()
                self._train = False

        if self._train:
            # We are in train mode
            if self._num_jobs_train_left > 0:
                # There are still jobs left
                num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train)
                self._num_jobs_train_left -= num_jobs_train

                # Let's try our best to keep the notion of curriculum learning
                self._reset_counters()

                # Append the training epoch
                self._epochs_running.append(Epoch(self, self._epoch, num_jobs_train, set_name='train'))

                if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0:
                    # The current epoch should also have a validation part
                    self._epochs_running.append(Epoch(self, self._epoch, self._num_jobs_dev, set_name='dev'))

                # Indicating that 'new' epoch(s) were provided
                result = True
            else:
                # No jobs left, but still in train mode: concluding training
                self._end_training()
                self._train = False

        if result:
            # Increment the epoch index
            self._epoch += 1

        return result

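The early-stopping branch above can be restated as a standalone predicate. The following is a sketch for illustration only, with nsteps, mean_thresh and std_thresh standing in for FLAGS.earlystop_nsteps, FLAGS.estop_mean_thresh and FLAGS.estop_std_thresh.

import numpy as np

def should_early_stop(dev_losses, nsteps=4, mean_thresh=0.5, std_thresh=0.5):
    # Stop when the latest validation loss exceeds every other loss in the window,
    # or when it has plateaued: close to the window mean with a small standard deviation.
    if len(dev_losses) < nsteps:
        return False
    window = dev_losses[-nsteps:]
    mean_loss = np.mean(window[:-1])
    std_loss = np.std(window[:-1])
    increased = window[-1] > np.max(window[:-1])
    plateaued = abs(window[-1] - mean_loss) < mean_thresh and std_loss < std_thresh
    return increased or plateaued
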
    def _end_training(self):
        self._training_time = stopwatch(self._training_time)
        log_info('FINISHED Optimization - training time: %s' % format_duration(self._training_time))

    def start(self):
        '''Starts Training Coordinator. If chief, it starts a web server for
        communication with non-chief instances.
        '''
        if self.is_chief:
            log_debug('Starting coordinator...')
            self._thread = Thread(target=self._httpd.serve_forever)
            self._thread.daemon = True
            self._thread.start()
            log_debug('Coordinator started. Thread id {}'.format(self._thread.ident))

    def stop(self, wait_for_running_epochs=True):
        '''Stops Training Coordinator. If chief, it waits for all epochs to be
        'done' and then shuts down the web server.
        '''
        if self.is_chief and self._thread:
            if wait_for_running_epochs:
                while len(self._epochs_running) > 0:
                    log_traffic('Coordinator is waiting for epochs to finish...')
                    time.sleep(5)
            log_debug('Stopping coordinator...')
            self._httpd.shutdown()
            log_debug('Coordinator stopped.')

    def _talk_to_chief(self, path, data=None, default=None):
        tries = 0
        while tries < FLAGS.coord_retries:
            tries += 1
            try:
                url = 'http://%s:%d%s' % (FLAGS.coord_host, FLAGS.coord_port, path)
                log_traffic('Contacting coordinator - url: %s, tries: %d ...' % (url, tries-1))
                res = urllib.request.urlopen(urllib.request.Request(url, data, { 'content-type': 'text/plain' }))
                str = res.read()
                status = res.getcode()
                log_traffic('Coordinator responded - url: %s, status: %s' % (url, status))
                if status == 200:
                    return str
                if status == 204: # We use 204 (no content) to indicate end of training
                    return default
            except urllib.error.HTTPError as error:
                log_traffic('Problem reaching coordinator - url: %s, HTTP code: %d' % (url, error.code))
                pass
            time.sleep(10)
        return default

    def get_next_index(self, set_name):
        '''Retrieves a new cluster-unique batch index for a given set name.
        Prevents applying one batch multiple times per epoch.

        Args:
            set_name (str): name of the data set - one of 'train', 'dev'

        Returns:
            int. new data set index
        '''
        with self._lock:
            if self.is_chief:
                member = '_index_' + set_name
                value = getattr(self, member, -1)
                setattr(self, member, value + 1)
                return value
            else:
                # We are a remote worker and have to hand over to the chief worker by HTTP
                log_traffic('Asking for next index...')
                value = int(self._talk_to_chief(PREFIX_NEXT_INDEX + set_name))
                log_traffic('Got index %d.' % value)
                return value

    def _get_job(self, worker=0):
        job = None
        # Find first running epoch that provides a next job
        for epoch in self._epochs_running:
            job = epoch.get_job(worker)
            if job:
                return job
        # No next job found
        return None

    def get_job(self, worker=0):
        '''Retrieves the first job for a worker.

        Kwargs:
            worker (int): index of the worker to get the first job for

        Returns:
            WorkerJob. a job of one of the running epochs that will get
            associated with the given worker and put into state 'running'
        '''
        # Let's ensure that this does not interfere with other workers/requests
        with self._lock:
            if self.is_chief:
                # First try to get a next job
                job = self._get_job(worker)

                if job is None:
                    # If there was no next job, we give it a second chance by triggering the epoch state machine
                    if self._next_epoch():
                        # Epoch state machine got a new epoch
                        # Second try to get a next job
                        job = self._get_job(worker)
                        if job is None:
                            # Even though the epoch state machine provided a new epoch, that epoch had no job for us
                            log_error('Unexpected case - no job for worker %d.' % (worker))
                        return job

                    # Epoch state machine has no new epoch
                    # This happens at the end of the whole training - nothing to worry about
                    log_traffic('No jobs left for worker %d.' % (worker))
                    self._log_all_jobs()
                    return None

                # We got a new job from one of the currently running epochs
                log_traffic('Got new %s' % job)
                return job

            # We are a remote worker and have to hand over to the chief worker by HTTP
            result = self._talk_to_chief(PREFIX_GET_JOB + str(FLAGS.task_index))
            if result:
                result = pickle.loads(result)
            return result

    def next_job(self, job):
        '''Sends a finished job back to the coordinator and retrieves in exchange the next one.

        Kwargs:
            job (WorkerJob): job that was finished by a worker and whose results are to be
                digested by the coordinator

        Returns:
            WorkerJob. next job of one of the running epochs that will get
            associated with the worker from the finished job and put into state 'running'
        '''
        if self.is_chief:
            # Try to find the epoch the job belongs to
            epoch = next((epoch for epoch in self._epochs_running if epoch.id == job.epoch_id), None)
            if epoch:
                # We are going to manipulate things - let's avoid undefined state
                with self._lock:
                    # Let the epoch finish the job
                    epoch.finish_job(job)
                    # Check if the epoch is done now
                    if epoch.done():
                        # If it declares itself done, move it from the 'running' to the 'done' collection
                        self._epochs_running.remove(epoch)
                        self._epochs_done.append(epoch)
                        log_info('%s' % epoch)
            else:
                # There was no running epoch found for this job - this should never happen.
                log_error('There is no running epoch of ID %d for job with ID %d. This should never happen.' % (job.epoch_id, job.id))
            return self.get_job(job.worker)

        # We are a remote worker and have to hand over to the chief worker by HTTP
        result = self._talk_to_chief('', data=pickle.dumps(job))
        if result:
            result = pickle.loads(result)
        return result

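For context, the removed coordinator is driven by a per-worker loop of roughly this shape (a sketch, not code from the repo; run_job is a hypothetical stand-in for the per-job training/validation work in DeepSpeech.py).

def drive_coordinator(coord, worker_index, run_job):
    # Fetch the first job for this worker, process it, then keep exchanging the
    # finished job for the next one until the coordinator signals end of training.
    job = coord.get_job(worker=worker_index)
    while job is not None:
        run_job(job)                 # run the job's batches and fill in its results
        job = coord.next_job(job)    # report results back, receive the next job
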
@ -19,20 +19,6 @@ def create_flags():
    tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
    tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')

    # Cluster configuration
    # =====================

    tf.app.flags.DEFINE_string ('ps_hosts', '', 'parameter servers - comma separated list of hostname:port pairs')
    tf.app.flags.DEFINE_string ('worker_hosts', '', 'workers - comma separated list of hostname:port pairs')
    tf.app.flags.DEFINE_string ('job_name', 'localhost', 'job name - one of localhost (default), worker, ps')
    tf.app.flags.DEFINE_integer ('task_index', 0, 'index of task within the job - worker with index 0 will be the chief')
    tf.app.flags.DEFINE_integer ('replicas', -1, 'total number of replicas - if negative, its absolute value is multiplied by the number of workers')
    tf.app.flags.DEFINE_integer ('replicas_to_agg', -1, 'number of replicas to aggregate - if negative, its absolute value is multiplied by the number of workers')
    tf.app.flags.DEFINE_integer ('coord_retries', 100, 'number of tries of workers connecting to training coordinator before failing')
    tf.app.flags.DEFINE_string ('coord_host', 'localhost', 'coordination server host')
    tf.app.flags.DEFINE_integer ('coord_port', 2500, 'coordination server port')
    tf.app.flags.DEFINE_integer ('iters_per_worker', 1, 'number of train or inference iterations per worker before results are sent back to coordinator')

    # Global Constants
    # ================

@ -74,15 +60,12 @@ def create_flags():
    tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit')
    tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit')

    # Step widths

    tf.app.flags.DEFINE_integer ('validation_step', 0, 'number of epochs we cycle through before validating the model - 0 means no validation steps')

    # Checkpointing

    tf.app.flags.DEFINE_string ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
    tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
    tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
    tf.app.flags.DEFINE_string ('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')

    # Exporting

@ -96,14 +79,12 @@ def create_flags():
    # Reporting

    tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
    tf.app.flags.DEFINE_boolean ('log_traffic', False, 'log cluster transaction and traffic information during debug logging')
    tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')

    tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console')
    tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report')

    tf.app.flags.DEFINE_string ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
    tf.app.flags.DEFINE_integer ('summary_secs', 0, 'interval in seconds for saving TensorBoard summaries - if 0, no summaries will be written')

    # Geometry

@ -115,15 +96,10 @@ def create_flags():

    # Early Stopping

    tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset. Make sure that dev FLAG is enabled for this to work')

    # This parameter is irrespective of the time taken by single epoch to complete and checkpoint saving intervals.
    # It is possible that early stopping is triggered far after the best checkpoint is already replaced by checkpoint saving interval mechanism.
    # One has to align the parameters (earlystop_nsteps, checkpoint_secs) accordingly as per the time taken by an epoch on different datasets.

    tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
    tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
    tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
    tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset')
    tf.app.flags.DEFINE_integer ('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
    tf.app.flags.DEFINE_float ('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
    tf.app.flags.DEFINE_float ('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')

    # Decoder

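This hunk renames the early-stopping flags (earlystop_nsteps to es_steps, estop_mean_thresh to es_mean_th, estop_std_thresh to es_std_th). A hypothetical sketch of reading the new names, following the create_flags()/tf.app.run() pattern used elsewhere in the repo; the printed message is illustrative only.

import tensorflow as tf
from util.flags import create_flags, FLAGS

def main(_):
    # Flags are parsed by tf.app.run() before main() is invoked.
    if FLAGS.early_stop:
        print('early stopping after %d validations (mean_th=%.2f, std_th=%.2f)'
              % (FLAGS.es_steps, FLAGS.es_mean_th, FLAGS.es_std_th))

if __name__ == '__main__':
    create_flags()
    tf.app.run()
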
@ -15,11 +15,6 @@ def log_debug(message):
    prefix_print('D ', message)


def log_traffic(message):
    if FLAGS.log_traffic:
        log_debug(message)


def log_info(message):
    if FLAGS.log_level <= 1:
        prefix_print('I ', message)
