From adf769043f0c48a44c05c5a24aac14f0b4951896 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sat, 1 Feb 2020 12:13:42 -0800 Subject: [PATCH] Add support for aggregating batch statistics across devices by using the newly added tf.keras.layers.experimental.SyncBatchNormalization layer. PiperOrigin-RevId: 292723222 Change-Id: I1c0458ec24c7e712ffa5e12dcf1f5efd6b4ce8ac --- .../python/distribute/ctl_correctness_test.py | 40 ++-- .../distribute/keras_correctness_test_base.py | 7 +- .../keras_image_model_correctness_test.py | 24 +- tensorflow/python/keras/layers/__init__.py | 2 + .../python/keras/layers/normalization.py | 6 +- .../python/keras/layers/normalization_v2.py | 182 +++++++++++++++ ...perimental.-sync-batch-normalization.pbtxt | 218 ++++++++++++++++++ ...tensorflow.keras.layers.experimental.pbtxt | 4 + 8 files changed, 463 insertions(+), 20 deletions(-) create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt diff --git a/tensorflow/python/distribute/ctl_correctness_test.py b/tensorflow/python/distribute/ctl_correctness_test.py index ec133fc19ef..fd2926adcf6 100644 --- a/tensorflow/python/distribute/ctl_correctness_test.py +++ b/tensorflow/python/distribute/ctl_correctness_test.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -_NUM_SAMPLES = 64 +_NUM_SAMPLES = 66 _BATCH_SIZE = 32 _RANDOM_SEED = 1337 _NUM_EPOCHS = 2 @@ -60,12 +60,16 @@ class MaybeStrategyScope(object): self._scope = None -def get_model(): +def get_model(sync_batchnorm=False): model = keras.Sequential() model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) model.add(keras.layers.Dense( 10, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-4))) + if sync_batchnorm: + model.add(keras.layers.SyncBatchNormalization()) + else: + model.add(keras.layers.BatchNormalization()) model.add(keras.layers.Dense(10, activation='relu')) model.add(keras.layers.Dense(1)) return model @@ -90,10 +94,13 @@ def compute_loss(labels, logits, reg_losses): def iteration_inside_func(initial_weights, dataset, optimizer_fn, - iteration_type, strategy=None): + iteration_type, strategy=None, sync_batchnorm=None): """Helper function to test iterating over data inside a tf.function.""" with MaybeStrategyScope(strategy): - model = get_model() + if strategy and sync_batchnorm: + model = get_model(sync_batchnorm) + else: + model = get_model() model.set_weights(initial_weights) optimizer = optimizer_fn() @@ -153,10 +160,10 @@ def iteration_inside_func(initial_weights, dataset, optimizer_fn, def iteration_outside_func(initial_weights, dataset, optimizer_fn, - iteration_type, strategy=None): + iteration_type, strategy=None, sync_batchnorm=None): """Helper function to test iterating over data outside a tf.function.""" with MaybeStrategyScope(strategy): - model = get_model() + model = get_model(sync_batchnorm=sync_batchnorm) model.set_weights(initial_weights) optimizer = optimizer_fn() @@ -223,16 +230,21 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase, optimizer_fn=strategy_combinations.optimizers_v1_and_v2, mode=['eager'], iteration_type=['iterator', 'dataset'], - inside_func=[False, True] + inside_func=[False, True], + sync_batchnorm=[True, False] )) def test_dnn_correctness_minus_tpus(self, distribution, optimizer_fn, - iteration_type, inside_func): + iteration_type, inside_func, + sync_batchnorm): + # TODO(anjs): Identify why this particular V1 
optimizer needs a higher tol. + if 'FtrlV1' in optimizer_fn._name and 'TPU' in type(distribution).__name__: + self.skipTest('Reduced tolerance of the order of 1e-1 required.') self.dnn_correctness(distribution, optimizer_fn, iteration_type, - inside_func) + inside_func, sync_batchnorm) def dnn_correctness(self, distribution, optimizer_fn, iteration_type, - inside_func): - model = get_model() + inside_func, sync_batchnorm=None): + model = get_model(sync_batchnorm) initial_weights = model.get_weights() dataset = get_data() if inside_func: @@ -241,13 +253,15 @@ class TestDistributionStrategyDnnCorrectness(test.TestCase, iteration_func = iteration_outside_func wts_with_ds, loss_with_ds, acc_with_ds = iteration_func( initial_weights, dataset, optimizer_fn, iteration_type, - strategy=distribution) + strategy=distribution, sync_batchnorm=sync_batchnorm) wts, loss, acc = iteration_func(initial_weights, dataset, optimizer_fn, - iteration_type) + iteration_type, + sync_batchnorm=sync_batchnorm) self.assertAllClose(wts, wts_with_ds, atol=1e-3, rtol=1e-3) self.assertAllClose(loss, loss_with_ds, atol=1e-3, rtol=1e-3) self.assertAllClose(acc, acc_with_ds, atol=1e-3, rtol=1e-3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py index 1c40a48e830..b9527127a5b 100644 --- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py +++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py @@ -386,7 +386,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, def set_up_test_config(self, use_numpy=False, use_validation_data=False, - with_batch_norm=False): + with_batch_norm=None): self.use_numpy = use_numpy self.use_validation_data = use_validation_data self.with_batch_norm = with_batch_norm @@ -435,7 +435,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, use_numpy, use_validation_data, experimental_run_tf_function=None, - with_batch_norm=False, + with_batch_norm=None, is_stateful_model=False, partial_last_batch=None, training_epochs=2): @@ -503,7 +503,8 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase, # First, special case, for multi-replica distributed training, batch # norm is not aggregated globally. So it is expected to have different # weights. 
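The comparison above expects regular batch normalization to end up with different weights under multi-replica training because each replica normalizes with the moments of its own shard of the batch rather than the moments of the global batch. A small NumPy illustration of that gap (not part of the patch) is below.

```
# Illustration only: per-replica moments vs. global-batch moments.
import numpy as np

# Pretend a global batch of four samples is split across two replicas.
replica_a = np.array([1.0, 2.0])
replica_b = np.array([10.0, 20.0])
global_batch = np.concatenate([replica_a, replica_b])

# Regular BatchNormalization: each replica normalizes with its own moments.
print('replica_a mean=%.2f var=%.2f' % (replica_a.mean(), replica_a.var()))
# replica_a mean=1.50 var=0.25
print('replica_b mean=%.2f var=%.2f' % (replica_b.mean(), replica_b.var()))
# replica_b mean=15.00 var=25.00

# SyncBatchNormalization: every replica normalizes with the global moments.
print('global    mean=%.2f var=%.2f' % (global_batch.mean(), global_batch.var()))
# global    mean=8.25 var=58.19
```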
- if (self.with_batch_norm and distribution.num_replicas_in_sync > 1): + if (self.with_batch_norm == 'regular' and + distribution.num_replicas_in_sync > 1): with self.assertRaises(AssertionError): compare_results( results_with_ds, diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py index 8f050f817a4..903067252af 100644 --- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras from tensorflow.python.distribute import combinations +from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.keras.distribute import keras_correctness_test_base from tensorflow.python.keras.optimizer_v2 import gradient_descent @@ -43,8 +44,10 @@ class DistributionStrategyCnnCorrectnessTest( strides=(4, 4), kernel_regularizer=keras.regularizers.l2(1e-4))( image) - if self.with_batch_norm: + if self.with_batch_norm == 'regular': c1 = keras.layers.BatchNormalization(name='bn1')(c1) + elif self.with_batch_norm == 'sync': + c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1) c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1) logits = keras.layers.Dense( 10, activation='softmax', name='pred')( @@ -107,7 +110,22 @@ class DistributionStrategyCnnCorrectnessTest( distribution, use_numpy, use_validation_data, - with_batch_norm=True, + with_batch_norm='regular', + experimental_run_tf_function=experimental_run_tf_function) + + @combinations.generate( + keras_correctness_test_base.all_strategy_and_input_config_combinations()) + def test_cnn_with_sync_batch_norm_correctness(self, distribution, use_numpy, + use_validation_data, + experimental_run_tf_function): + if not context.executing_eagerly() or not experimental_run_tf_function: + self.skipTest('SyncBatchNorm is not enabled in graph mode.') + + self.run_correctness_test( + distribution, + use_numpy, + use_validation_data, + with_batch_norm='sync', experimental_run_tf_function=experimental_run_tf_function) @combinations.generate( @@ -134,7 +152,7 @@ class DistributionStrategyCnnCorrectnessTest( distribution, use_numpy, use_validation_data, - with_batch_norm=True, + with_batch_norm='regular', partial_last_batch=True) diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 3f648b46bff..2370e138f09 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -135,6 +135,8 @@ from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. 
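The correctness suites above (the custom-training-loop test and the Keras image-model test alike) share one pattern: capture a single set of initial weights, train once under the distribution strategy and once without it, then compare weights, loss, and accuracy. The sketch below reproduces that pattern against the public API with `tf.keras.Model.fit`; `make_model` and `run_training` are illustrative helpers rather than names from the tests, and close weight agreement additionally assumes deterministic data order and identical initialization.

```
# Sketch of the with-vs-without-strategy comparison used by the tests above.
# Helper names here are illustrative and not taken from the patch.
import contextlib
import numpy as np
import tensorflow as tf


def make_model(sync_batchnorm=False):
  bn_cls = (tf.keras.layers.experimental.SyncBatchNormalization
            if sync_batchnorm else tf.keras.layers.BatchNormalization)
  return tf.keras.Sequential([
      tf.keras.layers.Dense(10, activation='relu', input_shape=(1,)),
      bn_cls(),
      tf.keras.layers.Dense(1),
  ])


def run_training(initial_weights, x, y, strategy=None, sync_batchnorm=False):
  scope = strategy.scope() if strategy else contextlib.nullcontext()
  with scope:
    model = make_model(sync_batchnorm)
    model.set_weights(initial_weights)
    model.compile(optimizer=tf.keras.optimizers.SGD(0.01), loss='mse')
  model.fit(x, y, batch_size=32, epochs=2, shuffle=False, verbose=0)
  return model.get_weights()


x = np.random.uniform(size=(64, 1)).astype(np.float32)
y = (3.0 * x).astype(np.float32)
initial_weights = make_model(sync_batchnorm=True).get_weights()

with_ds = run_training(initial_weights, x, y,
                       strategy=tf.distribute.MirroredStrategy(),
                       sync_batchnorm=True)
without_ds = run_training(initial_weights, x, y, sync_batchnorm=True)

for a, b in zip(with_ds, without_ds):
  np.testing.assert_allclose(a, b, atol=1e-3, rtol=1e-3)
```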
from tensorflow.python.keras.layers.normalization import LayerNormalization +from tensorflow.python.keras.layers.normalization_v2 import SyncBatchNormalization + if tf2.enabled(): from tensorflow.python.keras.layers.normalization_v2 import BatchNormalization from tensorflow.python.keras.layers.normalization import BatchNormalization as BatchNormalizationV1 diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index be686ad5e50..819e0e5929c 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -652,8 +652,12 @@ class BatchNormalizationBase(Layer): return (r, d, out_mean, out_variance) + def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims): + return nn.moments(inputs, reduction_axes, keep_dims=keep_dims) + def _moments(self, inputs, reduction_axes, keep_dims): - mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims) + mean, variance = self._calculate_mean_and_var(inputs, reduction_axes, + keep_dims) # TODO(b/129279393): Support zero batch input in non DistributionStrategy # code as well. if self._support_zero_size_input(): diff --git a/tensorflow/python/keras/layers/normalization_v2.py b/tensorflow/python/keras/layers/normalization_v2.py index 6a1049e773f..02e24d346db 100644 --- a/tensorflow/python/keras/layers/normalization_v2.py +++ b/tensorflow/python/keras/layers/normalization_v2.py @@ -18,10 +18,192 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import distribution_strategy_context as ds +from tensorflow.python.distribute import reduce_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.keras.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import keras_export +@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[]) # pylint: disable=g-classes-have-attributes +class SyncBatchNormalization(normalization.BatchNormalizationBase): + r"""Normalize and scale inputs or activations synchronously across replicas. + + Applies batch normalization to activations of the previous layer at each batch + by synchronizing the global batch statistics across all devices that are + training the model. For specific details about batch normalization please + refer to the `tf.keras.layers.BatchNormalization` layer docs. + + If this layer is used when using tf.distribute strategy to train models + across devices/workers, there will be an allreduce call to aggregate batch + statistics across all replicas at every training step. Without tf.distribute + strategy, this layer behaves as a regular `tf.keras.layers.BatchNormalization` + layer. + + Example usage: + ``` + strategy = tf.distribute.MirroredStrategy() + + with strategy.scope(): + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(16)) + model.add(tf.keras.layers.experimental.SyncBatchNormalization()) + ``` + + Arguments: + axis: Integer, the axis that should be normalized + (typically the features axis). + For instance, after a `Conv2D` layer with + `data_format="channels_first"`, + set `axis=1` in `BatchNormalization`. + momentum: Momentum for the moving average. + epsilon: Small float added to variance to avoid dividing by zero. + center: If True, add offset of `beta` to normalized tensor. 
+ If False, `beta` is ignored. + scale: If True, multiply by `gamma`. + If False, `gamma` is not used. + When the next layer is linear (also e.g. `nn.relu`), + this can be disabled since the scaling + will be done by the next layer. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + moving_mean_initializer: Initializer for the moving mean. + moving_variance_initializer: Initializer for the moving variance. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + beta_constraint: Optional constraint for the beta weight. + gamma_constraint: Optional constraint for the gamma weight. + renorm: Whether to use Batch Renormalization + (https://arxiv.org/abs/1702.03275). This adds extra variables during + training. The inference is the same for either value of this parameter. + renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to + scalar `Tensors` used to clip the renorm correction. The correction + `(r, d)` is used as `corrected_value = normalized_value * r + d`, with + `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + dmax are set to inf, 0, inf, respectively. + renorm_momentum: Momentum used to update the moving means and standard + deviations with renorm. Unlike `momentum`, this affects training + and should be neither too small (which would add noise) nor too large + (which would give stale estimates). Note that `momentum` is still applied + to get the means and variances for inference. + trainable: Boolean, if `True` the variables will be marked as trainable. + + Call arguments: + inputs: Input tensor (of any rank). + training: Python boolean indicating whether the layer should behave in + training mode or in inference mode. + - `training=True`: The layer will normalize its inputs using the + mean and variance of the current batch of inputs. + - `training=False`: The layer will normalize its inputs using the + mean and variance of its moving statistics, learned during training. + + Input shape: + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape: + Same shape as input. + + """ + + def __init__(self, + axis=-1, + momentum=0.99, + epsilon=1e-3, + center=True, + scale=True, + beta_initializer='zeros', + gamma_initializer='ones', + moving_mean_initializer='zeros', + moving_variance_initializer='ones', + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + renorm=False, + renorm_clipping=None, + renorm_momentum=0.99, + trainable=True, + adjustment=None, + name=None, + **kwargs): + + # Currently we only support aggregating over the global batch size. 
+ super(SyncBatchNormalization, self).__init__( + axis=axis, + momentum=momentum, + epsilon=epsilon, + center=center, + scale=scale, + beta_initializer=beta_initializer, + gamma_initializer=gamma_initializer, + moving_mean_initializer=moving_mean_initializer, + moving_variance_initializer=moving_variance_initializer, + beta_regularizer=beta_regularizer, + gamma_regularizer=gamma_regularizer, + beta_constraint=beta_constraint, + gamma_constraint=gamma_constraint, + renorm=renorm, + renorm_clipping=renorm_clipping, + renorm_momentum=renorm_momentum, + fused=False, + trainable=trainable, + virtual_batch_size=None, + name=name, + **kwargs) + + def _calculate_mean_and_var(self, x, axes, keep_dims): + + with ops.name_scope('moments', values=[x, axes]): + # The dynamic range of fp16 is too limited to support the collection of + # sufficient statistics. As a workaround we simply perform the operations + # on 32-bit floats before converting the mean and variance back to fp16 + y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x + replica_ctx = ds.get_replica_context() + if replica_ctx: + local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True) + local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes, + keepdims=True) + y_sum, y_squared_sum, global_batch_size = ( + replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [ + local_sum, local_squared_sum, array_ops.shape_v2(y)[0]])) + + axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))] + multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals), + dtypes.float32) + multiplier = multiplier * math_ops.cast(global_batch_size, + dtypes.float32) + + mean = y_sum / multiplier + y_squared_mean = y_squared_sum / multiplier + # var = E(x^2) - E(x)^2 + variance = y_squared_mean - math_ops.square(mean) + else: + # Compute true mean while keeping the dims for proper broadcasting. 
+ mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean') + # sample variance, not unbiased variance + # Note: stop_gradient does not change the gradient that gets + # backpropagated to the mean from the variance calculation, + # because that gradient is zero + variance = math_ops.reduce_mean( + math_ops.squared_difference(y, array_ops.stop_gradient(mean)), + axes, + keepdims=True, + name='variance') + if not keep_dims: + mean = array_ops.squeeze(mean, axes) + variance = array_ops.squeeze(variance, axes) + if x.dtype == dtypes.float16: + return (math_ops.cast(mean, dtypes.float16), + math_ops.cast(variance, dtypes.float16)) + else: + return (mean, variance) + + @keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring class BatchNormalization(normalization.BatchNormalizationBase): diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt new file mode 100644 index 00000000000..ceb406d3747 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-sync-batch-normalization.pbtxt @@ -0,0 +1,218 @@ +path: "tensorflow.keras.layers.experimental.SyncBatchNormalization" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'trainable\', \'adjustment\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'zeros\', \'ones\', \'zeros\', \'ones\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: 
"add_metric" + argspec: "args=[\'self\', \'value\', \'aggregation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git 
a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt
index 7f6d81d297a..f9d1e84781d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt
@@ -1,5 +1,9 @@
 path: "tensorflow.keras.layers.experimental"
 tf_module {
+  member {
+    name: "SyncBatchNormalization"
+    mtype: ""
+  }
   member {
     name: "preprocessing"
     mtype: ""
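`_calculate_mean_and_var` in the normalization_v2.py change computes the synchronized moments from three all-reduced quantities: the per-replica sum, the per-replica sum of squares, and the per-replica batch size, then applies `var = E[x^2] - E[x]^2`. The NumPy sketch below (not the patch's code) replays that arithmetic for a Dense-style input where only the batch axis is reduced, simulating the all-reduce with plain Python sums.

```
# NumPy replay (illustrative) of SyncBatchNormalization._calculate_mean_and_var
# for axes=(0,): each replica contributes a local sum, a local sum of squares,
# and its local batch size; replica_ctx.all_reduce(SUM, ...) is simulated here
# with ordinary sum() over the replica list.
import numpy as np

np.random.seed(0)
per_replica = [np.random.normal(size=(8, 4)) for _ in range(2)]  # two replicas

local_sums = [x.sum(axis=0, keepdims=True) for x in per_replica]
local_sq_sums = [(x ** 2).sum(axis=0, keepdims=True) for x in per_replica]
local_counts = [x.shape[0] for x in per_replica]

y_sum = sum(local_sums)            # all_reduce(SUM, local_sum)
y_sq_sum = sum(local_sq_sums)      # all_reduce(SUM, local_squared_sum)
global_batch = sum(local_counts)   # all_reduce(SUM, shape(y)[0])

mean = y_sum / global_batch
variance = y_sq_sum / global_batch - mean ** 2   # var = E[x^2] - E[x]^2

# The result matches the moments of the concatenated global batch.
full = np.concatenate(per_replica, axis=0)
np.testing.assert_allclose(mean.squeeze(), full.mean(axis=0), rtol=1e-6)
np.testing.assert_allclose(variance.squeeze(), full.var(axis=0), rtol=1e-6)
```

When the reduction also covers spatial axes (for example NHWC convolution activations), the divisor in the layer's code is the global batch size multiplied by the product of those spatial dimensions; the sketch above covers only the simple batch-axis case.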
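The ctl_correctness_test changes exercise the layer from a custom training loop, which is where the per-step all-reduce described in the docstring actually fires. A hedged sketch of that setup against the public API is below; the model, loss, and optimizer choices are illustrative rather than the test's exact configuration, and `strategy.run` is the current name of the `experimental_run_v2` entry point used at the time of this patch.

```
# Hedged sketch: SyncBatchNormalization driven from a custom training loop.
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 32

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(10, activation='relu', input_shape=(1,)),
      tf.keras.layers.experimental.SyncBatchNormalization(),
      tf.keras.layers.Dense(1),
  ])
  optimizer = tf.keras.optimizers.SGD(0.01)

x = np.random.uniform(size=(64, 1)).astype(np.float32)
y = (3.0 * x).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def train_step(dist_inputs):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      # The cross-replica all-reduce of the batch statistics happens inside
      # this call, once per training step, while training=True.
      preds = model(features, training=True)
      per_example_loss = tf.reduce_mean(tf.square(preds - labels), axis=1)
      loss = tf.nn.compute_average_loss(
          per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


for epoch in range(2):
  for batch in dist_dataset:
    train_step(batch)
```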