Export inception model after retrain

This commit is contained in:
Kamil Sindi 2018-01-04 13:13:38 -05:00
parent 5999ae54f9
commit 81ec5d2093

View File

@ -96,6 +96,10 @@ Visualize the summaries with this command:
tensorboard --logdir /tmp/retrain_logs tensorboard --logdir /tmp/retrain_logs
To use with Tensorflow Serving:
tensorflow_model_server --port=9000 --model_name=inception --model_base_path=/tmp/saved_models/
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -1004,6 +1008,38 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
return jpeg_data, mul_image return jpeg_data, mul_image
def export_model(sess, architecture, saved_model_dir):
if architecture == 'inception_v3':
input_tensor = 'DecodeJpeg/contents:0'
elif architecture.startswith('mobilenet_'):
input_tensor = 'input:0'
else:
raise ValueError('Unknown architecture', architecture)
in_image = sess.graph.get_tensor_by_name(input_tensor)
inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
out_classes = sess.graph.get_tensor_by_name('final_result:0')
outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
# Save out the SavedModel.
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
},
legacy_init_op=legacy_init_op)
builder.save()
def main(_): def main(_):
# Needed to make sure the logging output is visible. # Needed to make sure the logging output is visible.
# See https://github.com/tensorflow/tensorflow/issues/3047 # See https://github.com/tensorflow/tensorflow/issues/3047
@ -1179,6 +1215,8 @@ def main(_):
with gfile.FastGFile(FLAGS.output_labels, 'w') as f: with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n') f.write('\n'.join(image_lists.keys()) + '\n')
export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -1362,5 +1400,11 @@ if __name__ == '__main__':
takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
for more information on Mobilenet.\ for more information on Mobilenet.\
""") """)
parser.add_argument(
'--saved_model_dir',
type=str,
default='/tmp/saved_models/1/',
help='Where to save the exported graph.'
)
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)