From f7a715d506dac932a712f7e443ddbe1f3abf328d Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Fri, 12 Jul 2019 01:21:19 +0200
Subject: [PATCH] Use CuDNN RNN for training

---
 DeepSpeech.py | 103 +++++++++++++++++++++++++++++++++-----------------
 util/flags.py |   8 ++--
 2 files changed, 74 insertions(+), 37 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 21b366d2..09f9acdd 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -8,12 +8,12 @@ import sys
 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
 
-import time
 import numpy as np
 import progressbar
 import shutil
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
+import time
 
 from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
@@ -79,32 +79,65 @@ def dense(name, x, units, dropout_rate=None, relu=True):
 
 
 def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
-    # Forward direction cell:
-    fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
+    with tf.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
+        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
+                                                    reuse=reuse,
+                                                    name='cudnn_compatible_lstm_cell')
 
-    output, output_state = fw_cell(inputs=x,
-                                   dtype=tf.float32,
-                                   sequence_length=seq_length,
-                                   initial_state=previous_state)
+        output, output_state = fw_cell(inputs=x,
+                                       dtype=tf.float32,
+                                       sequence_length=seq_length,
+                                       initial_state=previous_state)
 
     return output, output_state
 
 
-def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
+def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
+    assert previous_state is None  # 'Passing previous state not supported with CuDNN backend'
+
     # Forward direction cell:
-    fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
+    if not rnn_impl_cudnn_rnn.cell:
+        with tf.variable_scope('rnn'):
+            fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
+                                                     num_units=Config.n_cell_dim,
+                                                     input_mode='linear_input',
+                                                     direction='unidirectional',
+                                                     dtype=tf.float32)
+            rnn_impl_cudnn_rnn.cell = fw_cell
 
-    # Split rank N tensor into list of rank N-1 tensors
-    x = [x[l] for l in range(x.shape[0])]
+    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
+                                                   sequence_lengths=seq_length)
 
-    # We parametrize the RNN implementation as the training and inference graph
-    # need to do different things here.
-    output, output_state = tf.nn.static_rnn(cell=fw_cell,
-                                            inputs=x,
-                                            initial_state=previous_state,
-                                            dtype=tf.float32,
-                                            sequence_length=seq_length)
-    output = tf.concat(output, 0)
+    return output, output_state
+
+# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
+# the object it creates the variables, and then you just call it several times
+# to enable variable re-use. Because all of our code is structured in an old
+# school TensorFlow style where you can just call tf.get_variable again with
+# reuse=True to reuse variables, we can't easily make use of the object-oriented
+# way CudnnLSTM is implemented, so we save a singleton instance in the function,
+# emulating a static function variable.
+rnn_impl_cudnn_rnn.cell = None
+
+
+def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
+    with tf.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
+        # Forward direction cell:
+        fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
+                                            reuse=reuse,
+                                            name='cudnn_compatible_lstm_cell')
+
+        # Split rank N tensor into list of rank N-1 tensors
+        x = [x[l] for l in range(x.shape[0])]
+
+        output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
+                                                  inputs=x,
+                                                  sequence_length=seq_length,
+                                                  initial_state=previous_state,
+                                                  dtype=tf.float32,
+                                                  scope='cell_0')
+
+        output = tf.concat(output, 0)
 
     return output, output_state
 
 
@@ -183,8 +216,13 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     # Obtain the next batch of data
     (batch_x, batch_seq_len), batch_y = iterator.get_next()
 
+    if FLAGS.use_cudnn_rnn:
+        rnn_impl = rnn_impl_cudnn_rnn
+    else:
+        rnn_impl = rnn_impl_lstmblockfusedcell
+
     # Calculate the logits of the batch
-    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse)
+    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
 
     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
@@ -573,7 +611,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
 
     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
-        previous_state = previous_state_c = previous_state_h = None
+        previous_states = None
     else:
         previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
         previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
@@ -632,7 +670,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     }
 
     if not FLAGS.export_tflite:
-        inputs.update({'input_lengths': seq_length})
+        inputs['input_lengths'] = seq_length
 
     outputs = {
         'outputs': logits,
@@ -659,19 +697,16 @@ def export():
     output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
     output_names = ",".join(output_names_tensors + output_names_ops)
 
-    mapping = None
-    if FLAGS.export_tflite:
-        # Create a saver using variables from the above newly created graph
-        # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
-        # a static RNN with a normal cell, so we need to rewrite the names to
-        # match the training weights when restoring.
-        def fixup(name):
-            if name.startswith('rnn/lstm_cell/'):
-                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
-            return name
-
-        mapping = {fixup(v.op.name): v for v in tf.global_variables()}
+    # Create a saver using variables from the above newly created graph
+    # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
+    # a static RNN with a normal cell, so we need to rewrite the names to
+    # match the training weights when restoring.
+    def fixup(name):
+        if name.startswith('rnn/lstm_cell/'):
+            return name.replace('rnn/lstm_cell/', 'rnn/cudnn_compatible_lstm_cell/')
+        return name
+
+    mapping = {fixup(v.op.name): v for v in tf.global_variables()}
 
     saver = tfv1.train.Saver(mapping)
 
     # Restore variables from training checkpoint
diff --git a/util/flags.py b/util/flags.py
index 468d8e50..997c8f30 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -51,9 +51,11 @@ def create_flags():
 
     f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')
 
-    # Performance(UNSUPPORTED)
-    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
-    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
+    # Performance
+
+    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU')
 
     # Sample limits
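
Note on the "static function variable" hack used in rnn_impl_cudnn_rnn above: Python has no static locals, so the patch caches the CudnnLSTM instance as an attribute on the function object and reuses it on later calls. A minimal standalone sketch of that pattern, with hypothetical names (build_layer, make_expensive_layer) that are not part of the DeepSpeech code:

def make_expensive_layer():
    # Stand-in for an object that creates its variables when instantiated
    # (as tf.contrib.cudnn_rnn.CudnnLSTM does) and is then called repeatedly.
    return lambda x: x

def build_layer(x):
    # Create the expensive object only on the first call...
    if build_layer.cell is None:
        build_layer.cell = make_expensive_layer()
    # ...and reuse the same cached instance (and hence its variables) afterwards.
    return build_layer.cell(x)

# The function attribute plays the role of a static function variable.
build_layer.cell = None

In the patch, rnn_impl_cudnn_rnn fills the role of build_layer and the cached object is the CudnnLSTM layer, so repeated graph-construction calls share one set of CuDNN weights instead of creating new variables each time.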