Call toco during export

Alexandre Lissy 2018-10-31 09:22:07 +01:00
parent c67f66f864
commit 92ded598fb
3 changed files with 66 additions and 28 deletions


@@ -18,10 +18,12 @@ import time
 import traceback
 import inspect
 import progressbar
+import tempfile
 from functools import partial
 from six.moves import zip, range, filter, urllib, BaseHTTPServer
 from tensorflow.python.tools import freeze_graph
+from tensorflow.contrib.lite.python import tflite_convert
 from threading import Thread, Lock
 from util.audio import audiofile_to_input_vector
 from util.feeding import DataSet, ModelFeeder
@@ -1831,9 +1833,8 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tfli
     return (
         {
             'input': input_tensor,
-            'input_lengths': seq_length,
-            'new_state_c': new_state_c,
-            'new_state_h': new_state_h,
+            'previous_state_c': previous_state_c,
+            'previous_state_h': previous_state_h,
         },
         {
             'outputs': logits,
@@ -1849,11 +1850,17 @@ def export():
     '''
     log_info('Exporting the model...')
     with tf.device('/cpu:0'):
+        from tensorflow.python.framework.ops import Tensor, Operation
         tf.reset_default_graph()
         session = tf.Session(config=session_config)
         inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
+        input_names = ",".join(tensor.op.name for tensor in inputs.values())
+        output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
+        output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
+        output_names = ",".join(output_names_tensors + output_names_ops)
+        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
         if not FLAGS.export_tflite:
             mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
@@ -1872,11 +1879,7 @@ def export():
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
         checkpoint_path = checkpoint.model_checkpoint_path
-        if not FLAGS.export_tflite:
-            output_filename = 'output_graph.pb'
-        else:
-            output_filename = 'output_graph.fb'
+        output_filename = 'output_graph.pb'
         if FLAGS.remove_export:
             if os.path.isdir(FLAGS.export_dir):
                 log_info('Removing old export')
@@ -1887,31 +1890,61 @@ def export():
             if not os.path.isdir(FLAGS.export_dir):
                 os.makedirs(FLAGS.export_dir)
-            if not FLAGS.export_tflite:
-                output_node_names = 'logits,initialize_state'
-                variables_blacklist = 'previous_state_c,previous_state_h'
-            else:
-                output_node_names = 'logits,new_state_c,new_state_h'
-                variables_blacklist = ''
-            # Freeze graph
-            freeze_graph.freeze_graph_with_def_protos(
-                input_graph_def=session.graph_def,
-                input_saver_def=saver.as_saver_def(),
-                input_checkpoint=checkpoint_path,
-                output_node_names=output_node_names,
-                restore_op_name=None,
-                filename_tensor_name=None,
-                output_graph=output_graph_path,
-                clear_devices=False,
-                variable_names_blacklist=variables_blacklist,
-                initializer_nodes='')
+            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+                freeze_graph.freeze_graph_with_def_protos(
+                    input_graph_def=session.graph_def,
+                    input_saver_def=saver.as_saver_def(),
+                    input_checkpoint=checkpoint_path,
+                    output_node_names=output_node_names,
+                    restore_op_name=None,
+                    filename_tensor_name=None,
+                    output_graph=output_file,
+                    clear_devices=False,
+                    variable_names_blacklist=variables_blacklist,
+                    initializer_nodes='')
+
+            if not FLAGS.export_tflite:
+                do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
+            else:
+                temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
+                os.close(temp_fd)
+                do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
+                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
+                class TFLiteFlags():
+                    def __init__(self):
+                        self.graph_def_file = temp_freeze
+                        self.inference_type = 'FLOAT'
+                        self.input_arrays = input_names
+                        self.input_shapes = input_shapes
+                        self.output_arrays = output_names
+                        self.output_file = output_tflite_path
+                        self.output_format = 'TFLITE'
+                        default_empty = [
+                            'inference_input_type',
+                            'mean_values',
+                            'default_ranges_min', 'default_ranges_max',
+                            'drop_control_dependency',
+                            'reorder_across_fake_quant',
+                            'change_concat_input_ranges',
+                            'allow_custom_ops',
+                            'converter_mode',
+                            'post_training_quantize',
+                            'dump_graphviz_dir',
+                            'dump_graphviz_video'
+                        ]
+                        for e in default_empty:
+                            self.__dict__[e] = None
+                flags = TFLiteFlags()
+                tflite_convert._convert_model(flags)
+                os.unlink(temp_freeze)
+                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
+
             log_info('Models exported at %s' % (FLAGS.export_dir))
     except RuntimeError as e:
         log_error(str(e))

 def do_single_file_inference(input_file_path):
     with tf.Session(config=session_config) as session:
         inputs, outputs = create_inference_graph(batch_size=1, use_new_decoder=True)
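The `TFLiteFlags` object in the hunk above mimics the argument namespace that the `tflite_convert` command-line tool would otherwise build from its own flags, so the in-process `tflite_convert._convert_model(flags)` call is roughly equivalent to running the CLI on the temporary frozen graph. A minimal sketch of that equivalent invocation, assuming a TensorFlow 1.x install that ships `tflite_convert`; the file names and tensor array names below are hypothetical placeholders, not values read from the actual graph:

```bash
# Rough CLI equivalent of the in-process conversion above (TensorFlow 1.x).
# Paths and array names are illustrative placeholders only.
tflite_convert \
  --graph_def_file=/model/export/destination/frozen_temp.pb \
  --output_file=/model/export/destination/output_graph.tflite \
  --output_format=TFLITE \
  --inference_type=FLOAT \
  --input_arrays="input_node,previous_state_c,previous_state_h" \
  --input_shapes="1,16,19,26:1,2048:1,2048" \
  --output_arrays="logits,new_state_c,new_state_h"
```

In the commit itself the frozen graph is written to a `tempfile.mkstemp` file inside `--export_dir` and deleted with `os.unlink` after conversion, so only the final `.tflite` artifact is shipped.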


@@ -45,6 +45,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech
 - [Training a model](#training-a-model)
 - [Checkpointing](#checkpointing)
 - [Exporting a model for inference](#exporting-a-model-for-inference)
+- [Exporting a model for TFLite](#exporting-a-model-for-tflite)
 - [Distributed computing across more than one machine](#distributed-training-across-more-than-one-machine)
 - [Continuing training from a release model](#continuing-training-from-a-release-model)
 - [Code documentation](#code-documentation)
@@ -317,6 +318,10 @@ Be aware however that checkpoints are only valid for the same model geometry the
 If the `--export_dir` parameter is provided, a model will have been exported to this directory during training.
 Refer to the corresponding [README.md](native_client/README.md) for information on building and running a client that can use the exported model.
 
+### Exporting a model for TFLite
+
+If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it: use the `--export_tflite` flag. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again, specifying the same `checkpoint_dir` that you used for training and passing `--notrain --notest --export_tflite --export_dir /model/export/destination`.
+
 ### Making a mmap-able model for inference
 
 The `output_graph.pb` model file generated in the above step will be loaded in memory to be dealt with when running inference.
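For instance, a re-export of an already trained model could look like the following hypothetical command, where the checkpoint and export paths are placeholders:

```bash
# Hypothetical example: re-export an existing checkpoint for TF Lite.
python DeepSpeech.py \
  --checkpoint_dir /path/to/checkpoints \
  --notrain --notest \
  --export_tflite \
  --export_dir /model/export/destination
```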


@@ -66,7 +66,7 @@ pushd ${HOME}/DeepSpeech/ds/
 popd
 
 cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
-cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS}
+cp /tmp/train/output_graph.tflite ${TASKCLUSTER_ARTIFACTS}
 
 if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
   convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")