Merge pull request #2174 from lissyx/strip-unused
Apply strip_unused on inference graph
This commit is contained in:
commit
dce068f3cb
@ -18,7 +18,7 @@ from datetime import datetime
|
|||||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||||
from evaluate import evaluate
|
from evaluate import evaluate
|
||||||
from six.moves import zip, range
|
from six.moves import zip, range
|
||||||
from tensorflow.python.tools import freeze_graph
|
from tensorflow.python.tools import freeze_graph, strip_unused_lib
|
||||||
from util.config import Config, initialize_globals
|
from util.config import Config, initialize_globals
|
||||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||||
from util.flags import create_flags, FLAGS
|
from util.flags import create_flags, FLAGS
|
||||||
@ -711,7 +711,7 @@ def export():
|
|||||||
os.makedirs(FLAGS.export_dir)
|
os.makedirs(FLAGS.export_dir)
|
||||||
|
|
||||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
|
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
|
||||||
return freeze_graph.freeze_graph_with_def_protos(
|
frozen = freeze_graph.freeze_graph_with_def_protos(
|
||||||
input_graph_def=tf.get_default_graph().as_graph_def(),
|
input_graph_def=tf.get_default_graph().as_graph_def(),
|
||||||
input_saver_def=saver.as_saver_def(),
|
input_saver_def=saver.as_saver_def(),
|
||||||
input_checkpoint=checkpoint_path,
|
input_checkpoint=checkpoint_path,
|
||||||
@ -723,6 +723,13 @@ def export():
|
|||||||
variable_names_blacklist=variables_blacklist,
|
variable_names_blacklist=variables_blacklist,
|
||||||
initializer_nodes='')
|
initializer_nodes='')
|
||||||
|
|
||||||
|
input_node_names = []
|
||||||
|
return strip_unused_lib.strip_unused(
|
||||||
|
input_graph_def=frozen,
|
||||||
|
input_node_names=input_node_names,
|
||||||
|
output_node_names=output_node_names.split(','),
|
||||||
|
placeholder_type_enum=tf.float32.as_datatype_enum)
|
||||||
|
|
||||||
if not FLAGS.export_tflite:
|
if not FLAGS.export_tflite:
|
||||||
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
|
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
|
||||||
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
|
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user