Fix serialization/deserialization for BatchNorm v1 and v2.
PiperOrigin-RevId: 238315500
Commit 0f371719b1 (parent 0993d774a8)
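
Note: the V1 and V2 layers become two classes that are both literally named BatchNormalization, defined in normalization.py and normalization_v2.py respectively, so a serialized config round-trips without a name-mapping table. A hedged sketch of the resulting behavior, mirroring the updated serialization test below (internal import paths, subject to change between releases):

from tensorflow.python import tf2
from tensorflow.python import keras
from tensorflow.python.keras.layers import normalization as batchnorm_v1
from tensorflow.python.keras.layers import normalization_v2 as batchnorm_v2

layer = batchnorm_v2.BatchNormalization(momentum=0.9)
config = keras.layers.serialize(layer)
assert config['class_name'] == 'BatchNormalization'  # same public name for V1 and V2

# Deserialization resolves that name against whichever version is active.
new_layer = keras.layers.deserialize(config)
if tf2.enabled():
  assert isinstance(new_layer, batchnorm_v2.BatchNormalization)
else:
  assert isinstance(new_layer, batchnorm_v1.BatchNormalization)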
tensorflow/python/compat/v2_compat.py
@@ -21,7 +21,6 @@ from __future__ import print_function
 from tensorflow.python import tf2
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.layers import normalization
 from tensorflow.python.ops import variable_scope

 from tensorflow.python.util.tf_export import tf_export
@@ -43,7 +42,6 @@ def enable_v2_behavior():
   ops.enable_eager_execution()
   tensor_shape.enable_v2_tensorshape()  # Also switched by tf2
   variable_scope.enable_resource_variables()
-  normalization.enable_v2_batch_normalization()


 @tf_export(v1=["disable_v2_behavior"])
@@ -61,4 +59,3 @@ def disable_v2_behavior():
   ops.disable_eager_execution()
   tensor_shape.disable_v2_tensorshape()  # Also switched by tf2
   variable_scope.disable_resource_variables()
-  normalization.disable_v2_batch_normalization()
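
With the enable/disable hooks gone from v2_compat, the V1/V2 choice is no longer flipped at runtime; it is fixed at import time. A minimal sketch of the selection pattern this commit moves to (illustrative only; the real conditional import lives in serialization.py below):

from tensorflow.python import tf2

if tf2.enabled():
  from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization
else:
  from tensorflow.python.keras.layers.normalization import BatchNormalization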
tensorflow/python/keras/BUILD
@@ -331,6 +331,7 @@ py_library(
         "layers/merge.py",
         "layers/noise.py",
         "layers/normalization.py",
+        "layers/normalization_v2.py",
         "layers/pooling.py",
         "layers/recurrent.py",
         "layers/recurrent_v2.py",
tensorflow/python/keras/layers/__init__.py
@@ -109,8 +109,9 @@ from tensorflow.python.keras.layers.noise import GaussianNoise
 from tensorflow.python.keras.layers.noise import GaussianDropout

 # Normalization layers.
-from tensorflow.python.keras.layers.normalization import BatchNormalization
 from tensorflow.python.keras.layers.normalization import LayerNormalization
+from tensorflow.python.keras.layers.normalization import BatchNormalization
+from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2

 # Kernelized layers.
 from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
tensorflow/python/keras/layers/normalization.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python import tf2
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -39,7 +38,6 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import keras_export
-from tensorflow.python.util.tf_export import tf_export


 class BatchNormalizationBase(Layer):
@@ -785,7 +783,7 @@ def _replace_in_base_docstring(old, new):


 @keras_export(v1=['keras.layers.BatchNormalization'])  # pylint: disable=missing-docstring
-class BatchNormalizationV1(BatchNormalizationBase):
+class BatchNormalization(BatchNormalizationBase):

   __doc__ = _replace_in_base_docstring(
       '''
@@ -801,33 +799,6 @@ class BatchNormalizationV1(BatchNormalizationBase):
   _USE_V2_BEHAVIOR = False


-@keras_export('keras.layers.BatchNormalization', v1=[])  # pylint: disable=missing-docstring
-class BatchNormalizationV2(BatchNormalizationBase):
-
-  pass
-
-
-BatchNormalization = None  # pylint: disable=invalid-name
-
-
-@tf_export(v1=['enable_v2_batch_normalization'])
-def enable_v2_batch_normalization():
-  global BatchNormalization  # pylint: disable=invalid-name
-  BatchNormalization = BatchNormalizationV2
-
-
-@tf_export(v1=['disable_v2_batch_normalization'])
-def disable_v2_batch_normalization():
-  global BatchNormalization  # pylint: disable=invalid-name
-  BatchNormalization = BatchNormalizationV1
-
-
-if tf2.enabled():
-  enable_v2_batch_normalization()
-else:
-  disable_v2_batch_normalization()
-
-
 @keras_export('keras.layers.experimental.LayerNormalization')
 class LayerNormalization(Layer):
   """Layer normalization layer (Ba et al., 2016).
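
The deleted runtime class-swapping is replaced by two concrete classes that differ only in a class attribute. A hedged, illustrative-only sketch of that pattern (_Base and _V1 are my names; in the real code the base is BatchNormalizationBase and the fused decision is ultimately settled in build()):

class _Base(object):
  _USE_V2_BEHAVIOR = True  # V2 semantics unless a subclass opts out

  def __init__(self, fused=None):
    # V1 historically defaulted to the fused implementation; V2 leaves the
    # decision open (None) until the layer is built against an input shape.
    if not self._USE_V2_BEHAVIOR and fused is None:
      fused = True
    self.fused = fused


class _V1(_Base):
  _USE_V2_BEHAVIOR = False


assert _Base().fused is None  # cf. test_v2_fused_attribute below
assert _V1().fused is True    # cf. test_v1_fused_attribute below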
tensorflow/python/keras/layers/normalization_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl.testing import parameterized
 import numpy as np

 from tensorflow.python import keras
@@ -26,6 +27,7 @@ from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers import normalization
+from tensorflow.python.keras.layers import normalization_v2
 from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent
@@ -131,18 +133,14 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
     _run_batchnorm_correctness_test(
         normalization.BatchNormalization, dtype='float32')
     _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float32', fused=True)
-    _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float32', fused=False)
+        normalization_v2.BatchNormalization, dtype='float32')

   @keras_parameterized.run_all_keras_modes
   def test_batchnorm_mixed_precision(self):
     _run_batchnorm_correctness_test(
         normalization.BatchNormalization, dtype='float16')
     _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float16', fused=True)
-    _run_batchnorm_correctness_test(
-        normalization.BatchNormalization, dtype='float16', fused=False)
+        normalization_v2.BatchNormalization, dtype='float16')

   @tf_test_util.run_in_graph_and_eager_modes
   def test_batchnorm_policy(self):
@@ -162,18 +160,18 @@ class BatchNormalizationV1Test(test.TestCase):

   @tf_test_util.run_in_graph_and_eager_modes
   def test_v1_fused_attribute(self):
-    norm = normalization.BatchNormalizationV1()
+    norm = normalization.BatchNormalization()
     inp = keras.layers.Input((4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, True)

-    norm = normalization.BatchNormalizationV1(fused=False)
+    norm = normalization.BatchNormalization(fused=False)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)

-    norm = normalization.BatchNormalizationV1(virtual_batch_size=2)
+    norm = normalization.BatchNormalization(virtual_batch_size=2)
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(2, 2, 2))
     norm(inp)
@@ -185,63 +183,63 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
   @keras_parameterized.run_all_keras_modes
   def test_basic_batchnorm_v2(self):
     testing_utils.layer_test(
-        normalization.BatchNormalizationV2,
+        normalization_v2.BatchNormalization,
         kwargs={'fused': True},
         input_shape=(3, 3, 3, 3))
     testing_utils.layer_test(
-        normalization.BatchNormalizationV2,
+        normalization_v2.BatchNormalization,
         kwargs={'fused': None},
         input_shape=(3, 3, 3))

   @tf_test_util.run_in_graph_and_eager_modes
   def test_v2_fused_attribute(self):
-    norm = normalization.BatchNormalizationV2()
+    norm = normalization_v2.BatchNormalization()
     self.assertEqual(norm.fused, None)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, True)

-    norm = normalization.BatchNormalizationV2()
+    norm = normalization_v2.BatchNormalization()
     self.assertEqual(norm.fused, None)
     inp = keras.layers.Input(shape=(4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)

-    norm = normalization.BatchNormalizationV2(virtual_batch_size=2)
+    norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)

-    norm = normalization.BatchNormalizationV2(fused=False)
+    norm = normalization_v2.BatchNormalization(fused=False)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, False)

-    norm = normalization.BatchNormalizationV2(fused=True, axis=[3])
+    norm = normalization_v2.BatchNormalization(fused=True, axis=[3])
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(4, 4, 4))
     norm(inp)
     self.assertEqual(norm.fused, True)

     with self.assertRaisesRegexp(ValueError, 'fused.*renorm'):
-      normalization.BatchNormalizationV2(fused=True, renorm=True)
+      normalization_v2.BatchNormalization(fused=True, renorm=True)

     with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
-      normalization.BatchNormalizationV2(fused=True, axis=2)
+      normalization_v2.BatchNormalization(fused=True, axis=2)

     with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
-      normalization.BatchNormalizationV2(fused=True, axis=[1, 3])
+      normalization_v2.BatchNormalization(fused=True, axis=[1, 3])

     with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'):
-      normalization.BatchNormalizationV2(fused=True, virtual_batch_size=2)
+      normalization_v2.BatchNormalization(fused=True, virtual_batch_size=2)

     with self.assertRaisesRegexp(ValueError, 'fused.*adjustment'):
-      normalization.BatchNormalizationV2(fused=True,
-                                         adjustment=lambda _: (1, 0))
+      normalization_v2.BatchNormalization(fused=True,
+                                          adjustment=lambda _: (1, 0))

-    norm = normalization.BatchNormalizationV2(fused=True)
+    norm = normalization_v2.BatchNormalization(fused=True)
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(4, 4))
     with self.assertRaisesRegexp(ValueError, '4D input tensors'):
@@ -272,14 +270,16 @@ def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
   np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)


-class NormalizationLayersGraphModeOnlyTest(test.TestCase):
+@parameterized.parameters(
+    [normalization.BatchNormalization, normalization_v2.BatchNormalization])
+class NormalizationLayersGraphModeOnlyTest(
+    test.TestCase, parameterized.TestCase):

-  def test_shared_batchnorm(self):
-    """Test that a BN layer can be shared across different data streams.
-    """
+  def test_shared_batchnorm(self, layer):
+    """Test that a BN layer can be shared across different data streams."""
     with self.cached_session():
       # Test single layer reuse
-      bn = keras.layers.BatchNormalization()
+      bn = layer()
       x1 = keras.layers.Input(shape=(10,))
       _ = bn(x1)

@@ -307,13 +307,13 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
       new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
       new_model.train_on_batch(x, x)

-  def test_that_trainable_disables_updates(self):
+  def test_that_trainable_disables_updates(self, layer):
     with self.cached_session():
       val_a = np.random.random((10, 4))
       val_out = np.random.random((10, 4))

       a = keras.layers.Input(shape=(4,))
-      layer = keras.layers.BatchNormalization(input_shape=(4,))
+      layer = layer(input_shape=(4,))
       b = layer(a)
       model = keras.models.Model(a, b)

@@ -346,11 +346,14 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
     self.assertAllClose(x1, x2, atol=1e-7)

   @tf_test_util.run_deprecated_v1
-  def test_batchnorm_trainable(self):
+  def test_batchnorm_trainable(self, layer):
     """Tests that batchnorm layer is trainable when learning phase is enabled.

     Computes mean and std for current inputs then
     applies batch normalization using them.

+    Args:
+      layer: Either V1 or V2 of BatchNormalization layer.
     """
     # TODO(fchollet): enable in all execution modes when issue with
     # learning phase setting is resolved.
@@ -361,7 +364,7 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):

     def get_model(bn_mean, bn_std):
       inp = keras.layers.Input(shape=(1,))
-      x = keras.layers.BatchNormalization()(inp)
+      x = layer()(inp)
       model1 = keras.models.Model(inp, x)
       model1.set_weights([
           np.array([1.]),
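
The test changes above lean on absl's class-level parameterization: decorating the TestCase applies the parameter list to every test method, which is how each test receives the extra layer argument. A self-contained sketch of the pattern (ExampleTest is illustrative, not from this commit):

from absl.testing import parameterized


@parameterized.parameters('a', 'b')
class ExampleTest(parameterized.TestCase):

  def test_value(self, value):
    # Runs twice: once with value='a', once with value='b'.
    self.assertIn(value, ('a', 'b'))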
tensorflow/python/keras/layers/normalization_v2.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The V2 implementation of Normalization layers.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.layers.normalization import BatchNormalizationBase
+from tensorflow.python.util.tf_export import keras_export
+
+
+@keras_export('keras.layers.BatchNormalization', v1=[])  # pylint: disable=missing-docstring
+class BatchNormalization(BatchNormalizationBase):
+
+  _USE_V2_BEHAVIOR = True
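
A short usage sketch of the new module (internal path; the fused default at construction is exactly what test_v2_fused_attribute above asserts):

from tensorflow.python.keras.layers import normalization_v2

norm = normalization_v2.BatchNormalization()
assert norm.fused is None  # V2 defers the fused decision to build()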
tensorflow/python/keras/layers/serialization.py
@@ -42,22 +42,13 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.util.tf_export import keras_export

 if tf2.enabled():
+  from tensorflow.python.keras.layers.normalization_v2 import *  # pylint: disable=g-import-not-at-top
   from tensorflow.python.keras.layers.recurrent_v2 import *  # pylint: disable=g-import-not-at-top

-# TODO(b/124791387): replace mapping with layer attribute.
-# Name conversion between class name and API symbol in config.
-_SERIALIZATION_TABLE = {
-    'BatchNormalizationV1': 'BatchNormalization',
-    'BatchNormalizationV2': 'BatchNormalization',
-}
-

 @keras_export('keras.layers.serialize')
 def serialize(layer):
-  layer_class_name = layer.__class__.__name__
-  if layer_class_name in _SERIALIZATION_TABLE:
-    layer_class_name = _SERIALIZATION_TABLE[layer_class_name]
-  return {'class_name': layer_class_name, 'config': layer.get_config()}
+  return {'class_name': layer.__class__.__name__, 'config': layer.get_config()}


 @keras_export('keras.layers.deserialize')
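
The deleted table existed only because the classes were named BatchNormalizationV1/V2 while serializing under the public name. Both classes are now literally named BatchNormalization, so layer.__class__.__name__ already is the API symbol, and the conditional import-star above puts the matching class in scope for deserialize(). A quick check (sketch):

from tensorflow.python.keras.layers import normalization as batchnorm_v1
from tensorflow.python.keras.layers import normalization_v2 as batchnorm_v2

assert batchnorm_v1.BatchNormalization.__name__ == 'BatchNormalization'
assert batchnorm_v2.BatchNormalization.__name__ == 'BatchNormalization'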
tensorflow/python/keras/layers/serialization_test.py
@@ -23,6 +23,8 @@ from absl.testing import parameterized
 from tensorflow.python import keras
 from tensorflow.python import tf2
 from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras.layers import normalization as batchnorm_v1
+from tensorflow.python.keras.layers import normalization_v2 as batchnorm_v2
 from tensorflow.python.keras.layers import recurrent as rnn_v1
 from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
 from tensorflow.python.platform import test
@@ -47,17 +49,21 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
                      keras.initializers.Ones)
     self.assertEqual(new_layer.units, 3)

-  def test_serialize_deserialize_batchnorm(self):
-    layer = keras.layers.BatchNormalization(
+  @parameterized.parameters(
+      [batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
+  def test_serialize_deserialize_batchnorm(self, batchnorm_layer):
+    layer = batchnorm_layer(
         momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
     config = keras.layers.serialize(layer)
     self.assertEqual(config['class_name'], 'BatchNormalization')
     new_layer = keras.layers.deserialize(config)
     self.assertEqual(new_layer.momentum, 0.9)
     if tf2.enabled():
+      self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization)
       self.assertEqual(new_layer.beta_initializer.__class__,
                        keras.initializers.ZerosV2)
     else:
+      self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization)
       self.assertEqual(new_layer.beta_initializer.__class__,
                        keras.initializers.Zeros)
     self.assertEqual(new_layer.gamma_regularizer.__class__,
tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -1,6 +1,6 @@
 path: "tensorflow.keras.layers.BatchNormalization"
 tf_class {
-  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt
@@ -1,7 +1,7 @@
 path: "tensorflow.layers.BatchNormalization"
 tf_class {
   is_instance: "<class \'tensorflow.python.layers.normalization.BatchNormalization\'>"
-  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1112,10 +1112,6 @@ tf_module {
     name: "disable_resource_variables"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "disable_v2_batch_normalization"
-    argspec: "args=[], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "disable_v2_behavior"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
|||||||
name: "enable_resource_variables"
|
name: "enable_resource_variables"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "enable_v2_batch_normalization"
|
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
member_method {
|
||||||
name: "enable_v2_behavior"
|
name: "enable_v2_behavior"
|
||||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -1,6 +1,6 @@
 path: "tensorflow.keras.layers.BatchNormalization"
 tf_class {
-  is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV2\'>"
+  is_instance: "<class \'tensorflow.python.keras.layers.normalization_v2.BatchNormalization\'>"
   is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"