Training and directly exporting as TF Lite
This commit is contained in:
parent
c3860f63a3
commit
5d30afdbad
@ -117,6 +117,7 @@ def create_flags():
|
||||
tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
|
||||
tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
|
||||
tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
|
||||
tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
|
||||
tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
|
||||
tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
|
||||
|
||||
@ -378,7 +379,7 @@ def variable_on_worker_level(name, shape, initializer):
|
||||
return var
|
||||
|
||||
|
||||
def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None):
|
||||
def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None, tflite=False):
|
||||
r'''
|
||||
That done, we will define the learned variables, the weights and biases,
|
||||
within the method ``BiRNN()`` which also constructs the neural network.
|
||||
@ -435,16 +436,32 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
|
||||
# Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
|
||||
|
||||
# Forward direction cell:
|
||||
if not tflite:
|
||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
|
||||
layers['fw_cell'] = fw_cell
|
||||
else:
|
||||
fw_cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim, reuse=reuse)
|
||||
|
||||
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
||||
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
||||
layer_3 = tf.reshape(layer_3, [n_steps, batch_size, n_hidden_3])
|
||||
if tflite:
|
||||
# Generated StridedSlice, not supported by NNAPI
|
||||
#n_layer_3 = []
|
||||
#for l in range(layer_3.shape[0]):
|
||||
# n_layer_3.append(layer_3[l])
|
||||
#layer_3 = n_layer_3
|
||||
|
||||
# Unstack/Unpack is not supported by NNAPI
|
||||
layer_3 = tf.unstack(layer_3, n_steps)
|
||||
|
||||
# We parametrize the RNN implementation as the training and inference graph
|
||||
# need to do different things here.
|
||||
if not tflite:
|
||||
output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
|
||||
else:
|
||||
output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32)
|
||||
output = tf.concat(output, 0)
|
||||
|
||||
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
||||
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
||||
@ -1754,13 +1771,18 @@ def train(server=None):
|
||||
' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
|
||||
sys.exit(1)
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
|
||||
def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tflite=False):
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
|
||||
seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
if not tflite:
|
||||
previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
|
||||
else:
|
||||
previous_state_c = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_h')
|
||||
|
||||
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
logits, layers = BiRNN(batch_x=input_tensor,
|
||||
@ -1768,7 +1790,14 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
|
||||
dropout=no_dropout,
|
||||
batch_size=batch_size,
|
||||
n_steps=n_steps,
|
||||
previous_state=previous_state)
|
||||
previous_state=previous_state,
|
||||
tflite=tflite)
|
||||
|
||||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
# Apply softmax for CTC decoder
|
||||
logits = tf.nn.softmax(logits)
|
||||
@ -1776,13 +1805,11 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
|
||||
# Initial zero state
|
||||
if not tflite:
|
||||
zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
|
||||
|
||||
initialize_c = tf.assign(previous_state_c, zero_state)
|
||||
initialize_h = tf.assign(previous_state_h, zero_state)
|
||||
|
||||
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
|
||||
|
||||
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
|
||||
logits = tf.identity(logits, name='logits')
|
||||
|
||||
@ -1796,6 +1823,24 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
|
||||
'initialize_state': initialize_state,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logits = tf.identity(logits, name='logits')
|
||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||
new_state_h = tf.identity(new_state_c, name='new_state_h')
|
||||
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
},
|
||||
{
|
||||
'outputs': logits,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def export():
|
||||
@ -1808,38 +1853,59 @@ def export():
|
||||
tf.reset_default_graph()
|
||||
session = tf.Session(config=session_config)
|
||||
|
||||
inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps)
|
||||
inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||
|
||||
# Create a saver using variables from the above newly created graph
|
||||
if not FLAGS.export_tflite:
|
||||
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
|
||||
else:
|
||||
# Create a saver using variables from the above newly created graph
|
||||
def fixup(name):
|
||||
if name.startswith('rnn/lstm_cell/'):
|
||||
return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
|
||||
return name
|
||||
|
||||
mapping = {fixup(v.op.name): v for v in tf.global_variables()}
|
||||
|
||||
saver = tf.train.Saver(mapping)
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
|
||||
checkpoint_path = checkpoint.model_checkpoint_path
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
output_filename = 'output_graph.pb'
|
||||
else:
|
||||
output_filename = 'output_graph.fb'
|
||||
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
try:
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
output_node_names = 'logits,initialize_state'
|
||||
variables_blacklist = 'previous_state_c,previous_state_h'
|
||||
else:
|
||||
output_node_names = 'logits,new_state_c,new_state_h'
|
||||
variables_blacklist = ''
|
||||
|
||||
# Freeze graph
|
||||
freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=session.graph_def,
|
||||
input_saver_def=saver.as_saver_def(),
|
||||
input_checkpoint=checkpoint_path,
|
||||
output_node_names='logits,initialize_state',
|
||||
output_node_names=output_node_names,
|
||||
restore_op_name=None,
|
||||
filename_tensor_name=None,
|
||||
output_graph=output_graph_path,
|
||||
clear_devices=False,
|
||||
initializer_nodes='',
|
||||
variable_names_blacklist='previous_state_c,previous_state_h')
|
||||
variable_names_blacklist=variables_blacklist,
|
||||
initializer_nodes='')
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
except RuntimeError as e:
|
||||
|
21
bin/run-tc-ldc93s1_tflite.sh
Executable file
21
bin/run-tc-ldc93s1_tflite.sh
Executable file
@ -0,0 +1,21 @@
|
||||
#!/bin/sh
|
||||
|
||||
set -xe
|
||||
|
||||
ldc93s1_dir="./data/ldc93s1-tc"
|
||||
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
|
||||
|
||||
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
|
||||
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
|
||||
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
|
||||
fi;
|
||||
|
||||
python -u DeepSpeech.py \
|
||||
--n_hidden 494 \
|
||||
--checkpoint_dir '/tmp/ckpt' \
|
||||
--export_dir '/tmp/train' \
|
||||
--decoder_library_path '/tmp/ds/libctc_decoder_with_kenlm.so' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--notrain --notest \
|
||||
--export_tflite \
|
@ -62,9 +62,11 @@ fi;
|
||||
|
||||
pushd ${HOME}/DeepSpeech/ds/
|
||||
time ./bin/run-tc-ldc93s1_new.sh 105
|
||||
time ./bin/run-tc-ldc93s1_tflite.sh
|
||||
popd
|
||||
|
||||
cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
|
||||
cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS}
|
||||
|
||||
if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
|
||||
convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")
|
||||
|
Loading…
Reference in New Issue
Block a user