Make layers used in Keras Applications aware of calls to enable/disable_v2_behavior.

PiperOrigin-RevId: 303807515
Change-Id: I4bca246dd446def8cde73b09c106d3a0cbad3264
This commit is contained in:
Francois Chollet 2020-03-30 13:13:23 -07:00 committed by TensorFlower Gardener
parent 433d40514b
commit 30f753a95f
14 changed files with 90 additions and 307 deletions

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -52,6 +52,8 @@ DENSENET201_WEIGHT_PATH_NO_TOP = (
BASE_WEIGTHS_PATH +
'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
def dense_block(x, blocks, name):
"""A dense block.
@ -133,8 +135,7 @@ def DenseNet(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the DenseNet architecture.
Reference paper:
@ -358,37 +359,12 @@ def DenseNet201(include_top=True,
@keras_export('keras.applications.densenet.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='torch')
@keras_export('keras.applications.densenet.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -28,9 +28,9 @@ import math
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -140,6 +140,8 @@ DENSE_KERNEL_INITIALIZER = {
}
}
layers = VersionAwareLayers()
def EfficientNet(
width_coefficient,
@ -157,8 +159,7 @@ def EfficientNet(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the EfficientNet architecture using given scaling coefficients.
Reference paper:
@ -664,18 +665,7 @@ def preprocess_input(x, data_format=None): # pylint: disable=unused-argument
@keras_export('keras.applications.efficientnet.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

View File

@ -28,9 +28,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -38,6 +38,7 @@ from tensorflow.python.util.tf_export import keras_export
BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/inception_resnet_v2/')
layers = None
@keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
@ -105,9 +106,11 @@ def InceptionResNetV2(include_top=True,
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
if 'layers' in kwargs:
layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)):
@ -378,36 +381,11 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
@keras_export('keras.applications.inception_resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.inception_resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -41,6 +41,8 @@ WEIGHTS_PATH_NO_TOP = (
'https://storage.googleapis.com/tensorflow/keras-applications/'
'inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.inception_v3.InceptionV3',
'keras.applications.InceptionV3')
@ -51,8 +53,7 @@ def InceptionV3(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the Inception v3 architecture.
Reference paper:
@ -406,36 +407,11 @@ def conv2d_bn(x,
@keras_export('keras.applications.inception_v3.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.inception_v3.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -67,9 +67,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
@ -77,6 +77,7 @@ from tensorflow.python.util.tf_export import keras_export
BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/mobilenet/')
layers = None
@keras_export('keras.applications.mobilenet.MobileNet',
@ -155,9 +156,11 @@ def MobileNet(input_shape=None,
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
if 'layers' in kwargs:
layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)):
@ -439,36 +442,11 @@ def _depthwise_conv_block(inputs,
@keras_export('keras.applications.mobilenet.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.mobilenet.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -80,9 +80,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
@ -90,6 +90,7 @@ from tensorflow.python.util.tf_export import keras_export
BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/mobilenet_v2/')
layers = None
@keras_export('keras.applications.mobilenet_v2.MobileNetV2',
@ -173,9 +174,11 @@ def MobileNetV2(input_shape=None,
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
if 'layers' in kwargs:
layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)):

View File

@ -44,9 +44,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
@ -60,6 +60,8 @@ NASNET_MOBILE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-mobile-no-top.h5'
NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + 'NASNet-large.h5'
NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + 'NASNet-large-no-top.h5'
layers = VersionAwareLayers()
def NASNet(
input_shape=None,
@ -74,8 +76,7 @@ def NASNet(
pooling=None,
classes=1000,
default_size=None,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates a NASNet model.
Reference paper:
@ -785,36 +786,11 @@ def _reduction_a_cell(ip, p, filters, block_id=None):
@keras_export('keras.applications.nasnet.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.nasnet.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -55,6 +55,8 @@ WEIGHTS_HASHES = {
('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3')
}
layers = None
def ResNet(stack_fn,
preact,
@ -129,9 +131,11 @@ def ResNet(stack_fn,
ValueError: if `classifier_activation` is not `softmax` or `None` when
using a pretrained top layer.
"""
if 'layers' in kwargs:
global layers
if 'layers' in kwargs:
layers = kwargs.pop('layers')
else:
layers = VersionAwareLayers()
if kwargs:
raise ValueError('Unknown argument(s): %s' % (kwargs,))
if not (weights in {'imagenet', None} or os.path.exists(weights)):
@ -517,17 +521,6 @@ def ResNet152(include_top=True,
@keras_export('keras.applications.resnet50.preprocess_input',
'keras.applications.resnet.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='caffe')
@ -535,20 +528,6 @@ def preprocess_input(x, data_format=None):
@keras_export('keras.applications.resnet50.decode_predictions',
'keras.applications.resnet.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -37,8 +37,7 @@ def ResNet50V2(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the ResNet50V2 architecture."""
def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2')
@ -57,8 +56,7 @@ def ResNet50V2(
input_shape,
pooling,
classes,
classifier_activation=classifier_activation,
)
classifier_activation=classifier_activation)
@keras_export('keras.applications.resnet_v2.ResNet101V2',
@ -70,8 +68,7 @@ def ResNet101V2(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the ResNet101V2 architecture."""
def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2')
@ -90,8 +87,7 @@ def ResNet101V2(
input_shape,
pooling,
classes,
classifier_activation=classifier_activation,
)
classifier_activation=classifier_activation)
@keras_export('keras.applications.resnet_v2.ResNet152V2',
@ -103,8 +99,7 @@ def ResNet152V2(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the ResNet152V2 architecture."""
def stack_fn(x):
x = resnet.stack2(x, 64, 3, name='conv2')
@ -123,43 +118,17 @@ def ResNet152V2(
input_shape,
pooling,
classes,
classifier_activation=classifier_activation,
)
classifier_activation=classifier_activation)
@keras_export('keras.applications.resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='tf')
@keras_export('keras.applications.resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/vgg16/'
'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16')
def VGG16(
@ -49,8 +51,7 @@ def VGG16(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the VGG16 model.
Reference paper:
@ -227,37 +228,12 @@ def VGG16(
@keras_export('keras.applications.vgg16.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='caffe')
@keras_export('keras.applications.vgg16.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -26,9 +26,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -40,6 +40,8 @@ WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
'keras-applications/vgg19/'
'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.vgg19.VGG19', 'keras.applications.VGG19')
def VGG19(
@ -49,8 +51,7 @@ def VGG19(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the VGG19 architecture.
Reference:
@ -232,37 +233,12 @@ def VGG19(
@keras_export('keras.applications.vgg19.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(
x, data_format=data_format, mode='caffe')
@keras_export('keras.applications.vgg19.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -30,9 +30,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export
@ -45,6 +45,8 @@ TF_WEIGHTS_PATH_NO_TOP = (
'https://storage.googleapis.com/tensorflow/keras-applications/'
'xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5')
layers = VersionAwareLayers()
@keras_export('keras.applications.xception.Xception',
'keras.applications.Xception')
@ -55,8 +57,7 @@ def Xception(
input_shape=None,
pooling=None,
classes=1000,
classifier_activation='softmax',
):
classifier_activation='softmax'):
"""Instantiates the Xception architecture.
Optionally loads weights pre-trained on ImageNet.
@ -312,36 +313,11 @@ def Xception(
@keras_export('keras.applications.xception.preprocess_input')
def preprocess_input(x, data_format=None):
"""Preprocesses a numpy array encoding a batch of images.
Arguments
x: A 4D numpy array consists of RGB values within [0, 255].
Returns
Preprocessed array.
Raises
ValueError: In case of unknown `data_format` argument.
"""
return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')
@keras_export('keras.applications.xception.decode_predictions')
def decode_predictions(preds, top=5):
"""Decodes the prediction result from the model.
Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
Raises
ValueError: In case of invalid shape of the `preds` array (must be 2D).
"""
return imagenet_utils.decode_predictions(preds, top=top)

View File

@ -242,9 +242,28 @@ from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
# Serialization functions
from tensorflow.python.keras.layers import serialization
from tensorflow.python.keras.layers.serialization import deserialize
from tensorflow.python.keras.layers.serialization import serialize
class VersionAwareLayers(object):
"""Utility to be used internally to access layers in a V1/V2-aware fashion.
When using layers within the Keras codebase, under the constraint that
e.g. `layers.BatchNormalization` should be the `BatchNormalization` version
corresponding to the current runtime (TF1 or TF2), do not simply access
`layers.BatchNormalization` since it would ignore e.g. an early
`compat.v2.disable_v2_behavior()` call. Instead, use an instance
of `VersionAwareLayers` (which you can use just like the `layers` module).
"""
def __getattr__(self, name):
serialization.populate_deserializable_objects()
if name in serialization.LOCAL.ALL_OBJECTS:
return serialization.LOCAL.ALL_OBJECTS[name]
return super(VersionAwareLayers, self).__getattr__(name)
del absolute_import
del division
del print_function

View File

@ -26,6 +26,7 @@ import threading
from tensorflow.python import tf2
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import convolutional
from tensorflow.python.keras.layers import convolutional_recurrent
@ -128,6 +129,7 @@ def populate_deserializable_objects():
from tensorflow.python.feature_column import sequence_feature_column as sfc # pylint: disable=g-import-not-at-top
LOCAL.ALL_OBJECTS['Input'] = input_layer.Input
LOCAL.ALL_OBJECTS['InputSpec'] = input_spec.InputSpec
LOCAL.ALL_OBJECTS['Network'] = models.Network
LOCAL.ALL_OBJECTS['Model'] = models.Model
LOCAL.ALL_OBJECTS['Sequential'] = models.Sequential
@ -135,6 +137,15 @@ def populate_deserializable_objects():
LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel
LOCAL.ALL_OBJECTS['DenseFeatures'] = dense_features.DenseFeatures
LOCAL.ALL_OBJECTS['SequenceFeatures'] = sfc.SequenceFeatures
# Merge layers, function versions.
LOCAL.ALL_OBJECTS['add'] = merge.add
LOCAL.ALL_OBJECTS['subtract'] = merge.subtract
LOCAL.ALL_OBJECTS['multiply'] = merge.multiply
LOCAL.ALL_OBJECTS['average'] = merge.average
LOCAL.ALL_OBJECTS['maximum'] = merge.maximum
LOCAL.ALL_OBJECTS['minimum'] = merge.minimum
LOCAL.ALL_OBJECTS['concatenate'] = merge.concatenate
LOCAL.ALL_OBJECTS['dot'] = merge.dot
@keras_export('keras.layers.serialize')