Change Keras batch normalization layer to use the running mean and average computation in fused_batch_norm.
PiperOrigin-RevId: 298446605
Change-Id: I87466c847e6c3c54a5f16d9588c19a72e24b15eb
parent 3f35d6d8b0
commit 2482c2354b
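The core of the change: rather than the layer blending batch statistics into its moving averages itself, fused_batch_norm is asked to do the blending by passing exponential_avg_factor = 1 - momentum together with the current moving mean/variance, so the op's returned statistics can be assigned back directly. A minimal NumPy sketch of why the two conventions agree (function names here are illustrative, not the TensorFlow API):

import numpy as np

def keras_style_update(moving, batch_stat, momentum):
  # Old path: the layer decays the moving statistic toward the batch
  # statistic using `momentum`.
  return moving * momentum + batch_stat * (1.0 - momentum)

def fused_style_update(moving, batch_stat, exponential_avg_factor):
  # New path: the fused op blends the statistics internally using
  # `exponential_avg_factor` and returns the already-updated value,
  # which the layer then assigns verbatim.
  return (1.0 - exponential_avg_factor) * moving + exponential_avg_factor * batch_stat

momentum = 0.99
moving_mean, batch_mean = 0.5, 2.0
a = keras_style_update(moving_mean, batch_mean, momentum)
b = fused_style_update(moving_mean, batch_mean, 1.0 - momentum)
assert np.isclose(a, b)  # both give 0.515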
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.compat import compat
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -513,9 +514,11 @@ class BatchNormalizationBase(Layer):
           K.zeros_like(update_delta))
       return state_ops.assign_sub(variable, update_delta, name=scope)
 
-  def _assign_new_value(self, variable, value):
+  def _assign_new_value(self, variable, value, inputs_size=None):
     with K.name_scope('AssignNewValue') as scope:
       with ops.colocate_with(variable):
+        if inputs_size is not None:
+          value = array_ops.where(inputs_size > 0, value, variable)
        return state_ops.assign(variable, value, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
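The new inputs_size argument guards against a replica that receives an empty batch under a distribution strategy: when inputs_size is 0, the freshly computed statistic would be meaningless, so the variable keeps its previous value. A rough NumPy sketch of the same select logic (assign_new_value below is a stand-in for illustration, not the layer's method):

import numpy as np

def assign_new_value(variable, value, inputs_size=None):
  # Mirrors the guarded assignment: only accept `value` when the local
  # batch actually contained elements; otherwise keep the old statistic.
  if inputs_size is not None and inputs_size <= 0:
    return variable
  return value

moving_mean = np.float32(0.3)
print(assign_new_value(moving_mean, np.float32(1.2), inputs_size=128))  # 1.2
print(assign_new_value(moving_mean, np.float32('nan'), inputs_size=0))  # 0.3, unchanged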
@@ -530,13 +533,41 @@ class BatchNormalizationBase(Layer):
     else:
       inputs_size = None
 
+    if compat.forward_compatible(2020, 3, 6):
+      exponential_avg_factor = 1.0 - self.momentum
+    else:
+      exponential_avg_factor = None
+
+    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
+      r"""Add or remove Bessel's correction."""
+      # Removes Bessel's correction if remove == True, adds it otherwise.
+      # This is to be consistent with non-fused batch norm. Note that the
+      # variance computed by fused batch norm is with Bessel's correction.
+      # This is only used in legacy V1 batch norm tests.
+      if self._bessels_correction_test_only:
+        return variance
+      sample_size = math_ops.cast(
+          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
+      if remove:
+        factor = (sample_size -
+                  math_ops.cast(1.0, variance.dtype)) / sample_size
+      else:
+        factor = sample_size / (
+            sample_size - math_ops.cast(1.0, variance.dtype))
+      return variance * factor
+
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
           inputs,
           gamma,
           beta,
+          mean=self.moving_mean,
+          variance=_maybe_add_or_remove_bessels_correction(
+              self.moving_variance, remove=False),
           epsilon=self.epsilon,
-          data_format=self._data_format)
+          is_training=True,
+          data_format=self._data_format,
+          exponential_avg_factor=exponential_avg_factor)
 
     def _fused_batch_norm_inference():
       return nn.fused_batch_norm(
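The helper converts between the biased (population) variance used by the non-fused path and the Bessel-corrected (sample) variance produced by the fused kernel: multiplying by (n - 1)/n removes the correction and n/(n - 1) adds it, where n is the per-feature sample size. A quick NumPy check of those factors (illustrative only):

import numpy as np

x = np.random.randn(4, 8, 8, 16).astype(np.float32)  # NHWC batch
n = x.size / x.shape[-1]  # per-feature sample size, as in the helper

biased = x.reshape(-1, 16).var(axis=0)              # population variance (ddof=0)
corrected = x.reshape(-1, 16).var(axis=0, ddof=1)   # Bessel-corrected (ddof=1)

# remove=True maps corrected -> biased; remove=False maps biased -> corrected.
assert np.allclose(corrected * (n - 1) / n, biased, rtol=1e-5)
assert np.allclose(biased * n / (n - 1), corrected, rtol=1e-5)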
@@ -551,40 +582,30 @@ class BatchNormalizationBase(Layer):
 
     output, mean, variance = tf_utils.smart_cond(
         training, _fused_batch_norm_training, _fused_batch_norm_inference)
-    if not self._bessels_correction_test_only:
-      # Remove Bessel's correction to be consistent with non-fused batch norm.
-      # Note that the variance computed by fused batch norm is
-      # with Bessel's correction.
-      sample_size = math_ops.cast(
-          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
-      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
-      variance *= factor
+    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)
 
     training_value = tf_utils.constant_value(training)
-    if training_value is None:
-      momentum = tf_utils.smart_cond(training,
-                                     lambda: self.momentum,
-                                     lambda: 1.0)
-    else:
-      momentum = ops.convert_to_tensor_v2(self.momentum)
     if training_value or training_value is None:
+      if not compat.forward_compatible(2020, 3, 6):
+        if training_value is None:
+          momentum = tf_utils.smart_cond(training, lambda: self.momentum,
+                                         lambda: 1.0)
+        else:
+          momentum = ops.convert_to_tensor_v2(self.momentum)
 
       def mean_update():
-        return self._assign_moving_average(self.moving_mean, mean, momentum,
-                                           inputs_size)
+        """Update self.moving_mean with the most recent data point."""
+        if compat.forward_compatible(2020, 3, 6):
+          return self._assign_new_value(self.moving_mean, mean, inputs_size)
+        else:
+          return self._assign_moving_average(self.moving_mean, mean, momentum,
+                                             inputs_size)
 
       def variance_update():
         """Update self.moving_variance with the most recent data point."""
         if self.renorm:
           # We apply epsilon as part of the moving_stddev to mirror the training
           # code path.
           moving_stddev = self._assign_moving_average(
               self.moving_stddev, math_ops.sqrt(variance + self.epsilon),
               momentum, inputs_size)
           return self._assign_new_value(
               self.moving_variance,
               # Apply relu in case floating point rounding causes it to go
               # negative.
               K.relu(moving_stddev * moving_stddev - self.epsilon))
-        return self._assign_moving_average(self.moving_variance, variance,
-                                           momentum, inputs_size)
+        if compat.forward_compatible(2020, 3, 6):
+          return self._assign_new_value(self.moving_variance, variance,
+                                        inputs_size)
+        else:
+          return self._assign_moving_average(self.moving_variance, variance,
+                                             momentum, inputs_size)
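In the unchanged renorm branch, the layer tracks moving_stddev = sqrt(variance + epsilon) rather than the variance itself, and when converting back it clips with relu so that floating point rounding cannot leave a slightly negative variance. A small sketch of that round trip (plain NumPy, not the layer code):

import numpy as np

epsilon = np.float32(1e-3)
variance = np.float32(1e-8)  # tiny variance, prone to rounding trouble

# The renorm path stores sqrt(variance + epsilon) as moving_stddev ...
moving_stddev = np.sqrt(variance + epsilon)
# ... and later recovers the variance; the max-with-0 (relu) guards against
# float32 rounding pushing moving_stddev**2 just below epsilon.
recovered = np.maximum(moving_stddev * moving_stddev - epsilon, np.float32(0.0))
print(recovered >= 0.0)  # True, even when rounding would have gone negative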