From 0f371719b10617d6f7da6622c345aa3b7b00175e Mon Sep 17 00:00:00 2001 From: Yanhui Liang Date: Wed, 13 Mar 2019 14:54:13 -0700 Subject: [PATCH] Fix serialization/deserialization for BatchNorm v1 and v2. PiperOrigin-RevId: 238315500 --- tensorflow/python/compat/v2_compat.py | 3 - tensorflow/python/keras/BUILD | 1 + tensorflow/python/keras/layers/__init__.py | 3 +- .../python/keras/layers/normalization.py | 33 +-------- .../python/keras/layers/normalization_test.py | 67 ++++++++++--------- .../python/keras/layers/normalization_v2.py | 28 ++++++++ .../python/keras/layers/serialization.py | 13 +--- .../python/keras/layers/serialization_test.py | 10 ++- ...ow.keras.layers.-batch-normalization.pbtxt | 2 +- ...nsorflow.layers.-batch-normalization.pbtxt | 2 +- .../tools/api/golden/v1/tensorflow.pbtxt | 8 --- ...ow.keras.layers.-batch-normalization.pbtxt | 2 +- 12 files changed, 81 insertions(+), 91 deletions(-) create mode 100644 tensorflow/python/keras/layers/normalization_v2.py diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py index 8a94939ae11..9961cae11c5 100644 --- a/tensorflow/python/compat/v2_compat.py +++ b/tensorflow/python/compat/v2_compat.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.python import tf2 from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.layers import normalization from tensorflow.python.ops import variable_scope from tensorflow.python.util.tf_export import tf_export @@ -43,7 +42,6 @@ def enable_v2_behavior(): ops.enable_eager_execution() tensor_shape.enable_v2_tensorshape() # Also switched by tf2 variable_scope.enable_resource_variables() - normalization.enable_v2_batch_normalization() @tf_export(v1=["disable_v2_behavior"]) @@ -61,4 +59,3 @@ def disable_v2_behavior(): ops.disable_eager_execution() tensor_shape.disable_v2_tensorshape() # Also switched by tf2 variable_scope.disable_resource_variables() - normalization.disable_v2_batch_normalization() diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 2eeb8e14736..e8782951062 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -331,6 +331,7 @@ py_library( "layers/merge.py", "layers/noise.py", "layers/normalization.py", + "layers/normalization_v2.py", "layers/pooling.py", "layers/recurrent.py", "layers/recurrent_v2.py", diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index b3b0298cfd7..016cb116823 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -109,8 +109,9 @@ from tensorflow.python.keras.layers.noise import GaussianNoise from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras.layers.normalization import BatchNormalization from tensorflow.python.keras.layers.normalization import LayerNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2 # Kernelized layers. from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 7221f662c10..f7ce5e654e4 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python import tf2 from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -39,7 +38,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import keras_export -from tensorflow.python.util.tf_export import tf_export class BatchNormalizationBase(Layer): @@ -131,7 +129,7 @@ class BatchNormalizationBase(Layer): Internal Covariate Shift](https://arxiv.org/abs/1502.03167) """ - # By default, the base class uses V2 behavior. The BatchNormalizationV1 + # By default, the base class uses V2 behavior. The BatchNormalization V1 # subclass sets this to False to use the V1 behavior. _USE_V2_BEHAVIOR = True @@ -785,7 +783,7 @@ def _replace_in_base_docstring(old, new): @keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring -class BatchNormalizationV1(BatchNormalizationBase): +class BatchNormalization(BatchNormalizationBase): __doc__ = _replace_in_base_docstring( ''' @@ -801,33 +799,6 @@ class BatchNormalizationV1(BatchNormalizationBase): _USE_V2_BEHAVIOR = False -@keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring -class BatchNormalizationV2(BatchNormalizationBase): - - pass - - -BatchNormalization = None # pylint: disable=invalid-name - - -@tf_export(v1=['enable_v2_batch_normalization']) -def enable_v2_batch_normalization(): - global BatchNormalization # pylint: disable=invalid-name - BatchNormalization = BatchNormalizationV2 - - -@tf_export(v1=['disable_v2_batch_normalization']) -def disable_v2_batch_normalization(): - global BatchNormalization # pylint: disable=invalid-name - BatchNormalization = BatchNormalizationV1 - - -if tf2.enabled(): - enable_v2_batch_normalization() -else: - disable_v2_batch_normalization() - - @keras_export('keras.layers.experimental.LayerNormalization') class LayerNormalization(Layer): """Layer normalization layer (Ba et al., 2016). diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index 3f4ba0c05a6..0a422e39f2e 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python import keras @@ -26,6 +27,7 @@ from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.layers import normalization +from tensorflow.python.keras.layers import normalization_v2 from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent @@ -131,18 +133,14 @@ class BatchNormalizationTest(keras_parameterized.TestCase): _run_batchnorm_correctness_test( normalization.BatchNormalization, dtype='float32') _run_batchnorm_correctness_test( - normalization.BatchNormalization, dtype='float32', fused=True) - _run_batchnorm_correctness_test( - normalization.BatchNormalization, dtype='float32', fused=False) + normalization_v2.BatchNormalization, dtype='float32') @keras_parameterized.run_all_keras_modes def test_batchnorm_mixed_precision(self): _run_batchnorm_correctness_test( normalization.BatchNormalization, dtype='float16') _run_batchnorm_correctness_test( - normalization.BatchNormalization, dtype='float16', fused=True) - _run_batchnorm_correctness_test( - normalization.BatchNormalization, dtype='float16', fused=False) + normalization_v2.BatchNormalization, dtype='float16') @tf_test_util.run_in_graph_and_eager_modes def test_batchnorm_policy(self): @@ -162,18 +160,18 @@ class BatchNormalizationV1Test(test.TestCase): @tf_test_util.run_in_graph_and_eager_modes def test_v1_fused_attribute(self): - norm = normalization.BatchNormalizationV1() + norm = normalization.BatchNormalization() inp = keras.layers.Input((4, 4, 4)) norm(inp) self.assertEqual(norm.fused, True) - norm = normalization.BatchNormalizationV1(fused=False) + norm = normalization.BatchNormalization(fused=False) self.assertEqual(norm.fused, False) inp = keras.layers.Input(shape=(4, 4, 4)) norm(inp) self.assertEqual(norm.fused, False) - norm = normalization.BatchNormalizationV1(virtual_batch_size=2) + norm = normalization.BatchNormalization(virtual_batch_size=2) self.assertEqual(norm.fused, True) inp = keras.layers.Input(shape=(2, 2, 2)) norm(inp) @@ -185,63 +183,63 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase): @keras_parameterized.run_all_keras_modes def test_basic_batchnorm_v2(self): testing_utils.layer_test( - normalization.BatchNormalizationV2, + normalization_v2.BatchNormalization, kwargs={'fused': True}, input_shape=(3, 3, 3, 3)) testing_utils.layer_test( - normalization.BatchNormalizationV2, + normalization_v2.BatchNormalization, kwargs={'fused': None}, input_shape=(3, 3, 3)) @tf_test_util.run_in_graph_and_eager_modes def test_v2_fused_attribute(self): - norm = normalization.BatchNormalizationV2() + norm = normalization_v2.BatchNormalization() self.assertEqual(norm.fused, None) inp = keras.layers.Input(shape=(4, 4, 4)) norm(inp) self.assertEqual(norm.fused, True) - norm = normalization.BatchNormalizationV2() + norm = normalization_v2.BatchNormalization() self.assertEqual(norm.fused, None) inp = keras.layers.Input(shape=(4, 4)) norm(inp) self.assertEqual(norm.fused, False) - norm = normalization.BatchNormalizationV2(virtual_batch_size=2) + norm = normalization_v2.BatchNormalization(virtual_batch_size=2) self.assertEqual(norm.fused, False) inp = keras.layers.Input(shape=(4, 4, 4)) norm(inp) self.assertEqual(norm.fused, False) - norm = normalization.BatchNormalizationV2(fused=False) + norm = normalization_v2.BatchNormalization(fused=False) self.assertEqual(norm.fused, False) inp = keras.layers.Input(shape=(4, 4, 4)) norm(inp) self.assertEqual(norm.fused, False) - norm = normalization.BatchNormalizationV2(fused=True, axis=[3]) + norm = normalization_v2.BatchNormalization(fused=True, axis=[3]) self.assertEqual(norm.fused, True) inp = keras.layers.Input(shape=(4, 4, 4)) norm(inp) self.assertEqual(norm.fused, True) with self.assertRaisesRegexp(ValueError, 'fused.*renorm'): - normalization.BatchNormalizationV2(fused=True, renorm=True) + normalization_v2.BatchNormalization(fused=True, renorm=True) with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'): - normalization.BatchNormalizationV2(fused=True, axis=2) + normalization_v2.BatchNormalization(fused=True, axis=2) with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'): - normalization.BatchNormalizationV2(fused=True, axis=[1, 3]) + normalization_v2.BatchNormalization(fused=True, axis=[1, 3]) with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'): - normalization.BatchNormalizationV2(fused=True, virtual_batch_size=2) + normalization_v2.BatchNormalization(fused=True, virtual_batch_size=2) with self.assertRaisesRegexp(ValueError, 'fused.*adjustment'): - normalization.BatchNormalizationV2(fused=True, - adjustment=lambda _: (1, 0)) + normalization_v2.BatchNormalization(fused=True, + adjustment=lambda _: (1, 0)) - norm = normalization.BatchNormalizationV2(fused=True) + norm = normalization_v2.BatchNormalization(fused=True) self.assertEqual(norm.fused, True) inp = keras.layers.Input(shape=(4, 4)) with self.assertRaisesRegexp(ValueError, '4D input tensors'): @@ -272,14 +270,16 @@ def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False): np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class NormalizationLayersGraphModeOnlyTest(test.TestCase): +@parameterized.parameters( + [normalization.BatchNormalization, normalization_v2.BatchNormalization]) +class NormalizationLayersGraphModeOnlyTest( + test.TestCase, parameterized.TestCase): - def test_shared_batchnorm(self): - """Test that a BN layer can be shared across different data streams. - """ + def test_shared_batchnorm(self, layer): + """Test that a BN layer can be shared across different data streams.""" with self.cached_session(): # Test single layer reuse - bn = keras.layers.BatchNormalization() + bn = layer() x1 = keras.layers.Input(shape=(10,)) _ = bn(x1) @@ -307,13 +307,13 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase): new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') new_model.train_on_batch(x, x) - def test_that_trainable_disables_updates(self): + def test_that_trainable_disables_updates(self, layer): with self.cached_session(): val_a = np.random.random((10, 4)) val_out = np.random.random((10, 4)) a = keras.layers.Input(shape=(4,)) - layer = keras.layers.BatchNormalization(input_shape=(4,)) + layer = layer(input_shape=(4,)) b = layer(a) model = keras.models.Model(a, b) @@ -346,11 +346,14 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase): self.assertAllClose(x1, x2, atol=1e-7) @tf_test_util.run_deprecated_v1 - def test_batchnorm_trainable(self): + def test_batchnorm_trainable(self, layer): """Tests that batchnorm layer is trainable when learning phase is enabled. Computes mean and std for current inputs then applies batch normalization using them. + + Args: + layer: Either V1 or V2 of BatchNormalization layer. """ # TODO(fchollet): enable in all execution modes when issue with # learning phase setting is resolved. @@ -361,7 +364,7 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase): def get_model(bn_mean, bn_std): inp = keras.layers.Input(shape=(1,)) - x = keras.layers.BatchNormalization()(inp) + x = layer()(inp) model1 = keras.models.Model(inp, x) model1.set_weights([ np.array([1.]), diff --git a/tensorflow/python/keras/layers/normalization_v2.py b/tensorflow/python/keras/layers/normalization_v2.py new file mode 100644 index 00000000000..05501a7bf2c --- /dev/null +++ b/tensorflow/python/keras/layers/normalization_v2.py @@ -0,0 +1,28 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The V2 implementation of Normalization layers. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.keras.layers.normalization import BatchNormalizationBase +from tensorflow.python.util.tf_export import keras_export + + +@keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring +class BatchNormalization(BatchNormalizationBase): + + _USE_V2_BEHAVIOR = True diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index a651f7f7989..35202617716 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -42,22 +42,13 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.util.tf_export import keras_export if tf2.enabled(): + from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top -# TODO(b/124791387): replace mapping with layer attribute. -# Name conversion between class name and API symbol in config. -_SERIALIZATION_TABLE = { - 'BatchNormalizationV1': 'BatchNormalization', - 'BatchNormalizationV2': 'BatchNormalization', -} - @keras_export('keras.layers.serialize') def serialize(layer): - layer_class_name = layer.__class__.__name__ - if layer_class_name in _SERIALIZATION_TABLE: - layer_class_name = _SERIALIZATION_TABLE[layer_class_name] - return {'class_name': layer_class_name, 'config': layer.get_config()} + return {'class_name': layer.__class__.__name__, 'config': layer.get_config()} @keras_export('keras.layers.deserialize') diff --git a/tensorflow/python/keras/layers/serialization_test.py b/tensorflow/python/keras/layers/serialization_test.py index ab86529bf3c..5e9fa3cef8d 100644 --- a/tensorflow/python/keras/layers/serialization_test.py +++ b/tensorflow/python/keras/layers/serialization_test.py @@ -23,6 +23,8 @@ from absl.testing import parameterized from tensorflow.python import keras from tensorflow.python import tf2 from tensorflow.python.framework import test_util as tf_test_util +from tensorflow.python.keras.layers import normalization as batchnorm_v1 +from tensorflow.python.keras.layers import normalization_v2 as batchnorm_v2 from tensorflow.python.keras.layers import recurrent as rnn_v1 from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2 from tensorflow.python.platform import test @@ -47,17 +49,21 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase): keras.initializers.Ones) self.assertEqual(new_layer.units, 3) - def test_serialize_deserialize_batchnorm(self): - layer = keras.layers.BatchNormalization( + @parameterized.parameters( + [batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization]) + def test_serialize_deserialize_batchnorm(self, batchnorm_layer): + layer = batchnorm_layer( momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2') config = keras.layers.serialize(layer) self.assertEqual(config['class_name'], 'BatchNormalization') new_layer = keras.layers.deserialize(config) self.assertEqual(new_layer.momentum, 0.9) if tf2.enabled(): + self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization) self.assertEqual(new_layer.beta_initializer.__class__, keras.initializers.ZerosV2) else: + self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization) self.assertEqual(new_layer.beta_initializer.__class__, keras.initializers.Zeros) self.assertEqual(new_layer.gamma_regularizer.__class__, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt index 2638adf5a3b..f2b80301df8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.BatchNormalization" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt index 751101585ab..f55c3e4c426 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.layers.BatchNormalization" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 9b8e383ae14..f0416fc5520 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1112,10 +1112,6 @@ tf_module { name: "disable_resource_variables" argspec: "args=[], varargs=None, keywords=None, defaults=None" } - member_method { - name: "disable_v2_batch_normalization" - argspec: "args=[], varargs=None, keywords=None, defaults=None" - } member_method { name: "disable_v2_behavior" argspec: "args=[], varargs=None, keywords=None, defaults=None" @@ -1160,10 +1156,6 @@ tf_module { name: "enable_resource_variables" argspec: "args=[], varargs=None, keywords=None, defaults=None" } - member_method { - name: "enable_v2_batch_normalization" - argspec: "args=[], varargs=None, keywords=None, defaults=None" - } member_method { name: "enable_v2_behavior" argspec: "args=[], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt index 200aa2f890f..5613c23641a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.BatchNormalization" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: ""