Merge pull request #2174 from lissyx/strip-unused
Apply strip_unused on inference graph
This commit is contained in:
commit
dce068f3cb
@ -18,7 +18,7 @@ from datetime import datetime
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from tensorflow.python.tools import freeze_graph
|
||||
from tensorflow.python.tools import freeze_graph, strip_unused_lib
|
||||
from util.config import Config, initialize_globals
|
||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from util.flags import create_flags, FLAGS
|
||||
@ -711,7 +711,7 @@ def export():
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
|
||||
return freeze_graph.freeze_graph_with_def_protos(
|
||||
frozen = freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=tf.get_default_graph().as_graph_def(),
|
||||
input_saver_def=saver.as_saver_def(),
|
||||
input_checkpoint=checkpoint_path,
|
||||
@ -723,6 +723,13 @@ def export():
|
||||
variable_names_blacklist=variables_blacklist,
|
||||
initializer_nodes='')
|
||||
|
||||
input_node_names = []
|
||||
return strip_unused_lib.strip_unused(
|
||||
input_graph_def=frozen,
|
||||
input_node_names=input_node_names,
|
||||
output_node_names=output_node_names.split(','),
|
||||
placeholder_type_enum=tf.float32.as_datatype_enum)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
|
||||
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
|
Loading…
x
Reference in New Issue
Block a user