Change Keras batch normalization layer to use the running mean and variance computation in fused_batch_norm.

PiperOrigin-RevId: 298446605
Change-Id: I87466c847e6c3c54a5f16d9588c19a72e24b15eb
A. Unique TensorFlower 2020-03-02 14:25:18 -08:00 committed by TensorFlower Gardener
parent 3f35d6d8b0
commit 2482c2354b
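
The core of the change: instead of computing the exponential moving average of the batch statistics in separate assign ops, the layer now passes exponential_avg_factor = 1.0 - momentum to the fused kernel and simply assigns the mean/variance the op returns. As a minimal sketch of that update rule (NumPy stand-in with made-up numbers, not code from this commit):

import numpy as np

momentum = 0.99                          # Keras BatchNormalization default
exponential_avg_factor = 1.0 - momentum  # value the layer now passes to the fused op

moving_mean, batch_mean = 0.0, 4.0       # hypothetical running and per-batch statistics
# The fused op folds the batch statistic into the running value itself:
#   new_moving = (1 - factor) * old_moving + factor * batch_stat
new_moving_mean = ((1.0 - exponential_avg_factor) * moving_mean
                   + exponential_avg_factor * batch_mean)
print(new_moving_mean)  # 0.04, i.e. momentum * old + (1 - momentum) * batch_mean

Because the op performs this arithmetic internally, mean_update and variance_update below reduce to a plain assign of the returned statistic rather than the read-modify-write done by _assign_moving_average.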


@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.compat import compat
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -513,9 +514,11 @@ class BatchNormalizationBase(Layer):
                                        K.zeros_like(update_delta))
         return state_ops.assign_sub(variable, update_delta, name=scope)
 
-  def _assign_new_value(self, variable, value):
+  def _assign_new_value(self, variable, value, inputs_size=None):
     with K.name_scope('AssignNewValue') as scope:
       with ops.colocate_with(variable):
+        if inputs_size is not None:
+          value = array_ops.where(inputs_size > 0, value, variable)
         return state_ops.assign(variable, value, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
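
Worth noting in the hunk above: the new inputs_size argument lets _assign_new_value skip the update when a replica receives an empty batch (e.g. under a distribution strategy), keeping the existing moving statistic instead. A tiny sketch of that guard (NumPy stand-in with hypothetical values, not the TF implementation):

import numpy as np

def assign_new_value_sketch(variable, value, inputs_size=None):
  # Mirrors the guard added above: keep the old value when the batch is empty.
  if inputs_size is not None:
    value = np.where(inputs_size > 0, value, variable)
  return value  # the real method assigns this back to the variable

print(assign_new_value_sketch(np.float32(1.0), np.float32(2.0), inputs_size=8))  # 2.0
print(assign_new_value_sketch(np.float32(1.0), np.float32(2.0), inputs_size=0))  # 1.0
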
@@ -530,13 +533,41 @@ class BatchNormalizationBase(Layer):
     else:
       inputs_size = None
 
+    if compat.forward_compatible(2020, 3, 6):
+      exponential_avg_factor = 1.0 - self.momentum
+    else:
+      exponential_avg_factor = None
+
+    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
+      r"""Add or remove Bessel's correction."""
+      # Removes Bessel's correction if remove == True, adds it otherwise.
+      # This is to be consistent with non-fused batch norm. Note that the
+      # variance computed by fused batch norm is with Bessel's correction.
+      # This is only used in legacy V1 batch norm tests.
+      if self._bessels_correction_test_only:
+        return variance
+      sample_size = math_ops.cast(
+          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
+      if remove:
+        factor = (sample_size -
+                  math_ops.cast(1.0, variance.dtype)) / sample_size
+      else:
+        factor = sample_size / (
+            sample_size - math_ops.cast(1.0, variance.dtype))
+      return variance * factor
+
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
           inputs,
           gamma,
           beta,
+          mean=self.moving_mean,
+          variance=_maybe_add_or_remove_bessels_correction(
+              self.moving_variance, remove=False),
           epsilon=self.epsilon,
-          data_format=self._data_format)
+          is_training=True,
+          data_format=self._data_format,
+          exponential_avg_factor=exponential_avg_factor)
 
     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
@@ -551,40 +582,30 @@ class BatchNormalizationBase(Layer):
     output, mean, variance = tf_utils.smart_cond(
         training, _fused_batch_norm_training, _fused_batch_norm_inference)
-    if not self._bessels_correction_test_only:
-      # Remove Bessel's correction to be consistent with non-fused batch norm.
-      # Note that the variance computed by fused batch norm is
-      # with Bessel's correction.
-      sample_size = math_ops.cast(
-          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
-      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
-      variance *= factor
+    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)
 
     training_value = tf_utils.constant_value(training)
-    if training_value is None:
-      momentum = tf_utils.smart_cond(training,
-                                     lambda: self.momentum,
-                                     lambda: 1.0)
-    else:
-      momentum = ops.convert_to_tensor_v2(self.momentum)
     if training_value or training_value is None:
+      if not compat.forward_compatible(2020, 3, 6):
+        if training_value is None:
+          momentum = tf_utils.smart_cond(training, lambda: self.momentum,
+                                         lambda: 1.0)
+        else:
+          momentum = ops.convert_to_tensor_v2(self.momentum)
 
       def mean_update():
-        return self._assign_moving_average(self.moving_mean, mean, momentum,
-                                           inputs_size)
+        """Update self.moving_mean with the most recent data point."""
+        if compat.forward_compatible(2020, 3, 6):
+          return self._assign_new_value(self.moving_mean, mean, inputs_size)
+        else:
+          return self._assign_moving_average(self.moving_mean, mean, momentum,
+                                             inputs_size)
 
       def variance_update():
         """Update self.moving_variance with the most recent data point."""
         if self.renorm:
           # We apply epsilon as part of the moving_stddev to mirror the training
           # code path.
           moving_stddev = self._assign_moving_average(
               self.moving_stddev, math_ops.sqrt(variance + self.epsilon),
               momentum, inputs_size)
           return self._assign_new_value(
               self.moving_variance,
               # Apply relu in case floating point rounding causes it to go
               # negative.
               K.relu(moving_stddev * moving_stddev - self.epsilon))
-        return self._assign_moving_average(self.moving_variance, variance,
-                                           momentum, inputs_size)
+        if compat.forward_compatible(2020, 3, 6):
+          return self._assign_new_value(self.moving_variance, variance,
+                                        inputs_size)
+        else:
+          return self._assign_moving_average(self.moving_variance, variance,
+                                             momentum, inputs_size)
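
For reference, the factor computed by the new _maybe_add_or_remove_bessels_correction helper, written out as a standalone sketch (NumPy, made-up shapes; not code from this commit): with n = size(inputs) / size(variance) samples per statistic, multiplying the variance by (n - 1) / n removes Bessel's correction (matching the biased variance of the non-fused path) and multiplying by n / (n - 1) adds it back.

import numpy as np

inputs = np.random.randn(2, 4, 4, 3).astype(np.float32)  # hypothetical NHWC batch
variance = inputs.var(axis=(0, 1, 2))                     # per-channel biased variance
sample_size = inputs.size / variance.size                 # 2 * 4 * 4 = 32 samples per channel

remove_factor = (sample_size - 1.0) / sample_size         # unbiased -> biased
add_factor = sample_size / (sample_size - 1.0)            # biased -> unbiased

unbiased = variance * add_factor
np.testing.assert_allclose(unbiased * remove_factor, variance, rtol=1e-6)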