Fix serialization/deserialization for BatchNorm v1 and v2.

PiperOrigin-RevId: 238315500
Author: Yanhui Liang
Date: 2019-03-13 14:54:13 -07:00
Committed by: TensorFlower Gardener
Commit: 0f371719b1 (parent: 0993d774a8)

12 changed files with 81 additions and 91 deletions
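Before this change, the two implementations were exposed as BatchNormalizationV1
and BatchNormalizationV2, with the public name BatchNormalization rebound to one
of them through module-level globals, which broke round-tripping through
keras.layers.serialize/deserialize. Both classes are now literally named
BatchNormalization, with the V2 variant in its own module (normalization_v2.py).
A minimal sketch of the fixed round trip, assuming a TensorFlow build at this
revision:

    import tensorflow as tf

    layer = tf.keras.layers.BatchNormalization(momentum=0.9)
    config = tf.keras.layers.serialize(layer)
    # V1 and V2 both serialize under the same public class name.
    assert config['class_name'] == 'BatchNormalization'
    # deserialize() resolves the name to the V1 or V2 class depending on
    # whether TF2 behavior is enabled.
    new_layer = tf.keras.layers.deserialize(config)
    assert new_layer.momentum == 0.9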

tensorflow/python/compat/v2_compat.py

@@ -21,7 +21,6 @@ from __future__ import print_function
 from tensorflow.python import tf2
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.layers import normalization
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util.tf_export import tf_export
@@ -43,7 +42,6 @@ def enable_v2_behavior():
   ops.enable_eager_execution()
   tensor_shape.enable_v2_tensorshape()  # Also switched by tf2
   variable_scope.enable_resource_variables()
-  normalization.enable_v2_batch_normalization()
 @tf_export(v1=["disable_v2_behavior"])
@@ -61,4 +59,3 @@ def disable_v2_behavior():
   ops.disable_eager_execution()
   tensor_shape.disable_v2_tensorshape()  # Also switched by tf2
   variable_scope.disable_resource_variables()
-  normalization.disable_v2_batch_normalization()

tensorflow/python/keras/BUILD

@@ -331,6 +331,7 @@ py_library(
         "layers/merge.py",
         "layers/noise.py",
         "layers/normalization.py",
+        "layers/normalization_v2.py",
         "layers/pooling.py",
         "layers/recurrent.py",
         "layers/recurrent_v2.py",

tensorflow/python/keras/layers/__init__.py

@@ -109,8 +109,9 @@ from tensorflow.python.keras.layers.noise import GaussianNoise
 from tensorflow.python.keras.layers.noise import GaussianDropout
 # Normalization layers.
-from tensorflow.python.keras.layers.normalization import BatchNormalization
 from tensorflow.python.keras.layers.normalization import LayerNormalization
+from tensorflow.python.keras.layers.normalization import BatchNormalization
+from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
 # Kernelized layers.
 from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
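Note that keras.layers continues to expose a BatchNormalizationV2 symbol by
aliasing the import, so code that referenced the old name keeps working; the
alias now points at the class defined in normalization_v2.py:

    from tensorflow.python.keras.layers import BatchNormalizationV2

    # Resolves to the V2 module after this change.
    print(BatchNormalizationV2.__module__)
    # tensorflow.python.keras.layers.normalization_v2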

tensorflow/python/keras/layers/normalization.py

@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from tensorflow.python import tf2
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -39,7 +38,6 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import keras_export
-from tensorflow.python.util.tf_export import tf_export
 class BatchNormalizationBase(Layer):
@@ -785,7 +783,7 @@ def _replace_in_base_docstring(old, new):
 @keras_export(v1=['keras.layers.BatchNormalization'])  # pylint: disable=missing-docstring
-class BatchNormalizationV1(BatchNormalizationBase):
+class BatchNormalization(BatchNormalizationBase):
   __doc__ = _replace_in_base_docstring(
       '''
@@ -801,33 +799,6 @@ class BatchNormalizationV1(BatchNormalizationBase):
   _USE_V2_BEHAVIOR = False
-@keras_export('keras.layers.BatchNormalization', v1=[])  # pylint: disable=missing-docstring
-class BatchNormalizationV2(BatchNormalizationBase):
-  pass
-BatchNormalization = None  # pylint: disable=invalid-name
-@tf_export(v1=['enable_v2_batch_normalization'])
-def enable_v2_batch_normalization():
-  global BatchNormalization  # pylint: disable=invalid-name
-  BatchNormalization = BatchNormalizationV2
-@tf_export(v1=['disable_v2_batch_normalization'])
-def disable_v2_batch_normalization():
-  global BatchNormalization  # pylint: disable=invalid-name
-  BatchNormalization = BatchNormalizationV1
-if tf2.enabled():
-  enable_v2_batch_normalization()
-else:
-  disable_v2_batch_normalization()
 @keras_export('keras.layers.experimental.LayerNormalization')
 class LayerNormalization(Layer):
   """Layer normalization layer (Ba et al., 2016).

tensorflow/python/keras/layers/normalization_test.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from absl.testing import parameterized
 import numpy as np
 from tensorflow.python import keras
@@ -26,6 +27,7 @@ from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers import normalization
+from tensorflow.python.keras.layers import normalization_v2
 from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent
@@ -131,18 +133,14 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
     _run_batchnorm_correctness_test(
         normalization.BatchNormalization, dtype='float32')
     _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float32', fused=True)
-    _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float32', fused=False)
+        normalization_v2.BatchNormalization, dtype='float32')
   @keras_parameterized.run_all_keras_modes
   def test_batchnorm_mixed_precision(self):
     _run_batchnorm_correctness_test(
         normalization.BatchNormalization, dtype='float16')
     _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float16', fused=True)
-    _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float16', fused=False)
+        normalization_v2.BatchNormalization, dtype='float16')
   @tf_test_util.run_in_graph_and_eager_modes
   def test_batchnorm_policy(self):
@@ -162,18 +160,18 @@ class BatchNormalizationV1Test(test.TestCase):
   @tf_test_util.run_in_graph_and_eager_modes
   def test_v1_fused_attribute(self):
-    norm = normalization.BatchNormalizationV1()
+    norm = normalization.BatchNormalization()
     inp = keras.layers.Input((4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, True)
-    norm = normalization.BatchNormalizationV1(fused=False)
+    norm = normalization.BatchNormalization(fused=False)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)
-    norm = normalization.BatchNormalizationV1(virtual_batch_size=2)
+    norm = normalization.BatchNormalization(virtual_batch_size=2)
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(2, 2, 2))
     norm(inp)
@@ -185,63 +183,63 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
   @keras_parameterized.run_all_keras_modes
   def test_basic_batchnorm_v2(self):
     testing_utils.layer_test(
-        normalization.BatchNormalizationV2,
+        normalization_v2.BatchNormalization,
         kwargs={'fused': True},
         input_shape=(3, 3, 3, 3))
     testing_utils.layer_test(
-        normalization.BatchNormalizationV2,
+        normalization_v2.BatchNormalization,
         kwargs={'fused': None},
         input_shape=(3, 3, 3))
   @tf_test_util.run_in_graph_and_eager_modes
   def test_v2_fused_attribute(self):
-    norm = normalization.BatchNormalizationV2()
+    norm = normalization_v2.BatchNormalization()
     self.assertEqual(norm.fused, None)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, True)
-    norm = normalization.BatchNormalizationV2()
+    norm = normalization_v2.BatchNormalization()
     self.assertEqual(norm.fused, None)
     inp = keras.layers.Input(shape=(4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)
-    norm = normalization.BatchNormalizationV2(virtual_batch_size=2)
+    norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
    self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)
-    norm = normalization.BatchNormalizationV2(fused=False)
+    norm = normalization_v2.BatchNormalization(fused=False)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)
-    norm = normalization.BatchNormalizationV2(fused=True, axis=[3])
+    norm = normalization_v2.BatchNormalization(fused=True, axis=[3])
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, True)
     with self.assertRaisesRegexp(ValueError, 'fused.*renorm'):
-      normalization.BatchNormalizationV2(fused=True, renorm=True)
+      normalization_v2.BatchNormalization(fused=True, renorm=True)
     with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
-      normalization.BatchNormalizationV2(fused=True, axis=2)
+      normalization_v2.BatchNormalization(fused=True, axis=2)
     with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
-      normalization.BatchNormalizationV2(fused=True, axis=[1, 3])
+      normalization_v2.BatchNormalization(fused=True, axis=[1, 3])
     with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'):
-      normalization.BatchNormalizationV2(fused=True, virtual_batch_size=2)
+      normalization_v2.BatchNormalization(fused=True, virtual_batch_size=2)
     with self.assertRaisesRegexp(ValueError, 'fused.*adjustment'):
-      normalization.BatchNormalizationV2(fused=True,
+      normalization_v2.BatchNormalization(fused=True,
                                          adjustment=lambda _: (1, 0))
-    norm = normalization.BatchNormalizationV2(fused=True)
+    norm = normalization_v2.BatchNormalization(fused=True)
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(4, 4))
     with self.assertRaisesRegexp(ValueError, '4D input tensors'):
@@ -272,14 +270,16 @@ def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
   np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
-class NormalizationLayersGraphModeOnlyTest(test.TestCase):
+@parameterized.parameters(
+    [normalization.BatchNormalization, normalization_v2.BatchNormalization])
+class NormalizationLayersGraphModeOnlyTest(
+    test.TestCase, parameterized.TestCase):
-  def test_shared_batchnorm(self):
-    """Test that a BN layer can be shared across different data streams.
-    """
+  def test_shared_batchnorm(self, layer):
+    """Test that a BN layer can be shared across different data streams."""
     with self.cached_session():
       # Test single layer reuse
-      bn = keras.layers.BatchNormalization()
+      bn = layer()
       x1 = keras.layers.Input(shape=(10,))
       _ = bn(x1)
@@ -307,13 +307,13 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
     new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
     new_model.train_on_batch(x, x)
-  def test_that_trainable_disables_updates(self):
+  def test_that_trainable_disables_updates(self, layer):
     with self.cached_session():
       val_a = np.random.random((10, 4))
       val_out = np.random.random((10, 4))
       a = keras.layers.Input(shape=(4,))
-      layer = keras.layers.BatchNormalization(input_shape=(4,))
+      layer = layer(input_shape=(4,))
       b = layer(a)
       model = keras.models.Model(a, b)
@@ -346,11 +346,14 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
     self.assertAllClose(x1, x2, atol=1e-7)
   @tf_test_util.run_deprecated_v1
-  def test_batchnorm_trainable(self):
+  def test_batchnorm_trainable(self, layer):
     """Tests that batchnorm layer is trainable when learning phase is enabled.
     Computes mean and std for current inputs then
     applies batch normalization using them.
+    Args:
+      layer: Either V1 or V2 of BatchNormalization layer.
     """
     # TODO(fchollet): enable in all execution modes when issue with
     # learning phase setting is resolved.
@@ -361,7 +364,7 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
     def get_model(bn_mean, bn_std):
       inp = keras.layers.Input(shape=(1,))
-      x = keras.layers.BatchNormalization()(inp)
+      x = layer()(inp)
       model1 = keras.models.Model(inp, x)
       model1.set_weights([
           np.array([1.]),
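The graph-mode-only tests above now run once per implementation through
absl.testing.parameterized, which injects each parameter as an extra test
argument. For reference, a minimal self-contained sketch of that pattern
(hypothetical test, not part of this commit):

    from absl.testing import absltest, parameterized

    class ExampleTest(parameterized.TestCase):

      @parameterized.parameters([list, tuple])
      def test_empty(self, ctor):
        # The decorator generates one test case per parameter.
        self.assertEqual(len(ctor()), 0)

    if __name__ == '__main__':
      absltest.main()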

tensorflow/python/keras/layers/normalization_v2.py (new file)

@@ -0,0 +1,28 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The V2 implementation of Normalization layers.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.python.keras.layers.normalization import BatchNormalizationBase
+from tensorflow.python.util.tf_export import keras_export
+@keras_export('keras.layers.BatchNormalization', v1=[])  # pylint: disable=missing-docstring
+class BatchNormalization(BatchNormalizationBase):
+  _USE_V2_BEHAVIOR = True
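The new module is intentionally thin: both layers share all of their logic
through BatchNormalizationBase and differ only in the _USE_V2_BEHAVIOR flag
(False for V1, True for V2). One visible difference, exercised by the tests
above, is how the fused attribute resolves. A quick sketch, assuming a
TensorFlow build at this revision:

    import tensorflow as tf
    from tensorflow.python.keras.layers import normalization
    from tensorflow.python.keras.layers import normalization_v2

    v1 = normalization.BatchNormalization()
    v2 = normalization_v2.BatchNormalization()
    print(type(v1).__name__, type(v2).__name__)  # both 'BatchNormalization'

    print(v1.fused, v2.fused)  # True None -- V2 defers the decision
    v2(tf.keras.layers.Input(shape=(4, 4, 4)))
    print(v2.fused)  # True once built on a 4D input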

tensorflow/python/keras/layers/serialization.py

@@ -42,22 +42,13 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.util.tf_export import keras_export
 if tf2.enabled():
+  from tensorflow.python.keras.layers.normalization_v2 import *  # pylint: disable=g-import-not-at-top
   from tensorflow.python.keras.layers.recurrent_v2 import *  # pylint: disable=g-import-not-at-top
-# TODO(b/124791387): replace mapping with layer attribute.
-# Name conversion between class name and API symbol in config.
-_SERIALIZATION_TABLE = {
-    'BatchNormalizationV1': 'BatchNormalization',
-    'BatchNormalizationV2': 'BatchNormalization',
-}
 @keras_export('keras.layers.serialize')
 def serialize(layer):
-  layer_class_name = layer.__class__.__name__
-  if layer_class_name in _SERIALIZATION_TABLE:
-    layer_class_name = _SERIALIZATION_TABLE[layer_class_name]
-  return {'class_name': layer_class_name, 'config': layer.get_config()}
+  return {'class_name': layer.__class__.__name__, 'config': layer.get_config()}
 @keras_export('keras.layers.deserialize')
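With both classes named BatchNormalization, layer.__class__.__name__ already
matches the public API symbol, so the serialize() mapping table is dead
weight. On the deserialize side, the star import above means the V2 module's
BatchNormalization shadows the V1 one whenever tf2.enabled() is true, which is
how the same config resolves to the right class. A condensed illustration (the
real lookup goes through deserialize_keras_object):

    from tensorflow.python.keras.layers import normalization
    from tensorflow.python.keras.layers import normalization_v2

    # Same class __name__, different home modules:
    assert normalization.BatchNormalization.__name__ == 'BatchNormalization'
    assert normalization_v2.BatchNormalization.__name__ == 'BatchNormalization'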

tensorflow/python/keras/layers/serialization_test.py

@@ -23,6 +23,8 @@ from absl.testing import parameterized
 from tensorflow.python import keras
 from tensorflow.python import tf2
 from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras.layers import normalization as batchnorm_v1
+from tensorflow.python.keras.layers import normalization_v2 as batchnorm_v2
 from tensorflow.python.keras.layers import recurrent as rnn_v1
 from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
 from tensorflow.python.platform import test
@@ -47,17 +49,21 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
                      keras.initializers.Ones)
     self.assertEqual(new_layer.units, 3)
-  def test_serialize_deserialize_batchnorm(self):
-    layer = keras.layers.BatchNormalization(
+  @parameterized.parameters(
+      [batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
+  def test_serialize_deserialize_batchnorm(self, batchnorm_layer):
+    layer = batchnorm_layer(
         momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
     config = keras.layers.serialize(layer)
     self.assertEqual(config['class_name'], 'BatchNormalization')
     new_layer = keras.layers.deserialize(config)
     self.assertEqual(new_layer.momentum, 0.9)
     if tf2.enabled():
+      self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization)
       self.assertEqual(new_layer.beta_initializer.__class__,
                        keras.initializers.ZerosV2)
     else:
+      self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization)
       self.assertEqual(new_layer.beta_initializer.__class__,
                        keras.initializers.Zeros)
     self.assertEqual(new_layer.gamma_regularizer.__class__,

tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt

@@ -1,6 +1,6 @@
 path: "tensorflow.keras.layers.BatchNormalization"
 tf_class {
-  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"

tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt

@@ -1,7 +1,7 @@
 path: "tensorflow.layers.BatchNormalization"
 tf_class {
   is_instance: "<class \'tensorflow.python.layers.normalization.BatchNormalization\'>"
-  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"

tensorflow/tools/api/golden/v1/tensorflow.pbtxt

@@ -1112,10 +1112,6 @@ tf_module {
     name: "disable_resource_variables"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "disable_v2_batch_normalization"
-    argspec: "args=[], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "disable_v2_behavior"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
@@ -1160,10 +1156,6 @@ tf_module {
     name: "enable_resource_variables"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "enable_v2_batch_normalization"
-    argspec: "args=[], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "enable_v2_behavior"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"

tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt

@@ -1,6 +1,6 @@
 path: "tensorflow.keras.layers.BatchNormalization"
 tf_class {
-  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV2\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.normalization_v2.BatchNormalization\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"