Training and directly exporting as TF Lite

Alexandre Lissy 2018-10-30 09:43:24 +01:00
parent c3860f63a3
commit 5d30afdbad
3 changed files with 121 additions and 32 deletions

DeepSpeech.py

@@ -117,6 +117,7 @@ def create_flags():
     tf.app.flags.DEFINE_string  ('export_dir',     '',    'directory in which exported models are stored - if omitted, the model won\'t get exported')
     tf.app.flags.DEFINE_integer ('export_version', 1,     'version number of the exported model')
     tf.app.flags.DEFINE_boolean ('remove_export',  False, 'whether to remove old exported models')
+    tf.app.flags.DEFINE_boolean ('export_tflite',  False, 'export a graph ready for TF Lite engine')
     tf.app.flags.DEFINE_boolean ('use_seq_length', True,  'have sequence_length in the exported graph (will make tfcompile unhappy)')
     tf.app.flags.DEFINE_integer ('n_steps',        16,    'how many timesteps to process at once by the export graph, higher values mean more latency')
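
The new flag follows the tf.app.flags idiom used throughout this file: after tf.app.run() has parsed argv, it surfaces as FLAGS.export_tflite. A minimal standalone sketch of that mechanism (illustrative, not part of the commit):

    import tensorflow as tf

    tf.app.flags.DEFINE_boolean('export_tflite', False,
                                'export a graph ready for TF Lite engine')
    FLAGS = tf.app.flags.FLAGS

    def main(_):
        # Passing --export_tflite on the command line flips this to True.
        print('export_tflite:', FLAGS.export_tflite)

    if __name__ == '__main__':
        tf.app.run()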
@@ -378,7 +379,7 @@ def variable_on_worker_level(name, shape, initializer):
     return var

-def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None):
+def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None, tflite=False):
     r'''
     That done, we will define the learned variables, the weights and biases,
     within the method ``BiRNN()`` which also constructs the neural network.
@@ -435,16 +436,32 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
     # Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.

     # Forward direction cell:
-    fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
-    layers['fw_cell'] = fw_cell
+    if not tflite:
+        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
+        layers['fw_cell'] = fw_cell
+    else:
+        fw_cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim, reuse=reuse)

     # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
     # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
     layer_3 = tf.reshape(layer_3, [n_steps, batch_size, n_hidden_3])

+    if tflite:
+        # Generated StridedSlice, not supported by NNAPI
+        #n_layer_3 = []
+        #for l in range(layer_3.shape[0]):
+        #    n_layer_3.append(layer_3[l])
+        #layer_3 = n_layer_3
+
+        # Unstack/Unpack is not supported by NNAPI
+        layer_3 = tf.unstack(layer_3, n_steps)
+
-    output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
+    # We parametrize the RNN implementation as the training and inference graph
+    # need to do different things here.
+    if not tflite:
+        output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
+    else:
+        output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32)
+        output = tf.concat(output, 0)

     # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
     # to a tensor of shape [n_steps*batch_size, n_cell_dim]
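
Context for the two branches above: the fused cell and tf.nn.static_rnn take their input in different forms, which is what the unstack step is for. A minimal sketch with toy dimensions (names and sizes here are illustrative, not the model's real ones):

    import tensorflow as tf

    n_steps, batch_size, n_input, n_units = 16, 1, 64, 32  # toy sizes

    # Time-major input, as produced by the tf.reshape() above.
    x = tf.placeholder(tf.float32, [n_steps, batch_size, n_input])

    # Training/default-export path: the fused cell consumes the whole
    # [max_time, batch_size, input_size] tensor in a single call.
    fused_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_units)
    fused_out, fused_state = fused_cell(x, dtype=tf.float32)

    # TF Lite path: static_rnn expects a Python list of per-timestep
    # [batch_size, input_size] tensors, hence the tf.unstack() above.
    steps = tf.unstack(x, n_steps)
    cell = tf.nn.rnn_cell.LSTMCell(n_units)
    outputs, state = tf.nn.static_rnn(cell, steps, dtype=tf.float32)
    merged = tf.concat(outputs, 0)  # back to [n_steps*batch_size, n_units]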
@@ -1754,13 +1771,18 @@ def train(server=None):
                      ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
             sys.exit(1)

-def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
+def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tflite=False):
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
     input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
     seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')

-    previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
-    previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
+    if not tflite:
+        previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
+        previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
+    else:
+        previous_state_c = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_c')
+        previous_state_h = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_h')

     previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)

     logits, layers = BiRNN(batch_x=input_tensor,
@@ -1768,7 +1790,14 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
                            dropout=no_dropout,
                            batch_size=batch_size,
                            n_steps=n_steps,
-                           previous_state=previous_state)
+                           previous_state=previous_state,
+                           tflite=tflite)
+
+    # TF Lite runtime will check that input dimensions are 1, 2 or 4
+    # by default we get 3, the middle one being batch_size which is forced to
+    # one on inference graph, so remove that dimension
+    if tflite:
+        logits = tf.squeeze(logits, [1])

     # Apply softmax for CTC decoder
     logits = tf.nn.softmax(logits)
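
The squeeze is pure shape bookkeeping: with batch_size pinned to 1 in the inference graph, the logits come out of BiRNN 3-D, and the TF Lite runtime only accepts 1-, 2- or 4-D tensors, so the singleton batch axis is dropped. An illustrative check (toy class count):

    import tensorflow as tf

    n_steps, n_classes = 16, 29  # toy sizes
    logits = tf.zeros([n_steps, 1, n_classes])  # [n_steps, batch_size=1, n_classes]
    squeezed = tf.squeeze(logits, [1])          # 2-D: [n_steps, n_classes]
    print(squeezed.shape)                       # (16, 29)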
@@ -1776,26 +1805,42 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
     new_state_c, new_state_h = layers['rnn_output_state']

-    # Initial zero state
-    zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
-
-    initialize_c = tf.assign(previous_state_c, zero_state)
-    initialize_h = tf.assign(previous_state_h, zero_state)
-
-    initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
-
-    with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
-        logits = tf.identity(logits, name='logits')
-
-    return (
-        {
-            'input': input_tensor,
-            'input_lengths': seq_length,
-        },
-        {
-            'outputs': logits,
-            'initialize_state': initialize_state,
-        }
-    )
+    if not tflite:
+        # Initial zero state
+        zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
+        initialize_c = tf.assign(previous_state_c, zero_state)
+        initialize_h = tf.assign(previous_state_h, zero_state)
+        initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
+        with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
+            logits = tf.identity(logits, name='logits')
+
+        return (
+            {
+                'input': input_tensor,
+                'input_lengths': seq_length,
+            },
+            {
+                'outputs': logits,
+                'initialize_state': initialize_state,
+            }
+        )
+    else:
+        logits = tf.identity(logits, name='logits')
+        new_state_c = tf.identity(new_state_c, name='new_state_c')
+        new_state_h = tf.identity(new_state_h, name='new_state_h')
+
+        return (
+            {
+                'input': input_tensor,
+                'input_lengths': seq_length,
+                'new_state_c': new_state_c,
+                'new_state_h': new_state_h,
+            },
+            {
+                'outputs': logits,
+                'new_state_c': new_state_c,
+                'new_state_h': new_state_h,
+            }
+        )
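
The two return signatures imply different calling conventions: the default graph keeps the LSTM state in variables (reset once via the initialize_state op, then updated as a side effect of fetching logits through the control_dependencies block), while the TF Lite graph is stateless and the caller must thread the state through every run. A hedged sketch of that caller loop — session, chunks and n_cell_dim are assumed to exist, and only the tensor names come from the graph above:

    import numpy as np

    state_c = np.zeros([1, n_cell_dim], dtype=np.float32)
    state_h = np.zeros([1, n_cell_dim], dtype=np.float32)

    # Each chunk is a feature window of shape [1, n_steps, 2*n_context+1, n_input].
    for chunk in chunks:
        probs, state_c, state_h = session.run(
            ['logits:0', 'new_state_c:0', 'new_state_h:0'],
            feed_dict={
                'input_node:0': chunk,
                'previous_state_c:0': state_c,  # state is shuttled manually
                'previous_state_h:0': state_h,
            })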
@@ -1808,38 +1853,59 @@ def export():
     tf.reset_default_graph()
     session = tf.Session(config=session_config)

-    inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps)
+    inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)

-    # Create a saver using variables from the above newly created graph
-    mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
+    if not FLAGS.export_tflite:
+        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
+    else:
+        # Create a saver using variables from the above newly created graph
+        def fixup(name):
+            if name.startswith('rnn/lstm_cell/'):
+                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
+            return name
+
+        mapping = {fixup(v.op.name): v for v in tf.global_variables()}

     saver = tf.train.Saver(mapping)

     # Restore variables from training checkpoint
     checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
     checkpoint_path = checkpoint.model_checkpoint_path

+    if not FLAGS.export_tflite:
+        output_filename = 'output_graph.pb'
+    else:
+        output_filename = 'output_graph.fb'
+
     if FLAGS.remove_export:
         if os.path.isdir(FLAGS.export_dir):
             log_info('Removing old export')
             shutil.rmtree(FLAGS.export_dir)

     try:
-        output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
+        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

         if not os.path.isdir(FLAGS.export_dir):
             os.makedirs(FLAGS.export_dir)

+        if not FLAGS.export_tflite:
+            output_node_names = 'logits,initialize_state'
+            variables_blacklist = 'previous_state_c,previous_state_h'
+        else:
+            output_node_names = 'logits,new_state_c,new_state_h'
+            variables_blacklist = ''
+
         # Freeze graph
         freeze_graph.freeze_graph_with_def_protos(
             input_graph_def=session.graph_def,
             input_saver_def=saver.as_saver_def(),
             input_checkpoint=checkpoint_path,
-            output_node_names='logits,initialize_state',
+            output_node_names=output_node_names,
             restore_op_name=None,
             filename_tensor_name=None,
             output_graph=output_graph_path,
             clear_devices=False,
-            initializer_nodes='',
-            variable_names_blacklist='previous_state_c,previous_state_h')
+            variable_names_blacklist=variables_blacklist,
+            initializer_nodes='')

         log_info('Models exported at %s' % (FLAGS.export_dir))
     except RuntimeError as e:
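
The visible hunk stops at the freeze step. Turning a frozen GraphDef into a flatbuffer the TF Lite interpreter can load takes a converter pass, which this excerpt does not show. A sketch of what that pass could look like with the contrib converter API of the TF 1.11/1.12 era (this exact call and the .tflite path are assumptions for illustration, not lines from the commit; the array names follow the output_node_names above):

    import tensorflow as tf

    # Assumed follow-on step, not shown in the hunk above.
    converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
        '/tmp/train/output_graph.fb',  # frozen GraphDef written by export()
        input_arrays=['input_node', 'previous_state_c', 'previous_state_h'],
        output_arrays=['logits', 'new_state_c', 'new_state_h'])
    tflite_model = converter.convert()
    with open('/tmp/train/output_graph.tflite', 'wb') as f:
        f.write(tflite_model)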

bin/run-tc-ldc93s1_tflite.sh (new executable file)

@@ -0,0 +1,21 @@
+#!/bin/sh
+
+set -xe
+
+ldc93s1_dir="./data/ldc93s1-tc"
+ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
+
+if [ ! -f "${ldc93s1_csv}" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
+    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
+fi;
+
+python -u DeepSpeech.py \
+  --n_hidden 494 \
+  --checkpoint_dir '/tmp/ckpt' \
+  --export_dir '/tmp/train' \
+  --decoder_library_path '/tmp/ds/libctc_decoder_with_kenlm.so' \
+  --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
+  --lm_trie_path 'data/smoke_test/vocab.trie' \
+  --notrain --notest \
+  --export_tflite \


@@ -62,9 +62,11 @@ fi;
 pushd ${HOME}/DeepSpeech/ds/
     time ./bin/run-tc-ldc93s1_new.sh 105
+    time ./bin/run-tc-ldc93s1_tflite.sh
 popd

 cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
+cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS}

 if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
     convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")