Fix serialization/deserialization for BatchNorm v1 and v2.

PiperOrigin-RevId: 238315500
This commit is contained in:
Yanhui Liang 2019-03-13 14:54:13 -07:00 committed by TensorFlower Gardener
parent 0993d774a8
commit 0f371719b1
12 changed files with 81 additions and 91 deletions

View File

@ -21,7 +21,6 @@ from __future__ import print_function
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.layers import normalization
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -43,7 +42,6 @@ def enable_v2_behavior():
ops.enable_eager_execution() ops.enable_eager_execution()
tensor_shape.enable_v2_tensorshape() # Also switched by tf2 tensor_shape.enable_v2_tensorshape() # Also switched by tf2
variable_scope.enable_resource_variables() variable_scope.enable_resource_variables()
normalization.enable_v2_batch_normalization()
@tf_export(v1=["disable_v2_behavior"]) @tf_export(v1=["disable_v2_behavior"])
@ -61,4 +59,3 @@ def disable_v2_behavior():
ops.disable_eager_execution() ops.disable_eager_execution()
tensor_shape.disable_v2_tensorshape() # Also switched by tf2 tensor_shape.disable_v2_tensorshape() # Also switched by tf2
variable_scope.disable_resource_variables() variable_scope.disable_resource_variables()
normalization.disable_v2_batch_normalization()

View File

