Respect --load when exporting

This commit is contained in:
Reuben Morais 2020-02-17 12:45:37 +01:00
parent cedd72da9b
commit c46d8396bc
1 changed files with 24 additions and 38 deletions

View File

@ -21,7 +21,6 @@ from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph, strip_unused_lib
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
@ -733,49 +732,40 @@ def export():
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops)
output_names = output_names_tensors + output_names_ops
# Create a saver using variables from the above newly created graph
saver = tfv1.train.Saver()
# Restore variables from training checkpoint
checkpoint = tf.train.get_checkpoint_state(FLAGS.save_checkpoint_dir)
checkpoint_path = checkpoint.model_checkpoint_path
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
try:
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
frozen = freeze_graph.freeze_graph_with_def_protos(
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
input_saver_def=saver.as_saver_def(),
input_checkpoint=checkpoint_path,
output_node_names=output_node_names,
restore_op_name=None,
filename_tensor_name=None,
output_graph=output_file,
clear_devices=False,
variable_names_blacklist=variables_blacklist,
initializer_nodes='')
output_node_names=output_names)
input_node_names = []
return strip_unused_lib.strip_unused(
input_graph_def=frozen,
input_node_names=input_node_names,
output_node_names=output_node_names.split(','),
placeholder_type_enum=tf.float32.as_datatype_enum)
frozen_graph = do_graph_freeze(output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
@ -784,7 +774,7 @@ def export():
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [ tf.lite.Optimize.DEFAULT ]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
@ -792,11 +782,7 @@ def export():
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
log_info('Models exported at %s' % (FLAGS.export_dir))
except RuntimeError as e:
log_error(str(e))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip