From 77b77e1108e1ce53b38b46feb34ebd5bfa9255b0 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 15 Oct 2019 12:51:07 -0700 Subject: [PATCH] Remove the __init__.py for keras_application. All the reference to shortcut in __init__ has been replaced with explicit import. PiperOrigin-RevId: 274867057 Change-Id: Icd9d642d638821b8e825ec266aaa6d054ae900c8 --- .../debug/lib/check_numerics_callback_test.py | 4 +- tensorflow/python/keras/__init__.py | 1 - tensorflow/python/keras/api/BUILD | 12 +++++ .../python/keras/applications/__init__.py | 33 ++----------- .../keras/applications/applications_test.py | 46 +++++++++++-------- 5 files changed, 46 insertions(+), 50 deletions(-) diff --git a/tensorflow/python/debug/lib/check_numerics_callback_test.py b/tensorflow/python/debug/lib/check_numerics_callback_test.py index ae353ca9636..426ad946d74 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback_test.py +++ b/tensorflow/python/debug/lib/check_numerics_callback_test.py @@ -31,10 +31,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras import applications from tensorflow.python.keras import layers from tensorflow.python.keras import models from tensorflow.python.keras import optimizer_v2 +from tensorflow.python.keras.applications import mobilenet_v2 from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker_v2 @@ -514,7 +514,7 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase): def testMobileNetV2Fit(self): """Test training Keras MobileNetV2 application works w/ check numerics.""" check_numerics_callback.enable_check_numerics() - model = applications.MobileNetV2(alpha=0.1, weights=None) + model = mobilenet_v2.MobileNetV2(alpha=0.1, weights=None) xs = np.zeros([2] + list(model.input_shape[1:])) ys = np.zeros([2] + list(model.output_shape[1:])) diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py index 15fcd412d7b..d59d277d9c7 100644 --- a/tensorflow/python/keras/__init__.py +++ b/tensorflow/python/keras/__init__.py @@ -23,7 +23,6 @@ from __future__ import print_function from tensorflow.python import tf2 -from tensorflow.python.keras import applications from tensorflow.python.keras import datasets from tensorflow.python.keras import estimator from tensorflow.python.keras import layers diff --git a/tensorflow/python/keras/api/BUILD b/tensorflow/python/keras/api/BUILD index 1bcf0996801..e53d9c0ab82 100644 --- a/tensorflow/python/keras/api/BUILD +++ b/tensorflow/python/keras/api/BUILD @@ -14,6 +14,18 @@ keras_packages = [ "tensorflow.python", "tensorflow.python.keras", "tensorflow.python.keras.activations", + "tensorflow.python.keras.applications.densenet", + "tensorflow.python.keras.applications.imagenet_utils", + "tensorflow.python.keras.applications.inception_resnet_v2", + "tensorflow.python.keras.applications.inception_v3", + "tensorflow.python.keras.applications.mobilenet", + "tensorflow.python.keras.applications.mobilenet_v2", + "tensorflow.python.keras.applications.nasnet", + "tensorflow.python.keras.applications.resnet", + "tensorflow.python.keras.applications.resnet_v2", + "tensorflow.python.keras.applications.vgg16", + "tensorflow.python.keras.applications.vgg19", + "tensorflow.python.keras.applications.xception", "tensorflow.python.keras.backend", "tensorflow.python.keras.callbacks", "tensorflow.python.keras.callbacks_v1", diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py index ef89793fd73..601c990e2f6 100644 --- a/tensorflow/python/keras/applications/__init__.py +++ b/tensorflow/python/keras/applications/__init__.py @@ -21,13 +21,6 @@ from __future__ import print_function import keras_applications -from tensorflow.python.keras import backend -from tensorflow.python.keras import engine -from tensorflow.python.keras import layers -from tensorflow.python.keras import models -from tensorflow.python.keras.utils import all_utils -from tensorflow.python.util import tf_inspect - def keras_modules_injection(base_fun): """Decorator injecting tf.keras replacements for Keras modules. @@ -39,6 +32,10 @@ def keras_modules_injection(base_fun): Decorated function that injects keyword argument for the tf.keras modules required by the Applications. """ + from tensorflow.python.keras import backend + from tensorflow.python.keras import layers + from tensorflow.python.keras import models + from tensorflow.python.keras.utils import all_utils def wrapper(*args, **kwargs): kwargs['backend'] = backend @@ -48,25 +45,3 @@ def keras_modules_injection(base_fun): kwargs['utils'] = all_utils return base_fun(*args, **kwargs) return wrapper - - -from tensorflow.python.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2 -from tensorflow.python.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras.applications.mobilenet_v2 import MobileNetV2 -from tensorflow.python.keras.applications.nasnet import NASNetLarge -from tensorflow.python.keras.applications.nasnet import NASNetMobile -from tensorflow.python.keras.applications.resnet import ResNet50 -from tensorflow.python.keras.applications.resnet import ResNet101 -from tensorflow.python.keras.applications.resnet import ResNet152 -from tensorflow.python.keras.applications.resnet_v2 import ResNet50V2 -from tensorflow.python.keras.applications.resnet_v2 import ResNet101V2 -from tensorflow.python.keras.applications.resnet_v2 import ResNet152V2 -from tensorflow.python.keras.applications.vgg16 import VGG16 -from tensorflow.python.keras.applications.vgg19 import VGG19 -from tensorflow.python.keras.applications.xception import Xception diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py index 06c5c6708a6..957fd5c7a82 100644 --- a/tensorflow/python/keras/applications/applications_test.py +++ b/tensorflow/python/keras/applications/applications_test.py @@ -20,28 +20,38 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.keras import applications +from tensorflow.python.keras.applications import densenet +from tensorflow.python.keras.applications import inception_resnet_v2 +from tensorflow.python.keras.applications import inception_v3 +from tensorflow.python.keras.applications import mobilenet +from tensorflow.python.keras.applications import mobilenet_v2 +from tensorflow.python.keras.applications import nasnet +from tensorflow.python.keras.applications import resnet +from tensorflow.python.keras.applications import resnet_v2 +from tensorflow.python.keras.applications import vgg16 +from tensorflow.python.keras.applications import vgg19 +from tensorflow.python.keras.applications import xception from tensorflow.python.platform import test MODEL_LIST = [ - (applications.ResNet50, 2048), - (applications.ResNet101, 2048), - (applications.ResNet152, 2048), - (applications.ResNet50V2, 2048), - (applications.ResNet101V2, 2048), - (applications.ResNet152V2, 2048), - (applications.VGG16, 512), - (applications.VGG19, 512), - (applications.Xception, 2048), - (applications.InceptionV3, 2048), - (applications.InceptionResNetV2, 1536), - (applications.MobileNet, 1024), - (applications.MobileNetV2, 1280), - (applications.DenseNet121, 1024), - (applications.DenseNet169, 1664), - (applications.DenseNet201, 1920), - (applications.NASNetMobile, 1056), + (resnet.ResNet50, 2048), + (resnet.ResNet101, 2048), + (resnet.ResNet152, 2048), + (resnet_v2.ResNet50V2, 2048), + (resnet_v2.ResNet101V2, 2048), + (resnet_v2.ResNet152V2, 2048), + (vgg16.VGG16, 512), + (vgg19.VGG19, 512), + (xception.Xception, 2048), + (inception_v3.InceptionV3, 2048), + (inception_resnet_v2.InceptionResNetV2, 1536), + (mobilenet.MobileNet, 1024), + (mobilenet_v2.MobileNetV2, 1280), + (densenet.DenseNet121, 1024), + (densenet.DenseNet169, 1664), + (densenet.DenseNet201, 1920), + (nasnet.NASNetMobile, 1056), ]