Call toco during export
This commit is contained in: parent c67f66f864 · commit 92ded598fb
@@ -18,10 +18,12 @@ import time
import traceback
import inspect
import progressbar
import tempfile

from functools import partial
from six.moves import zip, range, filter, urllib, BaseHTTPServer
from tensorflow.python.tools import freeze_graph
from tensorflow.contrib.lite.python import tflite_convert
from threading import Thread, Lock
from util.audio import audiofile_to_input_vector
from util.feeding import DataSet, ModelFeeder
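Note: `tempfile` and `tflite_convert` are the imports this commit leans on; the former stages an intermediate frozen graph, the latter is the Python module behind TensorFlow 1.x's `tflite_convert` CLI. A minimal defensive-import sketch, not part of the commit, assuming a TF 1.x build where `contrib.lite` may be absent:

    # Hypothetical guard, not in this commit: fall back gracefully when
    # the TF build ships without contrib.lite.
    try:
        from tensorflow.contrib.lite.python import tflite_convert
    except ImportError:
        tflite_convert = None  # TFLite export path unavailable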
@@ -1831,9 +1833,8 @@ def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False, tfli
    return (
        {
            'input': input_tensor,
            'input_lengths': seq_length,
            'new_state_c': new_state_c,
            'new_state_h': new_state_h,
            'previous_state_c': previous_state_c,
            'previous_state_h': previous_state_h,
        },
        {
            'outputs': logits,
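These dicts give the export code named handles on the graph's endpoints. A sketch of one inference step driven through them, assuming a `session` with restored weights, a zero initial LSTM state of width 2048 (DeepSpeech's default cell width), and placeholder feature dimensions; which dict exposes the `new_state_*` tensors depends on the export mode, so the sketch checks both. Illustrative only, not code from this commit:

    import numpy as np

    def state_tensor(name):
        # Assumption: new_state_c/new_state_h live in one of the two dicts.
        return outputs[name] if name in outputs else inputs[name]

    features = np.zeros([1, 16, 494], dtype=np.float32)  # placeholder window; dims assumed
    state_c = np.zeros([1, 2048], dtype=np.float32)      # 2048: assumed n_cell_dim
    state_h = np.zeros([1, 2048], dtype=np.float32)
    logits_val, state_c, state_h = session.run(
        [outputs['outputs'], state_tensor('new_state_c'), state_tensor('new_state_h')],
        feed_dict={
            inputs['input']: features,
            inputs['input_lengths']: [features.shape[1]],
            inputs['previous_state_c']: state_c,
            inputs['previous_state_h']: state_h,
        })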
@@ -1849,11 +1850,17 @@ def export():
    '''
    log_info('Exporting the model...')
    with tf.device('/cpu:0'):
        from tensorflow.python.framework.ops import Tensor, Operation

        tf.reset_default_graph()
        session = tf.Session(config=session_config)

        inputs, outputs = create_inference_graph(batch_size=1, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
        input_names = ",".join(tensor.op.name for tensor in inputs.values())
        output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor) ]
        output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation) ]
        output_names = ",".join(output_names_tensors + output_names_ops)
        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())

        if not FLAGS.export_tflite:
            mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
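The joined strings follow toco's CLI conventions: array names are comma-separated, while `input_shapes` carries one comma-separated shape per input, with inputs delimited by colons. A worked example with made-up tensors (names and dimensions are illustrative, not taken from the commit):

    # Hypothetical inputs: features [1, 16, 494], lengths [1], LSTM states [1, 2048].
    input_names  = "input_node,input_lengths,previous_state_c,previous_state_h"
    input_shapes = "1,16,494:1:1,2048:1,2048"  # colon-separated, one shape per input
    output_names = "logits,new_state_c,new_state_h"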
@@ -1872,11 +1879,7 @@ def export():
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        checkpoint_path = checkpoint.model_checkpoint_path

        if not FLAGS.export_tflite:
            output_filename = 'output_graph.pb'
        else:
            output_filename = 'output_graph.fb'

        output_filename = 'output_graph.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
@@ -1887,31 +1890,61 @@ def export():
            if not os.path.isdir(FLAGS.export_dir):
                os.makedirs(FLAGS.export_dir)

            if not FLAGS.export_tflite:
                output_node_names = 'logits,initialize_state'
                variables_blacklist = 'previous_state_c,previous_state_h'
            else:
                output_node_names = 'logits,new_state_c,new_state_h'
                variables_blacklist = ''
            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
                freeze_graph.freeze_graph_with_def_protos(
                    input_graph_def=session.graph_def,
                    input_saver_def=saver.as_saver_def(),
                    input_checkpoint=checkpoint_path,
                    output_node_names=output_node_names,
                    restore_op_name=None,
                    filename_tensor_name=None,
                    output_graph=output_file,
                    clear_devices=False,
                    variable_names_blacklist=variables_blacklist,
                    initializer_nodes='')

            # Freeze graph
            freeze_graph.freeze_graph_with_def_protos(
                input_graph_def=session.graph_def,
                input_saver_def=saver.as_saver_def(),
                input_checkpoint=checkpoint_path,
                output_node_names=output_node_names,
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=output_graph_path,
                clear_devices=False,
                variable_names_blacklist=variables_blacklist,
                initializer_nodes='')
            if not FLAGS.export_tflite:
                do_graph_freeze(output_file=output_graph_path, output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
            else:
                temp_fd, temp_freeze = tempfile.mkstemp(dir=FLAGS.export_dir)
                os.close(temp_fd)
                do_graph_freeze(output_file=temp_freeze, output_node_names=output_names, variables_blacklist='')
                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
                class TFLiteFlags():
                    def __init__(self):
                        self.graph_def_file = temp_freeze
                        self.inference_type = 'FLOAT'
                        self.input_arrays = input_names
                        self.input_shapes = input_shapes
                        self.output_arrays = output_names
                        self.output_file = output_tflite_path
                        self.output_format = 'TFLITE'

                        default_empty = [
                            'inference_input_type',
                            'mean_values',
                            'default_ranges_min', 'default_ranges_max',
                            'drop_control_dependency',
                            'reorder_across_fake_quant',
                            'change_concat_input_ranges',
                            'allow_custom_ops',
                            'converter_mode',
                            'post_training_quantize',
                            'dump_graphviz_dir',
                            'dump_graphviz_video'
                        ]
                        for e in default_empty:
                            self.__dict__[e] = None

                flags = TFLiteFlags()
                tflite_convert._convert_model(flags)
                os.unlink(temp_freeze)
                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
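`TFLiteFlags` mimics the argparse namespace that the `tflite_convert` command-line tool would hand to `_convert_model`, which lets the conversion run in-process right after freezing. Roughly the same conversion can be expressed through the TF 1.x high-level converter API; a sketch, assuming `temp_freeze` is still on disk and reusing the names computed above (not what the commit does):

    # Sketch only: tf.contrib.lite.TocoConverter is the TF 1.x high-level API
    # (renamed TFLiteConverter in later releases). input_shapes may also need
    # to be passed if the frozen graph leaves input dimensions undefined.
    converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
        graph_def_file=temp_freeze,
        input_arrays=input_names.split(','),
        output_arrays=output_names.split(','))
    with open(output_tflite_path, 'wb') as f:
        f.write(converter.convert())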

            log_info('Models exported at %s' % (FLAGS.export_dir))
        except RuntimeError as e:
            log_error(str(e))


def do_single_file_inference(input_file_path):
    with tf.Session(config=session_config) as session:
        inputs, outputs = create_inference_graph(batch_size=1, use_new_decoder=True)

@@ -45,6 +45,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech
- [Training a model](#training-a-model)
- [Checkpointing](#checkpointing)
- [Exporting a model for inference](#exporting-a-model-for-inference)
- [Exporting a model for TFLite](#exporting-a-model-for-tflite)
- [Distributed computing across more than one machine](#distributed-training-across-more-than-one-machine)
- [Continuing training from a release model](#continuing-training-from-a-release-model)
- [Code documentation](#code-documentation)
@@ -317,6 +318,10 @@ Be aware however that checkpoints are only valid for the same model geometry the
If the `--export_dir` parameter is provided, a model will have been exported to this directory during training.
Refer to the corresponding [README.md](native_client/README.md) for information on building and running a client that can use the exported model.

### Exporting a model for TFLite

If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it by using the `--export_tflite` flag. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --export_tflite --export_dir /model/export/destination`.
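A concrete re-export invocation might look like the following sketch (paths are placeholders; add the same geometry flags used for training if they differ from the defaults):

    python DeepSpeech.py --notrain --notest \
        --checkpoint_dir /model/checkpoints \
        --export_tflite \
        --export_dir /model/export/destination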

### Making a mmap-able model for inference

The `output_graph.pb` model file generated in the above step will be loaded entirely into memory when running inference.
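The usual remedy (see the `CONVERT_GRAPHDEF_MEMMAPPED` variable in the TaskCluster script below) is TensorFlow's `convert_graphdef_memmapped_format` tool, which rewrites the graph so the weights can be memory-mapped from disk instead. A hedged invocation sketch; the paths and the `.pbmm` output name are assumptions following DeepSpeech conventions:

    convert_graphdef_memmapped_format \
        --in_graph=/model/export/destination/output_graph.pb \
        --out_graph=/model/export/destination/output_graph.pbmm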
@@ -66,7 +66,7 @@ pushd ${HOME}/DeepSpeech/ds/
popd

cp /tmp/train/output_graph.pb ${TASKCLUSTER_ARTIFACTS}
cp /tmp/train/output_graph.fb ${TASKCLUSTER_ARTIFACTS}
cp /tmp/train/output_graph.tflite ${TASKCLUSTER_ARTIFACTS}

if [ ! -z "${CONVERT_GRAPHDEF_MEMMAPPED}" ]; then
    convert_graphdef=$(basename "${CONVERT_GRAPHDEF_MEMMAPPED}")