Make layers used in Keras Applications aware of calls to enable/disable_v2_behavior.

PiperOrigin-RevId: 303807515
Change-Id: I4bca246dd446def8cde73b09c106d3a0cbad3264
This commit is contained in:
Francois Chollet 2020-03-30 13:13:23 -07:00 committed by TensorFlower Gardener
parent 433d40514b
commit 30f753a95f
14 changed files with 90 additions and 307 deletions

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -52,6 +52,8 @@ DENSENET201_WEIGHT_PATH_NO_TOP = (
BASE_WEIGTHS_PATH + BASE_WEIGTHS_PATH +
'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5') 'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
def dense_block(x, blocks, name): def dense_block(x, blocks, name):
"""A dense block. """A dense block.
@ -133,8 +135,7 @@ def DenseNet(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the DenseNet architecture. """Instantiates the DenseNet architecture.
Reference paper: Reference paper:
@ -358,37 +359,12 @@ def DenseNet201(include_top=True,
@keras_export('keras.applications.densenet.preprocess_input') @keras_export('keras.applications.densenet.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input( return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='torch') x, data_format=data_format, mode='torch')
@keras_export('keras.applications.densenet.decode_predictions') @keras_export('keras.applications.densenet.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)
@ -405,7 +381,7 @@ DOC = """
Optionally loads weights pre-trained on ImageNet. Optionally loads weights pre-trained on ImageNet.
Note that the data format convention used by the model is Note that the data format convention used by the model is
the one specified in your Keras config at `~/.keras/keras.json`. the one specified in your Keras config at `~/.keras/keras.json`.
Arguments: Arguments:
include_top: whether to include the fully-connected include_top: whether to include the fully-connected
layer at the top of the network. layer at the top of the network.

View File

@ -28,9 +28,9 @@ import math
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -140,6 +140,8 @@ DENSE_KERNEL_INITIALIZER = {
} }
} }
layers = VersionAwareLayers()
def EfficientNet( def EfficientNet(
width_coefficient, width_coefficient,
@ -157,8 +159,7 @@ def EfficientNet(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the EfficientNet architecture using given scaling coefficients. """Instantiates the EfficientNet architecture using given scaling coefficients.
Reference paper: Reference paper:
@ -664,18 +665,7 @@ def preprocess_input(x, data_format=None): # pylint: disable=unused-argument
@keras_export('keras.applications.efficientnet.decode_predictions') @keras_export('keras.applications.efficientnet.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@ -28,9 +28,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -38,6 +38,7 @@ from tensorflow.python.util.tf_export import keras_export
BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/' BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/inception_resnet_v2/') 'keras-applications/inception_resnet_v2/')
layers = None
@keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2', @keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
@ -105,9 +106,11 @@ def InceptionResNetV2(include_top=True,
ValueError: if `classifier_activation` is not `softmax` or `None` when ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer. using a pretrained top layer.
""" """
global layers
if 'layers' in kwargs: if 'layers' in kwargs:
global layers
layers = kwargs.pop('layers') layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs: if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,)) raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)): if not (weights in {'imagenet', None} or os.path.exists(weights)):
@ -378,36 +381,11 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
@keras_export('keras.applications.inception_resnet_v2.preprocess_input') @keras_export('keras.applications.inception_resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.inception_resnet_v2.decode_predictions') @keras_export('keras.applications.inception_resnet_v2.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -41,6 +41,8 @@ WEIGHTS_PATH_NO_TOP = (
'https://storage.googleapis.com/tensorflow/keras-applications/' 'https://storage.googleapis.com/tensorflow/keras-applications/'
'inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5') 'inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.inception_v3.InceptionV3', @keras_export('keras.applications.inception_v3.InceptionV3',
'keras.applications.InceptionV3') 'keras.applications.InceptionV3')
@ -51,8 +53,7 @@ def InceptionV3(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the Inception v3 architecture. """Instantiates the Inception v3 architecture.
Reference paper: Reference paper:
@ -406,36 +407,11 @@ def conv2d_bn(x,
@keras_export('keras.applications.inception_v3.preprocess_input') @keras_export('keras.applications.inception_v3.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.inception_v3.decode_predictions') @keras_export('keras.applications.inception_v3.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -67,9 +67,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -77,6 +77,7 @@ from tensorflow.python.util.tf_export import keras_export
BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/' BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/mobilenet/') 'keras-applications/mobilenet/')
layers = None
@keras_export('keras.applications.mobilenet.MobileNet', @keras_export('keras.applications.mobilenet.MobileNet',
@ -155,9 +156,11 @@ def MobileNet(input_shape=None,
ValueError: if `classifier_activation` is not `softmax` or `None` when ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer. using a pretrained top layer.
""" """
global layers
if 'layers' in kwargs: if 'layers' in kwargs:
global layers
layers = kwargs.pop('layers') layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs: if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,)) raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)): if not (weights in {'imagenet', None} or os.path.exists(weights)):
@ -439,36 +442,11 @@ def _depthwise_conv_block(inputs,
@keras_export('keras.applications.mobilenet.preprocess_input') @keras_export('keras.applications.mobilenet.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.mobilenet.decode_predictions') @keras_export('keras.applications.mobilenet.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -80,9 +80,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -90,6 +90,7 @@ from tensorflow.python.util.tf_export import keras_export
BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/' BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/mobilenet_v2/') 'keras-applications/mobilenet_v2/')
layers = None
@keras_export('keras.applications.mobilenet_v2.MobileNetV2', @keras_export('keras.applications.mobilenet_v2.MobileNetV2',
@ -173,9 +174,11 @@ def MobileNetV2(input_shape=None,
ValueError: if `classifier_activation` is not `softmax` or `None` when ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer. using a pretrained top layer.
""" """
global layers
if 'layers' in kwargs: if 'layers' in kwargs:
global layers
layers = kwargs.pop('layers') layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs: if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,)) raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)): if not (weights in {'imagenet', None} or os.path.exists(weights)):

View File

@ -44,9 +44,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -60,6 +60,8 @@ NASNET_MOBILE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-mobile-no-top.h5'
NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + 'NASNet-large.h5' NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + 'NASNet-large.h5'
NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5' NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5'
layers = VersionAwareLayers()
def NASNet( def NASNet(
input_shape=None, input_shape=None,
@ -74,8 +76,7 @@ def NASNet(
pooling=None, pooling=None,
classes=1000, classes=1000,
default_size=None, default_size=None,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates a NASNet model. """Instantiates a NASNet model.
Reference paper: Reference paper:
@ -785,36 +786,11 @@ def _reduction_a_cell(ip, p, filters, block_id=None):
@keras_export('keras.applications.nasnet.preprocess_input') @keras_export('keras.applications.nasnet.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.nasnet.decode_predictions') @keras_export('keras.applications.nasnet.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -55,6 +55,8 @@ WEIGHTS_HASHES = {
('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3') ('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3')
} }
layers = None
def ResNet(stack_fn, def ResNet(stack_fn,
preact, preact,
@ -129,9 +131,11 @@ def ResNet(stack_fn,
ValueError: if `classifier_activation` is not `softmax` or `None` when ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer. using a pretrained top layer.
""" """
global layers
if 'layers' in kwargs: if 'layers' in kwargs:
global layers
layers = kwargs.pop('layers') layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs: if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,)) raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)): if not (weights in {'imagenet', None} or os.path.exists(weights)):
@ -517,17 +521,6 @@ def ResNet152(include_top=True,
@keras_export('keras.applications.resnet50.preprocess_input', @keras_export('keras.applications.resnet50.preprocess_input',
'keras.applications.resnet.preprocess_input') 'keras.applications.resnet.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input( return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='caffe') x, data_format=data_format, mode='caffe')
@ -535,20 +528,6 @@ def preprocess_input(x, data_format=None):
@keras_export('keras.applications.resnet50.decode_predictions', @keras_export('keras.applications.resnet50.decode_predictions',
'keras.applications.resnet.decode_predictions') 'keras.applications.resnet.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)
@ -565,7 +544,7 @@ DOC = """
Optionally loads weights pre-trained on ImageNet. Optionally loads weights pre-trained on ImageNet.
Note that the data format convention used by the model is Note that the data format convention used by the model is
the one specified in your Keras config at `~/.keras/keras.json`. the one specified in your Keras config at `~/.keras/keras.json`.
Arguments: Arguments:
include_top: whether to include the fully-connected include_top: whether to include the fully-connected
layer at the top of the network. layer at the top of the network.

View File

@ -37,8 +37,7 @@ def ResNet50V2(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the ResNet50V2 architecture.""" """Instantiates the ResNet50V2 architecture."""
def stack_fn(x): def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2') x = resnet.stack2(x, 64, 3, name='conv2')
@ -57,8 +56,7 @@ def ResNet50V2(
input_shape, input_shape,
pooling, pooling,
classes, classes,
classifier_activation=classifier_activation, classifier_activation=classifier_activation)
)
@keras_export('keras.applications.resnet_v2.ResNet101V2', @keras_export('keras.applications.resnet_v2.ResNet101V2',
@ -70,8 +68,7 @@ def ResNet101V2(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the ResNet101V2 architecture.""" """Instantiates the ResNet101V2 architecture."""
def stack_fn(x): def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2') x = resnet.stack2(x, 64, 3, name='conv2')
@ -90,8 +87,7 @@ def ResNet101V2(
input_shape, input_shape,
pooling, pooling,
classes, classes,
classifier_activation=classifier_activation, classifier_activation=classifier_activation)
)
@keras_export('keras.applications.resnet_v2.ResNet152V2', @keras_export('keras.applications.resnet_v2.ResNet152V2',
@ -103,8 +99,7 @@ def ResNet152V2(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the ResNet152V2 architecture.""" """Instantiates the ResNet152V2 architecture."""
def stack_fn(x): def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2') x = resnet.stack2(x, 64, 3, name='conv2')
@ -123,43 +118,17 @@ def ResNet152V2(
input_shape, input_shape,
pooling, pooling,
classes, classes,
classifier_activation=classifier_activation, classifier_activation=classifier_activation)
)
@keras_export('keras.applications.resnet_v2.preprocess_input') @keras_export('keras.applications.resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input( return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='tf') x, data_format=data_format, mode='tf')
@keras_export('keras.applications.resnet_v2.decode_predictions') @keras_export('keras.applications.resnet_v2.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/vgg16/' 'keras-applications/vgg16/'
'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5') 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16') @keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16')
def VGG16( def VGG16(
@ -49,8 +51,7 @@ def VGG16(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the VGG16 model. """Instantiates the VGG16 model.
Reference paper: Reference paper:
@ -227,37 +228,12 @@ def VGG16(
@keras_export('keras.applications.vgg16.preprocess_input') @keras_export('keras.applications.vgg16.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input( return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='caffe') x, data_format=data_format, mode='caffe')
@keras_export('keras.applications.vgg16.decode_predictions') @keras_export('keras.applications.vgg16.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/vgg19/' 'keras-applications/vgg19/'
'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5') 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.vgg19.VGG19', 'keras.applications.VGG19') @keras_export('keras.applications.vgg19.VGG19', 'keras.applications.VGG19')
def VGG19( def VGG19(
@ -49,8 +51,7 @@ def VGG19(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the VGG19 architecture. """Instantiates the VGG19 architecture.
Reference: Reference:
@ -232,37 +233,12 @@ def VGG19(
@keras_export('keras.applications.vgg19.preprocess_input') @keras_export('keras.applications.vgg19.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input( return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='caffe') x, data_format=data_format, mode='caffe')
@keras_export('keras.applications.vgg19.decode_predictions') @keras_export('keras.applications.vgg19.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -30,9 +30,9 @@ from __future__ import print_function
import os import os
from tensorflow.python.keras import backend from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -45,6 +45,8 @@ TF_WEIGHTS_PATH_NO_TOP = (
'https://storage.googleapis.com/tensorflow/keras-applications/' 'https://storage.googleapis.com/tensorflow/keras-applications/'
'xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5') 'xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.xception.Xception', @keras_export('keras.applications.xception.Xception',
'keras.applications.Xception') 'keras.applications.Xception')
@ -55,8 +57,7 @@ def Xception(
input_shape=None, input_shape=None,
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation='softmax', classifier_activation='softmax'):
):
"""Instantiates the Xception architecture. """Instantiates the Xception architecture.
Optionally loads weights pre-trained on ImageNet. Optionally loads weights pre-trained on ImageNet.
@ -312,36 +313,11 @@ def Xception(
@keras_export('keras.applications.xception.preprocess_input') @keras_export('keras.applications.xception.preprocess_input')
def preprocess_input(x, data_format=None): def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.xception.decode_predictions') @keras_export('keras.applications.xception.decode_predictions')
def decode_predictions(preds, top=5): def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top) return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -242,9 +242,28 @@ from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
# Serialization functions # Serialization functions
from tensorflow.python.keras.layers import serialization
from tensorflow.python.keras.layers.serialization import deserialize from tensorflow.python.keras.layers.serialization import deserialize
from tensorflow.python.keras.layers.serialization import serialize from tensorflow.python.keras.layers.serialization import serialize
class VersionAwareLayers(object):
"""Utility to be used internally to access layers in a V1/V2-aware fashion.
When using layers within the Keras codebase, under the constraint that
e.g. `layers.BatchNormalization` should be the `BatchNormalization` version
corresponding to the current runtime (TF1 or TF2), do not simply access
`layers.BatchNormalization` since it would ignore e.g. an early
`compat.v2.disable_v2_behavior()` call. Instead, use an instance
of `VersionAwareLayers` (which you can use just like the `layers` module).
"""
def __getattr__(self, name):
serialization.populate_deserializable_objects()
if name in serialization.LOCAL.ALL_OBJECTS:
return serialization.LOCAL.ALL_OBJECTS[name]
return super(VersionAwareLayers, self).__getattr__(name)
del absolute_import del absolute_import
del division del division
del print_function del print_function

View File

@ -26,6 +26,7 @@ import threading
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.layers import advanced_activations from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import convolutional from tensorflow.python.keras.layers import convolutional
from tensorflow.python.keras.layers import convolutional_recurrent from tensorflow.python.keras.layers import convolutional_recurrent
@ -128,6 +129,7 @@ def populate_deserializable_objects():
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
LOCAL.ALL_OBJECTS['Input'] = input_layer.Input LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec
LOCAL.ALL_OBJECTS['Network'] = models.Network LOCAL.ALL_OBJECTS['Network'] = models.Network
LOCAL.ALL_OBJECTS['Model'] = models.Model LOCAL.ALL_OBJECTS['Model'] = models.Model
LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
@ -135,6 +137,15 @@ def populate_deserializable_objects():
LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel
LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures
LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures
# Merge layers, function versions.
LOCAL.ALL_OBJECTS['add'] = merge.add
LOCAL.ALL_OBJECTS['subtract'] = merge.subtract
LOCAL.ALL_OBJECTS['multiply'] = merge.multiply
LOCAL.ALL_OBJECTS['average'] = merge.average
LOCAL.ALL_OBJECTS['maximum'] = merge.maximum
LOCAL.ALL_OBJECTS['minimum'] = merge.minimum
LOCAL.ALL_OBJECTS['concatenate'] = merge.concatenate
LOCAL.ALL_OBJECTS['dot'] = merge.dot
@keras_export('keras.layers.serialize') @keras_export('keras.layers.serialize')