commit 1d50667234
@@ -8,12 +8,12 @@ import sys
 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'

-import time
 import numpy as np
 import progressbar
 import shutil
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
+import time

 from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
@@ -79,32 +79,64 @@ def dense(name, x, units, dropout_rate=None, relu=True):


 def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
-    # Forward direction cell:
-    fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, reuse=reuse)
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
+        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
+                                                    reuse=reuse,
+                                                    name='cudnn_compatible_lstm_cell')

     output, output_state = fw_cell(inputs=x,
                                    dtype=tf.float32,
                                    sequence_length=seq_length,
                                    initial_state=previous_state)

     return output, output_state


+def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
+    assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
+
+    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
+    # the object it creates the variables, and then you just call it several times
+    # to enable variable re-use. Because all of our code is structured in an old
+    # school TensorFlow structure where you can just call tf.get_variable again with
+    # reuse=True to reuse variables, we can't easily make use of the object oriented
+    # way CudnnLSTM is implemented, so we save a singleton instance in the function,
+    # emulating a static function variable.
+    if not rnn_impl_cudnn_rnn.cell:
+        # Forward direction cell:
+        fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
+                                                 num_units=Config.n_cell_dim,
+                                                 input_mode='linear_input',
+                                                 direction='unidirectional',
+                                                 dtype=tf.float32)
+        rnn_impl_cudnn_rnn.cell = fw_cell
+
+    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
+                                                   sequence_lengths=seq_length)
+
+    return output, output_state
+
+rnn_impl_cudnn_rnn.cell = None
+
+
 def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
-    # Forward direction cell:
-    fw_cell = tf.nn.rnn_cell.LSTMCell(Config.n_cell_dim, reuse=reuse)
+    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
+        # Forward direction cell:
+        fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
+                                            reuse=reuse,
+                                            name='cudnn_compatible_lstm_cell')

     # Split rank N tensor into list of rank N-1 tensors
     x = [x[l] for l in range(x.shape[0])]

-    # We parametrize the RNN implementation as the training and inference graph
-    # need to do different things here.
-    output, output_state = tf.nn.static_rnn(cell=fw_cell,
-                                            inputs=x,
-                                            initial_state=previous_state,
-                                            dtype=tf.float32,
-                                            sequence_length=seq_length)
+    output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
+                                              inputs=x,
+                                              sequence_length=seq_length,
+                                              initial_state=previous_state,
+                                              dtype=tf.float32,
+                                              scope='cell_0')
     output = tf.concat(output, 0)

     return output, output_state

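The comment block in rnn_impl_cudnn_rnn above caches the CudnnLSTM object as an attribute on the function itself. A minimal standalone sketch of that "static function variable" pattern, with illustrative names rather than the project's code:

# Sketch only: the first call constructs the (expensive) object,
# every later call reuses the same instance.
def get_rnn_cell(num_units):
    if get_rnn_cell.cell is None:
        # Stand-in for tf.contrib.cudnn_rnn.CudnnLSTM(...); any object whose
        # construction should happen exactly once works the same way.
        get_rnn_cell.cell = {'num_units': num_units}
    return get_rnn_cell.cell

get_rnn_cell.cell = None  # initialize the attribute after the function is defined

assert get_rnn_cell(2048) is get_rnn_cell(2048)  # both calls return the same instance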
@@ -183,8 +215,13 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     # Obtain the next batch of data
     (batch_x, batch_seq_len), batch_y = iterator.get_next()

+    if FLAGS.use_cudnn_rnn:
+        rnn_impl = rnn_impl_cudnn_rnn
+    else:
+        rnn_impl = rnn_impl_lstmblockfusedcell
+
     # Calculate the logits of the batch
-    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse)
+    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

     # Compute the CTC loss using TensorFlow's `ctc_loss`
     total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
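The hunk above selects an RNN implementation from the flag and threads it into create_model as a callable. A rough sketch of that shape with hypothetical stand-in bodies (the diff only shows the new rnn_impl keyword argument, not create_model's internals):

# Sketch: the graph builder takes the RNN implementation as a callable, so the
# flag only has to pick a function. The bodies below are stand-ins, not real ops.
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
    return 'fused(%s)' % x, None

def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
    return 'cudnn(%s)' % x, None

def create_model(batch_x, seq_length, dropout, reuse=False, rnn_impl=rnn_impl_lstmblockfusedcell):
    output, output_state = rnn_impl(batch_x, seq_length, None, reuse)
    return output, output_state

use_cudnn_rnn = True  # stand-in for FLAGS.use_cudnn_rnn
rnn_impl = rnn_impl_cudnn_rnn if use_cudnn_rnn else rnn_impl_lstmblockfusedcell
logits, _ = create_model('batch_x', 16, dropout=0.0, rnn_impl=rnn_impl)
print(logits)  # cudnn(batch_x)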
@@ -573,7 +610,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):

     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
-        previous_state = previous_state_c = previous_state_h = None
+        previous_state = None
     else:
         previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
         previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
@@ -632,7 +669,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     }

     if not FLAGS.export_tflite:
-        inputs.update({'input_lengths': seq_length})
+        inputs['input_lengths'] = seq_length

     outputs = {
         'outputs': logits,
@@ -659,20 +696,8 @@ def export():
     output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
     output_names = ",".join(output_names_tensors + output_names_ops)

-    mapping = None
-    if FLAGS.export_tflite:
-        # Create a saver using variables from the above newly created graph
-        # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
-        # a static RNN with a normal cell, so we need to rewrite the names to
-        # match the training weights when restoring.
-        def fixup(name):
-            if name.startswith('rnn/lstm_cell/'):
-                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
-            return name
-
-        mapping = {fixup(v.op.name): v for v in tf.global_variables()}
-
-    saver = tfv1.train.Saver(mapping)
+    # Create a saver using variables from the above newly created graph
+    saver = tfv1.train.Saver()

     # Restore variables from training checkpoint
     checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
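The block removed above restored TFLite-export variables through a name mapping, because the training checkpoint stores the fused-cell weights under different names than the static-RNN inference graph uses. A small self-contained sketch of that mapping-based Saver, with a placeholder variable standing in for the real graph:

# Sketch of the name-mapping restore the removed export() code performed.
# tfv1.train.Saver accepts {name_in_checkpoint: variable}, so restoring through
# the mapping loads weights stored under 'lstm_fused_cell/...' into the
# inference graph's 'rnn/lstm_cell/...' variables.
import tensorflow.compat.v1 as tfv1

tfv1.disable_eager_execution()

with tfv1.variable_scope('rnn'):
    with tfv1.variable_scope('lstm_cell'):
        kernel = tfv1.get_variable('kernel', shape=[4, 4])  # placeholder variable

def fixup(name):
    if name.startswith('rnn/lstm_cell/'):
        return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
    return name

mapping = {fixup(v.op.name): v for v in tfv1.global_variables()}
saver = tfv1.train.Saver(mapping)  # restores 'lstm_fused_cell/kernel' into 'rnn/lstm_cell/kernel'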
@@ -51,9 +51,11 @@ def create_flags():

     f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')

-    # Performance(UNSUPPORTED)
-    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
-    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
+    # Performance
+    f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
+    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')

     # Sample limits

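For reference, a standalone sketch of defining and reading a boolean flag like the new use_cudnn_rnn with absl.flags; this assumes an absl-style flags module (suggested by the f.DEFINE_* calls above) and is not the project's actual flag plumbing:

# Sketch only: run as `python example.py --use_cudnn_rnn` to flip the backend.
from absl import app, flags

flags.DEFINE_boolean('use_cudnn_rnn', False,
                     'use CuDNN RNN backend for training on GPU')
FLAGS = flags.FLAGS

def main(_):
    # Mirrors the selection added in calculate_mean_edit_distance_and_loss above.
    backend = 'rnn_impl_cudnn_rnn' if FLAGS.use_cudnn_rnn else 'rnn_impl_lstmblockfusedcell'
    print('Selected RNN implementation:', backend)

if __name__ == '__main__':
    app.run(main)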