From b85ad3ea743f977d9fe2732e33048a61159fbc48 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Tue, 15 Dec 2020 11:31:21 +0200
Subject: [PATCH] Refactor train.py into separate scripts

Currently train.py is overloaded with many independent features.
Understanding the code, and what the result of a training call will be,
requires untangling the entire script. It is also an error-prone UX.
This is a first step toward separating independent parts into their own
scripts.
---
 .../deepspeech_training/deepspeech_model.py   | 287 +++++++++++
 training/deepspeech_training/evaluate.py      |  14 +-
 training/deepspeech_training/export.py        | 162 ++++++
 training/deepspeech_training/train.py         | 486 ++----------------
 .../training_graph_inference.py               |  77 +++
 training/deepspeech_training/util/config.py   |   1 +
 training/deepspeech_training/util/io.py       |  10 +-
 7 files changed, 581 insertions(+), 456 deletions(-)
 create mode 100644 training/deepspeech_training/deepspeech_model.py
 create mode 100644 training/deepspeech_training/export.py
 create mode 100644 training/deepspeech_training/training_graph_inference.py

diff --git a/training/deepspeech_training/deepspeech_model.py b/training/deepspeech_training/deepspeech_model.py
new file mode 100644
index 00000000..237fddbd
--- /dev/null
+++ b/training/deepspeech_training/deepspeech_model.py
@@ -0,0 +1,287 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import os
+import sys
+
+LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
+DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
+
+import numpy as np
+import tensorflow as tf
+import tensorflow.compat.v1 as tfv1
+
+tfv1.logging.set_verbosity({
+    '0': tfv1.logging.DEBUG,
+    '1': tfv1.logging.INFO,
+    '2': tfv1.logging.WARN,
+    '3': tfv1.logging.ERROR
+}.get(DESIRED_LOG_LEVEL))
+
+from .util.config import Config
+from .util.feeding import audio_to_features
+from .util.flags import FLAGS
+
+
+def variable_on_cpu(name, shape, initializer):
+    r"""
+    Next we concern ourselves with graph creation.
+    However, before we do so we must introduce a utility function ``variable_on_cpu()``
+    used to create a variable in CPU memory.
+    """
+    # Use the /cpu:0 device for scoped operations
+    with tf.device(Config.cpu_device):
+        # Create or get the appropriate variable
+        var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
+    return var
+
+
+def create_overlapping_windows(batch_x):
+    batch_size = tf.shape(input=batch_x)[0]
+    window_width = 2 * Config.n_context + 1
+    num_channels = Config.n_input
+
+    # Create a constant convolution filter using an identity matrix, so that the
+    # convolution returns patches of the input tensor as is, and we can create
+    # overlapping windows over the MFCCs.
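+    # For example, with the project defaults n_context=9 and n_input=26, the
+    # filter below has shape [19, 26, 494]: entry (t, c, t*26 + c) is one, so
+    # each output step is the 19 surrounding frames of 26 features
+    # concatenated into a single 494-wide vector.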
+    eye_filter = tf.constant(np.eye(window_width * num_channels)
+                               .reshape(window_width, num_channels, window_width * num_channels), tf.float32)
+
+    # Create overlapping windows
+    batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
+
+    # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
+    batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
+
+    return batch_x
+
+
+def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False):
+    with tfv1.variable_scope(name):
+        bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
+        weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
+
+    output = tf.nn.bias_add(tf.matmul(x, weights), bias)
+
+    if relu:
+        output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
+
+    if layer_norm:
+        with tfv1.variable_scope(name):
+            output = tf.contrib.layers.layer_norm(output)
+
+    if dropout_rate is not None:
+        output = tf.nn.dropout(output, rate=dropout_rate)
+
+    return output
+
+
+def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
+        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
+                                                    forget_bias=0,
+                                                    reuse=reuse,
+                                                    name='cudnn_compatible_lstm_cell')
+
+        output, output_state = fw_cell(inputs=x,
+                                       dtype=tf.float32,
+                                       sequence_length=seq_length,
+                                       initial_state=previous_state)
+
+    return output, output_state
+
+
+def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
+    assert previous_state is None  # 'Passing previous state not supported with CuDNN backend'
+
+    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
+    # the object it creates the variables, and then you just call it several times
+    # to enable variable re-use. Because all of our code is structured in an
+    # old-school TensorFlow style, where you can just call tf.get_variable again
+    # with reuse=True to reuse variables, we can't easily make use of the
+    # object-oriented way CudnnLSTM is implemented, so we save a singleton
+    # instance in the function, emulating a static function variable.
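+    # Note: as a consequence, a process only ever builds one CudnnLSTM cell;
+    # every call after the first reuses that cell and its variables.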
+    if not rnn_impl_cudnn_rnn.cell:
+        # Forward direction cell:
+        fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
+                                                 num_units=Config.n_cell_dim,
+                                                 input_mode='linear_input',
+                                                 direction='unidirectional',
+                                                 dtype=tf.float32)
+        rnn_impl_cudnn_rnn.cell = fw_cell
+
+    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
+                                                   sequence_lengths=seq_length)
+
+    return output, output_state
+
+rnn_impl_cudnn_rnn.cell = None
+
+
+def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
+        # Forward direction cell:
+        fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
+                                            forget_bias=0,
+                                            reuse=reuse,
+                                            name='cudnn_compatible_lstm_cell')
+
+        # Split rank N tensor into list of rank N-1 tensors
+        x = [x[l] for l in range(x.shape[0])]
+
+        output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
+                                                  inputs=x,
+                                                  sequence_length=seq_length,
+                                                  initial_state=previous_state,
+                                                  dtype=tf.float32,
+                                                  scope='cell_0')
+
+        output = tf.concat(output, 0)
+
+    return output, output_state
+
+
+def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
+    layers = {}
+
+    # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
+    if not batch_size:
+        batch_size = tf.shape(input=batch_x)[0]
+
+    # Create overlapping feature windows if needed
+    if overlap:
+        batch_x = create_overlapping_windows(batch_x)
+
+    # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
+    # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
+
+    # Permute n_steps and batch_size
+    batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
+    # Reshape to prepare input for first layer
+    batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context])  # (n_steps*batch_size, n_input + 2*n_input*n_context)
+    layers['input_reshaped'] = batch_x
+
+    # The next three blocks will pass `batch_x` through three hidden layers with
+    # clipped RELU activation and dropout.
+    layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0], layer_norm=FLAGS.layer_norm)
+    layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1], layer_norm=FLAGS.layer_norm)
+    layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2], layer_norm=FLAGS.layer_norm)
+
+    # `layer_3` is now reshaped into `[n_steps, batch_size, n_hidden_3]`,
+    # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
+    layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
+
+    # Run through the parametrized RNN implementation, as we use different RNNs
+    # for training and inference
+    output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
+
+    # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
+    # to a tensor of shape [n_steps*batch_size, n_cell_dim]
+    output = tf.reshape(output, [-1, Config.n_cell_dim])
+    layers['rnn_output'] = output
+    layers['rnn_output_state'] = output_state
+
+    # Now we feed `output` to the fifth hidden layer with clipped RELU activation
+    layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5], layer_norm=FLAGS.layer_norm)
+
+    # Now we apply a final linear layer creating `n_classes`-dimensional vectors, the logits.
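+    # relu=False keeps this layer linear; the softmax over characters is
+    # applied downstream instead, inside the CTC loss during training and
+    # explicitly in create_inference_graph at inference time.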
+    layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
+
+    # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
+    # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
+    # Note that this differs from the input in that it is time-major.
+    layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
+    layers['raw_logits'] = layer_6
+
+    # Output shape: [n_steps, batch_size, n_hidden_6]
+    return layer_6, layers
+
+
+def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
+    batch_size = batch_size if batch_size > 0 else None
+
+    # Create feature computation graph
+    input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
+    samples = tf.expand_dims(input_samples, -1)
+    mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate)
+    mfccs = tf.identity(mfccs, name='mfccs')
+
+    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
+    # This shape is read by the native_client in DS_CreateModel to know the
+    # value of n_steps, n_context and n_input. Make sure you update the code
+    # there if this shape is changed.
+    input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
+    seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
+
+    if batch_size is None:
+        # no state management since n_steps is expected to be dynamic too (see below)
+        previous_state = None
+    else:
+        previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
+        previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
+
+        previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
+
+    # One rate per layer
+    no_dropout = [None] * 6
+
+    if tflite:
+        rnn_impl = rnn_impl_static_rnn
+    else:
+        rnn_impl = rnn_impl_lstmblockfusedcell
+
+    logits, layers = create_model(batch_x=input_tensor,
+                                  batch_size=batch_size,
+                                  seq_length=seq_length if not FLAGS.export_tflite else None,
+                                  dropout=no_dropout,
+                                  previous_state=previous_state,
+                                  overlap=False,
+                                  rnn_impl=rnn_impl)
+
+    # TF Lite runtime will check that input dimensions are 1, 2 or 4.
+    # By default we get 3, the middle one being batch_size, which is forced to
+    # one in the inference graph, so remove that dimension.
+    if tflite:
+        logits = tf.squeeze(logits, [1])
+
+    # Apply softmax for CTC decoder
+    probs = tf.nn.softmax(logits, name='logits')
+
+    if batch_size is None:
+        if tflite:
+            raise NotImplementedError('dynamic batch_size does not support tflite or streaming')
+        if n_steps > 0:
+            raise NotImplementedError('dynamic batch_size expects n_steps to be dynamic too')
+        return (
+            {
+                'input': input_tensor,
+                'input_lengths': seq_length,
+            },
+            {
+                'outputs': probs,
+            },
+            layers
+        )
+
+    new_state_c, new_state_h = layers['rnn_output_state']
+    new_state_c = tf.identity(new_state_c, name='new_state_c')
+    new_state_h = tf.identity(new_state_h, name='new_state_h')
+
+    inputs = {
+        'input': input_tensor,
+        'previous_state_c': previous_state_c,
+        'previous_state_h': previous_state_h,
+        'input_samples': input_samples,
+    }
+
+    if not FLAGS.export_tflite:
+        inputs['input_lengths'] = seq_length
+
+    outputs = {
+        'outputs': probs,
+        'new_state_c': new_state_c,
+        'new_state_h': new_state_h,
+        'mfccs': mfccs,
+    }
+
+    return inputs, outputs, layers
diff --git a/training/deepspeech_training/evaluate.py
b/training/deepspeech_training/evaluate.py index 965b3370..376f0e95 100755 --- a/training/deepspeech_training/evaluate.py +++ b/training/deepspeech_training/evaluate.py @@ -133,6 +133,13 @@ def evaluate(test_csvs, create_model): return samples +def test(): + from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel + samples = evaluate(FLAGS.test_files.split(','), create_model) + if FLAGS.test_output_file: + save_samples_json(samples, FLAGS.test_output_file) + + def main(_): initialize_globals() @@ -141,16 +148,13 @@ def main(_): 'the --test_files flag.') sys.exit(1) - from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel - samples = evaluate(FLAGS.test_files.split(','), create_model) - - if FLAGS.test_output_file: - save_samples_json(samples, FLAGS.test_output_file) + test() def run_script(): create_flags() absl.app.run(main) + if __name__ == '__main__': run_script() diff --git a/training/deepspeech_training/export.py b/training/deepspeech_training/export.py new file mode 100644 index 00000000..11b50813 --- /dev/null +++ b/training/deepspeech_training/export.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import os +import sys + +LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0 +DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL + +import absl.app +import tensorflow as tf +import tensorflow.compat.v1 as tfv1 +import shutil + +from .deepspeech_model import create_inference_graph +from .util.checkpoints import load_graph_for_evaluation +from .util.config import Config, initialize_globals +from .util.flags import create_flags, FLAGS +from .util.io import open_remote, rmtree_remote, listdir_remote, is_remote_path, isdir_remote +from .util.logging import log_error, log_info + + +def file_relative_read(fname): + return open(os.path.join(os.path.dirname(__file__), fname)).read() + + +def export(): + r''' + Restores the trained variables into a simpler graph that will be exported for serving. 
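+
+    The graph is frozen (variables are converted to constants), trimmed to its
+    output nodes, and written either as a .pb protocol buffer or, when
+    --export_tflite is set, converted to a .tflite model. A metadata markdown
+    file describing the export is written alongside the model.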
+ ''' + log_info('Exporting the model...') + + inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite) + + graph_version = int(file_relative_read('GRAPH_VERSION').strip()) + assert graph_version > 0 + + outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version') + outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate') + outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len') + outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step') + outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width') + outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet') + + if FLAGS.export_language: + outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language') + + # Prevent further graph changes + tfv1.get_default_graph().finalize() + + output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)] + output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)] + output_names = output_names_tensors + output_names_ops + + with tf.Session() as session: + # Restore variables from checkpoint + load_graph_for_evaluation(session) + + output_filename = FLAGS.export_file_name + '.pb' + if FLAGS.remove_export: + if isdir_remote(FLAGS.export_dir): + log_info('Removing old export') + rmtree_remote(FLAGS.export_dir) + + output_graph_path = os.path.join(FLAGS.export_dir, output_filename) + + if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir): + os.makedirs(FLAGS.export_dir) + + frozen_graph = tfv1.graph_util.convert_variables_to_constants( + sess=session, + input_graph_def=tfv1.get_default_graph().as_graph_def(), + output_node_names=output_names) + + frozen_graph = tfv1.graph_util.extract_sub_graph( + graph_def=frozen_graph, + dest_nodes=output_names) + + if not FLAGS.export_tflite: + with open_remote(output_graph_path, 'wb') as fout: + fout.write(frozen_graph.SerializeToString()) + else: + output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) + + converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values()) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite + converter.allow_custom_ops = True + tflite_model = converter.convert() + + with open_remote(output_tflite_path, 'wb') as fout: + fout.write(tflite_model) + + log_info('Models exported at %s' % (FLAGS.export_dir)) + + metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format( + FLAGS.export_author_id, + FLAGS.export_model_name, + FLAGS.export_model_version)) + + model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow' + with open_remote(metadata_fname, 'w') as f: + f.write('---\n') + f.write('author: {}\n'.format(FLAGS.export_author_id)) + f.write('model_name: {}\n'.format(FLAGS.export_model_name)) + f.write('model_version: {}\n'.format(FLAGS.export_model_version)) + f.write('contact_info: {}\n'.format(FLAGS.export_contact_info)) + f.write('license: {}\n'.format(FLAGS.export_license)) + f.write('language: {}\n'.format(FLAGS.export_language)) + f.write('runtime: {}\n'.format(model_runtime)) + 
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version)) + f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version)) + f.write('acoustic_model_url: \n') + f.write('scorer_url: \n') + f.write('---\n') + f.write('{}\n'.format(FLAGS.export_description)) + + log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname)) + + +def package_zip(): + # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip + export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/' + if is_remote_path(export_dir): + log_error("Cannot package remote path zip %s. Please do this manually." % export_dir) + return + + zip_filename = os.path.dirname(export_dir) + + shutil.copy(FLAGS.scorer_path, export_dir) + + archive = shutil.make_archive(zip_filename, 'zip', export_dir) + log_info('Exported packaged model {}'.format(archive)) + + +def main(_): + initialize_globals() + + if FLAGS.export_dir: + tfv1.reset_default_graph() + + if not FLAGS.export_zip: + # Export to folder + export() + else: + # Export and zip, TFLite only, creates package readable by Java example app + FLAGS.export_tflite = True + + if listdir_remote(FLAGS.export_dir): + log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) + sys.exit(1) + + export() + package_zip() + else: + log_error('Calling export script directly but no --export_dir specified') + sys.exit(1) + + +if __name__ == '__main__': + create_flags() + absl.app.run(main) diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index 94ca7c04..635f4a8f 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -1,7 +1,5 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from __future__ import absolute_import, division, print_function - import os import sys @@ -10,13 +8,13 @@ DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.a os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL import absl.app -import numpy as np import progressbar -import shutil import tensorflow as tf import tensorflow.compat.v1 as tfv1 import time +from datetime import datetime + tfv1.logging.set_verbosity({ '0': tfv1.logging.DEBUG, '1': tfv1.logging.INFO, @@ -24,197 +22,23 @@ tfv1.logging.set_verbosity({ '3': tfv1.logging.ERROR }.get(DESIRED_LOG_LEVEL)) -from datetime import datetime -from ds_ctcdecoder import ctc_beam_search_decoder, Scorer -from .evaluate import evaluate -from six.moves import zip, range +from ds_ctcdecoder import Scorer + +from . import export +from . import evaluate +from . 
import training_graph_inference +from .deepspeech_model import create_model, rnn_impl_lstmblockfusedcell, rnn_impl_cudnn_rnn from .util.config import Config, initialize_globals -from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint -from .util.evaluate_tools import save_samples_json -from .util.feeding import create_dataset, audio_to_features, audiofile_to_features +from .util.checkpoints import load_or_init_graph_for_training, reload_best_checkpoint +from .util.feeding import create_dataset from .util.flags import create_flags, FLAGS from .util.helpers import check_ctcdecoder_version, ExceptionBox from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn -from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote +from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path + check_ctcdecoder_version() -# Graph Creation -# ============== - -def variable_on_cpu(name, shape, initializer): - r""" - Next we concern ourselves with graph creation. - However, before we do so we must introduce a utility function ``variable_on_cpu()`` - used to create a variable in CPU memory. - """ - # Use the /cpu:0 device for scoped operations - with tf.device(Config.cpu_device): - # Create or get apropos variable - var = tfv1.get_variable(name=name, shape=shape, initializer=initializer) - return var - - -def create_overlapping_windows(batch_x): - batch_size = tf.shape(input=batch_x)[0] - window_width = 2 * Config.n_context + 1 - num_channels = Config.n_input - - # Create a constant convolution filter using an identity matrix, so that the - # convolution returns patches of the input tensor as is, and we can create - # overlapping windows over the MFCCs. 
- eye_filter = tf.constant(np.eye(window_width * num_channels) - .reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation - - # Create overlapping windows - batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME') - - # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input] - batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels]) - - return batch_x - - -def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False): - with tfv1.variable_scope(name): - bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer()) - weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")) - - output = tf.nn.bias_add(tf.matmul(x, weights), bias) - - if relu: - output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip) - - if layer_norm: - with tfv1.variable_scope(name): - output = tf.contrib.layers.layer_norm(output) - - if dropout_rate is not None: - output = tf.nn.dropout(output, rate=dropout_rate) - - return output - - -def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse): - with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'): - fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, - forget_bias=0, - reuse=reuse, - name='cudnn_compatible_lstm_cell') - - output, output_state = fw_cell(inputs=x, - dtype=tf.float32, - sequence_length=seq_length, - initial_state=previous_state) - - return output, output_state - - -def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _): - assert previous_state is None # 'Passing previous state not supported with CuDNN backend' - - # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate - # the object it creates the variables, and then you just call it several times - # to enable variable re-use. Because all of our code is structure in an old - # school TensorFlow structure where you can just call tf.get_variable again with - # reuse=True to reuse variables, we can't easily make use of the object oriented - # way CudnnLSTM is implemented, so we save a singleton instance in the function, - # emulating a static function variable. 
- if not rnn_impl_cudnn_rnn.cell: - # Forward direction cell: - fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1, - num_units=Config.n_cell_dim, - input_mode='linear_input', - direction='unidirectional', - dtype=tf.float32) - rnn_impl_cudnn_rnn.cell = fw_cell - - output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x, - sequence_lengths=seq_length) - - return output, output_state - -rnn_impl_cudnn_rnn.cell = None - - -def rnn_impl_static_rnn(x, seq_length, previous_state, reuse): - with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'): - # Forward direction cell: - fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim, - forget_bias=0, - reuse=reuse, - name='cudnn_compatible_lstm_cell') - - # Split rank N tensor into list of rank N-1 tensors - x = [x[l] for l in range(x.shape[0])] - - output, output_state = tfv1.nn.static_rnn(cell=fw_cell, - inputs=x, - sequence_length=seq_length, - initial_state=previous_state, - dtype=tf.float32, - scope='cell_0') - - output = tf.concat(output, 0) - - return output, output_state - - -def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell): - layers = {} - - # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context] - if not batch_size: - batch_size = tf.shape(input=batch_x)[0] - - # Create overlapping feature windows if needed - if overlap: - batch_x = create_overlapping_windows(batch_x) - - # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`. - # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`. - - # Permute n_steps and batch_size - batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3]) - # Reshape to prepare input for first layer - batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context) - layers['input_reshaped'] = batch_x - - # The next three blocks will pass `batch_x` through three hidden layers with - # clipped RELU activation and dropout. - layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0], layer_norm=FLAGS.layer_norm) - layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1], layer_norm=FLAGS.layer_norm) - layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2], layer_norm=FLAGS.layer_norm) - - # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`, - # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`. - layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3]) - - # Run through parametrized RNN implementation, as we use different RNNs - # for training and inference - output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse) - - # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim] - # to a tensor of shape [n_steps*batch_size, n_cell_dim] - output = tf.reshape(output, [-1, Config.n_cell_dim]) - layers['rnn_output'] = output - layers['rnn_output_state'] = output_state - - # Now we feed `output` to the fifth hidden layer with clipped RELU activation - layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5], layer_norm=FLAGS.layer_norm) - - # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits. 
- layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False) - - # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6] - # to the slightly more useful shape [n_steps, batch_size, n_hidden_6]. - # Note, that this differs from the input in that it is time-major. - layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits') - layers['raw_logits'] = layer_6 - - # Output shape: [n_steps, batch_size, n_hidden_6] - return layer_6, layers - # Accuracy and Loss # ================= @@ -678,255 +502,6 @@ def train(): log_debug('Session closed.') -def test(): - samples = evaluate(FLAGS.test_files.split(','), create_model) - if FLAGS.test_output_file: - save_samples_json(samples, FLAGS.test_output_file) - - -def create_inference_graph(batch_size=1, n_steps=16, tflite=False): - batch_size = batch_size if batch_size > 0 else None - - # Create feature computation graph - input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples') - samples = tf.expand_dims(input_samples, -1) - mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate) - mfccs = tf.identity(mfccs, name='mfccs') - - # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input] - # This shape is read by the native_client in DS_CreateModel to know the - # value of n_steps, n_context and n_input. Make sure you update the code - # there if this shape is changed. - input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node') - seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths') - - if batch_size <= 0: - # no state management since n_step is expected to be dynamic too (see below) - previous_state = None - else: - previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c') - previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h') - - previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h) - - # One rate per layer - no_dropout = [None] * 6 - - if tflite: - rnn_impl = rnn_impl_static_rnn - else: - rnn_impl = rnn_impl_lstmblockfusedcell - - logits, layers = create_model(batch_x=input_tensor, - batch_size=batch_size, - seq_length=seq_length if not FLAGS.export_tflite else None, - dropout=no_dropout, - previous_state=previous_state, - overlap=False, - rnn_impl=rnn_impl) - - # TF Lite runtime will check that input dimensions are 1, 2 or 4 - # by default we get 3, the middle one being batch_size which is forced to - # one on inference graph, so remove that dimension - if tflite: - logits = tf.squeeze(logits, [1]) - - # Apply softmax for CTC decoder - probs = tf.nn.softmax(logits, name='logits') - - if batch_size <= 0: - if tflite: - raise NotImplementedError('dynamic batch_size does not support tflite nor streaming') - if n_steps > 0: - raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too') - return ( - { - 'input': input_tensor, - 'input_lengths': seq_length, - }, - { - 'outputs': probs, - }, - layers - ) - - new_state_c, new_state_h = layers['rnn_output_state'] - new_state_c = tf.identity(new_state_c, name='new_state_c') - new_state_h = tf.identity(new_state_h, name='new_state_h') - - inputs = { - 'input': input_tensor, - 'previous_state_c': previous_state_c, - 'previous_state_h': previous_state_h, - 'input_samples': input_samples, - } - - if not 
FLAGS.export_tflite: - inputs['input_lengths'] = seq_length - - outputs = { - 'outputs': probs, - 'new_state_c': new_state_c, - 'new_state_h': new_state_h, - 'mfccs': mfccs, - } - - return inputs, outputs, layers - - -def file_relative_read(fname): - return open(os.path.join(os.path.dirname(__file__), fname)).read() - - -def export(): - r''' - Restores the trained variables into a simpler graph that will be exported for serving. - ''' - log_info('Exporting the model...') - - inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite) - - graph_version = int(file_relative_read('GRAPH_VERSION').strip()) - assert graph_version > 0 - - outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version') - outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate') - outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len') - outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step') - outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width') - outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet') - - if FLAGS.export_language: - outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language') - - # Prevent further graph changes - tfv1.get_default_graph().finalize() - - output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)] - output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)] - output_names = output_names_tensors + output_names_ops - - with tf.Session() as session: - # Restore variables from checkpoint - load_graph_for_evaluation(session) - - output_filename = FLAGS.export_file_name + '.pb' - if FLAGS.remove_export: - if isdir_remote(FLAGS.export_dir): - log_info('Removing old export') - remove_remote(FLAGS.export_dir) - - output_graph_path = os.path.join(FLAGS.export_dir, output_filename) - - if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir): - os.makedirs(FLAGS.export_dir) - - frozen_graph = tfv1.graph_util.convert_variables_to_constants( - sess=session, - input_graph_def=tfv1.get_default_graph().as_graph_def(), - output_node_names=output_names) - - frozen_graph = tfv1.graph_util.extract_sub_graph( - graph_def=frozen_graph, - dest_nodes=output_names) - - if not FLAGS.export_tflite: - with open_remote(output_graph_path, 'wb') as fout: - fout.write(frozen_graph.SerializeToString()) - else: - output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) - - converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values()) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite - converter.allow_custom_ops = True - tflite_model = converter.convert() - - with open_remote(output_tflite_path, 'wb') as fout: - fout.write(tflite_model) - - log_info('Models exported at %s' % (FLAGS.export_dir)) - - metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format( - FLAGS.export_author_id, - FLAGS.export_model_name, - FLAGS.export_model_version)) - - model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow' - with open_remote(metadata_fname, 'w') as f: - 
f.write('---\n') - f.write('author: {}\n'.format(FLAGS.export_author_id)) - f.write('model_name: {}\n'.format(FLAGS.export_model_name)) - f.write('model_version: {}\n'.format(FLAGS.export_model_version)) - f.write('contact_info: {}\n'.format(FLAGS.export_contact_info)) - f.write('license: {}\n'.format(FLAGS.export_license)) - f.write('language: {}\n'.format(FLAGS.export_language)) - f.write('runtime: {}\n'.format(model_runtime)) - f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version)) - f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version)) - f.write('acoustic_model_url: \n') - f.write('scorer_url: \n') - f.write('---\n') - f.write('{}\n'.format(FLAGS.export_description)) - - log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname)) - - -def package_zip(): - # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip - export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/' - if is_remote_path(export_dir): - log_error("Cannot package remote path zip %s. Please do this manually." % export_dir) - return - - zip_filename = os.path.dirname(export_dir) - - shutil.copy(FLAGS.scorer_path, export_dir) - - archive = shutil.make_archive(zip_filename, 'zip', export_dir) - log_info('Exported packaged model {}'.format(archive)) - - -def do_single_file_inference(input_file_path): - with tfv1.Session(config=Config.session_config) as session: - inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1) - - # Restore variables from training checkpoint - load_graph_for_evaluation(session) - - features, features_len = audiofile_to_features(input_file_path) - previous_state_c = np.zeros([1, Config.n_cell_dim]) - previous_state_h = np.zeros([1, Config.n_cell_dim]) - - # Add batch dimension - features = tf.expand_dims(features, 0) - features_len = tf.expand_dims(features_len, 0) - - # Evaluate - features = create_overlapping_windows(features).eval(session=session) - features_len = features_len.eval(session=session) - - probs = outputs['outputs'].eval(feed_dict={ - inputs['input']: features, - inputs['input_lengths']: features_len, - inputs['previous_state_c']: previous_state_c, - inputs['previous_state_h']: previous_state_h, - }, session=session) - - probs = np.squeeze(probs) - - if FLAGS.scorer_path: - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, - FLAGS.scorer_path, Config.alphabet) - else: - scorer = None - decoded = ctc_beam_search_decoder(probs, Config.alphabet, FLAGS.beam_width, - scorer=scorer, cutoff_prob=FLAGS.cutoff_prob, - cutoff_top_n=FLAGS.cutoff_top_n) - # Print highest probability result - print(decoded[0][1]) - - def early_training_checks(): # Check for proper scorer early if FLAGS.scorer_path: @@ -948,38 +523,51 @@ def main(_): initialize_globals() early_training_checks() + def deprecated_msg(prefix): + return ('{} Using the training script as a generic driver for all training ' + 'related functionality is deprecated and will be removed soon. 
Use ' + 'the specific scripts: train/evaluate/export/training_graph_inference.'.format(prefix)) + if FLAGS.train_files: tfv1.reset_default_graph() tfv1.set_random_seed(FLAGS.random_seed) train() + else: + log_warn(deprecated_msg('Calling training script without --train_files.')) if FLAGS.test_files: + log_warn(deprecated_msg('Specifying --test_files when calling train script.')) tfv1.reset_default_graph() - test() + evaluate.test() - if FLAGS.export_dir and not FLAGS.export_zip: + if FLAGS.export_dir: + log_warn(deprecated_msg('Specifying --export_dir when calling train script.')) tfv1.reset_default_graph() - export() - if FLAGS.export_zip: - tfv1.reset_default_graph() - FLAGS.export_tflite = True + if not FLAGS.export_zip: + # Export to folder + export.export() + else: + # Export and zip, TFLite only, creates package readable by Java example app + FLAGS.export_tflite = True - if listdir_remote(FLAGS.export_dir): - log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) - sys.exit(1) + if listdir_remote(FLAGS.export_dir): + log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) + sys.exit(1) - export() - package_zip() + export.export() + export.package_zip() if FLAGS.one_shot_infer: + log_warn(deprecated_msg('Specifying --one_shot_infer when calling train script.')) tfv1.reset_default_graph() - do_single_file_inference(FLAGS.one_shot_infer) + training_graph_inference.do_single_file_inference(FLAGS.one_shot_infer) def run_script(): create_flags() absl.app.run(main) + if __name__ == '__main__': run_script() diff --git a/training/deepspeech_training/training_graph_inference.py b/training/deepspeech_training/training_graph_inference.py new file mode 100644 index 00000000..18758e2f --- /dev/null +++ b/training/deepspeech_training/training_graph_inference.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import os +import sys + +LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0 +DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL + +import absl.app +import numpy as np +import tensorflow as tf +import tensorflow.compat.v1 as tfv1 + +from ds_ctcdecoder import ctc_beam_search_decoder, Scorer +from .deepspeech_model import create_inference_graph, create_overlapping_windows +from .util.checkpoints import load_graph_for_evaluation +from .util.config import Config, initialize_globals +from .util.feeding import audiofile_to_features +from .util.flags import create_flags, FLAGS +from .util.logging import log_error + + +def do_single_file_inference(input_file_path): + with tfv1.Session(config=Config.session_config) as session: + inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1) + + # Restore variables from training checkpoint + load_graph_for_evaluation(session) + + features, features_len = audiofile_to_features(input_file_path) + previous_state_c = np.zeros([1, Config.n_cell_dim]) + previous_state_h = np.zeros([1, Config.n_cell_dim]) + + # Add batch dimension + features = tf.expand_dims(features, 0) + features_len = tf.expand_dims(features_len, 0) + + # Evaluate + features = create_overlapping_windows(features).eval(session=session) + features_len = features_len.eval(session=session) + + probs = outputs['outputs'].eval(feed_dict={ + inputs['input']: features, + inputs['input_lengths']: features_len, + inputs['previous_state_c']: previous_state_c, + inputs['previous_state_h']: 
previous_state_h, + }, session=session) + + probs = np.squeeze(probs) + + if FLAGS.scorer_path: + scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, + FLAGS.scorer_path, Config.alphabet) + else: + scorer = None + decoded = ctc_beam_search_decoder(probs, Config.alphabet, FLAGS.beam_width, + scorer=scorer, cutoff_prob=FLAGS.cutoff_prob, + cutoff_top_n=FLAGS.cutoff_top_n) + # Print highest probability result + print(decoded[0][1]) + + +def main(_): + initialize_globals() + + if FLAGS.one_shot_infer: + tfv1.reset_default_graph() + do_single_file_inference(FLAGS.one_shot_infer) + else: + log_error('Calling training_graph_inference script directly but no --one_shot_infer input audio file specified') + sys.exit(1) + + +if __name__ == '__main__': + create_flags() + absl.app.run(main) diff --git a/training/deepspeech_training/util/config.py b/training/deepspeech_training/util/config.py index 18da6eed..0bdb8d30 100755 --- a/training/deepspeech_training/util/config.py +++ b/training/deepspeech_training/util/config.py @@ -15,6 +15,7 @@ from .helpers import parse_file_size from .augmentations import parse_augmentations from .io import path_exists_remote + class ConfigSingleton: _config = None diff --git a/training/deepspeech_training/util/io.py b/training/deepspeech_training/util/io.py index 947b43af..0feeba6a 100644 --- a/training/deepspeech_training/util/io.py +++ b/training/deepspeech_training/util/io.py @@ -77,5 +77,11 @@ def remove_remote(filename): """ Wrapper that can remove local and remote files like `gs://...` """ - # Conditional import - return gfile.remove_remote(filename) \ No newline at end of file + return gfile.remove(filename) + + +def rmtree_remote(foldername): + """ + Wrapper that can remove local and remote directories like `gs://...` + """ + return gfile.rmtree(foldername)
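
Example invocations of the new entry points (a sketch; it assumes the
deepspeech_training package is importable and uses existing flag names):

    # Evaluate a checkpoint on a test set
    python -m deepspeech_training.evaluate --checkpoint_dir ckpts --test_files test.csv

    # Export a frozen inference graph; add --export_tflite for a TFLite model
    python -m deepspeech_training.export --checkpoint_dir ckpts --export_dir exported

    # One-shot inference through the training graph
    python -m deepspeech_training.training_graph_inference --checkpoint_dir ckpts --one_shot_infer audio.wav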