From 2482c2354be49a83e179c603e45819dbe4e6adf0 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 2 Mar 2020 14:25:18 -0800
Subject: [PATCH] Change Keras batch normalization layer to use the running
 mean and average computation in fused_batch_norm.

PiperOrigin-RevId: 298446605
Change-Id: I87466c847e6c3c54a5f16d9588c19a72e24b15eb
---
 .../python/keras/layers/normalization.py     | 79 ++++++++++++-------
 1 file changed, 50 insertions(+), 29 deletions(-)

diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index ca4734ef8cc..c2d152ea384 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.compat import compat
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -513,9 +514,11 @@ class BatchNormalizationBase(Layer):
           K.zeros_like(update_delta))
       return state_ops.assign_sub(variable, update_delta, name=scope)
 
-  def _assign_new_value(self, variable, value):
+  def _assign_new_value(self, variable, value, inputs_size=None):
     with K.name_scope('AssignNewValue') as scope:
       with ops.colocate_with(variable):
+        if inputs_size is not None:
+          value = array_ops.where(inputs_size > 0, value, variable)
         return state_ops.assign(variable, value, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
@@ -530,13 +533,41 @@ class BatchNormalizationBase(Layer):
     else:
       inputs_size = None
 
+    if compat.forward_compatible(2020, 3, 6):
+      exponential_avg_factor = 1.0 - self.momentum
+    else:
+      exponential_avg_factor = None
+
+    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
+      r"""Add or remove Bessel's correction."""
+      # Removes Bessel's correction if remove == True, adds it otherwise.
+      # This is to be consistent with non-fused batch norm. Note that the
+      # variance computed by fused batch norm is with Bessel's correction.
+      # This is only used in legacy V1 batch norm tests.
+      if self._bessels_correction_test_only:
+        return variance
+      sample_size = math_ops.cast(
+          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
+      if remove:
+        factor = (sample_size -
+                  math_ops.cast(1.0, variance.dtype)) / sample_size
+      else:
+        factor = sample_size / (
+            sample_size - math_ops.cast(1.0, variance.dtype))
+      return variance * factor
+
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
           inputs,
           gamma,
           beta,
+          mean=self.moving_mean,
+          variance=_maybe_add_or_remove_bessels_correction(
+              self.moving_variance, remove=False),
           epsilon=self.epsilon,
-          data_format=self._data_format)
+          is_training=True,
+          data_format=self._data_format,
+          exponential_avg_factor=exponential_avg_factor)
 
     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
@@ -551,40 +582,30 @@ class BatchNormalizationBase(Layer):
     output, mean, variance = tf_utils.smart_cond(
         training, _fused_batch_norm_training, _fused_batch_norm_inference)
 
-    if not self._bessels_correction_test_only:
-      # Remove Bessel's correction to be consistent with non-fused batch norm.
-      # Note that the variance computed by fused batch norm is
-      # with Bessel's correction.
-      sample_size = math_ops.cast(
-          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
-      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
-      variance *= factor
+    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)
 
     training_value = tf_utils.constant_value(training)
-    if training_value is None:
-      momentum = tf_utils.smart_cond(training,
-                                     lambda: self.momentum,
-                                     lambda: 1.0)
-    else:
-      momentum = ops.convert_to_tensor_v2(self.momentum)
     if training_value or training_value is None:
+      if not compat.forward_compatible(2020, 3, 6):
+        if training_value is None:
+          momentum = tf_utils.smart_cond(training, lambda: self.momentum,
+                                         lambda: 1.0)
+        else:
+          momentum = ops.convert_to_tensor_v2(self.momentum)
+
       def mean_update():
-        return self._assign_moving_average(self.moving_mean, mean, momentum,
-                                           inputs_size)
+        """Update self.moving_mean with the most recent data point."""
+        if compat.forward_compatible(2020, 3, 6):
+          return self._assign_new_value(self.moving_mean, mean, inputs_size)
+        else:
+          return self._assign_moving_average(self.moving_mean, mean, momentum,
+                                             inputs_size)
 
       def variance_update():
         """Update self.moving_variance with the most recent data point."""
-        if self.renorm:
-          # We apply epsilon as part of the moving_stddev to mirror the training
-          # code path.
-          moving_stddev = self._assign_moving_average(
-              self.moving_stddev, math_ops.sqrt(variance + self.epsilon),
-              momentum, inputs_size)
-          return self._assign_new_value(
-              self.moving_variance,
-              # Apply relu in case floating point rounding causes it to go
-              # negative.
-              K.relu(moving_stddev * moving_stddev - self.epsilon))
+        if compat.forward_compatible(2020, 3, 6):
+          return self._assign_new_value(self.moving_variance, variance,
+                                        inputs_size)
         else:
           return self._assign_moving_average(self.moving_variance, variance,
                                              momentum, inputs_size)
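
Note (not part of the patch): the sketch below is a minimal illustration, under the assumption that fused_batch_norm with is_training=True applies the usual exponential-moving-average rule to the statistics passed in via mean/variance. It shows why passing exponential_avg_factor = 1.0 - self.momentum and feeding the current moving statistics into the kernel reproduces the update the layer previously did on the Python side via _assign_moving_average. The names moving_stat, batch_stat, keras_moving_average_update, and fused_kernel_update are illustrative only.

# Minimal sketch; assumed update rule, hypothetical helper names.
def keras_moving_average_update(moving_stat, batch_stat, momentum):
  # Update previously done in Python by _assign_moving_average:
  #   moving -= (1 - momentum) * (moving - batch)
  return moving_stat - (1.0 - momentum) * (moving_stat - batch_stat)

def fused_kernel_update(moving_stat, batch_stat, exponential_avg_factor):
  # Update now assumed to happen inside fused_batch_norm when
  # is_training=True:
  #   moving = (1 - factor) * moving + factor * batch
  return ((1.0 - exponential_avg_factor) * moving_stat
          + exponential_avg_factor * batch_stat)

momentum = 0.99
moving_mean, batch_mean = 0.5, 2.0
old = keras_moving_average_update(moving_mean, batch_mean, momentum)
new = fused_kernel_update(moving_mean, batch_mean, 1.0 - momentum)
assert abs(old - new) < 1e-12  # both give 0.515

With this equivalence, the layer only needs _assign_new_value to store the statistics the kernel returns, which is presumably why the momentum-based Python-side update is kept only behind the forward-compatibility check.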