From 5d30afdbad77084b090aecbf21bd78c551406671 Mon Sep 17 00:00:00 2001
From: Alexandre Lissy
Date: Tue, 30 Oct 2018 09:43:24 +0100
Subject: [PATCH] Training and directly exporting as TF Lite

---
 DeepSpeech.py                | 130 ++++++++++++++++++++++++++---------
 bin/run-tc-ldc93s1_tflite.sh |  21 ++++++
 tc-train-tests.sh            |   2 +
 3 files changed, 121 insertions(+), 32 deletions(-)
 create mode 100755 bin/run-tc-ldc93s1_tflite.sh

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 832813b6..0f021be8 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -117,6 +117,7 @@ def create_flags():
     tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
     tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
     tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
+    tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for the TF Lite engine')
     tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
     tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
 
@@ -378,7 +379,7 @@ def variable_on_worker_level(name, shape, initializer):
     return var
 
 
-def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None):
+def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1, previous_state=None, tflite=False):
     r'''
     That done, we will define the learned variables, the weights and biases,
     within the method ``BiRNN()`` which also constructs the neural network.
@@ -435,16 +436,32 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
     # Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
 
     # Forward direction cell:
-    fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
-    layers['fw_cell'] = fw_cell
+    if not tflite:
+        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(n_cell_dim, reuse=reuse)
+        layers['fw_cell'] = fw_cell
+    else:
+        fw_cell = tf.nn.rnn_cell.LSTMCell(n_cell_dim, reuse=reuse)
 
     # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
     # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
     layer_3 = tf.reshape(layer_3, [n_steps, batch_size, n_hidden_3])
+    if tflite:
+        # Indexing time steps this way generates StridedSlice, not supported by NNAPI
+        #n_layer_3 = []
+        #for l in range(layer_3.shape[0]):
+        #    n_layer_3.append(layer_3[l])
+        #layer_3 = n_layer_3
+
+        # Unstack/Unpack is not supported by NNAPI either, but works on TF Lite
+        layer_3 = tf.unstack(layer_3, n_steps)
 
     # We parametrize the RNN implementation as the training and inference graph
     # need to do different things here.
-    output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
+    if not tflite:
+        output, output_state = fw_cell(inputs=layer_3, dtype=tf.float32, sequence_length=seq_length, initial_state=previous_state)
+    else:
+        output, output_state = tf.nn.static_rnn(fw_cell, layer_3, previous_state, tf.float32)
+        output = tf.concat(output, 0)
 
     # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
     # to a tensor of shape [n_steps*batch_size, n_cell_dim]
@@ -1754,13 +1771,18 @@ def train(server=None):
                      ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
             sys.exit(1)
 
-def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
+def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tflite=False):
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
     input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
     seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
 
-    previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
-    previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
+    if not tflite:
+        previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
+        previous_state_h = variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=None)
+    else:
+        previous_state_c = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_c')
+        previous_state_h = tf.placeholder(tf.float32, [batch_size, n_cell_dim], name='previous_state_h')
+
     previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
 
     logits, layers = BiRNN(batch_x=input_tensor,
@@ -1768,7 +1790,14 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
                            dropout=no_dropout,
                            batch_size=batch_size,
                            n_steps=n_steps,
-                           previous_state=previous_state)
+                           previous_state=previous_state,
+                           tflite=tflite)
+
+    # The TF Lite runtime only accepts tensors of rank 1, 2 or 4; by default
+    # logits has rank 3, the middle dimension being batch_size, which is fixed
+    # to one in the inference graph, so we squeeze that dimension out
+    if tflite:
+        logits = tf.squeeze(logits, [1])
 
     # Apply softmax for CTC decoder
     logits = tf.nn.softmax(logits)
@@ -1776,26 +1805,42 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
     new_state_c, new_state_h = layers['rnn_output_state']
 
     # Initial zero state
-    zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
+    if not tflite:
+        zero_state = tf.zeros([batch_size, n_cell_dim], tf.float32)
+        initialize_c = tf.assign(previous_state_c, zero_state)
+        initialize_h = tf.assign(previous_state_h, zero_state)
+        initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
+        with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
+            logits = tf.identity(logits, name='logits')
 
-    initialize_c = tf.assign(previous_state_c, zero_state)
-    initialize_h = tf.assign(previous_state_h, zero_state)
-
-    initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
-
-    with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
+        return (
+            {
+                'input': input_tensor,
+                'input_lengths': seq_length,
+            },
+            {
+                'outputs': logits,
+                'initialize_state': initialize_state,
+            }
+        )
+    else:
         logits = tf.identity(logits, name='logits')
+        new_state_c = tf.identity(new_state_c, name='new_state_c')
+        new_state_h = tf.identity(new_state_h, name='new_state_h')
 
-    return (
-        {
-            'input': input_tensor,
-            'input_lengths': seq_length,
-        },
-        {
-            'outputs': logits,
-            'initialize_state': initialize_state,
-        }
-    )
+        return (
+            {
+                'input': input_tensor,
+                'input_lengths': seq_length,
+                'previous_state_c': previous_state_c,
+                'previous_state_h': previous_state_h,
+            },
+            {
+                'outputs': logits,
+                'new_state_c': new_state_c,
+                'new_state_h': new_state_h,
+            }
+        )
 
 
 def export():
@@ -1808,38 +1853,59 @@ def export():
     tf.reset_default_graph()
     session = tf.Session(config=session_config)
 
-    inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps)
+    inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
+
+    if not FLAGS.export_tflite:
+        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
+    else:
+        # Create a saver using variables from the above newly created graph
+        def fixup(name):
+            if name.startswith('rnn/lstm_cell/'):
+                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
+            return name
+
+        mapping = {fixup(v.op.name): v for v in tf.global_variables()}
 
-    # Create a saver using variables from the above newly created graph
-    mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
     saver = tf.train.Saver(mapping)
 
     # Restore variables from training checkpoint
    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
     checkpoint_path = checkpoint.model_checkpoint_path
 
+    if not FLAGS.export_tflite:
+        output_filename = 'output_graph.pb'
+    else:
+        output_filename = 'output_graph.fb'
+
     if FLAGS.remove_export:
         if os.path.isdir(FLAGS.export_dir):
             log_info('Removing old export')
             shutil.rmtree(FLAGS.export_dir)
     try:
-        output_graph_path = os.path.join(FLAGS.export_dir, 'output_graph.pb')
+        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
 
         if not os.path.isdir(FLAGS.export_dir):
             os.makedirs(FLAGS.export_dir)
 
+        if not FLAGS.export_tflite:
+            output_node_names = 'logits,initialize_state'
+            variables_blacklist = 'previous_state_c,previous_state_h'
+        else:
+            output_node_names = 'logits,new_state_c,new_state_h'
+            variables_blacklist = ''
+
         # Freeze graph
         freeze_graph.freeze_graph_with_def_protos(
             input_graph_def=session.graph_def,
             input_saver_def=saver.as_saver_def(),
             input_checkpoint=checkpoint_path,
-            output_node_names='logits,initialize_state',
+            output_node_names=output_node_names,
             restore_op_name=None,
             filename_tensor_name=None,
             output_graph=output_graph_path,
             clear_devices=False,
-            initializer_nodes='',
-            variable_names_blacklist='previous_state_c,previous_state_h')
+            variable_names_blacklist=variables_blacklist,
+            initializer_nodes='')
 
         log_info('Models exported at %s' % (FLAGS.export_dir))
     except RuntimeError as e:
diff --git a/bin/run-tc-ldc93s1_tflite.sh b/bin/run-tc-ldc93s1_tflite.sh
new file mode 100755
index 00000000..df36d99b
--- /dev/null
+++ b/bin/run-tc-ldc93s1_tflite.sh
@@ -0,0 +1,21 @@
+#!/bin/sh
+
+set -xe
+
+ldc93s1_dir="./data/ldc93s1-tc"
+ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
+    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
+fi;
+
+python -u DeepSpeech.py \
+  --n_hidden 494 \
+  --checkpoint_dir '/tmp/ckpt' \
+  --export_dir '/tmp/train' \
+  --decoder_library_path '/tmp/ds/libctc_decoder_with_kenlm.so' \
+  --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
+  --lm_trie_path 'data/smoke_test/vocab.trie' \
+  --notrain --notest \
+  --export_tflite
diff --git a/tc-train-tests.sh b/tc-train-tests.sh
index b6de6d9c..fb65e891 100644
--- a/tc-train-tests.sh
+++ b/tc-train-tests.sh
@@ -62,9 +62,11 @@ fi;
 
 pushd ${HOME}/DeepSpeech/ds/
     time ./bin/run-tc-ldc93s1_new.sh 105
+    time ./bin/run-tc-ldc93s1_tflite.sh
 popd
 
 cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
+cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS}
 
 if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
     convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")
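
A note on the export path above: the `fixup()` mapping exists because the training graph's `LSTMBlockFusedCell` stores its weights in the checkpoint under a `lstm_fused_cell/` prefix, while the TF Lite inference graph's `tf.nn.rnn_cell.LSTMCell` names the same tensors `rnn/lstm_cell/...`. A quick way to confirm what the checkpoint actually contains, using the standard TF 1.x checkpoint API (not part of this patch; the `/tmp/ckpt` path comes from the test script):

    import tensorflow as tf

    # Print every variable stored in the training checkpoint; the LSTM kernel
    # and bias should appear under the 'lstm_fused_cell/' prefix that fixup()
    # maps the inference graph's 'rnn/lstm_cell/' names onto.
    for name, shape in tf.train.list_variables('/tmp/ckpt'):
        print(name, shape)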
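
Also note that the `output_graph.fb` written here is still a frozen GraphDef, only built from TF Lite-compatible ops with the RNN state turned into placeholders; producing an actual TF Lite flatbuffer is a follow-up step outside this patch. A minimal sketch of that conversion, assuming a TF 1.x build where `tf.contrib.lite.TocoConverter` is available (later 1.x releases rename it `TFLiteConverter`); the `.tflite` filename is illustrative, not something this patch emits:

    import tensorflow as tf

    # Node names match those created by create_inference_graph(tflite=True);
    # input_lengths is omitted since the static_rnn path never consumes it.
    converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
        '/tmp/train/output_graph.fb',
        input_arrays=['input_node', 'previous_state_c', 'previous_state_h'],
        output_arrays=['logits', 'new_state_c', 'new_state_h'])

    with open('/tmp/train/output_graph.tflite', 'wb') as f:
        f.write(converter.convert())  # convert() returns the flatbuffer bytes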
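
Turning the state variables into `previous_state_c`/`previous_state_h` placeholders and exposing `new_state_c`/`new_state_h` as outputs also changes the runtime contract: the TF Lite graph has no `initialize_state` op, so the caller owns the LSTM state and must thread it between chunks. A sketch of that loop under the assumptions above; `feature_chunks` is a stand-in for the caller's feature pipeline, the shapes use DeepSpeech's defaults (`n_input=26`, `2*n_context+1=19`, `n_steps=16`), and `n_cell_dim` must equal the `--n_hidden` used at training time (494 in the test script):

    import numpy as np
    import tensorflow as tf

    interpreter = tf.contrib.lite.Interpreter(model_path='/tmp/train/output_graph.tflite')
    interpreter.allocate_tensors()

    # Map tensor names to indices so tensors can be fed/fetched by node name
    inputs  = {d['name']: d['index'] for d in interpreter.get_input_details()}
    outputs = {d['name']: d['index'] for d in interpreter.get_output_details()}

    n_cell_dim = 494
    state_c = np.zeros((1, n_cell_dim), dtype=np.float32)  # caller-side zero state
    state_h = np.zeros((1, n_cell_dim), dtype=np.float32)

    for chunk in feature_chunks:  # each chunk: float32 array of shape [1, 16, 19, 26]
        interpreter.set_tensor(inputs['input_node'], chunk)
        interpreter.set_tensor(inputs['previous_state_c'], state_c)
        interpreter.set_tensor(inputs['previous_state_h'], state_h)
        interpreter.invoke()
        logits  = interpreter.get_tensor(outputs['logits'])       # [n_steps, alphabet size + 1]
        state_c = interpreter.get_tensor(outputs['new_state_c'])  # carry state forward
        state_h = interpreter.get_tensor(outputs['new_state_h'])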