Merge pull request #2174 from lissyx/strip-unused

Apply strip_unused on inference graph
This commit is contained in:
lissyx 2019-06-14 19:39:00 +02:00 committed by GitHub
commit dce068f3cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,7 +18,7 @@ from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import freeze_graph, strip_unused_lib
from util.config import Config, initialize_globals
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
@ -711,7 +711,7 @@ def export():
os.makedirs(FLAGS.export_dir)
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
return freeze_graph.freeze_graph_with_def_protos(
frozen = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=saver.as_saver_def(),
input_checkpoint=checkpoint_path,
@ -723,6 +723,13 @@ def export():
variable_names_blacklist=variables_blacklist,
initializer_nodes='')
input_node_names = []
return strip_unused_lib.strip_unused(
input_graph_def=frozen,
input_node_names=input_node_names,
output_node_names=output_node_names.split(','),
placeholder_type_enum=tf.float32.as_datatype_enum)
if not FLAGS.export_tflite:
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())