Merge pull request #15855 from ksindi/export-retrained-inception
Export inception model after retrain
This commit is contained in:
commit
4347f17abf
@ -41,7 +41,6 @@ The subfolder names are important, since they define what label is applied to
|
|||||||
each image, but the filenames themselves don't matter. Once your images are
|
each image, but the filenames themselves don't matter. Once your images are
|
||||||
prepared, you can run the training with a command like this:
|
prepared, you can run the training with a command like this:
|
||||||
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bazel build tensorflow/examples/image_retraining:retrain && \
|
bazel build tensorflow/examples/image_retraining:retrain && \
|
||||||
bazel-bin/tensorflow/examples/image_retraining/retrain \
|
bazel-bin/tensorflow/examples/image_retraining/retrain \
|
||||||
@ -70,12 +69,14 @@ on resource-limited platforms, you can try the `--architecture` flag with a
|
|||||||
Mobilenet model. For example:
|
Mobilenet model. For example:
|
||||||
|
|
||||||
Run floating-point version of mobilenet:
|
Run floating-point version of mobilenet:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python tensorflow/examples/image_retraining/retrain.py \
|
python tensorflow/examples/image_retraining/retrain.py \
|
||||||
--image_dir ~/flower_photos --architecture mobilenet_1.0_224
|
--image_dir ~/flower_photos --architecture mobilenet_1.0_224
|
||||||
```
|
```
|
||||||
|
|
||||||
Run quantized version of mobilenet:
|
Run quantized version of mobilenet:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python tensorflow/examples/image_retraining/retrain.py \
|
python tensorflow/examples/image_retraining/retrain.py \
|
||||||
--image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
|
--image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
|
||||||
@ -96,6 +97,12 @@ Visualize the summaries with this command:
|
|||||||
|
|
||||||
tensorboard --logdir /tmp/retrain_logs
|
tensorboard --logdir /tmp/retrain_logs
|
||||||
|
|
||||||
|
To use with Tensorflow Serving:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tensorflow_model_server --port=9000 --model_name=inception \
|
||||||
|
--model_base_path=/tmp/saved_models/
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -1004,6 +1011,46 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
|
|||||||
return jpeg_data, mul_image
|
return jpeg_data, mul_image
|
||||||
|
|
||||||
|
|
||||||
|
def export_model(sess, architecture, saved_model_dir):
|
||||||
|
"""Exports model for serving.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sess: Current active TensorFlow Session.
|
||||||
|
architecture: Model architecture.
|
||||||
|
saved_model_dir: Directory in which to save exported model and variables.
|
||||||
|
"""
|
||||||
|
if architecture == 'inception_v3':
|
||||||
|
input_tensor = 'DecodeJpeg/contents:0'
|
||||||
|
elif architecture.startswith('mobilenet_'):
|
||||||
|
input_tensor = 'input:0'
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown architecture', architecture)
|
||||||
|
in_image = sess.graph.get_tensor_by_name(input_tensor)
|
||||||
|
inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
|
||||||
|
|
||||||
|
out_classes = sess.graph.get_tensor_by_name('final_result:0')
|
||||||
|
outputs = {'prediction':
|
||||||
|
tf.saved_model.utils.build_tensor_info(out_classes)}
|
||||||
|
|
||||||
|
signature = tf.saved_model.signature_def_utils.build_signature_def(
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
|
||||||
|
|
||||||
|
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
|
||||||
|
|
||||||
|
# Save out the SavedModel.
|
||||||
|
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, [tf.saved_model.tag_constants.SERVING],
|
||||||
|
signature_def_map={
|
||||||
|
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||||
|
signature
|
||||||
|
},
|
||||||
|
legacy_init_op=legacy_init_op)
|
||||||
|
builder.save()
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
# Needed to make sure the logging output is visible.
|
# Needed to make sure the logging output is visible.
|
||||||
# See https://github.com/tensorflow/tensorflow/issues/3047
|
# See https://github.com/tensorflow/tensorflow/issues/3047
|
||||||
@ -1179,6 +1226,8 @@ def main(_):
|
|||||||
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
|
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
|
||||||
f.write('\n'.join(image_lists.keys()) + '\n')
|
f.write('\n'.join(image_lists.keys()) + '\n')
|
||||||
|
|
||||||
|
export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -1362,5 +1411,11 @@ if __name__ == '__main__':
|
|||||||
takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
|
takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
|
||||||
for more information on Mobilenet.\
|
for more information on Mobilenet.\
|
||||||
""")
|
""")
|
||||||
|
parser.add_argument(
|
||||||
|
'--saved_model_dir',
|
||||||
|
type=str,
|
||||||
|
default='/tmp/saved_models/1/',
|
||||||
|
help='Where to save the exported graph.'
|
||||||
|
)
|
||||||
FLAGS, unparsed = parser.parse_known_args()
|
FLAGS, unparsed = parser.parse_known_args()
|
||||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||||
|
Loading…
Reference in New Issue
Block a user