Respect --load when exporting
This commit is contained in:
parent
cedd72da9b
commit
c46d8396bc
|
@ -21,7 +21,6 @@ from datetime import datetime
|
|||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from tensorflow.python.tools import freeze_graph, strip_unused_lib
|
||||
from util.config import Config, initialize_globals
|
||||
from util.checkpoints import load_or_init_graph
|
||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
|
@ -733,49 +732,40 @@ def export():
|
|||
if FLAGS.export_language:
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||
output_names = ",".join(output_names_tensors + output_names_ops)
|
||||
output_names = output_names_tensors + output_names_ops
|
||||
|
||||
# Create a saver using variables from the above newly created graph
|
||||
saver = tfv1.train.Saver()
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.save_checkpoint_dir)
|
||||
checkpoint_path = checkpoint.model_checkpoint_path
|
||||
with tf.Session() as session:
|
||||
# Restore variables from checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
output_filename = FLAGS.export_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
try:
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
|
||||
frozen = freeze_graph.freeze_graph_with_def_protos(
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
sess=session,
|
||||
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
||||
input_saver_def=saver.as_saver_def(),
|
||||
input_checkpoint=checkpoint_path,
|
||||
output_node_names=output_node_names,
|
||||
restore_op_name=None,
|
||||
filename_tensor_name=None,
|
||||
output_graph=output_file,
|
||||
clear_devices=False,
|
||||
variable_names_blacklist=variables_blacklist,
|
||||
initializer_nodes='')
|
||||
output_node_names=output_names)
|
||||
|
||||
input_node_names = []
|
||||
return strip_unused_lib.strip_unused(
|
||||
input_graph_def=frozen,
|
||||
input_node_names=input_node_names,
|
||||
output_node_names=output_node_names.split(','),
|
||||
placeholder_type_enum=tf.float32.as_datatype_enum)
|
||||
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names)
|
||||
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
||||
graph_def=frozen_graph,
|
||||
dest_nodes=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
|
@ -784,7 +774,7 @@ def export():
|
|||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
converter.optimizations = [ tf.lite.Optimize.DEFAULT ]
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
||||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
|
@ -792,11 +782,7 @@ def export():
|
|||
with open(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
except RuntimeError as e:
|
||||
log_error(str(e))
|
||||
|
||||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
|
|
Loading…
Reference in New Issue