diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 87dfa34f932..07cb1bdf1b3 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import tf2 + # Generic layers. # pylint: disable=g-bad-import-order +# pylint: disable=g-import-not-at-top from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer from tensorflow.python.keras.engine.input_spec import InputSpec @@ -27,10 +30,20 @@ from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer # Preprocessing layers. -from tensorflow.python.keras.layers.preprocessing.normalization import Normalization -from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1 -from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization -from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1 +if tf2.enabled(): + from tensorflow.python.keras.layers.preprocessing.normalization import Normalization + from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1 + NormalizationV2 = Normalization + from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization + from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1 + TextVectorizationV2 = TextVectorization +else: + from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization + from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2 + NormalizationV1 = Normalization + from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization + from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2 + TextVectorizationV1 = TextVectorization # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU @@ -121,8 +134,14 @@ from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. from tensorflow.python.keras.layers.normalization import LayerNormalization -from tensorflow.python.keras.layers.normalization import BatchNormalization -from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2 +if tf2.enabled(): + from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization + from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1 + BatchNormalizationV2 = BatchNormalization +else: + from tensorflow.python.keras.layers.normalization import BatchNormalization + from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2 + BatchNormalizationV1 = BatchNormalization # Kernelized layers. from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures @@ -163,14 +182,32 @@ from tensorflow.python.keras.layers.recurrent import SimpleRNNCell from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell from tensorflow.python.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras.layers.recurrent import GRU -from tensorflow.python.keras.layers.recurrent import GRUCell -from tensorflow.python.keras.layers.recurrent import LSTM -from tensorflow.python.keras.layers.recurrent import LSTMCell -from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRU_v2 -from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCell_v2 -from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTM_v2 -from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCell_v2 +if tf2.enabled(): + from tensorflow.python.keras.layers.recurrent_v2 import GRU + from tensorflow.python.keras.layers.recurrent_v2 import GRUCell + from tensorflow.python.keras.layers.recurrent_v2 import LSTM + from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell + from tensorflow.python.keras.layers.recurrent import GRU as GRUV1 + from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1 + from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1 + from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1 + GRUV2 = GRU + GRUCellV2 = GRUCell + LSTMV2 = LSTM + LSTMCellV2 = LSTMCell +else: + from tensorflow.python.keras.layers.recurrent import GRU + from tensorflow.python.keras.layers.recurrent import GRUCell + from tensorflow.python.keras.layers.recurrent import LSTM + from tensorflow.python.keras.layers.recurrent import LSTMCell + from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2 + from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2 + from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2 + from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2 + GRUV1 = GRU + GRUCellV1 = GRUCell + LSTMV1 = LSTM + LSTMCellV1 = LSTMCell # Convolutional-recurrent layers. from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py index e3e193c3b63..1c20918ffc8 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py @@ -460,7 +460,7 @@ class CuDNNV1OnlyTest(keras_parameterized.TestCase): input_shape = (3, 5) def gru(cudnn=False, **kwargs): - layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRU + layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRUV1 return layer_class(2, input_shape=input_shape, **kwargs) def get_layer_weights(layer): diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py index 15cbf68c87a..a01e56be097 100644 --- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py +++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py @@ -256,7 +256,7 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp(ValueError, "does not work with "): wrapper_cls(cell) - cell = layers.LSTMCell_v2(10) + cell = layers.LSTMCellV2(10) with self.assertRaisesRegexp(ValueError, "does not work with "): wrapper_cls(cell) diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py index 96557410030..19340c1d86d 100644 --- a/tensorflow/python/keras/saving/hdf5_format_test.py +++ b/tensorflow/python/keras/saving/hdf5_format_test.py @@ -145,7 +145,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): (None, input_dim, 4, 4, 4), ], [ - (keras.layers.GRU(output_dim)), + (keras.layers.GRUV1(output_dim)), [np.random.random((input_dim, output_dim)), np.random.random((output_dim, output_dim)), np.random.random((output_dim,)), @@ -158,7 +158,7 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): (None, 4, input_dim), ], [ - (keras.layers.LSTM(output_dim)), + (keras.layers.LSTMV1(output_dim)), [np.random.random((input_dim, output_dim)), np.random.random((output_dim, output_dim)), np.random.random((output_dim,)), diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index c6f06069d7c..2554721eca2 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras import layers as keras_layers +from tensorflow.python.keras.layers import normalization as keras_normalization from tensorflow.python.layers import base from tensorflow.python.ops import init_ops from tensorflow.python.util import deprecation @@ -28,7 +28,7 @@ from tensorflow.python.util.tf_export import tf_export @tf_export(v1=['layers.BatchNormalization']) -class BatchNormalization(keras_layers.BatchNormalization, base.Layer): +class BatchNormalization(keras_normalization.BatchNormalization, base.Layer): """Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing @@ -170,7 +170,7 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer): @deprecation.deprecated( date=None, instructions='Use keras.layers.BatchNormalization instead. In ' 'particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not ' - 'be used (consult the `tf.keras.layers.batch_normalization` ' + 'be used (consult the `tf.keras.layers.BatchNormalization` ' 'documentation).') @tf_export(v1=['layers.batch_normalization']) def batch_normalization(inputs,