@ -331,6 +331,7 @@ py_library(
"layers/merge.py", "layers/merge.py",
"layers/noise.py", "layers/noise.py",
"layers/normalization.py", "layers/normalization.py",
"layers/normalization_v2.py",
"layers/pooling.py", "layers/pooling.py",
"layers/recurrent.py", "layers/recurrent.py",
"layers/recurrent_v2.py", "layers/recurrent_v2.py",

View File

@ -109,8 +109,9 @@ from tensorflow.python.keras.layers.noise import GaussianNoise
from tensorflow.python.keras.layers.noise import GaussianDropout from tensorflow.python.keras.layers.noise import GaussianDropout
# Normalization layers. # Normalization layers.
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras.layers.normalization import LayerNormalization from tensorflow.python.keras.layers.normalization import LayerNormalization
from tensorflow.python.keras.layers.normalization import BatchNormalization
from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization as BatchNormalizationV2
# Kernelized layers. # Kernelized layers.
from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -39,7 +38,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export
class BatchNormalizationBase(Layer): class BatchNormalizationBase(Layer):
@ -785,7 +783,7 @@ def _replace_in_base_docstring(old, new):
@keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring @keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring
class BatchNormalizationV1(BatchNormalizationBase): class BatchNormalization(BatchNormalizationBase):
__doc__ = _replace_in_base_docstring( __doc__ = _replace_in_base_docstring(
''' '''
@ -801,33 +799,6 @@ class BatchNormalizationV1(BatchNormalizationBase):
_USE_V2_BEHAVIOR = False _USE_V2_BEHAVIOR = False
@keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring
class BatchNormalizationV2(BatchNormalizationBase):
pass
BatchNormalization = None # pylint: disable=invalid-name
@tf_export(v1=['enable_v2_batch_normalization'])
def enable_v2_batch_normalization():
global BatchNormalization # pylint: disable=invalid-name
BatchNormalization = BatchNormalizationV2
@tf_export(v1=['disable_v2_batch_normalization'])
def disable_v2_batch_normalization():
global BatchNormalization # pylint: disable=invalid-name
BatchNormalization = BatchNormalizationV1
if tf2.enabled():
enable_v2_batch_normalization()
else:
disable_v2_batch_normalization()
@keras_export('keras.layers.experimental.LayerNormalization') @keras_export('keras.layers.experimental.LayerNormalization')
class LayerNormalization(Layer): class LayerNormalization(Layer):
"""Layer normalization layer (Ba et al., 2016). """Layer normalization layer (Ba et al., 2016).

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
@ -26,6 +27,7 @@ from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import normalization from tensorflow.python.keras.layers import normalization
from tensorflow.python.keras.layers import normalization_v2
from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent from tensorflow.python.training import gradient_descent
@ -131,18 +133,14 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
_run_batchnorm_correctness_test( _run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float32') normalization.BatchNormalization, dtype='float32')
_run_batchnorm_correctness_test( _run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float32', fused=True) normalization_v2.BatchNormalization, dtype='float32')
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float32', fused=False)
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
def test_batchnorm_mixed_precision(self): def test_batchnorm_mixed_precision(self):
_run_batchnorm_correctness_test( _run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float16') normalization.BatchNormalization, dtype='float16')
_run_batchnorm_correctness_test( _run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float16', fused=True) normalization_v2.BatchNormalization, dtype='float16')
_run_batchnorm_correctness_test(
normalization.BatchNormalization, dtype='float16', fused=False)
@tf_test_util.run_in_graph_and_eager_modes @tf_test_util.run_in_graph_and_eager_modes
def test_batchnorm_policy(self): def test_batchnorm_policy(self):
@ -162,18 +160,18 @@ class BatchNormalizationV1Test(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes @tf_test_util.run_in_graph_and_eager_modes
def test_v1_fused_attribute(self): def test_v1_fused_attribute(self):
norm = normalization.BatchNormalizationV1() norm = normalization.BatchNormalization()
inp = keras.layers.Input((4, 4, 4)) inp = keras.layers.Input((4, 4, 4))
norm(inp) norm(inp)
self.assertEqual(norm.fused, True) self.assertEqual(norm.fused, True)
norm = normalization.BatchNormalizationV1(fused=False) norm = normalization.BatchNormalization(fused=False)
self.assertEqual(norm.fused, False) self.assertEqual(norm.fused, False)
inp = keras.layers.Input(shape=(4, 4, 4)) inp = keras.layers.Input(shape=(4, 4, 4))
norm(inp) norm(inp)
self.assertEqual(norm.fused, False) self.assertEqual(norm.fused, False)
norm = normalization.BatchNormalizationV1(virtual_batch_size=2) norm = normalization.BatchNormalization(virtual_batch_size=2)
self.assertEqual(norm.fused, True) self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(2, 2, 2)) inp = keras.layers.Input(shape=(2, 2, 2))
norm(inp) norm(inp)
@ -185,63 +183,63 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
def test_basic_batchnorm_v2(self): def test_basic_batchnorm_v2(self):
testing_utils.layer_test( testing_utils.layer_test(
normalization.BatchNormalizationV2, normalization_v2.BatchNormalization,
kwargs={'fused': True}, kwargs={'fused': True},
input_shape=(3, 3, 3, 3)) input_shape=(3, 3, 3, 3))
testing_utils.layer_test( testing_utils.layer_test(
normalization.BatchNormalizationV2, normalization_v2.BatchNormalization,
kwargs={'fused': None}, kwargs={'fused': None},
input_shape=(3, 3, 3)) input_shape=(3, 3, 3))
@tf_test_util.run_in_graph_and_eager_modes @tf_test_util.run_in_graph_and_eager_modes
def test_v2_fused_attribute(self): def test_v2_fused_attribute(self):
norm = normalization.BatchNormalizationV2() norm = normalization_v2.BatchNormalization()
self.assertEqual(norm.fused, None) self.assertEqual(norm.fused, None)
inp = keras.layers.Input(shape=(4, 4, 4)) inp = keras.layers.Input(shape=(4, 4, 4))
norm(inp) norm(inp)
self.assertEqual(norm.fused, True) self.assertEqual(norm.fused, True)
norm = normalization.BatchNormalizationV2() norm = normalization_v2.BatchNormalization()
self.assertEqual(norm.fused, None) self.assertEqual(norm.fused, None)
inp = keras.layers.Input(shape=(4, 4)) inp = keras.layers.Input(shape=(4, 4))
norm(inp) norm(inp)
self.assertEqual(norm.fused, False) self.assertEqual(norm.fused, False)
norm = normalization.BatchNormalizationV2(virtual_batch_size=2) norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
self.assertEqual(norm.fused, False) self.assertEqual(norm.fused, False)
inp = keras.layers.Input(shape=(4, 4, 4)) inp = keras.layers.Input(shape=(4, 4, 4))
norm(inp) norm(inp)
self.assertEqual(norm.fused, False) self.assertEqual(norm.fused, False)
norm = normalization.BatchNormalizationV2(fused=False) norm = normalization_v2.BatchNormalization(fused=False)
self.assertEqual(norm.fused, False) self.assertEqual(norm.fused, False)
inp = keras.layers.Input(shape=(4, 4, 4)) inp = keras.layers.Input(shape=(4, 4, 4))
norm(inp) norm(inp)
self.assertEqual(norm.fused, False) self.assertEqual(norm.fused, False)
norm = normalization.BatchNormalizationV2(fused=True, axis=[3]) norm = normalization_v2.BatchNormalization(fused=True, axis=[3])
self.assertEqual(norm.fused, True) self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(4, 4, 4)) inp = keras.layers.Input(shape=(4, 4, 4))
norm(inp) norm(inp)
self.assertEqual(norm.fused, True) self.assertEqual(norm.fused, True)
with self.assertRaisesRegexp(ValueError, 'fused.*renorm'): with self.assertRaisesRegexp(ValueError, 'fused.*renorm'):
normalization.BatchNormalizationV2(fused=True, renorm=True) normalization_v2.BatchNormalization(fused=True, renorm=True)
with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'): with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
normalization.BatchNormalizationV2(fused=True, axis=2) normalization_v2.BatchNormalization(fused=True, axis=2)
with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'): with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
normalization.BatchNormalizationV2(fused=True, axis=[1, 3]) normalization_v2.BatchNormalization(fused=True, axis=[1, 3])
with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'): with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'):
normalization.BatchNormalizationV2(fused=True, virtual_batch_size=2) normalization_v2.BatchNormalization(fused=True, virtual_batch_size=2)
with self.assertRaisesRegexp(ValueError, 'fused.*adjustment'): with self.assertRaisesRegexp(ValueError, 'fused.*adjustment'):
normalization.BatchNormalizationV2(fused=True, normalization_v2.BatchNormalization(fused=True,
adjustment=lambda _: (1, 0)) adjustment=lambda _: (1, 0))
norm = normalization.BatchNormalizationV2(fused=True) norm = normalization_v2.BatchNormalization(fused=True)
self.assertEqual(norm.fused, True) self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(4, 4)) inp = keras.layers.Input(shape=(4, 4))
with self.assertRaisesRegexp(ValueError, '4D input tensors'): with self.assertRaisesRegexp(ValueError, '4D input tensors'):
@ -272,14 +270,16 @@ def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
class NormalizationLayersGraphModeOnlyTest(test.TestCase): @parameterized.parameters(
[normalization.BatchNormalization, normalization_v2.BatchNormalization])
class NormalizationLayersGraphModeOnlyTest(
test.TestCase, parameterized.TestCase):
def test_shared_batchnorm(self): def test_shared_batchnorm(self, layer):
"""Test that a BN layer can be shared across different data streams. """Test that a BN layer can be shared across different data streams."""
"""
with self.cached_session(): with self.cached_session():
# Test single layer reuse # Test single layer reuse
bn = keras.layers.BatchNormalization() bn = layer()
x1 = keras.layers.Input(shape=(10,)) x1 = keras.layers.Input(shape=(10,))
_ = bn(x1) _ = bn(x1)
@ -307,13 +307,13 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
new_model.train_on_batch(x, x) new_model.train_on_batch(x, x)
def test_that_trainable_disables_updates(self): def test_that_trainable_disables_updates(self, layer):
with self.cached_session(): with self.cached_session():
val_a = np.random.random((10, 4)) val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4)) val_out = np.random.random((10, 4))
a = keras.layers.Input(shape=(4,)) a = keras.layers.Input(shape=(4,))
layer = keras.layers.BatchNormalization(input_shape=(4,)) layer = layer(input_shape=(4,))
b = layer(a) b = layer(a)
model = keras.models.Model(a, b) model = keras.models.Model(a, b)
@ -346,11 +346,14 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
self.assertAllClose(x1, x2, atol=1e-7) self.assertAllClose(x1, x2, atol=1e-7)
@tf_test_util.run_deprecated_v1 @tf_test_util.run_deprecated_v1
def test_batchnorm_trainable(self): def test_batchnorm_trainable(self, layer):
"""Tests that batchnorm layer is trainable when learning phase is enabled. """Tests that batchnorm layer is trainable when learning phase is enabled.
Computes mean and std for current inputs then Computes mean and std for current inputs then
applies batch normalization using them. applies batch normalization using them.
Args:
layer: Either V1 or V2 of BatchNormalization layer.
""" """
# TODO(fchollet): enable in all execution modes when issue with # TODO(fchollet): enable in all execution modes when issue with
# learning phase setting is resolved. # learning phase setting is resolved.
@ -361,7 +364,7 @@ class NormalizationLayersGraphModeOnlyTest(test.TestCase):
def get_model(bn_mean, bn_std): def get_model(bn_mean, bn_std):
inp = keras.layers.Input(shape=(1,)) inp = keras.layers.Input(shape=(1,))
x = keras.layers.BatchNormalization()(inp) x = layer()(inp)
model1 = keras.models.Model(inp, x) model1 = keras.models.Model(inp, x)
model1.set_weights([ model1.set_weights([
np.array([1.]), np.array([1.]),

View File

@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The V2 implementation of Normalization layers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.layers.normalization import BatchNormalizationBase
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring
class BatchNormalization(BatchNormalizationBase):
_USE_V2_BEHAVIOR = True

View File

@ -42,22 +42,13 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
if tf2.enabled(): if tf2.enabled():
from tensorflow.python.keras.layers.normalization_v2 import * # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top from tensorflow.python.keras.layers.recurrent_v2 import * # pylint: disable=g-import-not-at-top
# TODO(b/124791387): replace mapping with layer attribute.
# Name conversion between class name and API symbol in config.
_SERIALIZATION_TABLE = {
'BatchNormalizationV1': 'BatchNormalization',
'BatchNormalizationV2': 'BatchNormalization',
}
@keras_export('keras.layers.serialize') @keras_export('keras.layers.serialize')
def serialize(layer): def serialize(layer):
layer_class_name = layer.__class__.__name__ return {'class_name': layer.__class__.__name__, 'config': layer.get_config()}
if layer_class_name in _SERIALIZATION_TABLE:
layer_class_name = _SERIALIZATION_TABLE[layer_class_name]
return {'class_name': layer_class_name, 'config': layer.get_config()}
@keras_export('keras.layers.deserialize') @keras_export('keras.layers.deserialize')

View File

@ -23,6 +23,8 @@ from absl.testing import parameterized
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras.layers import normalization as batchnorm_v1
from tensorflow.python.keras.layers import normalization_v2 as batchnorm_v2
from tensorflow.python.keras.layers import recurrent as rnn_v1 from tensorflow.python.keras.layers import recurrent as rnn_v1
from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2 from tensorflow.python.keras.layers import recurrent_v2 as rnn_v2
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -47,17 +49,21 @@ class LayerSerializationTest(parameterized.TestCase, test.TestCase):
keras.initializers.Ones) keras.initializers.Ones)
self.assertEqual(new_layer.units, 3) self.assertEqual(new_layer.units, 3)
def test_serialize_deserialize_batchnorm(self): @parameterized.parameters(
layer = keras.layers.BatchNormalization( [batchnorm_v1.BatchNormalization, batchnorm_v2.BatchNormalization])
def test_serialize_deserialize_batchnorm(self, batchnorm_layer):
layer = batchnorm_layer(
momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2') momentum=0.9, beta_initializer='zeros', gamma_regularizer='l2')
config = keras.layers.serialize(layer) config = keras.layers.serialize(layer)
self.assertEqual(config['class_name'], 'BatchNormalization') self.assertEqual(config['class_name'], 'BatchNormalization')
new_layer = keras.layers.deserialize(config) new_layer = keras.layers.deserialize(config)
self.assertEqual(new_layer.momentum, 0.9) self.assertEqual(new_layer.momentum, 0.9)
if tf2.enabled(): if tf2.enabled():
self.assertIsInstance(new_layer, batchnorm_v2.BatchNormalization)
self.assertEqual(new_layer.beta_initializer.__class__, self.assertEqual(new_layer.beta_initializer.__class__,
keras.initializers.ZerosV2) keras.initializers.ZerosV2)
else: else:
self.assertIsInstance(new_layer, batchnorm_v1.BatchNormalization)
self.assertEqual(new_layer.beta_initializer.__class__, self.assertEqual(new_layer.beta_initializer.__class__,
keras.initializers.Zeros) keras.initializers.Zeros)
self.assertEqual(new_layer.gamma_regularizer.__class__, self.assertEqual(new_layer.gamma_regularizer.__class__,

View File

@ -1,6 +1,6 @@
path: "tensorflow.keras.layers.BatchNormalization" path: "tensorflow.keras.layers.BatchNormalization"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>" is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>" is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"

View File

@ -1,7 +1,7 @@
path: "tensorflow.layers.BatchNormalization" path: "tensorflow.layers.BatchNormalization"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.layers.normalization.BatchNormalization\'>" is_instance: "<class \'tensorflow.python.layers.normalization.BatchNormalization\'>"
is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>" is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>" is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>" is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"

View File

@ -1112,10 +1112,6 @@ tf_module {
name: "disable_resource_variables" name: "disable_resource_variables"
argspec: "args=[], varargs=None, keywords=None, defaults=None" argspec: "args=[], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "disable_v2_batch_normalization"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "disable_v2_behavior" name: "disable_v2_behavior"
argspec: "args=[], varargs=None, keywords=None, defaults=None" argspec: "args=[], varargs=None, keywords=None, defaults=None"
@ -1160,10 +1156,6 @@ tf_module {
name: "enable_resource_variables" name: "enable_resource_variables"
argspec: "args=[], varargs=None, keywords=None, defaults=None" argspec: "args=[], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "enable_v2_batch_normalization"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "enable_v2_behavior" name: "enable_v2_behavior"
argspec: "args=[], varargs=None, keywords=None, defaults=None" argspec: "args=[], varargs=None, keywords=None, defaults=None"

View File

@ -1,6 +1,6 @@
path: "tensorflow.keras.layers.BatchNormalization" path: "tensorflow.keras.layers.BatchNormalization"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV2\'>" is_instance: "<class \'tensorflow.python.keras.layers.normalization_v2.BatchNormalization\'>"
is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>" is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationBase\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"