Merge pull request #35066 from tensorflow/layer_imports_internal
[r2.1 Cherrypick] Unify V1/2 layer naming in internal imports
Commit 03a3020934
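The change applies one pattern throughout the internal layer imports: the unqualified layer name follows `tf2.enabled()`, while explicit `...V1`/`...V2` aliases are always defined. From the perspective of internal code importing `tensorflow.python.keras.layers`, the effect looks roughly like the sketch below (internal-only API; the unit counts are arbitrary):

```python
from tensorflow.python import tf2
from tensorflow.python.keras import layers

default_cell = layers.LSTMCell(10)   # tracks the active API version
v1_cell = layers.LSTMCellV1(10)      # always the recurrent.py implementation
v2_cell = layers.LSTMCellV2(10)      # always the recurrent_v2.py implementation

# The unqualified name is the same object as one of the two aliases.
assert layers.LSTMCell is (layers.LSTMCellV2 if tf2.enabled()
                           else layers.LSTMCellV1)
```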
@@ -18,8 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python import tf2
+
 # Generic layers.
 # pylint: disable=g-bad-import-order
+# pylint: disable=g-import-not-at-top
 from tensorflow.python.keras.engine.input_layer import Input
 from tensorflow.python.keras.engine.input_layer import InputLayer
 from tensorflow.python.keras.engine.input_spec import InputSpec
@@ -27,10 +30,20 @@ from tensorflow.python.keras.engine.base_layer import Layer
 from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer

 # Preprocessing layers.
-from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
-from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
-from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
-from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
+if tf2.enabled():
+  from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
+  from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
+  NormalizationV2 = Normalization
+  from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
+  from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
+  TextVectorizationV2 = TextVectorization
+else:
+  from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization
+  from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2
+  NormalizationV1 = Normalization
+  from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
+  from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
+  TextVectorizationV1 = TextVectorization

 # Advanced activations.
 from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
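Both branches leave the V1 and V2 preprocessing classes importable under fixed names. A quick illustrative check against the internal modules referenced above (a sketch, not a test from this PR):

```python
from tensorflow.python.keras import layers
from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.keras.layers.preprocessing import normalization_v1

# Regardless of which branch ran, the aliases point at the module-level classes.
assert layers.NormalizationV2 is normalization.Normalization
assert layers.NormalizationV1 is normalization_v1.Normalization
```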
@@ -121,8 +134,14 @@ from tensorflow.python.keras.layers.noise import GaussianDropout

 # Normalization layers.
 from tensorflow.python.keras.layers.normalization import LayerNormalization
-from tensorflow.python.keras.layers.normalization import BatchNormalization
-from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
+if tf2.enabled():
+  from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
+  from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1
+  BatchNormalizationV2 = BatchNormalization
+else:
+  from tensorflow.python.keras.layers.normalization import BatchNormalization
+  from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
+  BatchNormalizationV1 = BatchNormalization

 # Kernelized layers.
 from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
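The same switch now covers `BatchNormalization`, so internal code can pin a specific implementation by alias instead of importing `normalization`/`normalization_v2` directly. A hedged sketch of the resulting invariant:

```python
from tensorflow.python.keras import layers
from tensorflow.python.keras.layers import normalization
from tensorflow.python.keras.layers import normalization_v2

assert layers.BatchNormalizationV1 is normalization.BatchNormalization
assert layers.BatchNormalizationV2 is normalization_v2.BatchNormalization
```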
@@ -163,14 +182,32 @@ from tensorflow.python.keras.layers.recurrent import SimpleRNNCell
 from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
 from tensorflow.python.keras.layers.recurrent import SimpleRNN

-from tensorflow.python.keras.layers.recurrent import GRU
-from tensorflow.python.keras.layers.recurrent import GRUCell
-from tensorflow.python.keras.layers.recurrent import LSTM
-from tensorflow.python.keras.layers.recurrent import LSTMCell
-from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRU_v2
-from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCell_v2
-from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTM_v2
-from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCell_v2
+if tf2.enabled():
+  from tensorflow.python.keras.layers.recurrent_v2 import GRU
+  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell
+  from tensorflow.python.keras.layers.recurrent_v2 import LSTM
+  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell
+  from tensorflow.python.keras.layers.recurrent import GRU as GRUV1
+  from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1
+  from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1
+  from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1
+  GRUV2 = GRU
+  GRUCellV2 = GRUCell
+  LSTMV2 = LSTM
+  LSTMCellV2 = LSTMCell
+else:
+  from tensorflow.python.keras.layers.recurrent import GRU
+  from tensorflow.python.keras.layers.recurrent import GRUCell
+  from tensorflow.python.keras.layers.recurrent import LSTM
+  from tensorflow.python.keras.layers.recurrent import LSTMCell
+  from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2
+  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2
+  from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2
+  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2
+  GRUV1 = GRU
+  GRUCellV1 = GRUCell
+  LSTMV1 = LSTM
+  LSTMCellV1 = LSTMCell

 # Convolutional-recurrent layers.
 from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D
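For the recurrent layers this also renames the previously exported `GRU_v2`/`GRUCell_v2`/`LSTM_v2`/`LSTMCell_v2` aliases to `GRUV2`/`GRUCellV2`/`LSTMV2`/`LSTMCellV2`; the test updates below pick up exactly that rename. A before/after sketch for an internal caller (unit counts are arbitrary):

```python
from tensorflow.python.keras import layers

# Before: the V2 aliases used a lowercase "_v2" suffix.
#   cell = layers.LSTMCell_v2(10)
#   gru_layer = layers.GRU_v2(4)

# After: the suffix matches the V1 naming style.
cell = layers.LSTMCellV2(10)
gru_layer = layers.GRUV2(4)
```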
@@ -460,7 +460,7 @@ class CuDNNV1OnlyTest(keras_parameterized.TestCase):
     input_shape = (3, 5)

     def gru(cudnn=False, **kwargs):
-      layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRU
+      layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRUV1
       return layer_class(2, input_shape=input_shape, **kwargs)

     def get_layer_weights(layer):
@@ -256,7 +256,7 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, "does not work with "):
       wrapper_cls(cell)

-    cell = layers.LSTMCell_v2(10)
+    cell = layers.LSTMCellV2(10)
     with self.assertRaisesRegexp(ValueError, "does not work with "):
       wrapper_cls(cell)

@@ -145,7 +145,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
           (None, input_dim, 4, 4, 4),
       ],
       [
-          (keras.layers.GRU(output_dim)),
+          (keras.layers.GRUV1(output_dim)),
           [np.random.random((input_dim, output_dim)),
            np.random.random((output_dim, output_dim)),
            np.random.random((output_dim,)),
@@ -158,7 +158,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
           (None, 4, input_dim),
       ],
       [
-          (keras.layers.LSTM(output_dim)),
+          (keras.layers.LSTMV1(output_dim)),
           [np.random.random((input_dim, output_dim)),
            np.random.random((output_dim, output_dim)),
            np.random.random((output_dim,)),
@@ -20,7 +20,7 @@ from __future__ import division
 from __future__ import print_function


-from tensorflow.python.keras import layers as keras_layers
+from tensorflow.python.keras.layers import normalization as keras_normalization
 from tensorflow.python.layers import base
 from tensorflow.python.ops import init_ops
 from tensorflow.python.util import deprecation
@@ -28,7 +28,7 @@ from tensorflow.python.util.tf_export import tf_export


 @tf_export(v1=['layers.BatchNormalization'])
-class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
+class BatchNormalization(keras_normalization.BatchNormalization, base.Layer):
   """Batch Normalization layer from http://arxiv.org/abs/1502.03167.

   "Batch Normalization: Accelerating Deep Network Training by Reducing
@@ -170,7 +170,7 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
 @deprecation.deprecated(
     date=None, instructions='Use keras.layers.BatchNormalization instead. In '
     'particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not '
-    'be used (consult the `tf.keras.layers.batch_normalization` '
+    'be used (consult the `tf.keras.layers.BatchNormalization` '
     'documentation).')
 @tf_export(v1=['layers.batch_normalization'])
 def batch_normalization(inputs,
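The remaining hunks re-base the deprecated `tf.compat.v1.layers.BatchNormalization` wrapper on the Keras `normalization` module directly, instead of going through the top-level `keras.layers` package, and fix the class name quoted in the deprecation message. A minimal sanity-check sketch, assuming the TF 1.x compat surface is available:

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.keras.layers import normalization as keras_normalization

# The v1 wrapper is still a subclass of the Keras BatchNormalization layer.
layer = tf.layers.BatchNormalization()
assert isinstance(layer, keras_normalization.BatchNormalization)
```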