298 lines
15 KiB
Python
298 lines
15 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Keras layers API."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python import tf2
|
|
|
|
# Generic layers.
|
|
# pylint: disable=g-bad-import-order
|
|
# pylint: disable=g-import-not-at-top
|
|
from tensorflow.python.keras.engine.input_layer import Input
|
|
from tensorflow.python.keras.engine.input_layer import InputLayer
|
|
from tensorflow.python.keras.engine.input_spec import InputSpec
|
|
from tensorflow.python.keras.engine.base_layer import Layer
|
|
from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer
|
|
|
|
# Image preprocessing layers.
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import CenterCrop
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomCrop
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomFlip
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomContrast
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomHeight
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomRotation
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomTranslation
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomWidth
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomZoom
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Resizing
|
|
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Rescaling
|
|
|
|
# Preprocessing layers.
|
|
if tf2.enabled():
|
|
from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding
|
|
from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding as CategoryEncodingV1
|
|
CategoryEncodingV2 = CategoryEncoding
|
|
from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup
|
|
from tensorflow.python.keras.layers.preprocessing.integer_lookup_v1 import IntegerLookup as IntegerLookupV1
|
|
IntegerLookupV2 = IntegerLookup
|
|
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
|
|
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
|
|
NormalizationV2 = Normalization
|
|
from tensorflow.python.keras.layers.preprocessing.string_lookup import StringLookup
|
|
from tensorflow.python.keras.layers.preprocessing.string_lookup_v1 import StringLookup as StringLookupV1
|
|
StringLookupV2 = StringLookup
|
|
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
|
|
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
|
|
TextVectorizationV2 = TextVectorization
|
|
else:
|
|
from tensorflow.python.keras.layers.preprocessing.integer_lookup_v1 import IntegerLookup
|
|
from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup as IntegerLookupV2
|
|
IntegerLookupV1 = IntegerLookup
|
|
from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding
|
|
from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding as CategoryEncodingV2
|
|
CategoryEncodingV1 = CategoryEncoding
|
|
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization
|
|
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2
|
|
NormalizationV1 = Normalization
|
|
from tensorflow.python.keras.layers.preprocessing.string_lookup_v1 import StringLookup
|
|
from tensorflow.python.keras.layers.preprocessing.string_lookup import StringLookup as StringLookupV2
|
|
StringLookupV1 = StringLookup
|
|
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
|
|
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
|
|
TextVectorizationV1 = TextVectorization
|
|
from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing
|
|
from tensorflow.python.keras.layers.preprocessing.discretization import Discretization
|
|
from tensorflow.python.keras.layers.preprocessing.hashing import Hashing
|
|
|
|
# Advanced activations.
|
|
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
|
|
from tensorflow.python.keras.layers.advanced_activations import PReLU
|
|
from tensorflow.python.keras.layers.advanced_activations import ELU
|
|
from tensorflow.python.keras.layers.advanced_activations import ReLU
|
|
from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU
|
|
from tensorflow.python.keras.layers.advanced_activations import Softmax
|
|
|
|
# Convolution layers.
|
|
from tensorflow.python.keras.layers.convolutional import Conv1D
|
|
from tensorflow.python.keras.layers.convolutional import Conv2D
|
|
from tensorflow.python.keras.layers.convolutional import Conv3D
|
|
from tensorflow.python.keras.layers.convolutional import Conv1DTranspose
|
|
from tensorflow.python.keras.layers.convolutional import Conv2DTranspose
|
|
from tensorflow.python.keras.layers.convolutional import Conv3DTranspose
|
|
from tensorflow.python.keras.layers.convolutional import SeparableConv1D
|
|
from tensorflow.python.keras.layers.convolutional import SeparableConv2D
|
|
|
|
# Convolution layer aliases.
|
|
from tensorflow.python.keras.layers.convolutional import Convolution1D
|
|
from tensorflow.python.keras.layers.convolutional import Convolution2D
|
|
from tensorflow.python.keras.layers.convolutional import Convolution3D
|
|
from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose
|
|
from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose
|
|
from tensorflow.python.keras.layers.convolutional import SeparableConvolution1D
|
|
from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D
|
|
from tensorflow.python.keras.layers.convolutional import DepthwiseConv2D
|
|
|
|
# Image processing layers.
|
|
from tensorflow.python.keras.layers.convolutional import UpSampling1D
|
|
from tensorflow.python.keras.layers.convolutional import UpSampling2D
|
|
from tensorflow.python.keras.layers.convolutional import UpSampling3D
|
|
from tensorflow.python.keras.layers.convolutional import ZeroPadding1D
|
|
from tensorflow.python.keras.layers.convolutional import ZeroPadding2D
|
|
from tensorflow.python.keras.layers.convolutional import ZeroPadding3D
|
|
from tensorflow.python.keras.layers.convolutional import Cropping1D
|
|
from tensorflow.python.keras.layers.convolutional import Cropping2D
|
|
from tensorflow.python.keras.layers.convolutional import Cropping3D
|
|
|
|
# Core layers.
|
|
from tensorflow.python.keras.layers.core import Masking
|
|
from tensorflow.python.keras.layers.core import Dropout
|
|
from tensorflow.python.keras.layers.core import SpatialDropout1D
|
|
from tensorflow.python.keras.layers.core import SpatialDropout2D
|
|
from tensorflow.python.keras.layers.core import SpatialDropout3D
|
|
from tensorflow.python.keras.layers.core import Activation
|
|
from tensorflow.python.keras.layers.core import Reshape
|
|
from tensorflow.python.keras.layers.core import Permute
|
|
from tensorflow.python.keras.layers.core import Flatten
|
|
from tensorflow.python.keras.layers.core import RepeatVector
|
|
from tensorflow.python.keras.layers.core import Lambda
|
|
from tensorflow.python.keras.layers.core import Dense
|
|
from tensorflow.python.keras.layers.core import ActivityRegularization
|
|
|
|
# Dense Attention layers.
|
|
from tensorflow.python.keras.layers.dense_attention import AdditiveAttention
|
|
from tensorflow.python.keras.layers.dense_attention import Attention
|
|
|
|
# Embedding layers.
|
|
from tensorflow.python.keras.layers.embeddings import Embedding
|
|
|
|
# Einsum-based dense layer/
|
|
from tensorflow.python.keras.layers.einsum_dense import EinsumDense
|
|
|
|
# Multi-head Attention layer.
|
|
from tensorflow.python.keras.layers.multi_head_attention import MultiHeadAttention
|
|
|
|
# Locally-connected layers.
|
|
from tensorflow.python.keras.layers.local import LocallyConnected1D
|
|
from tensorflow.python.keras.layers.local import LocallyConnected2D
|
|
|
|
# Merge layers.
|
|
from tensorflow.python.keras.layers.merge import Add
|
|
from tensorflow.python.keras.layers.merge import Subtract
|
|
from tensorflow.python.keras.layers.merge import Multiply
|
|
from tensorflow.python.keras.layers.merge import Average
|
|
from tensorflow.python.keras.layers.merge import Maximum
|
|
from tensorflow.python.keras.layers.merge import Minimum
|
|
from tensorflow.python.keras.layers.merge import Concatenate
|
|
from tensorflow.python.keras.layers.merge import Dot
|
|
from tensorflow.python.keras.layers.merge import add
|
|
from tensorflow.python.keras.layers.merge import subtract
|
|
from tensorflow.python.keras.layers.merge import multiply
|
|
from tensorflow.python.keras.layers.merge import average
|
|
from tensorflow.python.keras.layers.merge import maximum
|
|
from tensorflow.python.keras.layers.merge import minimum
|
|
from tensorflow.python.keras.layers.merge import concatenate
|
|
from tensorflow.python.keras.layers.merge import dot
|
|
|
|
# Noise layers.
|
|
from tensorflow.python.keras.layers.noise import AlphaDropout
|
|
from tensorflow.python.keras.layers.noise import GaussianNoise
|
|
from tensorflow.python.keras.layers.noise import GaussianDropout
|
|
|
|
# Normalization layers.
|
|
from tensorflow.python.keras.layers.normalization import LayerNormalization
|
|
from tensorflow.python.keras.layers.normalization_v2 import SyncBatchNormalization
|
|
|
|
if tf2.enabled():
|
|
from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
|
|
from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1
|
|
BatchNormalizationV2 = BatchNormalization
|
|
else:
|
|
from tensorflow.python.keras.layers.normalization import BatchNormalization
|
|
from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
|
|
BatchNormalizationV1 = BatchNormalization
|
|
|
|
# Kernelized layers.
|
|
from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
|
|
|
|
# Pooling layers.
|
|
from tensorflow.python.keras.layers.pooling import MaxPooling1D
|
|
from tensorflow.python.keras.layers.pooling import MaxPooling2D
|
|
from tensorflow.python.keras.layers.pooling import MaxPooling3D
|
|
from tensorflow.python.keras.layers.pooling import AveragePooling1D
|
|
from tensorflow.python.keras.layers.pooling import AveragePooling2D
|
|
from tensorflow.python.keras.layers.pooling import AveragePooling3D
|
|
from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D
|
|
from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D
|
|
from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D
|
|
from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D
|
|
from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D
|
|
from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D
|
|
|
|
# Pooling layer aliases.
|
|
from tensorflow.python.keras.layers.pooling import MaxPool1D
|
|
from tensorflow.python.keras.layers.pooling import MaxPool2D
|
|
from tensorflow.python.keras.layers.pooling import MaxPool3D
|
|
from tensorflow.python.keras.layers.pooling import AvgPool1D
|
|
from tensorflow.python.keras.layers.pooling import AvgPool2D
|
|
from tensorflow.python.keras.layers.pooling import AvgPool3D
|
|
from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D
|
|
from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D
|
|
from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D
|
|
from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D
|
|
from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D
|
|
from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D
|
|
|
|
# Recurrent layers.
|
|
from tensorflow.python.keras.layers.recurrent import RNN
|
|
from tensorflow.python.keras.layers.recurrent import AbstractRNNCell
|
|
from tensorflow.python.keras.layers.recurrent import StackedRNNCells
|
|
from tensorflow.python.keras.layers.recurrent import SimpleRNNCell
|
|
from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
|
|
from tensorflow.python.keras.layers.recurrent import SimpleRNN
|
|
|
|
if tf2.enabled():
|
|
from tensorflow.python.keras.layers.recurrent_v2 import GRU
|
|
from tensorflow.python.keras.layers.recurrent_v2 import GRUCell
|
|
from tensorflow.python.keras.layers.recurrent_v2 import LSTM
|
|
from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell
|
|
from tensorflow.python.keras.layers.recurrent import GRU as GRUV1
|
|
from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1
|
|
from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1
|
|
from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1
|
|
GRUV2 = GRU
|
|
GRUCellV2 = GRUCell
|
|
LSTMV2 = LSTM
|
|
LSTMCellV2 = LSTMCell
|
|
else:
|
|
from tensorflow.python.keras.layers.recurrent import GRU
|
|
from tensorflow.python.keras.layers.recurrent import GRUCell
|
|
from tensorflow.python.keras.layers.recurrent import LSTM
|
|
from tensorflow.python.keras.layers.recurrent import LSTMCell
|
|
from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2
|
|
from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2
|
|
from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2
|
|
from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2
|
|
GRUV1 = GRU
|
|
GRUCellV1 = GRUCell
|
|
LSTMV1 = LSTM
|
|
LSTMCellV1 = LSTMCell
|
|
|
|
# Convolutional-recurrent layers.
|
|
from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D
|
|
|
|
# CuDNN recurrent layers.
|
|
from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNLSTM
|
|
from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNGRU
|
|
|
|
# Wrapper functions
|
|
from tensorflow.python.keras.layers.wrappers import Wrapper
|
|
from tensorflow.python.keras.layers.wrappers import Bidirectional
|
|
from tensorflow.python.keras.layers.wrappers import TimeDistributed
|
|
|
|
# # RNN Cell wrappers.
|
|
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DeviceWrapper
|
|
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper
|
|
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
|
|
|
# Serialization functions
|
|
from tensorflow.python.keras.layers import serialization
|
|
from tensorflow.python.keras.layers.serialization import deserialize
|
|
from tensorflow.python.keras.layers.serialization import serialize
|
|
|
|
|
|
class VersionAwareLayers(object):
|
|
"""Utility to be used internally to access layers in a V1/V2-aware fashion.
|
|
|
|
When using layers within the Keras codebase, under the constraint that
|
|
e.g. `layers.BatchNormalization` should be the `BatchNormalization` version
|
|
corresponding to the current runtime (TF1 or TF2), do not simply access
|
|
`layers.BatchNormalization` since it would ignore e.g. an early
|
|
`compat.v2.disable_v2_behavior()` call. Instead, use an instance
|
|
of `VersionAwareLayers` (which you can use just like the `layers` module).
|
|
"""
|
|
|
|
def __getattr__(self, name):
|
|
serialization.populate_deserializable_objects()
|
|
if name in serialization.LOCAL.ALL_OBJECTS:
|
|
return serialization.LOCAL.ALL_OBJECTS[name]
|
|
return super(VersionAwareLayers, self).__getattr__(name)
|
|
|
|
del absolute_import
|
|
del division
|
|
del print_function
|