Address review comments

Reuben Morais 2019-07-19 11:07:58 +02:00
parent f7a715d506
commit fd3fbcaa78
2 changed files with 18 additions and 28 deletions


@@ -79,7 +79,7 @@ def dense(name, x, units, dropout_rate=None, relu=True):
 
 def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
-    with tf.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
         fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
                                                     reuse=reuse,
                                                     name='cudnn_compatible_lstm_cell')
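Note: the tf.variable_scope -> tfv1.variable_scope changes in this commit assume the v1 API is available under a compat alias. A minimal sketch of what that aliasing typically looks like; the exact import used by this codebase is an assumption here, not shown in the diff:

import tensorflow.compat.v1 as tfv1  # assumed alias providing variable_scope, placeholder, Saver, etc.

with tfv1.variable_scope('example', reuse=tfv1.AUTO_REUSE):
    w = tfv1.get_variable('w', shape=[4, 4])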
@@ -95,9 +95,15 @@ def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
 def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
     assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
 
-    # Forward direction cell:
+    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
+    # the object it creates the variables, and then you just call it several times
+    # to enable variable re-use. Because all of our code is structured in an old
+    # school TensorFlow style where you can just call tf.get_variable again with
+    # reuse=True to reuse variables, we can't easily make use of the object-oriented
+    # way CudnnLSTM is implemented, so we save a singleton instance in the function,
+    # emulating a static function variable.
     if not rnn_impl_cudnn_rnn.cell:
-        with tf.variable_scope('rnn'):
+        # Forward direction cell:
         fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
                                                   num_units=Config.n_cell_dim,
                                                   input_mode='linear_input',
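The "Hack" comment added above describes caching the CudnnLSTM object on the function itself so that every call reuses the same instance (and therefore the same variables). A minimal, self-contained sketch of that static-function-variable pattern, with a generic stand-in object rather than a real CudnnLSTM layer:

def get_shared_cell():
    # Create the expensive object once and cache it as a function attribute,
    # so later calls reuse the same instance.
    if get_shared_cell.cell is None:
        get_shared_cell.cell = object()  # stand-in for e.g. tf.contrib.cudnn_rnn.CudnnLSTM(...)
    return get_shared_cell.cell

# Initialize the "static variable", mirroring rnn_impl_cudnn_rnn.cell = None below.
get_shared_cell.cell = None

assert get_shared_cell() is get_shared_cell()  # same instance on every call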
@@ -110,18 +116,11 @@ def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
 
     return output, output_state
 
-# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
-# the object it creates the variables, and then you just call it several times
-# to enable variable re-use. Because all of our code is structure in an old
-# school TensorFlow structure where you can just call tf.get_variable again with
-# reuse=True to reuse variables, we can't easily make use of the object oriented
-# way CudnnLSTM is implemented, so we save a singleton instance in the function,
-# emulating a static function variable.
 rnn_impl_cudnn_rnn.cell = None
 
 def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
-    with tf.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
         # Forward direction cell:
         fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
                                             reuse=reuse,
@@ -611,7 +610,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
 
     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
-        previous_states = None
+        previous_state = None
     else:
         previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
         previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
@@ -698,16 +697,7 @@ def export():
     output_names = ",".join(output_names_tensors + output_names_ops)
 
     # Create a saver using variables from the above newly created graph
-    # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
-    # a static RNN with a normal cell, so we need to rewrite the names to
-    # match the training weights when restoring.
-    def fixup(name):
-        if name.startswith('rnn/lstm_cell/'):
-            return name.replace('rnn/lstm_cell/', 'rnn/cudnn_compatible_lstm_cell/')
-        return name
-
-    mapping = {fixup(v.op.name): v for v in tf.global_variables()}
-    saver = tfv1.train.Saver(mapping)
+    saver = tfv1.train.Saver()
 
     # Restore variables from training checkpoint
     checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
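For context on the block removed above: tfv1.train.Saver accepts a dict mapping checkpoint variable names to in-graph variables, which is how the deleted fixup() remapped rnn/lstm_cell/* weights onto the rnn/cudnn_compatible_lstm_cell/* names used at training time. A minimal sketch of that mechanism in isolation (names are illustrative, not the project's exact variables):

import tensorflow.compat.v1 as tfv1

def rename_for_checkpoint(name):
    # Map an in-graph variable name to the name stored in the checkpoint.
    if name.startswith('rnn/lstm_cell/'):
        return name.replace('rnn/lstm_cell/', 'rnn/cudnn_compatible_lstm_cell/')
    return name

# Keys are checkpoint names, values are the variables to load them into.
mapping = {rename_for_checkpoint(v.op.name): v for v in tfv1.global_variables()}
saver = tfv1.train.Saver(var_list=mapping)
# saver.restore(session, checkpoint_path) then loads each checkpoint tensor into
# the variable it is mapped to, regardless of the variable's in-graph name.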


@@ -55,7 +55,7 @@ def create_flags():
 
     f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
-    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU')
+    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
 
     # Sample limits
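The extended help text warns that checkpoints produced with the CuDNN backend can only be restored by the CuDNN backend. As a side note, f.DEFINE_boolean here follows the absl-style flags API; a minimal standalone sketch of defining and reading such a flag, assuming absl.flags directly, which may differ from the project's own flags wrapper:

from absl import app, flags

flags.DEFINE_boolean('use_cudnn_rnn', False,
                     'use CuDNN RNN backend for training on GPU. Note that '
                     'checkpoints created with this flag can only be used with '
                     'CuDNN RNN, i.e. fine tuning on a CPU device will not work')

def main(_):
    # Read the parsed flag value and pick an RNN backend name accordingly.
    backend = 'CudnnLSTM' if flags.FLAGS.use_cudnn_rnn else 'LSTMBlockFusedCell'
    print('Selected RNN backend:', backend)

if __name__ == '__main__':
    app.run(main)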