Merge pull request #15855 from ksindi/export-retrained-inception

Export inception model after retrain
This commit is contained in:
Martin Wicke 2018-02-15 17:20:19 -08:00 committed by GitHub
commit 4347f17abf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -41,7 +41,6 @@ The subfolder names are important, since they define what label is applied to
each image, but the filenames themselves don't matter. Once your images are each image, but the filenames themselves don't matter. Once your images are
prepared, you can run the training with a command like this: prepared, you can run the training with a command like this:
```bash ```bash
bazel build tensorflow/examples/image_retraining:retrain && \ bazel build tensorflow/examples/image_retraining:retrain && \
bazel-bin/tensorflow/examples/image_retraining/retrain \ bazel-bin/tensorflow/examples/image_retraining/retrain \
@ -70,12 +69,14 @@ on resource-limited platforms, you can try the `--architecture` flag with a
Mobilenet model. For example: Mobilenet model. For example:
Run floating-point version of mobilenet: Run floating-point version of mobilenet:
```bash ```bash
python tensorflow/examples/image_retraining/retrain.py \ python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos --architecture mobilenet_1.0_224 --image_dir ~/flower_photos --architecture mobilenet_1.0_224
``` ```
Run quantized version of mobilenet: Run quantized version of mobilenet:
```bash ```bash
python tensorflow/examples/image_retraining/retrain.py \ python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
@ -96,6 +97,12 @@ Visualize the summaries with this command:
tensorboard --logdir /tmp/retrain_logs tensorboard --logdir /tmp/retrain_logs
To use with Tensorflow Serving:
```bash
tensorflow_model_server --port=9000 --model_name=inception \
--model_base_path=/tmp/saved_models/
```
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -1004,6 +1011,46 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
return jpeg_data, mul_image return jpeg_data, mul_image
def export_model(sess, architecture, saved_model_dir):
"""Exports model for serving.
Args:
sess: Current active TensorFlow Session.
architecture: Model architecture.
saved_model_dir: Directory in which to save exported model and variables.
"""
if architecture == 'inception_v3':
input_tensor = 'DecodeJpeg/contents:0'
elif architecture.startswith('mobilenet_'):
input_tensor = 'input:0'
else:
raise ValueError('Unknown architecture', architecture)
in_image = sess.graph.get_tensor_by_name(input_tensor)
inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
out_classes = sess.graph.get_tensor_by_name('final_result:0')
outputs = {'prediction':
tf.saved_model.utils.build_tensor_info(out_classes)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
# Save out the SavedModel.
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature
},
legacy_init_op=legacy_init_op)
builder.save()
def main(_): def main(_):
# Needed to make sure the logging output is visible. # Needed to make sure the logging output is visible.
# See https://github.com/tensorflow/tensorflow/issues/3047 # See https://github.com/tensorflow/tensorflow/issues/3047
@ -1179,6 +1226,8 @@ def main(_):
with gfile.FastGFile(FLAGS.output_labels, 'w') as f: with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n') f.write('\n'.join(image_lists.keys()) + '\n')
export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -1362,5 +1411,11 @@ if __name__ == '__main__':
takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
for more information on Mobilenet.\ for more information on Mobilenet.\
""") """)
parser.add_argument(
'--saved_model_dir',
type=str,
default='/tmp/saved_models/1/',
help='Where to save the exported graph.'
)
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)