From 30f753a95f645caf47661ff0802b8b99a7ff4a8e Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 30 Mar 2020 13:13:23 -0700 Subject: [PATCH] Make layers used in Keras Applications aware of calls to `enable/disable_v2_behavior`. PiperOrigin-RevId: 303807515 Change-Id: I4bca246dd446def8cde73b09c106d3a0cbad3264 --- .../python/keras/applications/densenet.py | 34 +++------------ .../python/keras/applications/efficientnet.py | 24 +++-------- .../keras/applications/inception_resnet_v2.py | 32 +++----------- .../python/keras/applications/inception_v3.py | 32 ++------------ .../python/keras/applications/mobilenet.py | 32 +++----------- .../python/keras/applications/mobilenet_v2.py | 7 ++- .../python/keras/applications/nasnet.py | 32 ++------------ .../python/keras/applications/resnet.py | 35 +++------------ .../python/keras/applications/resnet_v2.py | 43 +++---------------- tensorflow/python/keras/applications/vgg16.py | 32 ++------------ tensorflow/python/keras/applications/vgg19.py | 32 ++------------ .../python/keras/applications/xception.py | 32 ++------------ tensorflow/python/keras/layers/__init__.py | 19 ++++++++ .../python/keras/layers/serialization.py | 11 +++++ 14 files changed, 90 insertions(+), 307 deletions(-) diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py index 9b11c342536..fe353bcef15 100644 --- a/tensorflow/python/keras/applications/densenet.py +++ b/tensorflow/python/keras/applications/densenet.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -52,6 +52,8 @@ DENSENET201_WEIGHT_PATH_NO_TOP = ( BASE_WEIGTHS_PATH + 'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + def dense_block(x, blocks, name): """A dense block. @@ -133,8 +135,7 @@ def DenseNet( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the DenseNet architecture. Reference paper: @@ -358,37 +359,12 @@ def DenseNet201(include_top=True, @keras_export('keras.applications.densenet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='torch') @keras_export('keras.applications.densenet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) @@ -405,7 +381,7 @@ DOC = """ Optionally loads weights pre-trained on ImageNet. Note that the data format convention used by the model is the one specified in your Keras config at `~/.keras/keras.json`. - + Arguments: include_top: whether to include the fully-connected layer at the top of the network. diff --git a/tensorflow/python/keras/applications/efficientnet.py b/tensorflow/python/keras/applications/efficientnet.py index 4b9487dcdd6..0487450f880 100644 --- a/tensorflow/python/keras/applications/efficientnet.py +++ b/tensorflow/python/keras/applications/efficientnet.py @@ -28,9 +28,9 @@ import math import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -140,6 +140,8 @@ DENSE_KERNEL_INITIALIZER = { } } +layers = VersionAwareLayers() + def EfficientNet( width_coefficient, @@ -157,8 +159,7 @@ def EfficientNet( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the EfficientNet architecture using given scaling coefficients. Reference paper: @@ -664,18 +665,7 @@ def preprocess_input(x, data_format=None): # pylint: disable=unused-argument @keras_export('keras.applications.efficientnet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) + + +decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__ diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py index 7f338f82597..d4ffd372a10 100644 --- a/tensorflow/python/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2.py @@ -28,9 +28,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -38,6 +38,7 @@ from tensorflow.python.util.tf_export import keras_export BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/inception_resnet_v2/') +layers = None @keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2', @@ -105,9 +106,11 @@ def InceptionResNetV2(include_top=True, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): @@ -378,36 +381,11 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): @keras_export('keras.applications.inception_resnet_v2.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.inception_resnet_v2.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py index fa44becfe48..21f65b1fbc7 100644 --- a/tensorflow/python/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/applications/inception_v3.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -41,6 +41,8 @@ WEIGHTS_PATH_NO_TOP = ( 'https://storage.googleapis.com/tensorflow/keras-applications/' 'inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.inception_v3.InceptionV3', 'keras.applications.InceptionV3') @@ -51,8 +53,7 @@ def InceptionV3( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the Inception v3 architecture. Reference paper: @@ -406,36 +407,11 @@ def conv2d_bn(x, @keras_export('keras.applications.inception_v3.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.inception_v3.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py index d935282f98a..c79627c6aa7 100644 --- a/tensorflow/python/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -67,9 +67,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging @@ -77,6 +77,7 @@ from tensorflow.python.util.tf_export import keras_export BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/mobilenet/') +layers = None @keras_export('keras.applications.mobilenet.MobileNet', @@ -155,9 +156,11 @@ def MobileNet(input_shape=None, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): @@ -439,36 +442,11 @@ def _depthwise_conv_block(inputs, @keras_export('keras.applications.mobilenet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.mobilenet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py index bdd21c3da62..59aeba572e3 100644 --- a/tensorflow/python/keras/applications/mobilenet_v2.py +++ b/tensorflow/python/keras/applications/mobilenet_v2.py @@ -80,9 +80,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging @@ -90,6 +90,7 @@ from tensorflow.python.util.tf_export import keras_export BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/mobilenet_v2/') +layers = None @keras_export('keras.applications.mobilenet_v2.MobileNetV2', @@ -173,9 +174,11 @@ def MobileNetV2(input_shape=None, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py index 3da415dbb12..5c3117d8a47 100644 --- a/tensorflow/python/keras/applications/nasnet.py +++ b/tensorflow/python/keras/applications/nasnet.py @@ -44,9 +44,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging @@ -60,6 +60,8 @@ NASNET_MOBILE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-mobile-no-top.h5' NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + 'NASNet-large.h5' NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5' +layers = VersionAwareLayers() + def NASNet( input_shape=None, @@ -74,8 +76,7 @@ def NASNet( pooling=None, classes=1000, default_size=None, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates a NASNet model. Reference paper: @@ -785,36 +786,11 @@ def _reduction_a_cell(ip, p, filters, block_id=None): @keras_export('keras.applications.nasnet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.nasnet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/resnet.py b/tensorflow/python/keras/applications/resnet.py index 3e33bb04bdd..ecb3f31e0c9 100644 --- a/tensorflow/python/keras/applications/resnet.py +++ b/tensorflow/python/keras/applications/resnet.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -55,6 +55,8 @@ WEIGHTS_HASHES = { ('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3') } +layers = None + def ResNet(stack_fn, preact, @@ -129,9 +131,11 @@ def ResNet(stack_fn, ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ + global layers if 'layers' in kwargs: - global layers layers = kwargs.pop('layers') + else: + layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) if not (weights in {'imagenet', None} or os.path.exists(weights)): @@ -517,17 +521,6 @@ def ResNet152(include_top=True, @keras_export('keras.applications.resnet50.preprocess_input', 'keras.applications.resnet.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='caffe') @@ -535,20 +528,6 @@ def preprocess_input(x, data_format=None): @keras_export('keras.applications.resnet50.decode_predictions', 'keras.applications.resnet.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) @@ -565,7 +544,7 @@ DOC = """ Optionally loads weights pre-trained on ImageNet. Note that the data format convention used by the model is the one specified in your Keras config at `~/.keras/keras.json`. - + Arguments: include_top: whether to include the fully-connected layer at the top of the network. diff --git a/tensorflow/python/keras/applications/resnet_v2.py b/tensorflow/python/keras/applications/resnet_v2.py index 2e1ee272c4b..a8f6e526ad5 100644 --- a/tensorflow/python/keras/applications/resnet_v2.py +++ b/tensorflow/python/keras/applications/resnet_v2.py @@ -37,8 +37,7 @@ def ResNet50V2( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the ResNet50V2 architecture.""" def stack_fn(x): x = resnet.stack2(x, 64, 3, name='conv2') @@ -57,8 +56,7 @@ def ResNet50V2( input_shape, pooling, classes, - classifier_activation=classifier_activation, - ) + classifier_activation=classifier_activation) @keras_export('keras.applications.resnet_v2.ResNet101V2', @@ -70,8 +68,7 @@ def ResNet101V2( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the ResNet101V2 architecture.""" def stack_fn(x): x = resnet.stack2(x, 64, 3, name='conv2') @@ -90,8 +87,7 @@ def ResNet101V2( input_shape, pooling, classes, - classifier_activation=classifier_activation, - ) + classifier_activation=classifier_activation) @keras_export('keras.applications.resnet_v2.ResNet152V2', @@ -103,8 +99,7 @@ def ResNet152V2( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the ResNet152V2 architecture.""" def stack_fn(x): x = resnet.stack2(x, 64, 3, name='conv2') @@ -123,43 +118,17 @@ def ResNet152V2( input_shape, pooling, classes, - classifier_activation=classifier_activation, - ) + classifier_activation=classifier_activation) @keras_export('keras.applications.resnet_v2.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='tf') @keras_export('keras.applications.resnet_v2.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py index 534d2cff6be..3a523dc5dc3 100644 --- a/tensorflow/python/keras/applications/vgg16.py +++ b/tensorflow/python/keras/applications/vgg16.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/vgg16/' 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16') def VGG16( @@ -49,8 +51,7 @@ def VGG16( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the VGG16 model. Reference paper: @@ -227,37 +228,12 @@ def VGG16( @keras_export('keras.applications.vgg16.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='caffe') @keras_export('keras.applications.vgg16.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py index 81c90e1ebb4..e4385cc8f6a 100644 --- a/tensorflow/python/keras/applications/vgg19.py +++ b/tensorflow/python/keras/applications/vgg19.py @@ -26,9 +26,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/' 'keras-applications/vgg19/' 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.vgg19.VGG19', 'keras.applications.VGG19') def VGG19( @@ -49,8 +51,7 @@ def VGG19( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the VGG19 architecture. Reference: @@ -232,37 +233,12 @@ def VGG19( @keras_export('keras.applications.vgg19.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input( x, data_format=data_format, mode='caffe') @keras_export('keras.applications.vgg19.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py index 5ea0f14cc79..d92bfd0f4c6 100644 --- a/tensorflow/python/keras/applications/xception.py +++ b/tensorflow/python/keras/applications/xception.py @@ -30,9 +30,9 @@ from __future__ import print_function import os from tensorflow.python.keras import backend -from tensorflow.python.keras import layers from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils from tensorflow.python.util.tf_export import keras_export @@ -45,6 +45,8 @@ TF_WEIGHTS_PATH_NO_TOP = ( 'https://storage.googleapis.com/tensorflow/keras-applications/' 'xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5') +layers = VersionAwareLayers() + @keras_export('keras.applications.xception.Xception', 'keras.applications.Xception') @@ -55,8 +57,7 @@ def Xception( input_shape=None, pooling=None, classes=1000, - classifier_activation='softmax', -): + classifier_activation='softmax'): """Instantiates the Xception architecture. Optionally loads weights pre-trained on ImageNet. @@ -312,36 +313,11 @@ def Xception( @keras_export('keras.applications.xception.preprocess_input') def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments - x: A 4D numpy array consists of RGB values within [0, 255]. - - Returns - Preprocessed array. - - Raises - ValueError: In case of unknown `data_format` argument. - """ return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') @keras_export('keras.applications.xception.decode_predictions') def decode_predictions(preds, top=5): - """Decodes the prediction result from the model. - - Arguments - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises - ValueError: In case of invalid shape of the `preds` array (must be 2D). - """ return imagenet_utils.decode_predictions(preds, top=top) diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 9fd902d70e9..9b4bc46ef31 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -242,9 +242,28 @@ from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper # Serialization functions +from tensorflow.python.keras.layers import serialization from tensorflow.python.keras.layers.serialization import deserialize from tensorflow.python.keras.layers.serialization import serialize + +class VersionAwareLayers(object): + """Utility to be used internally to access layers in a V1/V2-aware fashion. + + When using layers within the Keras codebase, under the constraint that + e.g. `layers.BatchNormalization` should be the `BatchNormalization` version + corresponding to the current runtime (TF1 or TF2), do not simply access + `layers.BatchNormalization` since it would ignore e.g. an early + `compat.v2.disable_v2_behavior()` call. Instead, use an instance + of `VersionAwareLayers` (which you can use just like the `layers` module). + """ + + def __getattr__(self, name): + serialization.populate_deserializable_objects() + if name in serialization.LOCAL.ALL_OBJECTS: + return serialization.LOCAL.ALL_OBJECTS[name] + return super(VersionAwareLayers, self).__getattr__(name) + del absolute_import del division del print_function diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index a73f7744d0e..64bee4d6121 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -26,6 +26,7 @@ import threading from tensorflow.python import tf2 from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import input_layer +from tensorflow.python.keras.engine import input_spec from tensorflow.python.keras.layers import advanced_activations from tensorflow.python.keras.layers import convolutional from tensorflow.python.keras.layers import convolutional_recurrent @@ -128,6 +129,7 @@ def populate_deserializable_objects(): from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top LOCAL.ALL_OBJECTS['Input'] = input_layer.Input + LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec LOCAL.ALL_OBJECTS['Network'] = models.Network LOCAL.ALL_OBJECTS['Model'] = models.Model LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential @@ -135,6 +137,15 @@ def populate_deserializable_objects(): LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures + # Merge layers, function versions. + LOCAL.ALL_OBJECTS['add'] = merge.add + LOCAL.ALL_OBJECTS['subtract'] = merge.subtract + LOCAL.ALL_OBJECTS['multiply'] = merge.multiply + LOCAL.ALL_OBJECTS['average'] = merge.average + LOCAL.ALL_OBJECTS['maximum'] = merge.maximum + LOCAL.ALL_OBJECTS['minimum'] = merge.minimum + LOCAL.ALL_OBJECTS['concatenate'] = merge.concatenate + LOCAL.ALL_OBJECTS['dot'] = merge.dot @keras_export('keras.layers.serialize')