diff --git a/.compute b/.compute index 38d03dd0..16977ac5 100755 --- a/.compute +++ b/.compute @@ -21,5 +21,4 @@ python3 -u DeepSpeech.py \ --display_step 0 \ --validation_step 1 \ --checkpoint_dir "../keep" \ - --summary_dir "../keep/summaries" \ - --decoder_library_path "../tmp/native_client/libctc_decoder_with_kenlm.so" + --summary_dir "../keep/summaries" diff --git a/.install b/.install index d7801ca0..f776c1e4 100755 --- a/.install +++ b/.install @@ -7,4 +7,10 @@ pip install tensorflow-gpu==1.12.0rc2 python3 util/taskcluster.py --arch gpu --target ../tmp/native_client +# Install ds_ctcdecoder package from TaskCluster +VERSION=$(python -c 'import pkg_resources; print(pkg_resources.safe_version(open("VERSION").read()))') +PYVER=$(python -c 'import sys; print("cp{0}{1}-cp{0}{1}m".format(sys.version_info.major, sys.version_info.minor))') +python3 util/taskcluster.py --arch cpu --target ../tmp --artifact "ds_ctcdecoder-${VERSION}-${PYVER}-manylinux1_x86_64.whl" +pip install ../tmp/ds_ctcdecoder-*.whl + mkdir -p ../keep/summaries diff --git a/DeepSpeech.py b/DeepSpeech.py index 83e02485..9cc5bf48 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -8,356 +8,24 @@ import sys log_level_index = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_index > 0 and log_level_index < len(sys.argv) else '3' -import datetime -import pickle -import shutil -import six -import subprocess -import tensorflow as tf -import time -import traceback -import inspect import progressbar +import shutil import tempfile +import tensorflow as tf +import traceback +import evaluate -from functools import partial -from six.moves import zip, range, filter, urllib, BaseHTTPServer -from tensorflow.python.tools import freeze_graph +from ds_ctcdecoder import ctc_beam_search_decoder, Scorer +from six.moves import zip, range from tensorflow.contrib.lite.python import tflite_convert -from threading import Thread, Lock +from tensorflow.python.tools import freeze_graph from util.audio import audiofile_to_input_vector from util.feeding import DataSet, ModelFeeder +from util.logging import * +from util.flags import create_flags, FLAGS +from util.coordinator import C, initialize_globals from util.preprocess import preprocess -from util.gpu import get_available_gpus -from util.shared_lib import check_cupti -from util.text import sparse_tensor_value_to_texts, wer, levenshtein, Alphabet -from xdg import BaseDirectory as xdg -import numpy as np - - -def create_flags(): - # Importer - # ======== - - tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged') - tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged') - tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged') - tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training') - - tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged') - tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. 
multiple files will get merged') - tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged') - - # Cluster configuration - # ===================== - - tf.app.flags.DEFINE_string ('ps_hosts', '', 'parameter servers - comma separated list of hostname:port pairs') - tf.app.flags.DEFINE_string ('worker_hosts', '', 'workers - comma separated list of hostname:port pairs') - tf.app.flags.DEFINE_string ('job_name', 'localhost', 'job name - one of localhost (default), worker, ps') - tf.app.flags.DEFINE_integer ('task_index', 0, 'index of task within the job - worker with index 0 will be the chief') - tf.app.flags.DEFINE_integer ('replicas', -1, 'total number of replicas - if negative, its absolute value is multiplied by the number of workers') - tf.app.flags.DEFINE_integer ('replicas_to_agg', -1, 'number of replicas to aggregate - if negative, its absolute value is multiplied by the number of workers') - tf.app.flags.DEFINE_integer ('coord_retries', 100, 'number of tries of workers connecting to training coordinator before failing') - tf.app.flags.DEFINE_string ('coord_host', 'localhost', 'coordination server host') - tf.app.flags.DEFINE_integer ('coord_port', 2500, 'coordination server port') - tf.app.flags.DEFINE_integer ('iters_per_worker', 1, 'number of train or inference iterations per worker before results are sent back to coordinator') - - # Global Constants - # ================ - - tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network') - tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network') - tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained') - - tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers') - tf.app.flags.DEFINE_float ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate') - tf.app.flags.DEFINE_float ('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate') - tf.app.flags.DEFINE_float ('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0') - tf.app.flags.DEFINE_float ('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0') - tf.app.flags.DEFINE_float ('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate') - - tf.app.flags.DEFINE_float ('relu_clip', 20.0, 'ReLU clipping value for non-recurrant layers') - - # Adam optimizer (http://arxiv.org/abs/1412.6980) parameters - - tf.app.flags.DEFINE_float ('beta1', 0.9, 'beta 1 parameter of Adam optimizer') - tf.app.flags.DEFINE_float ('beta2', 0.999, 'beta 2 parameter of Adam optimizer') - tf.app.flags.DEFINE_float ('epsilon', 1e-8, 'epsilon parameter of Adam optimizer') - tf.app.flags.DEFINE_float ('learning_rate', 0.001, 'learning rate of Adam optimizer') - - # Batch sizes - - tf.app.flags.DEFINE_integer ('train_batch_size', 1, 'number of elements in a training batch') - tf.app.flags.DEFINE_integer ('dev_batch_size', 1, 'number of elements in a validation batch') - tf.app.flags.DEFINE_integer ('test_batch_size', 1, 'number of elements in a test batch') - - tf.app.flags.DEFINE_integer ('export_batch_size', 1, 'number of elements per batch on the exported graph') - - # Performance (UNSUPPORTED) - tf.app.flags.DEFINE_integer ('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details') - tf.app.flags.DEFINE_integer 
('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details') - - # Sample limits - - tf.app.flags.DEFINE_integer ('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit') - tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit') - tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit') - - # Step widths - - tf.app.flags.DEFINE_integer ('display_step', 0, 'number of epochs we cycle through before displaying detailed progress - 0 means no progress display') - tf.app.flags.DEFINE_integer ('validation_step', 0, 'number of epochs we cycle through before validating the model - a detailed progress report is dependent on "--display_step" - 0 means no validation steps') - - # Checkpointing - - tf.app.flags.DEFINE_string ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') - tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds') - tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5') - - # Exporting - - tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported') - tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model') - tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models') - tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine') - tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)') - tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency') - - # Reporting - - tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL') - tf.app.flags.DEFINE_boolean ('log_traffic', False, 'log cluster transaction and traffic information during debug logging') - tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. 
Log level should be > 0.') - - tf.app.flags.DEFINE_string ('wer_log_pattern', '', 'pattern for machine readable global logging of WER progress; has to contain %%s, %%s and %%f for the set name, the date and the float respectively; example: "GLOBAL LOG: logwer(\'12ade231\', %%s, %%s, %%f)" would result in some entry like "GLOBAL LOG: logwer(\'12ade231\', \'train\', \'2017-05-18T03:09:48-0700\', 0.05)"; if omitted (default), there will be no logging') - - tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console') - tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report') - - tf.app.flags.DEFINE_string ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification') - tf.app.flags.DEFINE_integer ('summary_secs', 0, 'interval in seconds for saving TensorBoard summaries - if 0, no summaries will be written') - - # Geometry - - tf.app.flags.DEFINE_integer ('n_hidden', 2048, 'layer width to use when initialising layers') - - # Initialization - - tf.app.flags.DEFINE_integer ('random_seed', 4567, 'default random seed that is used to initialize variables') - - # Early Stopping - - tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset. Make sure that dev FLAG is enabled for this to work') - - # This parameter is irrespective of the time taken by single epoch to complete and checkpoint saving intervals. - # It is possible that early stopping is triggered far after the best checkpoint is already replaced by checkpoint saving interval mechanism. - # One has to align the parameters (earlystop_nsteps, checkpoint_secs) accordingly as per the time taken by an epoch on different datasets. - - tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point') - tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required') - tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required') - - # Decoder - - tf.app.flags.DEFINE_string ('decoder_library_path', 'native_client/libctc_decoder_with_kenlm.so', 'path to the libctc_decoder_with_kenlm.so library containing the decoder implementation.') - tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.') - tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM') - tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie') - tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions') - tf.app.flags.DEFINE_float ('lm_weight', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') - tf.app.flags.DEFINE_float ('valid_word_count_weight', 2.10, 'valid word insertion weight. 
This is used to lessen the word insertion penalty when the inserted word is part of the vocabulary.') - - # Inference mode - - tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.') - -FLAGS = tf.app.flags.FLAGS - -def initialize_globals(): - - # ps and worker hosts required for p2p cluster setup - FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(','))) - FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(','))) - - # Determine, if we are the chief worker - global is_chief - is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker') - - # Initializing and starting the training coordinator - global COORD - COORD = TrainingCoordinator() - COORD.start() - - # The absolute number of computing nodes - regardless of cluster or single mode - global num_workers - num_workers = max(1, len(FLAGS.worker_hosts)) - - # Create a cluster from the parameter server and worker hosts. - global cluster - cluster = tf.train.ClusterSpec({'ps': FLAGS.ps_hosts, 'worker': FLAGS.worker_hosts}) - - # If replica numbers are negative, we multiply their absolute values with the number of workers - if FLAGS.replicas < 0: - FLAGS.replicas = num_workers * -FLAGS.replicas - if FLAGS.replicas_to_agg < 0: - FLAGS.replicas_to_agg = num_workers * -FLAGS.replicas_to_agg - - # The device path base for this node - global worker_device - worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index) - - # This node's CPU device - global cpu_device - cpu_device = worker_device + '/cpu:0' - - # This node's available GPU devices - global available_devices - available_devices = [worker_device + gpu for gpu in get_available_gpus()] - - # If there is no GPU available, we fall back to CPU based operation - if 0 == len(available_devices): - available_devices = [cpu_device] - - # Set default dropout rates - if FLAGS.dropout_rate2 < 0: - FLAGS.dropout_rate2 = FLAGS.dropout_rate - if FLAGS.dropout_rate3 < 0: - FLAGS.dropout_rate3 = FLAGS.dropout_rate - if FLAGS.dropout_rate6 < 0: - FLAGS.dropout_rate6 = FLAGS.dropout_rate - - global dropout_rates - dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)] - - global no_dropout - no_dropout = [ 0.0 ] * 6 - - # Set default checkpoint dir - if len(FLAGS.checkpoint_dir) == 0: - FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints')) - - # Set default summary dir - if len(FLAGS.summary_dir) == 0: - FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries')) - - # Standard session configuration that'll be used for all new sessions. 
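Everything this hunk deletes from initialize_globals() resurfaces behind the C object that the new `from util.coordinator import C, initialize_globals` import provides. util/coordinator.py itself is not part of this diff, so the following is only a hypothetical sketch of the shape implied by the call sites (C.n_input, C.alphabet, C.session_config, C.COORD and friends):

```python
# Hypothetical sketch of util/coordinator.py; assumed, not shown in this
# diff. Attribute names are taken from the call sites in DeepSpeech.py.
class _Config(object):
    def __getattr__(self, name):
        # Only reached for attributes that were never set: fail loudly
        # instead of returning None when initialize_globals() wasn't called.
        raise RuntimeError('Configuration value "%s" is not set yet.' % name)

C = _Config()

def initialize_globals():
    C.n_input = 26       # MFCC features per time step
    C.n_context = 9      # frames of context on each side
    C.n_cell_dim = 2048  # LSTM cell state dimension (FLAGS.n_hidden)
    # ...plus session_config, alphabet, available_devices, COORD, etc.
```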
- global session_config - session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement, - inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, - intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads) - - global alphabet - alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path)) - - # Geometric Constants - # =================== - - # For an explanation of the meaning of the geometric constants, please refer to - # doc/Geometry.md - - # Number of MFCC features - global n_input - n_input = 26 # TODO: Determine this programatically from the sample rate - - # The number of frames in the context - global n_context - n_context = 9 # TODO: Determine the optimal value using a validation data set - - # Number of units in hidden layers - global n_hidden - n_hidden = FLAGS.n_hidden - - global n_hidden_1 - n_hidden_1 = n_hidden - - global n_hidden_2 - n_hidden_2 = n_hidden - - global n_hidden_5 - n_hidden_5 = n_hidden - - # LSTM cell state dimension - global n_cell_dim - n_cell_dim = n_hidden - - # The number of units in the third layer, which feeds in to the LSTM - global n_hidden_3 - n_hidden_3 = n_cell_dim - - # The number of characters in the target language plus one - global n_character - n_character = alphabet.size() + 1 # +1 for CTC blank label - - # The number of units in the sixth layer - global n_hidden_6 - n_hidden_6 = n_character - - # Queues that are used to gracefully stop parameter servers. - # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting. - # Each ps will dequeue as many tokens as there are workers before joining/quitting. - # This ensures parameter servers won't quit, if still required by at least one worker and - # also won't wait forever (like with a standard `server.join()`). - global done_queues - done_queues = [] - for i, ps in enumerate(FLAGS.ps_hosts): - # Queues are hosted by their respective owners - with tf.device('/job:ps/task:%d' % i): - done_queues.append(tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i))) - - # Placeholder to pass in the worker's index as token - global token_placeholder - token_placeholder = tf.placeholder(tf.int32) - - # Enqueue operations for each parameter server - global done_enqueues - done_enqueues = [queue.enqueue(token_placeholder) for queue in done_queues] - - # Dequeue operations for each parameter server - global done_dequeues - done_dequeues = [queue.dequeue() for queue in done_queues] - - if len(FLAGS.one_shot_infer) > 0: - FLAGS.train = False - FLAGS.test = False - FLAGS.export_dir = '' - if not os.path.exists(FLAGS.one_shot_infer): - log_error('Path specified in --one_shot_infer is not a valid file.') - exit(1) - - if not os.path.exists(os.path.abspath(FLAGS.decoder_library_path)): - print('ERROR: The decoder library file does not exist. 
Make sure you have ' \ - 'downloaded or built the native client binaries and pass the ' \ - 'appropriate path to the binaries in the --decoder_library_path parameter.') - - global custom_op_module - custom_op_module = tf.load_op_library(FLAGS.decoder_library_path) - - -# Logging functions -# ================= - -def prefix_print(prefix, message): - print(prefix + ('\n' + prefix).join(message.split('\n'))) - -def log_debug(message): - if FLAGS.log_level == 0: - prefix_print('D ', message) - -def log_traffic(message): - if FLAGS.log_traffic: - log_debug(message) - -def log_info(message): - if FLAGS.log_level <= 1: - prefix_print('I ', message) - -def log_warn(message): - if FLAGS.log_level <= 2: - prefix_print('W ', message) - -def log_error(message): - if FLAGS.log_level <= 3: - prefix_print('E ', message) +from util.text import Alphabet # Graph Creation @@ -371,9 +39,9 @@ def variable_on_worker_level(name, shape, initializer): ''' # Use the /cpu:0 device on worker_device for scoped operations if len(FLAGS.ps_hosts) == 0: - device = worker_device + device = C.worker_device else: - device = tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster) + device = tf.train.replica_device_setter(worker_device=C.worker_device, cluster=C.cluster) with tf.device(device): # Create or get apropos variable @@ -407,29 +75,29 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1 # Permute n_steps and batch_size batch_x = tf.transpose(batch_x, [1, 0, 2, 3]) # Reshape to prepare input for first layer - batch_x = tf.reshape(batch_x, [-1, n_input + 2*n_input*n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context) + batch_x = tf.reshape(batch_x, [-1, C.n_input + 2*C.n_input*C.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context) layers['input_reshaped'] = batch_x # The next three blocks will pass `batch_x` through three hidden layers with # clipped RELU activation and dropout. 
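The reshape above flattens each time step into a single row: the current frame's n_input MFCC features plus n_context frames of context on either side. With the defaults from the globals removed in this same file (n_input = 26, n_context = 9), the width of the first layer's input works out as follows:

```python
# Layer-1 input width: current frame plus n_context frames left and right.
n_input = 26    # MFCC features per frame
n_context = 9   # context frames on each side
feature_width = n_input + 2 * n_input * n_context
assert feature_width == 494
```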
# 1st layer - b1 = variable_on_worker_level('b1', [n_hidden_1], tf.zeros_initializer()) - h1 = variable_on_worker_level('h1', [n_input + 2*n_input*n_context, n_hidden_1], tf.contrib.layers.xavier_initializer()) + b1 = variable_on_worker_level('b1', [C.n_hidden_1], tf.zeros_initializer()) + h1 = variable_on_worker_level('h1', [C.n_input + 2*C.n_input*C.n_context, C.n_hidden_1], tf.contrib.layers.xavier_initializer()) layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), FLAGS.relu_clip) layer_1 = tf.nn.dropout(layer_1, (1.0 - dropout[0])) layers['layer_1'] = layer_1 # 2nd layer - b2 = variable_on_worker_level('b2', [n_hidden_2], tf.zeros_initializer()) - h2 = variable_on_worker_level('h2', [n_hidden_1, n_hidden_2], tf.contrib.layers.xavier_initializer()) + b2 = variable_on_worker_level('b2', [C.n_hidden_2], tf.zeros_initializer()) + h2 = variable_on_worker_level('h2', [C.n_hidden_1, C.n_hidden_2], tf.contrib.layers.xavier_initializer()) layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), FLAGS.relu_clip) layer_2 = tf.nn.dropout(layer_2, (1.0 - dropout[1])) layers['layer_2'] = layer_2 # 3rd layer - b3 = variable_on_worker_level('b3', [n_hidden_3], tf.zeros_initializer()) - h3 = variable_on_worker_level('h3', [n_hidden_2, n_hidden_3], tf.contrib.layers.xavier_initializer()) + b3 = variable_on_worker_level('b3', [C.n_hidden_3], tf.zeros_initializer()) + h3 = variable_on_worker_level('h3', [C.n_hidden_2, C.n_hidden_3], tf.contrib.layers.xavier_initializer()) layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip) layer_3 = tf.nn.dropout(layer_3, (1.0 - dropout[2])) layers['layer_3'] = layer_3 @@ -439,14 +107,14 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1 # Forward direction cell: if not tflite: - fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse) + fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(C.n_cell_dim, reuse=reuse) layers['fw_cell'] = fw_cell else: - fw_cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim, reuse=reuse) + fw_cell = tf.nn.rnn_cell.LSTMCell(C.n_cell_dim, reuse=reuse) # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`, # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`. 
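The comment above is the constraint driving the reshape that follows: LSTMBlockFusedCell consumes one time-major tensor covering the whole sequence instead of being stepped frame by frame, which is also why the TF Lite path above swaps in the plain LSTMCell. A minimal sketch of the fused call shape (sizes are illustrative; the call mirrors the one made later in this function):

```python
import tensorflow as tf

n_steps, batch_size, n_cell_dim = 16, 1, 2048
layer_3 = tf.zeros([n_steps, batch_size, n_cell_dim])  # [max_time, batch, input]
seq_length = tf.fill([batch_size], n_steps)

fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim)
# One call processes all time steps; output is [n_steps, batch_size, n_cell_dim].
output, output_state = fw_cell(layer_3, dtype=tf.float32, sequence_length=seq_length)
```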
- layer_3 = tf.reshape(layer_3, [n_steps, batch_size, n_hidden_3]) + layer_3 = tf.reshape(layer_3, [n_steps, batch_size, C.n_hidden_3]) if tflite: # Generated StridedSlice, not supported by NNAPI #n_layer_3 = [] @@ -467,48 +135,33 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1 # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim] # to a tensor of shape [n_steps*batch_size, n_cell_dim] - output = tf.reshape(output, [-1, n_cell_dim]) + output = tf.reshape(output, [-1, C.n_cell_dim]) layers['rnn_output'] = output layers['rnn_output_state'] = output_state # Now we feed `output` to the fifth hidden layer with clipped RELU activation and dropout - b5 = variable_on_worker_level('b5', [n_hidden_5], tf.zeros_initializer()) - h5 = variable_on_worker_level('h5', [n_cell_dim, n_hidden_5], tf.contrib.layers.xavier_initializer()) + b5 = variable_on_worker_level('b5', [C.n_hidden_5], tf.zeros_initializer()) + h5 = variable_on_worker_level('h5', [C.n_cell_dim, C.n_hidden_5], tf.contrib.layers.xavier_initializer()) layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(output, h5), b5)), FLAGS.relu_clip) layer_5 = tf.nn.dropout(layer_5, (1.0 - dropout[5])) layers['layer_5'] = layer_5 # Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5` # creating `n_classes` dimensional vectors, the logits. - b6 = variable_on_worker_level('b6', [n_hidden_6], tf.zeros_initializer()) - h6 = variable_on_worker_level('h6', [n_hidden_5, n_hidden_6], tf.contrib.layers.xavier_initializer()) + b6 = variable_on_worker_level('b6', [C.n_hidden_6], tf.zeros_initializer()) + h6 = variable_on_worker_level('h6', [C.n_hidden_5, C.n_hidden_6], tf.contrib.layers.xavier_initializer()) layer_6 = tf.add(tf.matmul(layer_5, h6), b6) layers['layer_6'] = layer_6 # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6] # to the slightly more useful shape [n_steps, batch_size, n_hidden_6]. # Note, that this differs from the input in that it is time-major. 
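Just below, decode_with_lm, the wrapper around the custom ctc_beam_search_decoder_with_lm op, is deleted rather than ported: decoding now leaves the TensorFlow graph entirely and happens in plain Python through the ds_ctcdecoder package added to the imports. A rough sketch of the replacement call path (signatures assumed from those imports, not spelled out in this diff; the defaults mirror the lm_weight, valid_word_count_weight and beam_width flags removed above):

```python
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer

def decode_logits(probs, alphabet, lm_binary_path, lm_trie_path,
                  lm_weight=1.50, word_count_weight=2.10, beam_width=1024):
    # probs: softmax output for one utterance, shape [time, n_character].
    scorer = Scorer(lm_weight, word_count_weight,
                    lm_binary_path, lm_trie_path, alphabet)
    # Candidate (score, transcription) pairs, best first.
    return ctc_beam_search_decoder(probs, alphabet, beam_width, scorer=scorer)
```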
-    layer_6 = tf.reshape(layer_6, [n_steps, batch_size, n_hidden_6], name="raw_logits")
+    layer_6 = tf.reshape(layer_6, [n_steps, batch_size, C.n_hidden_6], name="raw_logits")
     layers['raw_logits'] = layer_6

     # Output shape: [n_steps, batch_size, n_hidden_6]
     return layer_6, layers

-def decode_with_lm(inputs, sequence_length, beam_width=100,
-                   top_paths=1, merge_repeated=True):
-    decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
-        custom_op_module.ctc_beam_search_decoder_with_lm(
-            inputs, sequence_length, beam_width=beam_width,
-            model_path=FLAGS.lm_binary_path, trie_path=FLAGS.lm_trie_path, alphabet_path=FLAGS.alphabet_config_path,
-            lm_weight=FLAGS.lm_weight, valid_word_count_weight=FLAGS.valid_word_count_weight,
-            top_paths=top_paths, merge_repeated=merge_repeated))
-
-    return (
-        [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
-         in zip(decoded_ixs, decoded_vals, decoded_shapes)],
-        log_probabilities)
-
-
 # Accuracy and Loss
 # =================
@@ -538,23 +191,8 @@ def calculate_mean_edit_distance_and_loss(model_feeder, tower, dropout, reuse):
     # Calculate the average loss across the batch
     avg_loss = tf.reduce_mean(total_loss)

-    # Beam search decode the batch
-    decoded, _ = decode_with_lm(logits, batch_seq_len, merge_repeated=False, beam_width=FLAGS.beam_width)
-
-    # Compute the edit (Levenshtein) distance
-    distance = tf.edit_distance(tf.cast(decoded[0], tf.int32), batch_y)
-
-    # Compute the mean edit distance
-    mean_edit_distance = tf.reduce_mean(distance)
-
-    # Finally we return the
-    # - calculated total and
-    # - average losses,
-    # - the Levenshtein distance,
-    # - the recognition mean edit distance,
-    # - the decoded batch and
-    # - the original batch_y (which contains the verified transcriptions).
-    return total_loss, avg_loss, distance, mean_edit_distance, decoded, batch_y
+    # Finally we return the average loss
+    return avg_loss


 # Adam Optimization
@@ -590,78 +228,38 @@ def create_optimizer():
 # on which all operations within the tower execute.
 # For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.

-def get_tower_results(model_feeder, optimizer):
+def get_tower_results(model_feeder, optimizer, dropout_rates):
     r'''
     With this preliminary step out of the way, we can for each GPU introduce a
-    tower for which's batch we calculate
-
-    * The CTC decodings ``decoded``,
-    * The (total) loss against the outcome (Y) ``total_loss``,
-    * The loss averaged over the whole batch ``avg_loss``,
-    * The optimization gradient (computed based on the averaged loss),
-    * The Levenshtein distances between the decodings and their transcriptions ``distance``,
-    * The mean edit distance of the outcome averaged over the whole batch ``mean_edit_distance``
-
-    and retain the original ``labels`` (Y).
-    ``decoded``, ``labels``, the optimization gradient, ``distance``, ``mean_edit_distance``,
-    ``total_loss`` and ``avg_loss`` are collected into the corresponding arrays
-    ``tower_decodings``, ``tower_labels``, ``tower_gradients``, ``tower_distances``,
-    ``tower_mean_edit_distances``, ``tower_total_losses``, ``tower_avg_losses`` (dimension 0 being the tower).
-    Finally this new method ``get_tower_results()`` will return those tower arrays.
-    In case of ``tower_mean_edit_distances`` and ``tower_avg_losses``, it will return the
-    averaged values instead of the arrays.
+    tower, for whose batch we calculate and return the optimization gradients
+    and the average loss across towers.
     '''
-    # Tower labels to return
-    tower_labels = []
-
-    # Tower decodings to return
-    tower_decodings = []
-
-    # Tower distances to return
-    tower_distances = []
-
-    # Tower total batch losses to return
-    tower_total_losses = []
+    # To calculate the mean of the losses
+    tower_avg_losses = []

     # Tower gradients to return
     tower_gradients = []

-    # To calculate the mean of the mean edit distances
-    tower_mean_edit_distances = []
-
-    # To calculate the mean of the losses
-    tower_avg_losses = []
-
     with tf.variable_scope(tf.get_variable_scope()):
         # Loop over available_devices
-        for i in range(len(available_devices)):
+        for i in range(len(C.available_devices)):
             # Execute operations of tower i on device i
             if len(FLAGS.ps_hosts) == 0:
-                device = available_devices[i]
+                device = C.available_devices[i]
             else:
-                device = tf.train.replica_device_setter(worker_device=available_devices[i], cluster=cluster)
+                device = tf.train.replica_device_setter(worker_device=C.available_devices[i], cluster=C.cluster)
             with tf.device(device):
                 # Create a scope for all operations of tower i
                 with tf.name_scope('tower_%d' % i) as scope:
                     # Calculate the avg_loss and mean_edit_distance and retrieve the decoded
                     # batch along with the original batch's labels (Y) of this tower
-                    total_loss, avg_loss, distance, mean_edit_distance, decoded, labels = \
-                        calculate_mean_edit_distance_and_loss(model_feeder, i, dropout_rates, reuse=i>0)
+                    avg_loss = calculate_mean_edit_distance_and_loss(model_feeder, i, dropout_rates, reuse=i>0)

                     # Allow for variables to be re-used by the next tower
                     tf.get_variable_scope().reuse_variables()

-                    # Retain tower's labels (Y)
-                    tower_labels.append(labels)
-
-                    # Retain tower's decoded batch
-                    tower_decodings.append(decoded)
-
-                    # Retain tower's distances
-                    tower_distances.append(distance)
-
-                    # Retain tower's total losses
-                    tower_total_losses.append(total_loss)
+                    # Retain tower's avg losses
+                    tower_avg_losses.append(avg_loss)

                     # Compute gradients for model parameters using tower's mini-batch
                     gradients = optimizer.compute_gradients(avg_loss)
@@ -669,21 +267,13 @@ def get_tower_results(model_feeder, optimizer):
                     # Retain tower's gradients
                     tower_gradients.append(gradients)

-                    # Retain tower's mean edit distance
-                    tower_mean_edit_distances.append(mean_edit_distance)
-
-                    # Retain tower's avg losses
-                    tower_avg_losses.append(avg_loss)

     avg_loss_across_towers = tf.reduce_mean(tower_avg_losses, 0)

     tf.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

-    # Return the results tuple, the gradients, and the means of mean edit distances and losses
-    return (tower_labels, tower_decodings, tower_distances, tower_total_losses), \
-           tower_gradients, \
-           tf.reduce_mean(tower_mean_edit_distances, 0), \
-           avg_loss_across_towers
+    # Return gradients and the average loss
+    return tower_gradients, avg_loss_across_towers


 def average_gradients(tower_gradients):
@@ -696,7 +286,7 @@ def average_gradients(tower_gradients):
     average_grads = []

     # Run this on cpu_device to conserve GPU memory
-    with tf.device(cpu_device):
+    with tf.device(C.cpu_device):
         # Loop over gradient/variable pairs from all towers
         for grad_and_vars in zip(*tower_gradients):
             # Introduce grads to store the gradients for the current variable
@@ -756,714 +346,21 @@ def log_grads_and_vars(grads_and_vars):
     for gradient, variable in grads_and_vars:
         log_variable(variable, gradient=gradient)

-def get_git_revision_hash():
-    return subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()
-
-def get_git_branch():
-    return 
subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip() - # Helpers # ======= -def calculate_report(results_tuple): - r''' - This routine will calculate a WER report. - It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest - loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER). - ''' - samples = [] - items = list(zip(*results_tuple)) - total_levenshtein = 0.0 - total_label_length = 0.0 - for label, decoding, distance, loss in items: - sample_wer = wer(label, decoding) - sample = Sample(label, decoding, loss, distance, sample_wer) - samples.append(sample) - total_levenshtein += levenshtein(label.split(), decoding.split()) - total_label_length += float(len(label.split())) - - # Getting the WER from the accumulated levenshteins and lengths - samples_wer = total_levenshtein / total_label_length - - # Filter out all items with WER=0 - samples = [s for s in samples if s.wer > 0] - - # Order the remaining items by their loss (lowest loss on top) - samples.sort(key=lambda s: s.loss) - - # Take only the first report_count items - samples = samples[:FLAGS.report_count] - - # Order this top FLAGS.report_count items by their WER (lowest WER on top) - samples.sort(key=lambda s: s.wer) - - return samples_wer, samples - -def collect_results(results_tuple, returns): - r''' - This routine will help collecting partial results for the WER reports. - The ``results_tuple`` is composed of an array of the original labels, - an array of the corresponding decodings, an array of the corrsponding - distances and an array of the corresponding losses. ``returns`` is built up - in a similar way, containing just the unprocessed results of one - ``session.run`` call (effectively of one batch). - Labels and decodings are converted to text before splicing them into their - corresponding results_tuple lists. In the case of decodings, - for now we just pick the first available path. - ''' - # Each of the arrays within results_tuple will get extended by a batch of each available device - for i in range(len(available_devices)): - # Collect the labels - results_tuple[0].extend(sparse_tensor_value_to_texts(returns[0][i], alphabet)) - - # Collect the decodings - at the moment we default to the first one - results_tuple[1].extend(sparse_tensor_value_to_texts(returns[1][i][0], alphabet)) - - # Collect the distances - results_tuple[2].extend(returns[2][i]) - - # Collect the losses - results_tuple[3].extend(returns[3][i]) - - -# For reporting we also need a standard way to do time measurements. -def stopwatch(start_duration=0): - r''' - This function will toggle a stopwatch. - The first call starts it, second call stops it, third call continues it etc. - So if you want to measure the accumulated time spent in a certain area of the code, - you can surround that code by stopwatch-calls like this: - - .. code:: python - - fun_time = 0 # initializes a stopwatch - [...] - for i in range(10): - [...] - # Starts/continues the stopwatch - fun_time is now a point in time (again) - fun_time = stopwatch(fun_time) - fun() - # Pauses the stopwatch - fun_time is now a duration - fun_time = stopwatch(fun_time) - [...] - # The following line only makes sense after an even call of :code:`fun_time = stopwatch(fun_time)`. 
- print 'Time spent in fun():', format_duration(fun_time) - - ''' - if start_duration == 0: - return datetime.datetime.utcnow() - else: - return datetime.datetime.utcnow() - start_duration - -def format_duration(duration): - '''Formats the result of an even stopwatch call as hours:minutes:seconds''' - duration = duration if isinstance(duration, int) else duration.seconds - m, s = divmod(duration, 60) - h, m = divmod(m, 60) - return '%d:%02d:%02d' % (h, m, s) - - -# Execution -# ========= - -# String constants for different services of the web handler -PREFIX_NEXT_INDEX = '/next_index_' -PREFIX_GET_JOB = '/get_job_' - -# Global ID counter for all objects requiring an ID -id_counter = 0 - -def new_id(): - '''Returns a new ID that is unique on process level. Not thread-safe. - - Returns: - int. The new ID - ''' - global id_counter - id_counter += 1 - return id_counter - -class Sample(object): - '''Represents one item of a WER report. - - Args: - src (str): source text - res (str): resulting text - loss (float): computed loss of this item - mean_edit_distance (float): computed mean edit distance of this item - ''' - def __init__(self, src, res, loss, mean_edit_distance, sample_wer): - self.src = src - self.res = res - self.loss = loss - self.mean_edit_distance = mean_edit_distance - self.wer = sample_wer - - def __str__(self): - return 'WER: %f, loss: %f, mean edit distance: %f\n - src: "%s"\n - res: "%s"' % (self.wer, self.loss, self.mean_edit_distance, self.src, self.res) - -class WorkerJob(object): - '''Represents a job that should be executed by a worker. - - Args: - epoch_id (int): the ID of the 'parent' epoch - index (int): the epoch index of the 'parent' epoch - set_name (str): the name of the data-set - one of 'train', 'dev', 'test' - steps (int): the number of `session.run` calls - report (bool): if this job should produce a WER report - ''' - def __init__(self, epoch_id, index, set_name, steps, report): - self.id = new_id() - self.epoch_id = epoch_id - self.index = index - self.worker = -1 - self.set_name = set_name - self.steps = steps - self.report = report - self.loss = -1 - self.mean_edit_distance = -1 - self.wer = -1 - self.samples = [] - - def __str__(self): - return 'Job (ID: %d, worker: %d, epoch: %d, set_name: %s)' % (self.id, self.worker, self.index, self.set_name) - -class Epoch(object): - '''Represents an epoch that should be executed by the Training Coordinator. - Creates `num_jobs` `WorkerJob` instances in state 'open'. - - Args: - index (int): the epoch index of the 'parent' epoch - num_jobs (int): the number of jobs in this epoch - - Kwargs: - set_name (str): the name of the data-set - one of 'train', 'dev', 'test' - report (bool): if this job should produce a WER report - ''' - def __init__(self, index, num_jobs, set_name='train', report=False): - self.id = new_id() - self.index = index - self.num_jobs = num_jobs - self.set_name = set_name - self.report = report - self.wer = -1 - self.loss = -1 - self.mean_edit_distance = -1 - self.jobs_open = [] - self.jobs_running = [] - self.jobs_done = [] - self.samples = [] - for i in range(self.num_jobs): - self.jobs_open.append(WorkerJob(self.id, self.index, self.set_name, FLAGS.iters_per_worker, self.report)) - - def name(self): - '''Gets a printable name for this epoch. - - Returns: - str. 
printable name for this epoch - ''' - if self.index >= 0: - ename = ' of Epoch %d' % self.index - else: - ename = '' - if self.set_name == 'train': - return 'Training%s' % ename - elif self.set_name == 'dev': - return 'Validation%s' % ename - else: - return 'Test%s' % ename - - def get_job(self, worker): - '''Gets the next open job from this epoch. The job will be marked as 'running'. - - Args: - worker (int): index of the worker that takes the job - - Returns: - WorkerJob. job that has been marked as running for this worker - ''' - if len(self.jobs_open) > 0: - job = self.jobs_open.pop(0) - self.jobs_running.append(job) - job.worker = worker - return job - else: - return None - - def finish_job(self, job): - '''Finishes a running job. Removes it from the running jobs list and adds it to the done jobs list. - - Args: - job (WorkerJob): the job to put into state 'done' - ''' - index = next((i for i in range(len(self.jobs_running)) if self.jobs_running[i].id == job.id), -1) - if index >= 0: - self.jobs_running.pop(index) - self.jobs_done.append(job) - log_traffic('%s - Moved %s from running to done.' % (self.name(), job)) - else: - log_warn('%s - There is no job with ID %d registered as running.' % (self.name(), job.id)) - - def done(self): - '''Checks, if all jobs of the epoch are in state 'done'. - It also lazy-prepares a WER report from the result data of all jobs. - - Returns: - bool. if all jobs of the epoch are 'done' - ''' - if len(self.jobs_open) == 0 and len(self.jobs_running) == 0: - num_jobs = len(self.jobs_done) - if num_jobs > 0: - jobs = self.jobs_done - self.jobs_done = [] - if not self.num_jobs == num_jobs: - log_warn('%s - Number of steps not equal to number of jobs done.' % (self.name())) - - agg_loss = 0.0 - agg_wer = 0.0 - agg_mean_edit_distance = 0.0 - - for i in range(num_jobs): - job = jobs.pop(0) - agg_loss += job.loss - if self.report: - agg_wer += job.wer - agg_mean_edit_distance += job.mean_edit_distance - self.samples.extend(job.samples) - - self.loss = agg_loss / num_jobs - - # if the job was for validation dataset then append it to the COORD's _loss for early stop verification - if (FLAGS.early_stop is True) and (self.set_name == 'dev'): - COORD._dev_losses.append(self.loss) - - if self.report: - self.wer = agg_wer / num_jobs - self.mean_edit_distance = agg_mean_edit_distance / num_jobs - - # Order samles by their loss (lowest loss on top) - self.samples.sort(key=lambda s: s.loss) - - # Take only the first report_count items - self.samples = self.samples[:FLAGS.report_count] - - # Order this top FLAGS.report_count items by their WER (lowest WER on top) - self.samples.sort(key=lambda s: s.wer) - - # Append WER to WER log file - if len(FLAGS.wer_log_pattern) > 0: - time = datetime.datetime.utcnow().isoformat() - # Log WER progress - print(FLAGS.wer_log_pattern % (time, self.set_name, self.wer)) - - return True - return False - - def job_status(self): - '''Provides a printable overview of the states of the jobs of this epoch. - - Returns: - str. 
printable overall job state - ''' - return '%s - jobs open: %d, jobs running: %d, jobs done: %d' % (self.name(), len(self.jobs_open), len(self.jobs_running), len(self.jobs_done)) - - def __str__(self): - if not self.done(): - return self.job_status() - - if not self.report: - return '%s - loss: %f' % (self.name(), self.loss) - - s = '%s - WER: %f, loss: %s, mean edit distance: %f' % (self.name(), self.wer, self.loss, self.mean_edit_distance) - if len(self.samples) > 0: - line = '\n' + ('-' * 80) - for sample in self.samples: - s += '%s\n%s' % (line, sample) - s += line - return s - - -class TrainingCoordinator(object): - ''' Central training coordination class. - Used for distributing jobs among workers of a cluster. - Instantiated on all workers, calls of non-chief workers will transparently - HTTP-forwarded to the chief worker instance. - ''' - - class TrainingCoordinationHandler(BaseHTTPServer.BaseHTTPRequestHandler): - '''Handles HTTP requests from remote workers to the Training Coordinator. - ''' - def _send_answer(self, data=None): - self.send_response(200) - self.send_header('content-type', 'text/plain') - self.end_headers() - if data: - self.wfile.write(data) - - def do_GET(self): - if COORD.started: - if self.path.startswith(PREFIX_NEXT_INDEX): - index = COORD.get_next_index(self.path[len(PREFIX_NEXT_INDEX):]) - if index >= 0: - self._send_answer(str(index).encode("utf-8")) - return - elif self.path.startswith(PREFIX_GET_JOB): - job = COORD.get_job(worker=int(self.path[len(PREFIX_GET_JOB):])) - if job: - self._send_answer(pickle.dumps(job)) - return - self.send_response(204) # end of training - else: - self.send_response(202) # not ready yet - self.end_headers() - - def do_POST(self): - if COORD.started: - src = self.rfile.read(int(self.headers['content-length'])) - job = COORD.next_job(pickle.loads(src)) - if job: - self._send_answer(pickle.dumps(job)) - return - self.send_response(204) # end of training - else: - self.send_response(202) # not ready yet - self.end_headers() - - def log_message(self, format, *args): - '''Overriding base method to suppress web handler messages on stdout. - ''' - return - - def __init__(self): - self._init() - self._lock = Lock() - self.started = False - if is_chief: - self._httpd = BaseHTTPServer.HTTPServer((FLAGS.coord_host, FLAGS.coord_port), TrainingCoordinator.TrainingCoordinationHandler) - - def _reset_counters(self): - self._index_train = 0 - self._index_dev = 0 - self._index_test = 0 - - def _init(self): - self._epochs_running = [] - self._epochs_done = [] - self._reset_counters() - self._dev_losses = [] - - def _log_all_jobs(self): - '''Use this to debug-print epoch state''' - log_debug('Epochs - running: %d, done: %d' % (len(self._epochs_running), len(self._epochs_done))) - for epoch in self._epochs_running: - log_debug(' - running: ' + epoch.job_status()) - - def start_coordination(self, model_feeder, step=0): - '''Starts to coordinate epochs and jobs among workers on base of - data-set sizes, the (global) step and FLAGS parameters. 
- - Args: - model_feeder (ModelFeeder): data-sets to be used for coordinated training - - Kwargs: - step (int): global step of a loaded model to determine starting point - ''' - with self._lock: - self._init() - - # Number of GPUs per worker - fixed for now by local reality or cluster setup - gpus_per_worker = len(available_devices) - - # Number of batches processed per job per worker - batches_per_job = gpus_per_worker * max(1, FLAGS.iters_per_worker) - - # Number of batches per global step - batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg) - - # Number of global steps per epoch - to be at least 1 - steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step) - - # The start epoch of our training - self._epoch = step // steps_per_epoch - - # Number of additional 'jobs' trained already 'on top of' our start epoch - jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job - - # Total number of train/dev/test jobs covering their respective whole sets (one epoch) - self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job) - self._num_jobs_dev = max(1, model_feeder.dev.total_batches // batches_per_job) - self._num_jobs_test = max(1, model_feeder.test.total_batches // batches_per_job) - - if FLAGS.epoch < 0: - # A negative epoch means to add its absolute number to the epochs already computed - self._target_epoch = self._epoch + abs(FLAGS.epoch) - else: - self._target_epoch = FLAGS.epoch - - # State variables - # We only have to train, if we are told so and are not at the target epoch yet - self._train = FLAGS.train and self._target_epoch > self._epoch - self._test = FLAGS.test - - if self._train: - # The total number of jobs for all additional epochs to be trained - # Will be decremented for each job that is produced/put into state 'open' - self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained - log_info('STARTING Optimization') - self._training_time = stopwatch() - - # Important for debugging - log_debug('step: %d' % step) - log_debug('epoch: %d' % self._epoch) - log_debug('target epoch: %d' % self._target_epoch) - log_debug('steps per epoch: %d' % steps_per_epoch) - log_debug('number of batches in train set: %d' % model_feeder.train.total_batches) - log_debug('batches per job: %d' % batches_per_job) - log_debug('batches per step: %d' % batches_per_step) - log_debug('number of jobs in train set: %d' % self._num_jobs_train) - log_debug('number of jobs already trained in first epoch: %d' % jobs_trained) - - self._next_epoch() - - # The coordinator is ready to serve - self.started = True - - def _next_epoch(self): - # State-machine of the coordination process - - # Indicates, if there were 'new' epoch(s) provided - result = False - - # Make sure that early stop is enabled and validation part is enabled - if (FLAGS.early_stop is True) and (FLAGS.validation_step > 0) and (len(self._dev_losses) >= FLAGS.earlystop_nsteps): - - # Calculate the mean of losses for past epochs - mean_loss = np.mean(self._dev_losses[-FLAGS.earlystop_nsteps:-1]) - # Calculate the standard deviation for losses from validation part in the past epochs - std_loss = np.std(self._dev_losses[-FLAGS.earlystop_nsteps:-1]) - # Update the list of losses incurred - self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:] - log_debug('Checking for early stopping (last %d steps) validation loss: %f, with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, 
mean_loss)) - - # Check if validation loss has started increasing or is not decreasing substantially, making sure slight fluctuations don't bother the early stopping from working - if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh): - # Time to early stop - log_info('Early stop triggered as (for last %d steps) validation loss: %f with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss)) - self._dev_losses = [] - self._end_training() - self._train = False - - if self._train: - # We are in train mode - if self._num_jobs_train_left > 0: - # There are still jobs left - num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train) - self._num_jobs_train_left -= num_jobs_train - - # Let's try our best to keep the notion of curriculum learning - self._reset_counters() - - # If the training part of the current epoch should generate a WER report - is_display_step = FLAGS.display_step > 0 and (FLAGS.display_step == 1 or self._epoch > 0) and (self._epoch % FLAGS.display_step == 0 or self._epoch == self._target_epoch) - # Append the training epoch - self._epochs_running.append(Epoch(self._epoch, num_jobs_train, set_name='train', report=is_display_step)) - - if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0: - # The current epoch should also have a validation part - self._epochs_running.append(Epoch(self._epoch, self._num_jobs_dev, set_name='dev', report=is_display_step)) - - - # Indicating that there were 'new' epoch(s) provided - result = True - else: - # No jobs left, but still in train mode: concluding training - self._end_training() - self._train = False - - if self._test and not self._train: - # We shall test, and are not in train mode anymore - self._test = False - self._epochs_running.append(Epoch(self._epoch, self._num_jobs_test, set_name='test', report=True)) - # Indicating that there were 'new' epoch(s) provided - result = True - - if result: - # Increment the epoch index - shared among train and test 'state' - self._epoch += 1 - return result - - def _end_training(self): - self._training_time = stopwatch(self._training_time) - log_info('FINISHED Optimization - training time: %s' % format_duration(self._training_time)) - - def start(self): - '''Starts Training Coordinator. If chief, it starts a web server for - communication with non-chief instances. - ''' - if is_chief: - log_debug('Starting coordinator...') - self._thread = Thread(target=self._httpd.serve_forever) - self._thread.daemon = True - self._thread.start() - log_debug('Coordinator started.') - - def stop(self, wait_for_running_epochs=True): - '''Stops Training Coordinator. If chief, it waits for all epochs to be - 'done' and then shuts down the web server. - ''' - if is_chief: - if wait_for_running_epochs: - while len(self._epochs_running) > 0: - log_traffic('Coordinator is waiting for epochs to finish...') - time.sleep(5) - log_debug('Stopping coordinator...') - self._httpd.shutdown() - log_debug('Coordinator stopped.') - - def _talk_to_chief(self, path, data=None, default=None): - tries = 0 - while tries < FLAGS.coord_retries: - tries += 1 - try: - url = 'http://%s:%d%s' % (FLAGS.coord_host, FLAGS.coord_port, path) - log_traffic('Contacting coordinator - url: %s, tries: %d ...' 
% (url, tries-1)) - res = urllib.request.urlopen(urllib.request.Request(url, data, { 'content-type': 'text/plain' })) - str = res.read() - status = res.getcode() - log_traffic('Coordinator responded - url: %s, status: %s' % (url, status)) - if status == 200: - return str - if status == 204: # We use 204 (no content) to indicate end of training - return default - except urllib.error.HTTPError as error: - log_traffic('Problem reaching coordinator - url: %s, HTTP code: %d' % (url, error.code)) - pass - time.sleep(10) - return default - - def get_next_index(self, set_name): - '''Retrives a new cluster-unique batch index for a given set-name. - Prevents applying one batch multiple times per epoch. - - Args: - set_name (str): name of the data set - one of 'train', 'dev', 'test' - - Returns: - int. new data set index - ''' - with self._lock: - if is_chief: - member = '_index_' + set_name - value = getattr(self, member, -1) - setattr(self, member, value + 1) - return value - else: - # We are a remote worker and have to hand over to the chief worker by HTTP - log_traffic('Asking for next index...') - value = int(self._talk_to_chief(PREFIX_NEXT_INDEX + set_name)) - log_traffic('Got index %d.' % value) - return value - - def _get_job(self, worker=0): - job = None - # Find first running epoch that provides a next job - for epoch in self._epochs_running: - job = epoch.get_job(worker) - if job: - return job - # No next job found - return None - - def get_job(self, worker=0): - '''Retrieves the first job for a worker. - - Kwargs: - worker (int): index of the worker to get the first job for - - Returns: - WorkerJob. a job of one of the running epochs that will get - associated with the given worker and put into state 'running' - ''' - # Let's ensure that this does not interfere with other workers/requests - with self._lock: - if is_chief: - # First try to get a next job - job = self._get_job(worker) - - if job is None: - # If there was no next job, we give it a second chance by triggering the epoch state machine - if self._next_epoch(): - # Epoch state machine got a new epoch - # Second try to get a next job - job = self._get_job(worker) - if job is None: - # Albeit the epoch state machine got a new epoch, the epoch had no new job for us - log_error('Unexpected case - no job for worker %d.' % (worker)) - return job - - # Epoch state machine has no new epoch - # This happens at the end of the whole training - nothing to worry about - log_traffic('No jobs left for worker %d.' % (worker)) - self._log_all_jobs() - return None - - # We got a new job from one of the currently running epochs - log_traffic('Got new %s' % job) - return job - - # We are a remote worker and have to hand over to the chief worker by HTTP - result = self._talk_to_chief(PREFIX_GET_JOB + str(FLAGS.task_index)) - if result: - result = pickle.loads(result) - return result - - def next_job(self, job): - '''Sends a finished job back to the coordinator and retrieves in exchange the next one. - - Kwargs: - job (WorkerJob): job that was finished by a worker and who's results are to be - digested by the coordinator - - Returns: - WorkerJob. 
next job of one of the running epochs that will get - associated with the worker from the finished job and put into state 'running' - ''' - if is_chief: - # Try to find the epoch the job belongs to - epoch = next((epoch for epoch in self._epochs_running if epoch.id == job.epoch_id), None) - if epoch: - # We are going to manipulate things - let's avoid undefined state - with self._lock: - # Let the epoch finish the job - epoch.finish_job(job) - # Check, if epoch is done now - if epoch.done(): - # If it declares itself done, move it from 'running' to 'done' collection - self._epochs_running.remove(epoch) - self._epochs_done.append(epoch) - log_info('%s' % epoch) - else: - # There was no running epoch found for this job - this should never happen. - log_error('There is no running epoch of ID %d for job with ID %d. This should never happen.' % (job.epoch_id, job.id)) - return self.get_job(job.worker) - - # We are a remote worker and have to hand over to the chief worker by HTTP - result = self._talk_to_chief('', data=pickle.dumps(job)) - if result: - result = pickle.loads(result) - return result - def send_token_to_ps(session, kill=False): # Sending our token (the task_index as a debug opportunity) to each parameter server. # kill switch tokens are negative and decremented by 1 to deal with task_index 0 token = -FLAGS.task_index-1 if kill else FLAGS.task_index kind = 'kill switch' if kill else 'stop' - for index, enqueue in enumerate(done_enqueues): + for index, enqueue in enumerate(C.done_enqueues): log_debug('Sending %s token to ps %d...' % (kind, index)) session.run(enqueue, feed_dict={ token_placeholder: token }) log_debug('Sent %s token to ps %d.' % (kind, index)) + def train(server=None): r''' Trains the network on a given server of a cluster. @@ -1474,53 +371,41 @@ def train(server=None): # It will automagically get incremented by the optimizer. 
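With the module-level placeholders gone, train() below owns its dropout_rates, and the no_dropout_feed_dict used further down when reading back the global step is just the all-zeros mapping over them. A sketch of both feed dicts, reconstructed on the assumption that they follow the per-layer dropout flags (the actual construction falls between the hunks shown):

```python
import tensorflow as tf
from util.flags import FLAGS  # flags module introduced by this refactor

dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i))
                 for i in range(6)]
# Rates 2, 3 and 6 default to dropout_rate when left negative.
per_layer_rates = [FLAGS.dropout_rate, FLAGS.dropout_rate2, FLAGS.dropout_rate3,
                   FLAGS.dropout_rate4, FLAGS.dropout_rate5, FLAGS.dropout_rate6]
dropout_feed_dict = dict(zip(dropout_rates, per_layer_rates))
no_dropout_feed_dict = {rate: 0.0 for rate in dropout_rates}
```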
global_step = tf.Variable(0, trainable=False, name='global_step') + dropout_rates = [tf.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)] + # Reading training set train_data = preprocess(FLAGS.train_files.split(','), FLAGS.train_batch_size, - n_input, - n_context, - alphabet, + C.n_input, + C.n_context, + C.alphabet, hdf5_cache_path=FLAGS.train_cached_features_path) train_set = DataSet(train_data, FLAGS.train_batch_size, limit=FLAGS.limit_train, - next_index=lambda i: COORD.get_next_index('train')) + next_index=lambda i: C.COORD.get_next_index('train')) # Reading validation set dev_data = preprocess(FLAGS.dev_files.split(','), FLAGS.dev_batch_size, - n_input, - n_context, - alphabet, + C.n_input, + C.n_context, + C.alphabet, hdf5_cache_path=FLAGS.dev_cached_features_path) dev_set = DataSet(dev_data, FLAGS.dev_batch_size, limit=FLAGS.limit_dev, - next_index=lambda i: COORD.get_next_index('dev')) - - # Reading test set - test_data = preprocess(FLAGS.test_files.split(','), - FLAGS.test_batch_size, - n_input, - n_context, - alphabet, - hdf5_cache_path=FLAGS.test_cached_features_path) - - test_set = DataSet(test_data, - FLAGS.test_batch_size, - limit=FLAGS.limit_test, - next_index=lambda i: COORD.get_next_index('test')) + next_index=lambda i: C.COORD.get_next_index('dev')) # Combining all sets to a multi set model feeder model_feeder = ModelFeeder(train_set, dev_set, - test_set, - n_input, - n_context, - alphabet, - tower_feeder_count=len(available_devices)) + C.n_input, + C.n_context, + C.alphabet, + tower_feeder_count=len(C.available_devices)) # Create the optimizer optimizer = create_optimizer() @@ -1532,7 +417,7 @@ def train(server=None): total_num_replicas=FLAGS.replicas) # Get the data_set specific graph end-points - results_tuple, gradients, mean_edit_distance, loss = get_tower_results(model_feeder, optimizer) + gradients, loss = get_tower_results(model_feeder, optimizer, dropout_rates) # Average tower gradients across GPUs avg_tower_gradients = average_gradients(gradients) @@ -1548,8 +433,7 @@ def train(server=None): step_summary_writers = { 'train': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120), - 'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120), - 'test': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'test'), max_queue=120) + 'dev': tf.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120) } # Apply gradients to modify the model @@ -1618,17 +502,14 @@ def train(server=None): update_progressbar.total_jobs = None update_progressbar.current_job_index = 0 - current_epoch = COORD._epoch-1 + current_epoch = C.COORD._epoch-1 if set_name == "train": log_info('Training epoch %i...' % current_epoch) - update_progressbar.total_jobs = COORD._num_jobs_train - elif set_name == "dev": + update_progressbar.total_jobs = C.COORD._num_jobs_train + else: log_info('Validating epoch %i...' % current_epoch) - update_progressbar.total_jobs = COORD._num_jobs_dev - elif set_name == "test": - log_info('Testing epoch %i...' % current_epoch) - update_progressbar.total_jobs = COORD._num_jobs_test + update_progressbar.total_jobs = C.COORD._num_jobs_dev # recreate pbar update_progressbar.pbar = progressbar.ProgressBar(max_value=update_progressbar.total_jobs, @@ -1649,22 +530,22 @@ def train(server=None): # or an error occurs. 
try: with tf.train.MonitoredTrainingSession(master='' if server is None else server.target, - is_chief=is_chief, + is_chief=C.is_chief, hooks=hooks, checkpoint_dir=FLAGS.checkpoint_dir, save_checkpoint_secs=None, # already taken care of by a hook - config=session_config) as session: + config=C.session_config) as session: tf.get_default_graph().finalize() try: - if is_chief: + if C.is_chief: # Retrieving global_step from the (potentially restored) model model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train) step = session.run(global_step, feed_dict=no_dropout_feed_dict) - COORD.start_coordination(model_feeder, step) + C.COORD.start_coordination(model_feeder, step) # Get the first job - job = COORD.get_job() + job = C.COORD.get_job() while job and not session.should_stop(): log_debug('Computing %s...' % job) @@ -1693,17 +574,6 @@ def train(server=None): # Setting the training operation in case of training requested train_op = apply_gradient_op if is_train else [] - # Requirements to display a WER report - if job.report: - # Reset mean edit distance - total_mean_edit_distance = 0.0 - # Create report results tuple - report_results = ([],[],[],[]) - # Extend the session.run parameters - report_params = [results_tuple, mean_edit_distance] - else: - report_params = [] - # So far the only extra parameter is the feed_dict extra_params = { 'feed_dict': feed_dict } @@ -1716,7 +586,7 @@ def train(server=None): log_debug('Starting batch...') # Compute the batch - _, current_step, batch_loss, batch_report, step_summary = session.run([train_op, global_step, loss, report_params, step_summaries_op], **extra_params) + _, current_step, batch_loss, step_summary = session.run([train_op, global_step, loss, step_summaries_op], **extra_params) # Log step summaries step_summary_writer.add_summary(step_summary, current_step) @@ -1727,17 +597,8 @@ def train(server=None): # Add batch to loss total_loss += batch_loss - if job.report: - # Collect individual sample results - collect_results(report_results, batch_report[0]) - # Add batch to total_mean_edit_distance - total_mean_edit_distance += batch_report[1] - # Gathering job results job.loss = total_loss / job.steps - if job.report: - job.mean_edit_distance = total_mean_edit_distance / job.steps - job.wer, job.samples = calculate_report(report_results) # Display progressbar if FLAGS.show_progressbar: @@ -1745,7 +606,7 @@ def train(server=None): # Send the current job to coordinator and receive the next one log_debug('Sending %s...' % job) - job = COORD.next_job(job) + job = C.COORD.next_job(job) if update_progressbar.pbar: update_progressbar.pbar.finish() @@ -1759,7 +620,7 @@ def train(server=None): # Only chief has a SyncReplicasOptimizer queue runner that needs to be stopped for unblocking process exit. # A rather graceful way to do this is by stopping the ps. # Only one party can send it w/o failing. 
- if is_chief: + if C.is_chief: send_token_to_ps(session, kill=True) sys.exit(1) @@ -1773,23 +634,38 @@ def train(server=None): ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir)) sys.exit(1) + +def test(): + # Reading test set + test_data = preprocess(FLAGS.test_files.split(','), + FLAGS.test_batch_size, + C.n_input, + C.n_context, + C.alphabet, + hdf5_cache_path=FLAGS.test_cached_features_path) + + graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1) + + evaluate.evaluate(test_data, graph, C.alphabet) + + def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tflite=False): # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input] - input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node') + input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*C.n_context+1, C.n_input], name='input_node') seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths') if not tflite: - previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None) - previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None) + previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, C.n_cell_dim], initializer=None) + previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, C.n_cell_dim], initializer=None) else: - previous_state_c = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_c') - previous_state_h = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_h') + previous_state_c = tf.placeholder(tf.float32, [batch_size, C.n_cell_dim], name='previous_state_c') + previous_state_h = tf.placeholder(tf.float32, [batch_size, C.n_cell_dim], name='previous_state_h') previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h) logits, layers = BiRNN(batch_x=input_tensor, seq_length=seq_length if FLAGS.use_seq_length else None, - dropout=no_dropout, + dropout=C.no_dropout, batch_size=batch_size, n_steps=n_steps, previous_state=previous_state, @@ -1808,7 +684,7 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tfli # Initial zero state if not tflite: - zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32) + zero_state = tf.zeros([batch_size, C.n_cell_dim], tf.float32) initialize_c = tf.assign(previous_state_c, zero_state) initialize_h = tf.assign(previous_state_h, zero_state) initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state') @@ -1855,7 +731,7 @@ def export(): from tensorflow.python.framework.ops import Tensor, Operation tf.reset_default_graph() - session = tf.Session(config=session_config) + session = tf.Session(config=C.session_config) inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite) input_names = ",".join(tensor.op.name for tensor in inputs.values()) @@ -1948,8 +824,10 @@ def export(): log_error(str(e)) def do_single_file_inference(input_file_path): - with tf.Session(config=session_config) as session: - inputs, outputs, _ = create_inference_graph(batch_size=1, use_new_decoder=True) + import numpy as np + + with tf.Session(config=C.session_config) as session: + inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1, use_new_decoder=True) # Create a saver using variables from the above newly created 
graph mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')} @@ -1968,77 +846,70 @@ def do_single_file_inference(input_file_path): session.run(outputs['initialize_state']) - features = audiofile_to_input_vector(input_file_path, n_input, n_context) - num_strides = len(features) - (n_context * 2) + features = audiofile_to_input_vector(input_file_path, C.n_input, C.n_context) + num_strides = len(features) - (C.n_context * 2) # Create a view into the array with overlapping strides of size # numcontext (past) + 1 (present) + numcontext (future) - window_size = 2*n_context+1 + window_size = 2*C.n_context+1 features = np.lib.stride_tricks.as_strided( features, - (num_strides, window_size, n_input), + (num_strides, window_size, C.n_input), (features.strides[0], features.strides[0], features.strides[1]), writeable=False) - logits = np.empty([0, 1, alphabet.size()+1]) - for i in range(0, len(features), FLAGS.n_steps): - chunk = features[i:i+FLAGS.n_steps] + logits = session.run(outputs['outputs'], feed_dict = { + inputs['input']: [features], + inputs['input_lengths']: [num_strides], + }) - # pad with zeros if not enough steps (len(features) % FLAGS.n_steps != 0) - if len(chunk) < FLAGS.n_steps: - chunk = np.pad(chunk, - ( - (0, FLAGS.n_steps - len(chunk)), - (0, 0), - (0, 0) - ), - mode='constant', - constant_values=0) + logits = np.squeeze(logits) - output = session.run(outputs['outputs'], feed_dict = { - inputs['input']: [chunk], - inputs['input_lengths']: [len(chunk)], - }) - logits = np.concatenate((logits, output)) - - decoded, _ = decode_with_lm(logits, [len(logits)], merge_repeated=False, beam_width=FLAGS.beam_width) - output = session.run(decoded) - - text = sparse_tensor_value_to_texts(output[0], alphabet) - - print(text[0]) + scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, + FLAGS.lm_binary_path, FLAGS.lm_trie_path, + C.alphabet) + decoded = ctc_beam_search_decoder(logits, C.alphabet, FLAGS.beam_width, scorer=scorer) + # Print highest probability result + print(decoded[0][1]) -def main(_) : - +def main(_): initialize_globals() if FLAGS.train or FLAGS.test: if len(FLAGS.worker_hosts) == 0: # Only one local task: this process (default case - no cluster) - train() + with tf.Graph().as_default(): + train() + # Now do a final test epoch + if FLAGS.test: + with tf.Graph().as_default(): + test() log_debug('Done.') else: # Create and start a server for the local task. - server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) + server = tf.train.Server(C.cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == 'ps': # We are a parameter server and therefore we just wait for all workers to finish # by waiting for their stop tokens. with tf.Session(server.target) as session: for worker in FLAGS.worker_hosts: log_debug('Waiting for stop token...') - token = session.run(done_dequeues[FLAGS.task_index]) + token = session.run(C.done_dequeues[FLAGS.task_index]) if token < 0: log_debug('Got a kill switch token from worker %i.' % abs(token + 1)) break log_debug('Got a stop token from worker %i.' % token) log_debug('Session closed.') + + if FLAGS.test: + test() elif FLAGS.job_name == 'worker': # We are a worker and therefore we have to do some work. # Assigns ops to the local worker by default. 
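The as_strided call in do_single_file_inference above builds the overlapping context windows without copying the feature matrix: each of the num_strides windows is 2*n_context+1 frames wide and shares memory with its neighbours. A self-contained sketch of the same trick, using the constants set in initialize_globals:

    import numpy as np

    n_input, n_context = 26, 9  # MFCC features and context frames, as in util/coordinator.py
    features = np.random.rand(100, n_input).astype(np.float32)

    num_strides = len(features) - 2 * n_context
    window_size = 2 * n_context + 1
    windows = np.lib.stride_tricks.as_strided(
        features,
        (num_strides, window_size, n_input),
        (features.strides[0], features.strides[0], features.strides[1]),
        writeable=False)

    assert windows.shape == (82, 19, 26)
    # Window k starts at frame k; no data was copied.
    assert np.array_equal(windows[5, 0], features[5])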
with tf.device(tf.train.replica_device_setter( - worker_device=worker_device, - cluster=cluster)): + worker_device=C.worker_device, + cluster=C.cluster)): # Do the training train(server) @@ -2046,7 +917,7 @@ def main(_) : log_debug('Server stopped.') # Are we the main process? - if is_chief: + if C.is_chief: # Doing solo/post-processing work just on the main process... # Exporting the model if FLAGS.export_dir: @@ -2056,7 +927,7 @@ def main(_) : do_single_file_inference(FLAGS.one_shot_infer) # Stopping the coordinator - COORD.stop() + C.COORD.stop() if __name__ == '__main__' : create_flags() diff --git a/evaluate.py b/evaluate.py index 59c5b2e8..b3f54c93 100755 --- a/evaluate.py +++ b/evaluate.py @@ -15,8 +15,10 @@ import tensorflow as tf from attrdict import AttrDict from collections import namedtuple from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer -from DeepSpeech import initialize_globals, create_flags, log_debug, log_info, log_warn, log_error, create_inference_graph -from multiprocessing import Pool +from util.flags import create_flags +from util.coordinator import C, initialize_globals +from util.logging import log_debug, log_info, log_warn, log_error +from multiprocessing import Pool, cpu_count from six.moves import zip, range from util.audio import audiofile_to_input_vector from util.text import Alphabet, ctc_label_dense_to_sparse, wer, levenshtein @@ -86,31 +88,11 @@ def calculate_report(labels, decodings, distances, losses): return samples_wer, samples -def main(_): - initialize_globals() - - if not FLAGS.test_files: - log_error('You need to specify what files to use for evaluation via ' - 'the --test_files flag.') - exit(1) - - global alphabet - alphabet = Alphabet(FLAGS.alphabet_config_path) - - scorer = Scorer(FLAGS.lm_weight, FLAGS.valid_word_count_weight, +def evaluate(test_data, inference_graph, alphabet): + scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, - alphabet) + alphabet) - # sort examples by length, improves packing of batches and timesteps - test_data = preprocess( - FLAGS.test_files.split(','), - FLAGS.test_batch_size, - alphabet=alphabet, - numcep=N_FEATURES, - numcontext=N_CONTEXT, - hdf5_cache_path=FLAGS.hdf5_test_set).sort_values( - by="features_len", - ascending=False) def create_windows(features): num_strides = len(features) - (N_CONTEXT * 2) @@ -130,7 +112,7 @@ def main(_): test_data['features'] = test_data['features'].apply(create_windows) with tf.Session() as session: - inputs, outputs, layers = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1) + inputs, outputs, layers = inference_graph # Transpose to batch major for decoder transposed = tf.transpose(outputs['outputs'], [1, 0, 2]) @@ -192,7 +174,10 @@ widget=progressbar.AdaptiveETA) # Get number of accessible CPU cores for this process - num_processes = len(os.sched_getaffinity(0)) + try: + num_processes = cpu_count() + except NotImplementedError: + num_processes = 1 # Second pass, decode logits and compute WER and edit distance metrics for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))): @@ -221,7 +206,38 @@ print(' - res: "%s"' % sample.res) print('-' * 80) + return samples + + +def main(_): + initialize_globals() + + if not FLAGS.test_files: + log_error('You need to specify what files to use for evaluation via ' + 'the --test_files flag.') + exit(1) + + global alphabet + alphabet = Alphabet(FLAGS.alphabet_config_path) + + # sort examples by length, improves packing of batches and timesteps + test_data = preprocess( + FLAGS.test_files.split(','), + FLAGS.test_batch_size, + alphabet=alphabet, + numcep=N_FEATURES, + numcontext=N_CONTEXT, + hdf5_cache_path=FLAGS.hdf5_test_set).sort_values( + by="features_len", + ascending=False) + + from DeepSpeech import create_inference_graph + graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1) + + samples = evaluate(test_data, graph, alphabet) + if FLAGS.test_output_file: + # Save decoded tuples as JSON, converting NumPy floats to Python floats json.dump(samples, open(FLAGS.test_output_file, 'w'), default=lambda x: float(x))
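The default=lambda x: float(x) handed to json.dump above is what lets NumPy scalars survive serialization - the stdlib encoder rejects np.float32 values such as per-sample losses. A quick illustration:

    import json
    import numpy as np

    sample = {'loss': np.float32(1.5)}
    # json.dumps(sample) would raise TypeError: Object of type float32 is not JSON serializable
    print(json.dumps(sample, default=lambda x: float(x)))  # {"loss": 1.5}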
diff --git a/requirements.txt b/requirements.txt index 03cc814d..b4329ffe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ bs4 six requests tables +attrdict diff --git a/util/coordinator.py b/util/coordinator.py new file mode 100644 index 00000000..f2fc9036 --- /dev/null +++ b/util/coordinator.py @@ -0,0 +1,706 @@ +from __future__ import absolute_import, division, print_function + +import numpy as np +import os +import pickle +import tensorflow as tf +import time + +from attrdict import AttrDict +from datetime import datetime +from six.moves import zip, range, filter, urllib, BaseHTTPServer +from threading import Thread, Lock +from util.gpu import get_available_gpus +from util.flags import FLAGS +from util.logging import * +from util.text import Alphabet +from xdg import BaseDirectory as xdg + +# Execution +# ========= + +# For reporting we also need a standard way to do time measurements. +def stopwatch(start_duration=0): + r''' + This function will toggle a stopwatch. + The first call starts it, second call stops it, third call continues it etc. + So if you want to measure the accumulated time spent in a certain area of the code, + you can surround that code by stopwatch-calls like this: + + .. code:: python + + fun_time = 0 # initializes a stopwatch + [...] + for i in range(10): + [...] + # Starts/continues the stopwatch - fun_time is now a point in time (again) + fun_time = stopwatch(fun_time) + fun() + # Pauses the stopwatch - fun_time is now a duration + fun_time = stopwatch(fun_time) + [...] + # The following line only makes sense after an even call of :code:`fun_time = stopwatch(fun_time)`. + print('Time spent in fun():', format_duration(fun_time)) + + ''' + if start_duration == 0: + return datetime.utcnow() + else: + return datetime.utcnow() - start_duration + +def format_duration(duration): + '''Formats the result of an even stopwatch call as hours:minutes:seconds''' + duration = duration if isinstance(duration, int) else duration.seconds + m, s = divmod(duration, 60) + h, m = divmod(m, 60) + return '%d:%02d:%02d' % (h, m, s)
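The toggle semantics of stopwatch are easy to get wrong: odd calls return a start timestamp, even calls return the accumulated duration, and gaps between pairs are not counted. A short usage sketch (sleep lengths are arbitrary):

    import time
    from util.coordinator import stopwatch, format_duration

    t = stopwatch()        # start - t is a datetime
    time.sleep(1)
    t = stopwatch(t)       # pause - t is now a duration of about 1s
    time.sleep(1)          # this gap is not counted
    t = stopwatch(t)       # continue - t is a datetime again, shifted back by 1s
    time.sleep(1)
    t = stopwatch(t)       # pause - about 2s accumulated
    print(format_duration(t))  # 0:00:02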
+ + +# String constants for different services of the web handler +PREFIX_NEXT_INDEX = '/next_index_' +PREFIX_GET_JOB = '/get_job_' + +# Global ID counter for all objects requiring an ID +id_counter = 0 + +def new_id(): + '''Returns a new ID that is unique on process level. Not thread-safe. + + Returns: + int. The new ID + ''' + global id_counter + id_counter += 1 + return id_counter + +class WorkerJob(object): + '''Represents a job that should be executed by a worker. + + Args: + epoch_id (int): the ID of the 'parent' epoch + index (int): the epoch index of the 'parent' epoch + set_name (str): the name of the data-set - one of 'train', 'dev' + steps (int): the number of `session.run` calls + ''' + def __init__(self, epoch_id, index, set_name, steps): + self.id = new_id() + self.epoch_id = epoch_id + self.index = index + self.worker = -1 + self.set_name = set_name + self.steps = steps + self.loss = -1 + self.samples = [] + + def __str__(self): + return 'Job (ID: %d, worker: %d, epoch: %d, set_name: %s)' % (self.id, self.worker, self.index, self.set_name) + +class Epoch(object): + '''Represents an epoch that should be executed by the Training Coordinator. + Creates `num_jobs` `WorkerJob` instances in state 'open'. + + Args: + index (int): the epoch index of the 'parent' epoch + num_jobs (int): the number of jobs in this epoch + + Kwargs: + set_name (str): the name of the data-set - one of 'train', 'dev' + ''' + def __init__(self, index, num_jobs, set_name='train'): + self.id = new_id() + self.index = index + self.num_jobs = num_jobs + self.set_name = set_name + self.loss = -1 + self.jobs_open = [] + self.jobs_running = [] + self.jobs_done = [] + for i in range(self.num_jobs): + self.jobs_open.append(WorkerJob(self.id, self.index, self.set_name, FLAGS.iters_per_worker)) + + def name(self): + '''Gets a printable name for this epoch. + + Returns: + str. printable name for this epoch + ''' + if self.index >= 0: + ename = ' of Epoch %d' % self.index + else: + ename = '' + if self.set_name == 'train': + return 'Training%s' % ename + else: + return 'Validation%s' % ename + + def get_job(self, worker): + '''Gets the next open job from this epoch. The job will be marked as 'running'. + + Args: + worker (int): index of the worker that takes the job + + Returns: + WorkerJob. job that has been marked as running for this worker + ''' + if len(self.jobs_open) > 0: + job = self.jobs_open.pop(0) + self.jobs_running.append(job) + job.worker = worker + return job + else: + return None + + def finish_job(self, job): + '''Finishes a running job. Removes it from the running jobs list and adds it to the done jobs list. + + Args: + job (WorkerJob): the job to put into state 'done' + ''' + index = next((i for i in range(len(self.jobs_running)) if self.jobs_running[i].id == job.id), -1) + if index >= 0: + self.jobs_running.pop(index) + self.jobs_done.append(job) + log_traffic('%s - Moved %s from running to done.' % (self.name(), job)) + else: + log_warn('%s - There is no job with ID %d registered as running.' % (self.name(), job.id)) + + def done(self): + '''Checks if all jobs of the epoch are in state 'done'. + + Returns: + bool. if all jobs of the epoch are 'done' + ''' + if len(self.jobs_open) == 0 and len(self.jobs_running) == 0: + num_jobs = len(self.jobs_done) + if num_jobs > 0: + jobs = self.jobs_done + self.jobs_done = [] + if not self.num_jobs == num_jobs: + log_warn('%s - Number of jobs done does not match number of jobs in epoch.' % (self.name())) + + agg_loss = 0.0 + + for i in range(num_jobs): + job = jobs.pop(0) + agg_loss += job.loss + + self.loss = agg_loss / num_jobs + + # If the job was for the validation set, append its loss to the coordinator's dev losses for early-stop checks + if (FLAGS.early_stop is True) and (self.set_name == 'dev'): + C.COORD._dev_losses.append(self.loss) + + return True + return False
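An Epoch is essentially a queue of WorkerJobs moving through the states open, running and done; once the last job lands in jobs_done, done() aggregates the per-job losses. A condensed walk-through on the chief (this assumes flags have already been created and parsed, since WorkerJob reads FLAGS.iters_per_worker):

    epoch = Epoch(0, 1, set_name='train')
    job = epoch.get_job(worker=0)   # job moves from 'open' to 'running'
    assert not epoch.done()         # one job still in flight

    job.loss = 0.42                 # normally filled in by the worker
    epoch.finish_job(job)           # job moves from 'running' to 'done'
    assert epoch.done()             # aggregates: epoch.loss == 0.42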
+ + def job_status(self): + '''Provides a printable overview of the states of the jobs of this epoch. + + Returns: + str. printable overall job state + ''' + return '%s - jobs open: %d, jobs running: %d, jobs done: %d' % (self.name(), len(self.jobs_open), len(self.jobs_running), len(self.jobs_done)) + + def __str__(self): + if not self.done(): + return self.job_status() + + return '%s - loss: %f' % (self.name(), self.loss) + + +class TrainingCoordinator(object): + ''' Central training coordination class. + Used for distributing jobs among workers of a cluster. + Instantiated on all workers, calls on non-chief workers are transparently + HTTP-forwarded to the chief worker instance. + ''' + + class TrainingCoordinationHandler(BaseHTTPServer.BaseHTTPRequestHandler): + '''Handles HTTP requests from remote workers to the Training Coordinator. + ''' + def _send_answer(self, data=None): + self.send_response(200) + self.send_header('content-type', 'text/plain') + self.end_headers() + if data: + self.wfile.write(data) + + def do_GET(self): + if C.COORD.started: + if self.path.startswith(PREFIX_NEXT_INDEX): + index = C.COORD.get_next_index(self.path[len(PREFIX_NEXT_INDEX):]) + if index >= 0: + self._send_answer(str(index).encode("utf-8")) + return + elif self.path.startswith(PREFIX_GET_JOB): + job = C.COORD.get_job(worker=int(self.path[len(PREFIX_GET_JOB):])) + if job: + self._send_answer(pickle.dumps(job)) + return + self.send_response(204) # end of training + else: + self.send_response(202) # not ready yet + self.end_headers() + + def do_POST(self): + if C.COORD.started: + src = self.rfile.read(int(self.headers['content-length'])) + job = C.COORD.next_job(pickle.loads(src)) + if job: + self._send_answer(pickle.dumps(job)) + return + self.send_response(204) # end of training + else: + self.send_response(202) # not ready yet + self.end_headers() + + def log_message(self, format, *args): + '''Overriding base method to suppress web handler messages on stdout. + ''' + return + + def __init__(self, is_chief): + self._init() + self._lock = Lock() + self.started = False + self.is_chief = is_chief + if is_chief: + self._httpd = BaseHTTPServer.HTTPServer((FLAGS.coord_host, FLAGS.coord_port), TrainingCoordinator.TrainingCoordinationHandler) + + def _reset_counters(self): + self._index_train = 0 + self._index_dev = 0 + + def _init(self): + self._epochs_running = [] + self._epochs_done = [] + self._reset_counters() + self._dev_losses = [] + + def _log_all_jobs(self): + '''Use this to debug-print epoch state''' + log_debug('Epochs - running: %d, done: %d' % (len(self._epochs_running), len(self._epochs_done))) + for epoch in self._epochs_running: + log_debug(' - running: ' + epoch.job_status()) + + def start_coordination(self, model_feeder, step=0): + '''Starts to coordinate epochs and jobs among workers based on + data-set sizes, the (global) step and FLAGS parameters.
+ + Args: + model_feeder (ModelFeeder): data-sets to be used for coordinated training + + Kwargs: + step (int): global step of a loaded model to determine starting point + ''' + with self._lock: + self._init() + + # Number of GPUs per worker - fixed for now by local reality or cluster setup + gpus_per_worker = len(C.available_devices) + + # Number of batches processed per job per worker + batches_per_job = gpus_per_worker * max(1, FLAGS.iters_per_worker) + + # Number of batches per global step + batches_per_step = gpus_per_worker * max(1, FLAGS.replicas_to_agg) + + # Number of global steps per epoch - to be at least 1 + steps_per_epoch = max(1, model_feeder.train.total_batches // batches_per_step) + + # The start epoch of our training + self._epoch = step // steps_per_epoch + + # Number of additional 'jobs' trained already 'on top of' our start epoch + jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job + + # Total number of train/dev jobs covering their respective whole sets (one epoch) + self._num_jobs_train = max(1, model_feeder.train.total_batches // batches_per_job) + self._num_jobs_dev = max(1, model_feeder.dev.total_batches // batches_per_job) + + if FLAGS.epoch < 0: + # A negative epoch means to add its absolute number to the epochs already computed + self._target_epoch = self._epoch + abs(FLAGS.epoch) + else: + self._target_epoch = FLAGS.epoch + + # State variables + # We only have to train, if we are told so and are not at the target epoch yet + self._train = FLAGS.train and self._target_epoch > self._epoch + + if self._train: + # The total number of jobs for all additional epochs to be trained + # Will be decremented for each job that is produced/put into state 'open' + self._num_jobs_train_left = (self._target_epoch - self._epoch) * self._num_jobs_train - jobs_trained + log_info('STARTING Optimization') + self._training_time = stopwatch() + + # Important for debugging + log_debug('step: %d' % step) + log_debug('epoch: %d' % self._epoch) + log_debug('target epoch: %d' % self._target_epoch) + log_debug('steps per epoch: %d' % steps_per_epoch) + log_debug('number of batches in train set: %d' % model_feeder.train.total_batches) + log_debug('batches per job: %d' % batches_per_job) + log_debug('batches per step: %d' % batches_per_step) + log_debug('number of jobs in train set: %d' % self._num_jobs_train) + log_debug('number of jobs already trained in start epoch: %d' % jobs_trained) + + self._next_epoch() + + # The coordinator is ready to serve + self.started = True
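The bookkeeping in start_coordination is easiest to sanity-check with concrete numbers. Say one worker with 2 GPUs, iters_per_worker=1, replicas_to_agg resolved to 2, a train set of 1000 batches and a global step of 1600 restored from a checkpoint (all values invented for illustration):

    gpus_per_worker = 2
    batches_per_job = gpus_per_worker * 1     # 2 batches handed out per job
    batches_per_step = gpus_per_worker * 2    # 4 batches consumed per global step
    steps_per_epoch = max(1, 1000 // batches_per_step)   # 250

    step = 1600
    start_epoch = step // steps_per_epoch                # 6: resume inside epoch 6
    jobs_trained = (step % steps_per_epoch) * batches_per_step // batches_per_job  # 200
    num_jobs_train = max(1, 1000 // batches_per_job)     # 500 jobs per full epoch

So with a target of 10 epochs, (10 - 6) * 500 - 200 = 1800 training jobs remain to be handed out.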
+ + def _next_epoch(self): + # State-machine of the coordination process + + # Indicates whether 'new' epoch(s) were provided + result = False + + # Only check for early stopping, if it is enabled and validation is active + if (FLAGS.early_stop is True) and (FLAGS.validation_step > 0) and (len(self._dev_losses) >= FLAGS.earlystop_nsteps): + + # Calculate the mean of losses for past epochs + mean_loss = np.mean(self._dev_losses[-FLAGS.earlystop_nsteps:-1]) + # Calculate the standard deviation for losses from validation part in the past epochs + std_loss = np.std(self._dev_losses[-FLAGS.earlystop_nsteps:-1]) + # Update the list of losses incurred + self._dev_losses = self._dev_losses[-FLAGS.earlystop_nsteps:] + log_debug('Checking for early stopping (last %d steps) validation loss: %f, with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss)) + + # Check if the validation loss has started increasing, or is no longer decreasing substantially - while making sure slight fluctuations don't trigger early stopping spuriously + if self._dev_losses[-1] > np.max(self._dev_losses[:-1]) or (abs(self._dev_losses[-1] - mean_loss) < FLAGS.estop_mean_thresh and std_loss < FLAGS.estop_std_thresh): + # Time to early stop + log_info('Early stop triggered as (for last %d steps) validation loss: %f with standard deviation: %f and mean: %f' % (FLAGS.earlystop_nsteps, self._dev_losses[-1], std_loss, mean_loss)) + self._dev_losses = [] + self._end_training() + self._train = False + + if self._train: + # We are in train mode + if self._num_jobs_train_left > 0: + # There are still jobs left + num_jobs_train = min(self._num_jobs_train_left, self._num_jobs_train) + self._num_jobs_train_left -= num_jobs_train + + # Let's try our best to keep the notion of curriculum learning + self._reset_counters() + + # Append the training epoch + self._epochs_running.append(Epoch(self._epoch, num_jobs_train, set_name='train')) + + if FLAGS.validation_step > 0 and (FLAGS.validation_step == 1 or self._epoch > 0) and self._epoch % FLAGS.validation_step == 0: + # The current epoch should also have a validation part + self._epochs_running.append(Epoch(self._epoch, self._num_jobs_dev, set_name='dev')) + + + # Indicating that there were 'new' epoch(s) provided + result = True + else: + # No jobs left, but still in train mode: concluding training + self._end_training() + self._train = False + + if result: + # Increment the epoch index + self._epoch += 1 + return result + + def _end_training(self): + self._training_time = stopwatch(self._training_time) + log_info('FINISHED Optimization - training time: %s' % format_duration(self._training_time)) + + def start(self): + '''Starts Training Coordinator. If chief, it starts a web server for + communication with non-chief instances. + ''' + if self.is_chief: + log_debug('Starting coordinator...') + self._thread = Thread(target=self._httpd.serve_forever) + self._thread.daemon = True + self._thread.start() + log_debug('Coordinator started.') + + def stop(self, wait_for_running_epochs=True): + '''Stops Training Coordinator. If chief, it waits for all epochs to be + 'done' and then shuts down the web server. + ''' + if self.is_chief: + if wait_for_running_epochs: + while len(self._epochs_running) > 0: + log_traffic('Coordinator is waiting for epochs to finish...') + time.sleep(5) + log_debug('Stopping coordinator...') + self._httpd.shutdown() + log_debug('Coordinator stopped.') + + def _talk_to_chief(self, path, data=None, default=None): + tries = 0 + while tries < FLAGS.coord_retries: + tries += 1 + try: + url = 'http://%s:%d%s' % (FLAGS.coord_host, FLAGS.coord_port, path) + log_traffic('Contacting coordinator - url: %s, tries: %d ...' % (url, tries-1)) + res = urllib.request.urlopen(urllib.request.Request(url, data, { 'content-type': 'text/plain' })) + answer = res.read() + status = res.getcode() + log_traffic('Coordinator responded - url: %s, status: %s' % (url, status)) + if status == 200: + return answer + if status == 204: # We use 204 (no content) to indicate end of training + return default + except urllib.error.HTTPError as error: + log_traffic('Problem reaching coordinator - url: %s, HTTP code: %d' % (url, error.code)) + pass + time.sleep(10) + return default
+ + def get_next_index(self, set_name): + '''Retrieves a new cluster-unique batch index for a given set-name. + Prevents applying one batch multiple times per epoch. + + Args: + set_name (str): name of the data set - one of 'train', 'dev' + + Returns: + int. new data set index + ''' + with self._lock: + if self.is_chief: + member = '_index_' + set_name + value = getattr(self, member, -1) + setattr(self, member, value + 1) + return value + else: + # We are a remote worker and have to hand over to the chief worker by HTTP + log_traffic('Asking for next index...') + value = int(self._talk_to_chief(PREFIX_NEXT_INDEX + set_name)) + log_traffic('Got index %d.' % value) + return value + + def _get_job(self, worker=0): + job = None + # Find first running epoch that provides a next job + for epoch in self._epochs_running: + job = epoch.get_job(worker) + if job: + return job + # No next job found + return None + + def get_job(self, worker=0): + '''Retrieves the first job for a worker. + + Kwargs: + worker (int): index of the worker to get the first job for + + Returns: + WorkerJob. a job of one of the running epochs that will get + associated with the given worker and put into state 'running' + ''' + # Let's ensure that this does not interfere with other workers/requests + with self._lock: + if self.is_chief: + # First try to get a next job + job = self._get_job(worker) + + if job is None: + # If there was no next job, we give it a second chance by triggering the epoch state machine + if self._next_epoch(): + # Epoch state machine got a new epoch + # Second try to get a next job + job = self._get_job(worker) + if job is None: + # Although the epoch state machine provided a new epoch, it had no job for us + log_error('Unexpected case - no job for worker %d.' % (worker)) + return job + + # Epoch state machine has no new epoch + # This happens at the end of the whole training - nothing to worry about + log_traffic('No jobs left for worker %d.' % (worker)) + self._log_all_jobs() + return None + + # We got a new job from one of the currently running epochs + log_traffic('Got new %s' % job) + return job + + # We are a remote worker and have to hand over to the chief worker by HTTP + result = self._talk_to_chief(PREFIX_GET_JOB + str(FLAGS.task_index)) + if result: + result = pickle.loads(result) + return result + + def next_job(self, job): + '''Sends a finished job back to the coordinator and retrieves in exchange the next one. + + Args: + job (WorkerJob): job that was finished by a worker and whose results are to be + digested by the coordinator + + Returns: + WorkerJob. next job of one of the running epochs that will get + associated with the worker from the finished job and put into state 'running' + ''' + if self.is_chief: + # Try to find the epoch the job belongs to + epoch = next((epoch for epoch in self._epochs_running if epoch.id == job.epoch_id), None) + if epoch: + # We are going to manipulate things - let's avoid undefined state + with self._lock: + # Let the epoch finish the job + epoch.finish_job(job) + # Check if the epoch is done now + if epoch.done(): + # If it declares itself done, move it from 'running' to 'done' collection + self._epochs_running.remove(epoch) + self._epochs_done.append(epoch) + log_info('%s' % epoch) + else: + # There was no running epoch found for this job - this should never happen. + log_error('There is no running epoch of ID %d for job with ID %d. This should never happen.'
% (job.epoch_id, job.id)) + return self.get_job(job.worker) + + # We are a remote worker and have to hand over to the chief worker by HTTP + result = self._talk_to_chief('', data=pickle.dumps(job)) + if result: + result = pickle.loads(result) + return result + +class GlobalConfig: + _config = None + + def __getattr__(self, name): + if not GlobalConfig._config: + raise RuntimeError("Global configuration not yet initialized.") + if not hasattr(GlobalConfig._config, name): + raise RuntimeError("Configuration option {} not found in config.".format(name)) + return GlobalConfig._config[name] + +C = GlobalConfig() + +def initialize_globals(): + c = AttrDict() + + # ps and worker hosts required for p2p cluster setup + FLAGS.ps_hosts = list(filter(len, FLAGS.ps_hosts.split(','))) + FLAGS.worker_hosts = list(filter(len, FLAGS.worker_hosts.split(','))) + + # The absolute number of computing nodes - regardless of cluster or single mode + c.num_workers = max(1, len(FLAGS.worker_hosts)) + + # Create a cluster from the parameter server and worker hosts. + c.cluster = tf.train.ClusterSpec({'ps': FLAGS.ps_hosts, 'worker': FLAGS.worker_hosts}) + + # If replica numbers are negative, we multiply their absolute values by the number of workers + if FLAGS.replicas < 0: + FLAGS.replicas = c.num_workers * -FLAGS.replicas + if FLAGS.replicas_to_agg < 0: + FLAGS.replicas_to_agg = c.num_workers * -FLAGS.replicas_to_agg + + # The device path base for this node + c.worker_device = '/job:%s/task:%d' % (FLAGS.job_name, FLAGS.task_index) + + # This node's CPU device + c.cpu_device = c.worker_device + '/cpu:0' + + # This node's available GPU devices + c.available_devices = [c.worker_device + gpu for gpu in get_available_gpus()] + + # If there is no GPU available, we fall back to CPU based operation + if 0 == len(c.available_devices): + c.available_devices = [c.cpu_device] + + # Set default dropout rates + if FLAGS.dropout_rate2 < 0: + FLAGS.dropout_rate2 = FLAGS.dropout_rate + if FLAGS.dropout_rate3 < 0: + FLAGS.dropout_rate3 = FLAGS.dropout_rate + if FLAGS.dropout_rate6 < 0: + FLAGS.dropout_rate6 = FLAGS.dropout_rate + + c.no_dropout = [ 0.0 ] * 6 + + # Set default checkpoint dir + if len(FLAGS.checkpoint_dir) == 0: + FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints')) + + # Set default summary dir + if len(FLAGS.summary_dir) == 0: + FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))
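C is a lazy proxy: attribute access fails loudly until initialize_globals() has populated GlobalConfig._config, which is what lets other modules do "from util.coordinator import C" at import time. The intended access pattern, sketched (assumes flags have been created and parsed first):

    from util.coordinator import C, initialize_globals

    try:
        C.n_input                   # before initialization
    except RuntimeError as e:
        print(e)                    # Global configuration not yet initialized.

    initialize_globals()
    print(C.n_input, C.n_context)   # 26 9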
+ + # Standard session configuration that'll be used for all new sessions. + c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement, inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads) + + c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path)) + + # Geometric Constants + # =================== + + # For an explanation of the meaning of the geometric constants, please refer to + # doc/Geometry.md + + # Number of MFCC features + c.n_input = 26 # TODO: Determine this programmatically from the sample rate + + # The number of frames in the context + c.n_context = 9 # TODO: Determine the optimal value using a validation data set + + # Number of units in hidden layers + c.n_hidden = FLAGS.n_hidden + + c.n_hidden_1 = c.n_hidden + + c.n_hidden_2 = c.n_hidden + + c.n_hidden_5 = c.n_hidden + + # LSTM cell state dimension + c.n_cell_dim = c.n_hidden + + # The number of units in the third layer, which feeds in to the LSTM + c.n_hidden_3 = c.n_cell_dim + + # The number of characters in the target language plus one + c.n_character = c.alphabet.size() + 1 # +1 for CTC blank label + + # The number of units in the sixth layer + c.n_hidden_6 = c.n_character + + # Queues that are used to gracefully stop parameter servers. + # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting. + # Each ps will dequeue as many tokens as there are workers before joining/quitting. + # This ensures parameter servers won't quit, if still required by at least one worker and + # also won't wait forever (like with a standard `server.join()`). + c.done_queues = [] + for i, ps in enumerate(FLAGS.ps_hosts): + # Queues are hosted by their respective owners + with tf.device('/job:ps/task:%d' % i): + c.done_queues.append(tf.FIFOQueue(1, tf.int32, shared_name=('queue%i' % i))) + + # Placeholder to pass in the worker's index as token + c.token_placeholder = tf.placeholder(tf.int32) + + # Enqueue operations for each parameter server + c.done_enqueues = [queue.enqueue(c.token_placeholder) for queue in c.done_queues] + + # Dequeue operations for each parameter server + c.done_dequeues = [queue.dequeue() for queue in c.done_queues] + + if len(FLAGS.one_shot_infer) > 0: + FLAGS.train = False + FLAGS.test = False + FLAGS.export_dir = '' + if not os.path.exists(FLAGS.one_shot_infer): + log_error('Path specified in --one_shot_infer is not a valid file.') + exit(1) + + # Determine if we are the chief worker + c.is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker') + + # Initializing and starting the training coordinator + c.COORD = TrainingCoordinator(c.is_chief) + c.COORD.start() + + GlobalConfig._config = c + diff --git a/util/feeding.py b/util/feeding.py index 65b90a26..26851bb7 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -12,14 +12,13 @@ class ModelFeeder(object): ''' Feeds data into a model. Feeding is parallelized by independent units called tower feeders (usually one per GPU). - Each tower feeder provides data from three runtime switchable sources (train, dev, test). - These sources are to be provided by three DataSet instances whos references are kept. + Each tower feeder provides data from runtime switchable sources (train, dev). + These sources are to be provided by the DataSet instances whose references are kept. Creates, owns and delegates to tower_feeder_count internal tower feeder objects.
''' def __init__(self, train_set, dev_set, - test_set, numcep, numcontext, alphabet, @@ -28,8 +27,7 @@ class ModelFeeder(object): self.train = train_set self.dev = dev_set - self.test = test_set - self.sets = [train_set, dev_set, test_set] + self.sets = [train_set, dev_set] self.numcep = numcep self.numcontext = numcontext self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count diff --git a/util/flags.py b/util/flags.py new file mode 100644 index 00000000..b58d8d42 --- /dev/null +++ b/util/flags.py @@ -0,0 +1,142 @@ +from __future__ import print_function + +import tensorflow as tf + +from xdg import BaseDirectory as xdg + + +FLAGS = tf.app.flags.FLAGS + + +def create_flags(): + # Importer + # ======== + + tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged') + tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged') + tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged') + tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training') + + tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged') + tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged') + tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. 
multiple files will get merged') + + # Cluster configuration + # ===================== + + tf.app.flags.DEFINE_string ('ps_hosts', '', 'parameter servers - comma separated list of hostname:port pairs') + tf.app.flags.DEFINE_string ('worker_hosts', '', 'workers - comma separated list of hostname:port pairs') + tf.app.flags.DEFINE_string ('job_name', 'localhost', 'job name - one of localhost (default), worker, ps') + tf.app.flags.DEFINE_integer ('task_index', 0, 'index of task within the job - worker with index 0 will be the chief') + tf.app.flags.DEFINE_integer ('replicas', -1, 'total number of replicas - if negative, its absolute value is multiplied by the number of workers') + tf.app.flags.DEFINE_integer ('replicas_to_agg', -1, 'number of replicas to aggregate - if negative, its absolute value is multiplied by the number of workers') + tf.app.flags.DEFINE_integer ('coord_retries', 100, 'number of tries of workers connecting to training coordinator before failing') + tf.app.flags.DEFINE_string ('coord_host', 'localhost', 'coordination server host') + tf.app.flags.DEFINE_integer ('coord_port', 2500, 'coordination server port') + tf.app.flags.DEFINE_integer ('iters_per_worker', 1, 'number of train or inference iterations per worker before results are sent back to coordinator') + + # Global Constants + # ================ + + tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network') + tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network') + tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained') + + tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers') + tf.app.flags.DEFINE_float ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate') + tf.app.flags.DEFINE_float ('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate') + tf.app.flags.DEFINE_float ('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0') + tf.app.flags.DEFINE_float ('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0') + tf.app.flags.DEFINE_float ('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate') + + tf.app.flags.DEFINE_float ('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers') + + # Adam optimizer (http://arxiv.org/abs/1412.6980) parameters + + tf.app.flags.DEFINE_float ('beta1', 0.9, 'beta 1 parameter of Adam optimizer') + tf.app.flags.DEFINE_float ('beta2', 0.999, 'beta 2 parameter of Adam optimizer') + tf.app.flags.DEFINE_float ('epsilon', 1e-8, 'epsilon parameter of Adam optimizer') + tf.app.flags.DEFINE_float ('learning_rate', 0.001, 'learning rate of Adam optimizer') + + # Batch sizes + + tf.app.flags.DEFINE_integer ('train_batch_size', 1, 'number of elements in a training batch') + tf.app.flags.DEFINE_integer ('dev_batch_size', 1, 'number of elements in a validation batch') + tf.app.flags.DEFINE_integer ('test_batch_size', 1, 'number of elements in a test batch') + + tf.app.flags.DEFINE_integer ('export_batch_size', 1, 'number of elements per batch on the exported graph') + + # Performance (UNSUPPORTED) + tf.app.flags.DEFINE_integer ('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details') + tf.app.flags.DEFINE_integer ('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details') + + # Sample limits
+ + tf.app.flags.DEFINE_integer ('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit') + tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit') + tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit') + + # Step widths + + tf.app.flags.DEFINE_integer ('validation_step', 0, 'number of epochs we cycle through before validating the model - 0 means no validation steps') + + # Checkpointing + + tf.app.flags.DEFINE_string ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') + tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds') + tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5') + + # Exporting + + tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported') + tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model') + tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models') + tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine') + tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)') + tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency') + + # Reporting + + tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL') + tf.app.flags.DEFINE_boolean ('log_traffic', False, 'log cluster transaction and traffic information during debug logging') + tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.') + + tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console') + tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report') + + tf.app.flags.DEFINE_string ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification') + tf.app.flags.DEFINE_integer ('summary_secs', 0, 'interval in seconds for saving TensorBoard summaries - if 0, no summaries will be written') + + # Geometry + + tf.app.flags.DEFINE_integer ('n_hidden', 2048, 'layer width to use when initialising layers') + + # Initialization + + tf.app.flags.DEFINE_integer ('random_seed', 4567, 'default random seed that is used to initialize variables') + + # Early Stopping + + tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset. Make sure that dev FLAG is enabled for this to work') + + # These parameters are independent of the time a single epoch takes and of the checkpoint saving interval. + # It is possible that early stopping triggers long after the best checkpoint has already been overwritten by the checkpoint saving mechanism. + # Align earlystop_nsteps and checkpoint_secs with the time an epoch takes on the given dataset. + + tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of validation steps to consider for early stopping - the loss history is not stored in checkpoints, so it is rebuilt from scratch after a checkpoint is restored') + tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required') + tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
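Concretely, the three thresholds just defined feed the following trigger check, mirrored from _next_epoch in util/coordinator.py (loss values invented; thresholds are the defaults):

    import numpy as np

    dev_losses = [0.81, 0.80, 0.82, 0.81]   # last earlystop_nsteps validation losses
    mean_loss = np.mean(dev_losses[:-1])    # 0.81
    std_loss = np.std(dev_losses[:-1])      # ~0.008

    rising = dev_losses[-1] > np.max(dev_losses[:-1])                   # False
    flat = abs(dev_losses[-1] - mean_loss) < 0.5 and std_loss < 0.5     # True
    print(rising or flat)   # True - training would early-stop here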
+ + # Decoder + + tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.') + tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM') + tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie') + tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions') + tf.app.flags.DEFINE_float ('lm_alpha', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') + tf.app.flags.DEFINE_float ('lm_beta', 2.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight.') + + # Inference mode + + tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.') + diff --git a/util/logging.py b/util/logging.py new file mode 100644 index 00000000..15060626 --- /dev/null +++ b/util/logging.py @@ -0,0 +1,35 @@ +from __future__ import print_function + +from util.flags import FLAGS + + +# Logging functions +# ================= + +def prefix_print(prefix, message): + print(prefix + ('\n' + prefix).join(message.split('\n'))) + + +def log_debug(message): + if FLAGS.log_level == 0: + prefix_print('D ', message) + + +def log_traffic(message): + if FLAGS.log_traffic: + log_debug(message) + + +def log_info(message): + if FLAGS.log_level <= 1: + prefix_print('I ', message) + + +def log_warn(message): + if FLAGS.log_level <= 2: + prefix_print('W ', message) + + +def log_error(message): + if FLAGS.log_level <= 3: + prefix_print('E ', message) \ No newline at end of file diff --git a/util/shared_lib.py b/util/shared_lib.py deleted file mode 100644 index 011109bd..00000000 --- a/util/shared_lib.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import print_function -from __future__ import absolute_import -from util.gpu import get_available_gpus -from ctypes import cdll -from sys import platform as _platform - -def get_cupti_libname(): - if _platform == 'linux' or _platform == 'linux2': - return 'libcupti.so' - elif _platform == 'darwin': - return 'libcupti.dylib' - elif _platform == 'win32': - return 'libcupti.dll' - -def check_cupti(): - # We want to ensure that user has properly configured its libs. - # We do this because dso load of libcupti will happen after a lot - # of computation happened, so easy to miss and loose time. - libname = get_cupti_libname() - cupti = check_so(libname) - if cupti is None: - print("INFO: No %s because no GPU, go ahead." % libname) - elif cupti is True: - print("INFO: Found %s." % libname) - else: - print("WARNING: Running on GPU but no %s could be found ; will be unable to report GPU VRAM usage."
% libname) - -def check_so(soname): - """ - Verify that we do have the 'soname' lib present in the system, and that it - can be loaded. - """ - - if len(get_available_gpus()) == 0: - return None - - # Try to force load lib, this would fail if the lib is not there :) - try: - lib = cdll.LoadLibrary(soname) - print("INFO: Found so as", lib) - assert lib.__class__.__name__ == 'CDLL' - assert lib._name == soname - return True - except OSError as ex: - print("WARNING:", ex) - return False
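With libctc_decoder_with_kenlm.so gone (see the .compute and .install changes at the top of this patch), every decoding path now goes through the ds_ctcdecoder wheel. A minimal single-utterance sketch of the API as this patch uses it - the probabilities are random stand-ins for real acoustic model output, and the paths and hyperparameters are the flag defaults:

    import numpy as np
    from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
    from util.text import Alphabet

    alphabet = Alphabet('data/alphabet.txt')
    scorer = Scorer(1.50, 2.10,  # lm_alpha, lm_beta
                    'data/lm/lm.binary', 'data/lm/trie', alphabet)

    # Stand-in for the softmaxed network output: [time, alphabet.size() + 1]
    probs = np.random.rand(100, alphabet.size() + 1).astype(np.float32)
    probs /= probs.sum(axis=1, keepdims=True)

    decoded = ctc_beam_search_decoder(probs, alphabet, 1024, scorer=scorer)
    print(decoded[0][1])  # transcription of the highest-probability beam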