Merge pull request #2240 from mozilla/cudnnrnn

Add CuDNN RNN support
commit 1d50667234
Reuben Morais, 2019-07-22 07:27:28 +00:00 (committed by GitHub)
2 changed files with 66 additions and 39 deletions

DeepSpeech.py

@@ -8,12 +8,12 @@ import sys
 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
 
-import time
 import numpy as np
 import progressbar
 import shutil
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
+import time
 
 from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer

@@ -79,32 +79,64 @@ def dense(name, x, units, dropout_rate=None, relu=True):
 def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
-    # Forward direction cell:
-    fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
-
-    output, output_state = fw_cell(inputs=x,
-                                   dtype=tf.float32,
-                                   sequence_length=seq_length,
-                                   initial_state=previous_state)
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
+        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
+                                                    reuse=reuse,
+                                                    name='cudnn_compatible_lstm_cell')
+
+        output, output_state = fw_cell(inputs=x,
+                                       dtype=tf.float32,
+                                       sequence_length=seq_length,
+                                       initial_state=previous_state)
 
     return output, output_state
 
 
+def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
+    assert previous_state is None  # 'Passing previous state not supported with CuDNN backend'
+
+    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
+    # the object it creates the variables, and then you just call it several times
+    # to enable variable re-use. Because all of our code is structured in an old-school
+    # TensorFlow style where you can just call tf.get_variable again with
+    # reuse=True to reuse variables, we can't easily make use of the object-oriented
+    # way CudnnLSTM is implemented, so we save a singleton instance in the function,
+    # emulating a static function variable.
+    if not rnn_impl_cudnn_rnn.cell:
+        # Forward direction cell:
+        fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
+                                                 num_units=Config.n_cell_dim,
+                                                 input_mode='linear_input',
+                                                 direction='unidirectional',
+                                                 dtype=tf.float32)
+        rnn_impl_cudnn_rnn.cell = fw_cell
+
+    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
+                                                   sequence_lengths=seq_length)
+
+    return output, output_state
+
+rnn_impl_cudnn_rnn.cell = None
+
+
 def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
-    # Forward direction cell:
-    fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
-
-    # Split rank N tensor into list of rank N-1 tensors
-    x = [x[l] for l in range(x.shape[0])]
-
-    # We parametrize the RNN implementation as the training and inference graph
-    # need to do different things here.
-    output, output_state = tf.nn.static_rnn(cell=fw_cell,
-                                            inputs=x,
-                                            initial_state=previous_state,
-                                            dtype=tf.float32,
-                                            sequence_length=seq_length)
-
-    output = tf.concat(output, 0)
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
+        # Forward direction cell:
+        fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
+                                            reuse=reuse,
+                                            name='cudnn_compatible_lstm_cell')
+
+        # Split rank N tensor into list of rank N-1 tensors
+        x = [x[l] for l in range(x.shape[0])]
+
+        output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
+                                                  inputs=x,
+                                                  sequence_length=seq_length,
+                                                  initial_state=previous_state,
+                                                  dtype=tf.float32,
+                                                  scope='cell_0')
+
+        output = tf.concat(output, 0)
 
     return output, output_state

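The singleton trick in rnn_impl_cudnn_rnn is plain Python, independent of TensorFlow: attributes set on a function object persist across calls, so a variable-owning layer can be created on first use and reused afterwards. A minimal, self-contained sketch of the pattern (all names hypothetical, not from this patch):

    class Cell:
        # Stands in for a variable-owning layer such as CudnnLSTM:
        # constructing it is expensive and must happen exactly once.
        def __init__(self):
            print('creating variables (should print once)')

        def __call__(self, x):
            return x * 2

    def apply_layer(x):
        # Emulate a static local variable by stashing the singleton
        # instance on the function object itself.
        if not apply_layer.cell:
            apply_layer.cell = Cell()
        return apply_layer.cell(x)

    apply_layer.cell = None

    print(apply_layer(1))  # creates the cell, prints 2
    print(apply_layer(2))  # reuses it, prints 4

The 'cudnn_lstm/rnn/multi_rnn_cell/cell_0' and 'cudnn_compatible_lstm_cell' scope names above evidently keep the three implementations' variable names aligned, which is what lets the export hunk further down replace the old name-remapping Saver with a plain one.
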
@@ -183,8 +215,13 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     # Obtain the next batch of data
     (batch_x, batch_seq_len), batch_y = iterator.get_next()
 
+    if FLAGS.use_cudnn_rnn:
+        rnn_impl = rnn_impl_cudnn_rnn
+    else:
+        rnn_impl = rnn_impl_lstmblockfusedcell
+
     # Calculate the logits of the batch
-    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse)
+    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
 
     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)

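The hunk above makes create_model accept the RNN implementation as a callable instead of hard-coding a backend. A self-contained toy of that dependency-injection shape (every name here is hypothetical):

    def rnn_impl_fast(xs):
        # stands in for rnn_impl_cudnn_rnn (GPU-only backend)
        return [2 * x for x in xs]

    def rnn_impl_portable(xs):
        # stands in for rnn_impl_lstmblockfusedcell (default backend)
        return [x + x for x in xs]

    def create_toy_model(batch, rnn_impl=rnn_impl_portable):
        # The builder only calls whatever implementation it is handed.
        return rnn_impl(batch)

    use_fast = True  # stands in for FLAGS.use_cudnn_rnn
    impl = rnn_impl_fast if use_fast else rnn_impl_portable
    print(create_toy_model([1, 2, 3], rnn_impl=impl))  # [2, 4, 6]
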
@@ -573,7 +610,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
-        previous_state = previous_state_c = previous_state_h = None
+        previous_state = None
     else:
         previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
         previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

@@ -632,7 +669,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     }
 
     if not FLAGS.export_tflite:
-        inputs.update({'input_lengths': seq_length})
+        inputs['input_lengths'] = seq_length
 
     outputs = {
         'outputs': logits,

@ -659,20 +696,8 @@ def export():
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)] output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops) output_names = ",".join(output_names_tensors + output_names_ops)
mapping = None # Create a saver using variables from the above newly created graph
if FLAGS.export_tflite: saver = tfv1.train.Saver()
# Create a saver using variables from the above newly created graph
# Training graph uses LSTMFusedCell, but the TFLite inference graph uses
# a static RNN with a normal cell, so we need to rewrite the names to
# match the training weights when restoring.
def fixup(name):
if name.startswith('rnn/lstm_cell/'):
return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
return name
mapping = {fixup(v.op.name): v for v in tf.global_variables()}
saver = tfv1.train.Saver(mapping)
# Restore variables from training checkpoint # Restore variables from training checkpoint
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
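The deleted block relied on the var_list form of tf.compat.v1.train.Saver, which accepts a dict mapping checkpoint names to graph variables; that is how the old TFLite export restored weights saved under 'lstm_fused_cell/...' into variables named 'rnn/lstm_cell/...'. A minimal sketch of that mechanism, assuming TF 1.x-style graph mode (variable and checkpoint names hypothetical):

    import tensorflow.compat.v1 as tfv1
    tfv1.disable_eager_execution()

    # A graph variable whose name differs from its name in the checkpoint.
    v = tfv1.get_variable('rnn/lstm_cell/kernel', shape=[4, 4])

    # Keyed by checkpoint name, valued by the in-graph variable to fill.
    saver = tfv1.train.Saver({'lstm_fused_cell/kernel': v})

With the cudnn-compatible scope names introduced earlier in this diff, the training and inference graphs agree on variable names, so a plain tfv1.train.Saver() now suffices.
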

util/flags.py

@@ -51,9 +51,11 @@ def create_flags():
     f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')
 
-    # Performance(UNSUPPORTED)
-    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
-    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
+    # Performance
+    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
 
     # Sample limits
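
The new flag appears to use absl-style DEFINE_* helpers. A standalone sketch of how such a boolean flag is declared and parsed with the real absl.flags API (outside this codebase; the argv list below is illustrative):

    from absl import flags

    FLAGS = flags.FLAGS
    flags.DEFINE_boolean('use_cudnn_rnn', False,
                         'use CuDNN RNN backend for training on GPU')

    # Parse an argv-style list; in a real program absl.app.run does this.
    FLAGS(['prog', '--use_cudnn_rnn'])
    print(FLAGS.use_cudnn_rnn)  # True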