Use CuDNN RNN for training

Reuben Morais 2019-07-12 01:21:19 +02:00
parent 7fd7381871
commit f7a715d506
2 changed files with 74 additions and 37 deletions


@@ -8,12 +8,12 @@ import sys
 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'

-import time
 import numpy as np
 import progressbar
 import shutil
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
+import time

 from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
@@ -79,32 +79,65 @@ def dense(name, x, units, dropout_rate=None, relu=True):
 def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
-    # Forward direction cell:
-    fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
-
-    output, output_state = fw_cell(inputs=x,
-                                   dtype=tf.float32,
-                                   sequence_length=seq_length,
-                                   initial_state=previous_state)
+    with tf.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
+        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
+                                                    reuse=reuse,
+                                                    name='cudnn_compatible_lstm_cell')
+
+        output, output_state = fw_cell(inputs=x,
+                                       dtype=tf.float32,
+                                       sequence_length=seq_length,
+                                       initial_state=previous_state)

     return output, output_state


-def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
-    # Forward direction cell:
-    fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
-
-    # Split rank N tensor into list of rank N-1 tensors
-    x = [x[l] for l in range(x.shape[0])]
-
-    # We parametrize the RNN implementation as the training and inference graph
-    # need to do different things here.
-    output, output_state = tf.nn.static_rnn(cell=fw_cell,
-                                            inputs=x,
-                                            initial_state=previous_state,
-                                            dtype=tf.float32,
-                                            sequence_length=seq_length)
-    output = tf.concat(output, 0)
+def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
+    assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
+
+    # Forward direction cell:
+    if not rnn_impl_cudnn_rnn.cell:
+        with tf.variable_scope('rnn'):
+            fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
+                                                     num_units=Config.n_cell_dim,
+                                                     input_mode='linear_input',
+                                                     direction='unidirectional',
+                                                     dtype=tf.float32)
+            rnn_impl_cudnn_rnn.cell = fw_cell
+
+    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
+                                                   sequence_lengths=seq_length)
+
+    return output, output_state
+
+# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
+# the object it creates the variables, and then you just call it several times
+# to enable variable re-use. Because all of our code is structured in an old
+# school TensorFlow style where you can just call tf.get_variable again with
+# reuse=True to reuse variables, we can't easily make use of the object oriented
+# way CudnnLSTM is implemented, so we save a singleton instance in the function,
+# emulating a static function variable.
+rnn_impl_cudnn_rnn.cell = None
+
+
+def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
+    with tf.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
+        # Forward direction cell:
+        fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
+                                            reuse=reuse,
+                                            name='cudnn_compatible_lstm_cell')
+
+        # Split rank N tensor into list of rank N-1 tensors
+        x = [x[l] for l in range(x.shape[0])]
+
+        output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
+                                                  inputs=x,
+                                                  sequence_length=seq_length,
+                                                  initial_state=previous_state,
+                                                  dtype=tf.float32,
+                                                  scope='cell_0')
+        output = tf.concat(output, 0)

     return output, output_state
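Note on the singleton above: CudnnLSTM creates its variables when the object is instantiated, so rnn_impl_cudnn_rnn builds the layer once and caches it on the function object. Below is a runnable sketch of that "static function variable" pattern, using a stand-in class instead of tf.contrib.cudnn_rnn.CudnnLSTM (no TensorFlow needed to run it):

class FakeCudnnLSTM:
    # Stand-in for tf.contrib.cudnn_rnn.CudnnLSTM: instantiating the real
    # layer creates its TensorFlow variables, which must happen only once.
    def __call__(self, inputs, sequence_lengths):
        return inputs, None

def rnn_impl(x, seq_length):
    if not rnn_impl.cell:
        rnn_impl.cell = FakeCudnnLSTM()  # built on the first call only
    return rnn_impl.cell(inputs=x, sequence_lengths=seq_length)

rnn_impl.cell = None  # emulates a static function variable

rnn_impl([1, 2, 3], 3)
cached = rnn_impl.cell
rnn_impl([4, 5, 6], 3)
assert rnn_impl.cell is cached  # the same layer (and its variables) is reused

The cudnn_lstm/rnn/multi_rnn_cell/cell_0 scope and the cudnn_compatible_lstm_cell name used by the other two implementations line their checkpoint variable names up with the ones CudnnLSTM writes, so a model trained with one backend can be restored with another.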
@@ -183,8 +216,13 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     # Obtain the next batch of data
     (batch_x, batch_seq_len), batch_y = iterator.get_next()

+    if FLAGS.use_cudnn_rnn:
+        rnn_impl = rnn_impl_cudnn_rnn
+    else:
+        rnn_impl = rnn_impl_lstmblockfusedcell
+
     # Calculate the logits of the batch
-    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse)
+    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
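The dispatch itself is plain Python: the graph builder now takes the RNN implementation as an argument. A minimal runnable sketch with stubs (the _stub names and the simplified create_model signature are illustrative, not the actual code):

def rnn_impl_cudnn_rnn_stub(x, seq_length, previous_state, _):
    return 'cudnn'

def rnn_impl_lstmblockfusedcell_stub(x, seq_length, previous_state, reuse):
    return 'fused'

def create_model_stub(batch_x, seq_length, dropout, reuse=False,
                      rnn_impl=rnn_impl_lstmblockfusedcell_stub):
    # The real create_model threads rnn_impl down to its recurrent layer.
    return rnn_impl(batch_x, seq_length, None, reuse)

use_cudnn_rnn = True  # stand-in for FLAGS.use_cudnn_rnn
rnn_impl = rnn_impl_cudnn_rnn_stub if use_cudnn_rnn else rnn_impl_lstmblockfusedcell_stub
assert create_model_stub(None, 1, 0.0, rnn_impl=rnn_impl) == 'cudnn'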
@@ -573,7 +611,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
-        previous_state = previous_state_c = previous_state_h = None
+        previous_states = None
     else:
         previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
         previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
@ -632,7 +670,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
} }
if not FLAGS.export_tflite: if not FLAGS.export_tflite:
inputs.update({'input_lengths': seq_length}) inputs['input_lengths'] = seq_length
outputs = { outputs = {
'outputs': logits, 'outputs': logits,
@@ -659,19 +697,16 @@ def export():
     output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
     output_names = ",".join(output_names_tensors + output_names_ops)

-    mapping = None
-    if FLAGS.export_tflite:
-        # Create a saver using variables from the above newly created graph
-        # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
-        # a static RNN with a normal cell, so we need to rewrite the names to
-        # match the training weights when restoring.
-        def fixup(name):
-            if name.startswith('rnn/lstm_cell/'):
-                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
-            return name
-
-        mapping = {fixup(v.op.name): v for v in tf.global_variables()}
+    # Create a saver using variables from the above newly created graph
+    # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
+    # a static RNN with a normal cell, so we need to rewrite the names to
+    # match the training weights when restoring.
+    def fixup(name):
+        if name.startswith('rnn/lstm_cell/'):
+            return name.replace('rnn/lstm_cell/', 'rnn/cudnn_compatible_lstm_cell/')
+        return name

+    mapping = {fixup(v.op.name): v for v in tf.global_variables()}
     saver = tfv1.train.Saver(mapping)

     # Restore variables from training checkpoint
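The behavior of the updated fixup is easy to verify in isolation: only names under rnn/lstm_cell/ are rewritten, now onto the CuDNN-compatible prefix rather than the old lstm_fused_cell/ one.

def fixup(name):
    if name.startswith('rnn/lstm_cell/'):
        return name.replace('rnn/lstm_cell/', 'rnn/cudnn_compatible_lstm_cell/')
    return name

assert fixup('rnn/lstm_cell/kernel') == 'rnn/cudnn_compatible_lstm_cell/kernel'
assert fixup('layer_1/weights') == 'layer_1/weights'  # other variables untouched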


@@ -51,9 +51,11 @@ def create_flags():
     f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')

-    # Performance(UNSUPPORTED)
-    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
-    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
+    # Performance
+    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU')

     # Sample limits
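Since use_cudnn_rnn defaults to False, the CuDNN backend is strictly opt-in and existing training runs are unaffected. A hypothetical invocation (the script name and the data flags are assumptions for illustration, not part of this diff):

python DeepSpeech.py --train_files train.csv --dev_files dev.csv --use_cudnn_rnn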