Training and directly exporting as TF Lite

Alexandre Lissy 2018-10-30 09:43:24 +01:00
parent c3860f63a3
commit 5d30afdbad
3 changed files with 121 additions and 32 deletions


@@ -117,6 +117,7 @@ def create_flags():
tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
@@ -378,7 +379,7 @@ def variable_on_worker_level(name, shape, initializer):
return var
def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None):
def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None, tflite=False):
r'''
That done, we will define the learned variables, the weights and biases,
within the method ``BiRNN()`` which also constructs the neural network.
@@ -435,16 +436,32 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
# Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
# Forward direction cell:
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
layers['fw_cell'] = fw_cell
if not tflite:
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
layers['fw_cell'] = fw_cell
else:
fw_cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim, reuse=reuse)
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [n_steps, batch_size, n_hidden_3])
if tflite:
# Generated StridedSlice, not supported by NNAPI
#n_layer_3 = []
#for l in range(layer_3.shape[0]):
# n_layer_3.append(layer_3[l])
#layer_3 = n_layer_3
# Unstack/Unpack is not supported by NNAPI
layer_3 = tf.unstack(layer_3, n_steps)
# We parametrize the RNN implementation as the training and inference graphs
# need to do different things here.
output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
if not tflite:
output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
else:
output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32)
output = tf.concat(output, 0)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
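For readers comparing the two paths above, a minimal standalone sketch (TF 1.x assumed; sizes and names are illustrative, not part of this commit) of why the TF Lite branch unstacks its input and concatenates its output: the fused cell consumes one time-major tensor and returns a stacked tensor, while tf.nn.static_rnn consumes and returns a Python list of per-step tensors.

import tensorflow as tf  # TF 1.x assumed

n_steps, batch_size, n_input, n_cell_dim = 16, 1, 494, 494
layer_3 = tf.placeholder(tf.float32, [n_steps, batch_size, n_input])

# Training/default export path: fused cell, one [n_steps, batch_size, n_cell_dim] output.
fused_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim)
fused_output, fused_state = fused_cell(layer_3, dtype=tf.float32)

# TF Lite path: plain LSTMCell driven by static_rnn over a list of n_steps tensors.
cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim)
steps = tf.unstack(layer_3, n_steps)             # n_steps tensors of shape [batch_size, n_input]
outputs, state = tf.nn.static_rnn(cell, steps, dtype=tf.float32)
stacked = tf.concat(outputs, 0)                  # [n_steps*batch_size, n_cell_dim]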
@@ -1754,13 +1771,18 @@ def train(server=None):
' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
sys.exit(1)
def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tflite=False):
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
if not tflite:
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
else:
previous_state_c = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_h')
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
logits, layers = BiRNN(batch_x=input_tensor,
@@ -1768,7 +1790,14 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
dropout=no_dropout,
batch_size=batch_size,
n_steps=n_steps,
previous_state=previous_state)
previous_state=previous_state,
tflite=tflite)
# The TF Lite runtime will check that input tensors have 1, 2 or 4 dimensions;
# by default we get 3, the middle one being batch_size, which is forced to
# one in the inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits)
@@ -1776,26 +1805,42 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
new_state_c, new_state_h = layers['rnn_output_state']
# Initial zero state
zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
if not tflite:
zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
initialize_c = tf.assign(previous_state_c, zero_state)
initialize_h = tf.assign(previous_state_h, zero_state)
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
logits = tf.identity(logits, name='logits')
initialize_c = tf.assign(previous_state_c, zero_state)
initialize_h = tf.assign(previous_state_h, zero_state)
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
'initialize_state': initialize_state,
}
)
else:
logits = tf.identity(logits, name='logits')
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
'initialize_state': initialize_state,
}
)
return (
{
'input': input_tensor,
'input_lengths': seq_length,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
},
{
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
}
)
def export():
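A note on the design choice above (not part of the commit): the default export keeps the LSTM state in graph variables that are reset through initialize_state and advanced as a side effect of evaluating logits, whereas the TF Lite export threads the state through explicit inputs and outputs so the caller carries it between calls. A self-contained sketch of the two patterns, with made-up sizes:

import numpy as np
import tensorflow as tf  # TF 1.x assumed

dim = 4

# Pattern used by the default export: state held in a variable, updated as a side effect.
prev_var = tf.Variable(tf.zeros([1, dim]), name='prev_var')
new_val = prev_var + 1.0
with tf.control_dependencies([tf.assign(prev_var, new_val)]):
    out_stateful = tf.identity(new_val)

# Pattern used by the TF Lite export: state fed and fetched explicitly.
prev_ph = tf.placeholder(tf.float32, [1, dim], name='prev_ph')
out_stateless = tf.identity(prev_ph + 1.0, name='new_state')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(out_stateful))                      # variable advances internally
    print(sess.run(out_stateful))
    state = np.zeros([1, dim], np.float32)
    state = sess.run(out_stateless, {prev_ph: state})  # caller carries the state forward
    state = sess.run(out_stateless, {prev_ph: state})
    print(state)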
@@ -1808,38 +1853,59 @@ def export():
tf.reset_default_graph()
session = tf.Session(config=session_config)
inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps)
inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
if not FLAGS.export_tflite:
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
else:
# Create a saver using variables from the above newly created graph
def fixup(name):
if name.startswith('rnn/lstm_cell/'):
return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
return name
mapping = {fixup(v.op.name): v for v in tf.global_variables()}
# Create a saver using variables from the above newly created graph
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
saver = tf.train.Saver(mapping)
# Restore variables from training checkpoint
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
checkpoint_path = checkpoint.model_checkpoint_path
if not FLAGS.export_tflite:
output_filename = 'output_graph.pb'
else:
output_filename = 'output_graph.fb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
try:
output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
if not FLAGS.export_tflite:
output_node_names = 'logits,initialize_state'
variables_blacklist = 'previous_state_c,previous_state_h'
else:
output_node_names = 'logits,new_state_c,new_state_h'
variables_blacklist = ''
# Freeze graph
freeze_graph.freeze_graph_with_def_protos(
input_graph_def=session.graph_def,
input_saver_def=saver.as_saver_def(),
input_checkpoint=checkpoint_path,
output_node_names='logits,initialize_state',
output_node_names=output_node_names,
restore_op_name=None,
filename_tensor_name=None,
output_graph=output_graph_path,
clear_devices=False,
initializer_nodes='',
variable_names_blacklist='previous_state_c,previous_state_h')
variable_names_blacklist=variables_blacklist,
initializer_nodes='')
log_info('Models exported at %s' % (FLAGS.export_dir))
except RuntimeError as e:
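For context on the fixup() mapping above (illustration only, with hypothetical variable names): the TF Lite graph builds its LSTM through tf.nn.rnn_cell.LSTMCell, whose variables live under rnn/lstm_cell/, while the training checkpoint was written by LSTMBlockFusedCell under lstm_fused_cell/, so the saver must look checkpoint entries up under the old prefix when restoring into the new graph.

def fixup(name):
    if name.startswith('rnn/lstm_cell/'):
        return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
    return name

# Hypothetical graph variable names; in the real mapping the values are the variables themselves.
graph_variable_names = [
    'layer_1/weights',        # untouched by fixup()
    'rnn/lstm_cell/kernel',   # created by tf.nn.rnn_cell.LSTMCell
    'rnn/lstm_cell/bias',
]
print({fixup(name): name for name in graph_variable_names})
# {'layer_1/weights': 'layer_1/weights',
#  'lstm_fused_cell/kernel': 'rnn/lstm_cell/kernel',
#  'lstm_fused_cell/bias': 'rnn/lstm_cell/bias'}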

bin/run-tc-ldc93s1_tflite.sh Executable file

@@ -0,0 +1,21 @@
#!/bin/sh
set -xe
ldc93s1_dir="./data/ldc93s1-tc"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
python -u DeepSpeech.py \
--n_hidden 494 \
--checkpoint_dir '/tmp/ckpt' \
--export_dir '/tmp/train' \
--decoder_library_path '/tmp/ds/libctc_decoder_with_kenlm.so' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--notrain --notest \
--export_tflite \


@@ -62,9 +62,11 @@ fi;
pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-ldc93s1_new.sh 105
time ./bin/run-tc-ldc93s1_tflite.sh
popd
cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS}
if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")