Add tests for pretrained weights of Keras Applications.

PiperOrigin-RevId: 289376982
Change-Id: I76620361cf0018051e51849856dc6cc3101e0327
This commit is contained in:
Yanhui Liang 2020-01-12 21:51:49 -08:00 committed by TensorFlower Gardener
parent fbdf6b193f
commit c25b583371
2 changed files with 313 additions and 0 deletions

View File

@ -50,6 +50,205 @@ tf_py_test(
],
)
# Add target for each application module file, to make sure it only
# runs the test for the application models contained in that
# application module when it has been modified.
tf_py_test(
name = "applications_load_weight_test_resnet",
srcs = ["applications_load_weight_test.py"],
args = ["--module=resnet"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_resnet_v2",
srcs = ["applications_load_weight_test.py"],
args = ["--module=resnet_v2"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_vgg16",
srcs = ["applications_load_weight_test.py"],
args = ["--module=vgg16"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_vgg19",
srcs = ["applications_load_weight_test.py"],
args = ["--module=vgg19"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_xception",
srcs = ["applications_load_weight_test.py"],
args = ["--module=xception"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_inception_v3",
srcs = ["applications_load_weight_test.py"],
args = ["--module=inception_v3"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_inception_resnet_v2",
srcs = ["applications_load_weight_test.py"],
args = ["--module=inception_resnet_v2"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_mobilenet",
srcs = ["applications_load_weight_test.py"],
args = ["--module=mobilenet"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_mobilenet_v2",
srcs = ["applications_load_weight_test.py"],
args = ["--module=mobilenet_v2"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_densenet",
size = "large",
srcs = ["applications_load_weight_test.py"],
args = ["--module=densenet"],
main = "applications_load_weight_test.py",
shard_count = 3,
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_efficientnet",
size = "large",
srcs = ["applications_load_weight_test.py"],
args = ["--module=efficientnet"],
main = "applications_load_weight_test.py",
shard_count = 8,
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "applications_load_weight_test_nasnet",
srcs = ["applications_load_weight_test.py"],
args = ["--module=nasnet"],
main = "applications_load_weight_test.py",
tags = [
"no_oss", # TODO(b/146940090): fix kokoro error
"no_pip",
],
deps = [
":applications",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "imagenet_utils_test",
size = "medium",

View File

@ -0,0 +1,114 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration tests for Keras applications."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import parameterized
import numpy as np
from tensorflow.python.keras.applications import densenet
from tensorflow.python.keras.applications import efficientnet
from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras.applications import nasnet
from tensorflow.python.keras.applications import resnet
from tensorflow.python.keras.applications import resnet_v2
from tensorflow.python.keras.applications import vgg16
from tensorflow.python.keras.applications import vgg19
from tensorflow.python.keras.applications import xception
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.platform import test
ARG_TO_MODEL = {
'resnet': (resnet, [resnet.ResNet50, resnet.ResNet101, resnet.ResNet152]),
'resnet_v2': (resnet_v2, [resnet_v2.ResNet50V2, resnet_v2.ResNet101V2,
resnet_v2.ResNet152V2]),
'vgg16': (vgg16, [vgg16.VGG16]),
'vgg19': (vgg19, [vgg19.VGG19]),
'xception': (xception, [xception.Xception]),
'inception_v3': (inception_v3, [inception_v3.InceptionV3]),
'inception_resnet_v2': (inception_resnet_v2,
[inception_resnet_v2.InceptionResNetV2]),
'mobilenet': (mobilenet, [mobilenet.MobileNet]),
'mobilenet_v2': (mobilenet_v2, [mobilenet_v2.MobileNetV2]),
'densenet': (densenet, [densenet.DenseNet121,
densenet.DenseNet169, densenet.DenseNet201]),
'nasnet': (nasnet, [nasnet.NASNetMobile, nasnet.NASNetLarge]),
'efficientnet': (efficientnet,
[efficientnet.EfficientNetB0, efficientnet.EfficientNetB1,
efficientnet.EfficientNetB2, efficientnet.EfficientNetB3,
efficientnet.EfficientNetB4, efficientnet.EfficientNetB5,
efficientnet.EfficientNetB6, efficientnet.EfficientNetB7])
}
TEST_IMAGE_PATH = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/tests/elephant.jpg')
_IMAGENET_CLASSES = 1000
# Add a flag to define which application module file is tested.
# This is set as an 'arg' in the build target to guarantee that
# it only triggers the tests of the application models in the module
# if that module file has been modified.
FLAGS = flags.FLAGS
flags.DEFINE_string('module', None,
'Application module used in this test.')
def _get_elephant(target_size):
# For models that don't include a Flatten step,
# the default is to accept variable-size inputs
# even when loading ImageNet weights (since it is possible).
# In this case, default to 299x299.
if target_size[0] is None:
target_size = (299, 299)
test_image = data_utils.get_file('elephant.jpg', TEST_IMAGE_PATH)
img = image.load_img(test_image, target_size=tuple(target_size))
x = image.img_to_array(img)
return np.expand_dims(x, axis=0)
class ApplicationsLoadWeightTest(test.TestCase, parameterized.TestCase):
def assertShapeEqual(self, shape1, shape2):
if len(shape1) != len(shape2):
raise AssertionError(
'Shapes are different rank: %s vs %s' % (shape1, shape2))
if shape1 != shape2:
raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2))
def test_application_pretrained_weights_loading(self):
app_module = ARG_TO_MODEL[FLAGS.module][0]
apps = ARG_TO_MODEL[FLAGS.module][1]
for app in apps:
model = app(weights='imagenet')
self.assertShapeEqual(model.output_shape, (None, _IMAGENET_CLASSES))
x = _get_elephant(model.input_shape[1:3])
x = app_module.preprocess_input(x)
preds = model.predict(x)
names = [p[1] for p in app_module.decode_predictions(preds)[0]]
# Test correct label is in top 3 (weak correctness test).
self.assertIn('African_elephant', names[:3])
if __name__ == '__main__':
test.main()