Replace examples/image_retraining by a pointer to TensorFlow Hub.
https://github.com/tensorflow/hub/tree/master/examples/image_retraining has the same tool, upgraded to use TensorFlow Hub instead of raw graph defs. PiperOrigin-RevId: 192502469
This commit is contained in:
parent
3fa224a453
commit
1a36eb1550
@ -1,51 +0,0 @@
|
|||||||
# Description:
|
|
||||||
# Transfer learning example for TensorFlow.
|
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
|
||||||
|
|
||||||
py_binary(
|
|
||||||
name = "retrain",
|
|
||||||
srcs = [
|
|
||||||
"retrain.py",
|
|
||||||
],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
visibility = ["//tensorflow:__subpackages__"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
"//tensorflow/python:framework",
|
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
|
||||||
"//tensorflow/python:graph_util",
|
|
||||||
"//tensorflow/python:platform",
|
|
||||||
"//tensorflow/python:util",
|
|
||||||
"//third_party/py/numpy",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "retrain_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = [
|
|
||||||
"retrain.py",
|
|
||||||
"retrain_test.py",
|
|
||||||
],
|
|
||||||
data = [
|
|
||||||
":data/labels.txt",
|
|
||||||
"//tensorflow/examples/label_image:data/grace_hopper.jpg",
|
|
||||||
],
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
":retrain",
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
"//tensorflow/python:framework_test_lib",
|
|
||||||
"//tensorflow/python:graph_util",
|
|
||||||
"//tensorflow/python:platform",
|
|
||||||
"//tensorflow/python:platform_test",
|
|
||||||
"//tensorflow/python:tensor_shape",
|
|
||||||
"//tensorflow/python:util",
|
|
||||||
"//third_party/py/numpy",
|
|
||||||
],
|
|
||||||
)
|
|
@ -1,12 +1,15 @@
|
|||||||
|
**NOTE: This code has moved to**
|
||||||
|
https://github.com/tensorflow/hub/tree/master/examples/image_retraining
|
||||||
|
|
||||||
retrain.py is an example script that shows how one can adapt a pretrained
|
retrain.py is an example script that shows how one can adapt a pretrained
|
||||||
network for other classification problems. A detailed overview of this script
|
network for other classification problems (including use with TFLite and
|
||||||
can be found at:
|
quantization).
|
||||||
https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0
|
|
||||||
|
|
||||||
The script also shows how one can train layers
|
|
||||||
with quantized weights and activations instead of taking a pre-trained floating
|
|
||||||
point model and then quantizing weights and activations.
|
|
||||||
The output graphdef produced by this script is compatible with the TensorFlow
|
|
||||||
Lite Optimizing Converter and can be converted to TFLite format.
|
|
||||||
|
|
||||||
|
As of TensorFlow 1.7, it is recommended to use a pretrained network from
|
||||||
|
TensorFlow Hub, using the new version of this example found in the location
|
||||||
|
above, as explained in TensorFlow's revised [image retraining
|
||||||
|
tutorial](https://www.tensorflow.org/tutorials/image_retraining).
|
||||||
|
|
||||||
|
Older versions of this example (using frozen GraphDefs instead of
|
||||||
|
TensorFlow Hub modules) are available in the release branches of
|
||||||
|
TensorFlow versions up to and including 1.7.
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
Runner-up
|
|
||||||
Winner
|
|
||||||
Loser
|
|
File diff suppressed because it is too large
Load Diff
@ -1,148 +0,0 @@
|
|||||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
# pylint: disable=g-bad-import-order,unused-import
|
|
||||||
"""Tests the graph freezing tool."""
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import os
|
|
||||||
|
|
||||||
from tensorflow.examples.image_retraining import retrain
|
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRetrainingTest(test_util.TensorFlowTestCase):
|
|
||||||
|
|
||||||
def dummyImageLists(self):
|
|
||||||
return {'label_one': {'dir': 'somedir', 'training': ['image_one.jpg',
|
|
||||||
'image_two.jpg'],
|
|
||||||
'testing': ['image_three.jpg', 'image_four.jpg'],
|
|
||||||
'validation': ['image_five.jpg', 'image_six.jpg']},
|
|
||||||
'label_two': {'dir': 'otherdir', 'training': ['image_one.jpg',
|
|
||||||
'image_two.jpg'],
|
|
||||||
'testing': ['image_three.jpg', 'image_four.jpg'],
|
|
||||||
'validation': ['image_five.jpg', 'image_six.jpg']}}
|
|
||||||
|
|
||||||
def testGetImagePath(self):
|
|
||||||
image_lists = self.dummyImageLists()
|
|
||||||
self.assertEqual('image_dir/somedir/image_one.jpg', retrain.get_image_path(
|
|
||||||
image_lists, 'label_one', 0, 'image_dir', 'training'))
|
|
||||||
self.assertEqual('image_dir/otherdir/image_four.jpg',
|
|
||||||
retrain.get_image_path(image_lists, 'label_two', 1,
|
|
||||||
'image_dir', 'testing'))
|
|
||||||
|
|
||||||
def testGetBottleneckPath(self):
|
|
||||||
image_lists = self.dummyImageLists()
|
|
||||||
self.assertEqual('bottleneck_dir/somedir/image_five.jpg_imagenet_v3.txt',
|
|
||||||
retrain.get_bottleneck_path(
|
|
||||||
image_lists, 'label_one', 0, 'bottleneck_dir',
|
|
||||||
'validation', 'imagenet_v3'))
|
|
||||||
|
|
||||||
def testShouldDistortImage(self):
|
|
||||||
self.assertEqual(False, retrain.should_distort_images(False, 0, 0, 0))
|
|
||||||
self.assertEqual(True, retrain.should_distort_images(True, 0, 0, 0))
|
|
||||||
self.assertEqual(True, retrain.should_distort_images(False, 10, 0, 0))
|
|
||||||
self.assertEqual(True, retrain.should_distort_images(False, 0, 1, 0))
|
|
||||||
self.assertEqual(True, retrain.should_distort_images(False, 0, 0, 50))
|
|
||||||
|
|
||||||
def testAddInputDistortions(self):
|
|
||||||
with tf.Graph().as_default():
|
|
||||||
with tf.Session() as sess:
|
|
||||||
retrain.add_input_distortions(True, 10, 10, 10, 299, 299, 3, 128, 128)
|
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortJPGInput:0'))
|
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortResult:0'))
|
|
||||||
|
|
||||||
@tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
|
|
||||||
def testAddFinalRetrainOps(self, flags_mock):
|
|
||||||
with tf.Graph().as_default():
|
|
||||||
with tf.Session() as sess:
|
|
||||||
bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
|
|
||||||
# Test creating final training op with quantization.
|
|
||||||
retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, False,
|
|
||||||
False)
|
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
|
|
||||||
|
|
||||||
@tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
|
|
||||||
def testAddFinalRetrainOpsQuantized(self, flags_mock):
|
|
||||||
# Ensure that the training and eval graph for quantized models are correctly
|
|
||||||
# created.
|
|
||||||
with tf.Graph().as_default() as g:
|
|
||||||
with tf.Session() as sess:
|
|
||||||
bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
|
|
||||||
# Test creating final training op with quantization, set is_training to
|
|
||||||
# true.
|
|
||||||
retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, True)
|
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
|
|
||||||
found_fake_quant = 0
|
|
||||||
for op in g.get_operations():
|
|
||||||
if op.type == 'FakeQuantWithMinMaxVars':
|
|
||||||
found_fake_quant += 1
|
|
||||||
# Ensure that the inputs of each FakeQuant operations has 2 Assign
|
|
||||||
# operations in the training graph (Assign[Min,Max]Last,
|
|
||||||
# Assign[Min,Max]Ema)
|
|
||||||
self.assertEqual(2,
|
|
||||||
len([i for i in op.inputs if 'Assign' in i.name]))
|
|
||||||
self.assertEqual(found_fake_quant, 2)
|
|
||||||
with tf.Graph().as_default() as g:
|
|
||||||
with tf.Session() as sess:
|
|
||||||
bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
|
|
||||||
# Test creating final training op with quantization, set is_training to
|
|
||||||
# false.
|
|
||||||
retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, False)
|
|
||||||
self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
|
|
||||||
found_fake_quant = 0
|
|
||||||
for op in g.get_operations():
|
|
||||||
if op.type == 'FakeQuantWithMinMaxVars':
|
|
||||||
found_fake_quant += 1
|
|
||||||
for i in op.inputs:
|
|
||||||
# Ensure that no operations are Assign operation since this is the
|
|
||||||
# evaluation graph.
|
|
||||||
self.assertTrue('Assign' not in i.name)
|
|
||||||
self.assertEqual(found_fake_quant, 2)
|
|
||||||
|
|
||||||
def testAddEvaluationStep(self):
|
|
||||||
with tf.Graph().as_default():
|
|
||||||
final = tf.placeholder(tf.float32, [1], name='final')
|
|
||||||
gt = tf.placeholder(tf.int64, [1], name='gt')
|
|
||||||
self.assertIsNotNone(retrain.add_evaluation_step(final, gt))
|
|
||||||
|
|
||||||
def testAddJpegDecoding(self):
|
|
||||||
with tf.Graph().as_default():
|
|
||||||
jpeg_data, mul_image = retrain.add_jpeg_decoding(10, 10, 3, 0, 255)
|
|
||||||
self.assertIsNotNone(jpeg_data)
|
|
||||||
self.assertIsNotNone(mul_image)
|
|
||||||
|
|
||||||
def testCreateModelInfo(self):
|
|
||||||
did_raise_value_error = False
|
|
||||||
try:
|
|
||||||
retrain.create_model_info('no_such_model_name')
|
|
||||||
except ValueError:
|
|
||||||
did_raise_value_error = True
|
|
||||||
self.assertTrue(did_raise_value_error)
|
|
||||||
model_info = retrain.create_model_info('inception_v3')
|
|
||||||
self.assertIsNotNone(model_info)
|
|
||||||
self.assertEqual(299, model_info['input_width'])
|
|
||||||
|
|
||||||
def testCreateModelInfoQuantized(self):
|
|
||||||
# Test for mobilenet_quantized
|
|
||||||
model_info = retrain.create_model_info('mobilenet_1.0_224')
|
|
||||||
self.assertIsNotNone(model_info)
|
|
||||||
self.assertEqual(224, model_info['input_width'])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
tf.test.main()
|
|
Loading…
Reference in New Issue
Block a user