Merge pull request #35066 from tensorflow/layer_imports_internal

[r2.1 Cherrypick] Unify V1/2 layer naming in internal imports
This commit is contained in:
Goldie Gadde 2019-12-20 16:35:07 -08:00 committed by GitHub
commit 03a3020934
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 21 deletions

View File

@ -18,8 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import tf2
# Generic layers.
# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.input_spec import InputSpec
@ -27,10 +30,20 @@ from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer
# Preprocessing layers.
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
if tf2.enabled():
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1
NormalizationV2 = Normalization
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1
TextVectorizationV2 = TextVectorization
else:
from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization
from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2
NormalizationV1 = Normalization
from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization
from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2
TextVectorizationV1 = TextVectorization
# Advanced activations.
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
@ -121,8 +134,14 @@ from tensorflow.python.keras.layers.noise import GaussianDropout
# Normalization layers.
from tensorflow.python.keras.layers.normalization import LayerNormalization
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
if tf2.enabled():
from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1
BatchNormalizationV2 = BatchNormalization
else:
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
BatchNormalizationV1 = BatchNormalization
# Kernelized layers.
from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
@ -163,14 +182,32 @@ from tensorflow.python.keras.layers.recurrent import SimpleRNNCell
from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
from tensorflow.python.keras.layers.recurrent import SimpleRNN
from tensorflow.python.keras.layers.recurrent import GRU
from tensorflow.python.keras.layers.recurrent import GRUCell
from tensorflow.python.keras.layers.recurrent import LSTM
from tensorflow.python.keras.layers.recurrent import LSTMCell
from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRU_v2
from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCell_v2
from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTM_v2
from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCell_v2
if tf2.enabled():
from tensorflow.python.keras.layers.recurrent_v2 import GRU
from tensorflow.python.keras.layers.recurrent_v2 import GRUCell
from tensorflow.python.keras.layers.recurrent_v2 import LSTM
from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell
from tensorflow.python.keras.layers.recurrent import GRU as GRUV1
from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1
from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1
from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1
GRUV2 = GRU
GRUCellV2 = GRUCell
LSTMV2 = LSTM
LSTMCellV2 = LSTMCell
else:
from tensorflow.python.keras.layers.recurrent import GRU
from tensorflow.python.keras.layers.recurrent import GRUCell
from tensorflow.python.keras.layers.recurrent import LSTM
from tensorflow.python.keras.layers.recurrent import LSTMCell
from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2
from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2
from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2
from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2
GRUV1 = GRU
GRUCellV1 = GRUCell
LSTMV1 = LSTM
LSTMCellV1 = LSTMCell
# Convolutional-recurrent layers.
from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D

View File

@ -460,7 +460,7 @@ class CuDNNV1OnlyTest(keras_parameterized.TestCase):
input_shape = (3, 5)
def gru(cudnn=False, **kwargs):
layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRU
layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRUV1
return layer_class(2, input_shape=input_shape, **kwargs)
def get_layer_weights(layer):

View File

@ -256,7 +256,7 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, "does not work with "):
wrapper_cls(cell)
cell = layers.LSTMCell_v2(10)
cell = layers.LSTMCellV2(10)
with self.assertRaisesRegexp(ValueError, "does not work with "):
wrapper_cls(cell)

View File

@ -145,7 +145,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
(None, input_dim, 4, 4, 4),
],
[
(keras.layers.GRU(output_dim)),
(keras.layers.GRUV1(output_dim)),
[np.random.random((input_dim, output_dim)),
np.random.random((output_dim, output_dim)),
np.random.random((output_dim,)),
@ -158,7 +158,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
(None, 4, input_dim),
],
[
(keras.layers.LSTM(output_dim)),
(keras.layers.LSTMV1(output_dim)),
[np.random.random((input_dim, output_dim)),
np.random.random((output_dim, output_dim)),
np.random.random((output_dim,)),

View File

@ -20,7 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import layers as keras_layers
from tensorflow.python.keras.layers import normalization as keras_normalization
from tensorflow.python.layers import base
from tensorflow.python.ops import init_ops
from tensorflow.python.util import deprecation
@ -28,7 +28,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=['layers.BatchNormalization'])
class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
class BatchNormalization(keras_normalization.BatchNormalization, base.Layer):
"""Batch Normalization layer from http://arxiv.org/abs/1502.03167.
"Batch Normalization: Accelerating Deep Network Training by Reducing
@ -170,7 +170,7 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
@deprecation.deprecated(
date=None, instructions='Use keras.layers.BatchNormalization instead. In '
'particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not '
'be used (consult the `tf.keras.layers.batch_normalization` '
'be used (consult the `tf.keras.layers.BatchNormalization` '
'documentation).')
@tf_export(v1=['layers.batch_normalization'])
def batch_normalization(inputs,