Respect --load when exporting
This commit is contained in:
parent
cedd72da9b
commit
c46d8396bc
|
@ -21,7 +21,6 @@ from datetime import datetime
|
||||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||||
from evaluate import evaluate
|
from evaluate import evaluate
|
||||||
from six.moves import zip, range
|
from six.moves import zip, range
|
||||||
from tensorflow.python.tools import freeze_graph, strip_unused_lib
|
|
||||||
from util.config import Config, initialize_globals
|
from util.config import Config, initialize_globals
|
||||||
from util.checkpoints import load_or_init_graph
|
from util.checkpoints import load_or_init_graph
|
||||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||||
|
@ -733,49 +732,40 @@ def export():
|
||||||
if FLAGS.export_language:
|
if FLAGS.export_language:
|
||||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||||
|
|
||||||
|
# Prevent further graph changes
|
||||||
|
tfv1.get_default_graph().finalize()
|
||||||
|
|
||||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||||
output_names = ",".join(output_names_tensors + output_names_ops)
|
output_names = output_names_tensors + output_names_ops
|
||||||
|
|
||||||
# Create a saver using variables from the above newly created graph
|
with tf.Session() as session:
|
||||||
saver = tfv1.train.Saver()
|
# Restore variables from checkpoint
|
||||||
|
if FLAGS.load == 'auto':
|
||||||
|
method_order = ['best', 'last']
|
||||||
|
else:
|
||||||
|
method_order = [FLAGS.load]
|
||||||
|
load_or_init_graph(session, method_order)
|
||||||
|
|
||||||
# Restore variables from training checkpoint
|
output_filename = FLAGS.export_name + '.pb'
|
||||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.save_checkpoint_dir)
|
if FLAGS.remove_export:
|
||||||
checkpoint_path = checkpoint.model_checkpoint_path
|
if os.path.isdir(FLAGS.export_dir):
|
||||||
|
log_info('Removing old export')
|
||||||
|
shutil.rmtree(FLAGS.export_dir)
|
||||||
|
|
||||||
output_filename = FLAGS.export_name + '.pb'
|
|
||||||
if FLAGS.remove_export:
|
|
||||||
if os.path.isdir(FLAGS.export_dir):
|
|
||||||
log_info('Removing old export')
|
|
||||||
shutil.rmtree(FLAGS.export_dir)
|
|
||||||
try:
|
|
||||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||||
|
|
||||||
if not os.path.isdir(FLAGS.export_dir):
|
if not os.path.isdir(FLAGS.export_dir):
|
||||||
os.makedirs(FLAGS.export_dir)
|
os.makedirs(FLAGS.export_dir)
|
||||||
|
|
||||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
|
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||||
frozen = freeze_graph.freeze_graph_with_def_protos(
|
sess=session,
|
||||||
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
||||||
input_saver_def=saver.as_saver_def(),
|
output_node_names=output_names)
|
||||||
input_checkpoint=checkpoint_path,
|
|
||||||
output_node_names=output_node_names,
|
|
||||||
restore_op_name=None,
|
|
||||||
filename_tensor_name=None,
|
|
||||||
output_graph=output_file,
|
|
||||||
clear_devices=False,
|
|
||||||
variable_names_blacklist=variables_blacklist,
|
|
||||||
initializer_nodes='')
|
|
||||||
|
|
||||||
input_node_names = []
|
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
||||||
return strip_unused_lib.strip_unused(
|
graph_def=frozen_graph,
|
||||||
input_graph_def=frozen,
|
dest_nodes=output_names)
|
||||||
input_node_names=input_node_names,
|
|
||||||
output_node_names=output_node_names.split(','),
|
|
||||||
placeholder_type_enum=tf.float32.as_datatype_enum)
|
|
||||||
|
|
||||||
frozen_graph = do_graph_freeze(output_node_names=output_names)
|
|
||||||
|
|
||||||
if not FLAGS.export_tflite:
|
if not FLAGS.export_tflite:
|
||||||
with open(output_graph_path, 'wb') as fout:
|
with open(output_graph_path, 'wb') as fout:
|
||||||
|
@ -784,7 +774,7 @@ def export():
|
||||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||||
|
|
||||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||||
converter.optimizations = [ tf.lite.Optimize.DEFAULT ]
|
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||||
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
||||||
converter.allow_custom_ops = True
|
converter.allow_custom_ops = True
|
||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
|
@ -792,11 +782,7 @@ def export():
|
||||||
with open(output_tflite_path, 'wb') as fout:
|
with open(output_tflite_path, 'wb') as fout:
|
||||||
fout.write(tflite_model)
|
fout.write(tflite_model)
|
||||||
|
|
||||||
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
|
|
||||||
|
|
||||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||||
except RuntimeError as e:
|
|
||||||
log_error(str(e))
|
|
||||||
|
|
||||||
def package_zip():
|
def package_zip():
|
||||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||||
|
|
Loading…
Reference in New Issue