Automated rollback of commit d208ae26ed

PiperOrigin-RevId: 243939780
Authored by Ruoxin Sang on 2019-04-16 22:54:42 -07:00; committed by TensorFlower Gardener
parent c16cf12009
commit 4a1931ae46
2 changed files with 55 additions and 30 deletions


@@ -43,7 +43,7 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.times(all_combinations,
                          combinations.combine(fused=[True, False])))
-  def disabled_testBNWithZeroBatchInput(self, distribution, fused):
+  def testBNWithZeroBatchInput(self, distribution, fused):
     with distribution.scope(), self.cached_session() as sess:
       bn_list = []
       inputs = np.random.random((0, 4, 4, 3)) + 100


@@ -423,7 +423,7 @@ class BatchNormalizationBase(Layer):
       self._scope.set_partitioner(partitioner)
     self.built = True

-  def _assign_moving_average(self, variable, value, momentum):
+  def _assign_moving_average(self, variable, value, momentum, inputs_size):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
       with ops.colocate_with(variable):
@@ -432,6 +432,9 @@ class BatchNormalizationBase(Layer):
           decay = math_ops.cast(decay, variable.dtype.base_dtype)
         update_delta = (
             variable - math_ops.cast(value, variable.dtype)) * decay
+        if inputs_size is not None:
+          update_delta = array_ops.where(inputs_size > 0, update_delta,
+                                         K.zeros_like(update_delta))
         return state_ops.assign_sub(variable, update_delta, name=scope)

   def _fused_batch_norm(self, inputs, training):
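
For readers skimming the diff, here is a minimal standalone sketch (TF2 eager, with illustrative names; the layer itself goes through the internal array_ops/state_ops modules shown above) of what the new guard does: when inputs_size is 0 the delta is zeroed, so assign_sub leaves the moving statistic untouched.

import tensorflow as tf

def assign_moving_average(variable, value, momentum, inputs_size):
  # decay = 1 - momentum, matching the update_delta computation above.
  decay = tf.cast(1.0 - momentum, variable.dtype.base_dtype)
  update_delta = (variable - tf.cast(value, variable.dtype)) * decay
  if inputs_size is not None:
    # Zero-batch guard: an empty batch must not move the statistic.
    update_delta = tf.where(inputs_size > 0, update_delta,
                            tf.zeros_like(update_delta))
  return variable.assign_sub(update_delta)

moving_mean = tf.Variable([1.0, 2.0])
assign_moving_average(moving_mean, [3.0, 4.0], 0.99, tf.constant(0))
print(moving_mean.numpy())   # unchanged: [1. 2.]
assign_moving_average(moving_mean, [3.0, 4.0], 0.99, tf.constant(32))
print(moving_mean.numpy())   # nudged toward [3. 4.]: roughly [1.02 2.02]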
@@ -439,6 +442,14 @@ class BatchNormalizationBase(Layer):
     beta = self.beta if self.center else self._beta_const
     gamma = self.gamma if self.scale else self._gamma_const

+    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
+    # code as well.
+    if distribution_strategy_context.has_strategy(
+    ) and not inputs.shape.is_fully_defined():
+      inputs_size = array_ops.size(inputs)
+    else:
+      inputs_size = None
+
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
           inputs,
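
A quick standalone illustration (assumed TF2; not part of the commit) of the condition tested above: the runtime element count only matters when the static shape leaves the batch dimension unknown, and tf.size reports 0 for an empty batch.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4, 4, 3])])
def runtime_element_count(inputs):
  # Static shape (None, 4, 4, 3) is not fully defined, so fall back to a
  # runtime count; it is 0 whenever the batch is empty.
  assert not inputs.shape.is_fully_defined()
  return tf.size(inputs)

print(runtime_element_count(tf.zeros([0, 4, 4, 3])))  # 0
print(runtime_element_count(tf.zeros([2, 4, 4, 3])))  # 96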
@@ -479,31 +490,24 @@ class BatchNormalizationBase(Layer):
       if training_value or training_value is None:
         if distribution_strategy_context.in_cross_replica_context():
           strategy = distribution_strategy_context.get_strategy()
-
-          def mean_update():
-            return strategy.extended.update(self.moving_mean,
-                                            self._assign_moving_average,
-                                            (mean, self.momentum))
-
-          def variance_update():
-            return strategy.extended.update(self.moving_variance,
-                                            self._assign_moving_average,
-                                            (variance, self.momentum))
+          mean_update = strategy.extended.update(
+              self.moving_mean, self._assign_moving_average,
+              (mean, self.momentum, inputs_size))
+          variance_update = strategy.extended.update(
+              self.moving_variance, self._assign_moving_average,
+              (variance, self.momentum, inputs_size))
         else:
-
-          def mean_update():
-            return self._assign_moving_average(self.moving_mean, mean, momentum)
-
-          def variance_update():
-            return self._assign_moving_average(self.moving_variance, variance,
-                                               momentum)
+          mean_update = self._assign_moving_average(self.moving_mean, mean,
+                                                    momentum, inputs_size)
+          variance_update = self._assign_moving_average(
+              self.moving_variance, variance, momentum, inputs_size)

         self.add_update(mean_update, inputs=True)
         self.add_update(variance_update, inputs=True)

       return output

-  def _renorm_correction_and_moments(self, mean, variance, training):
+  def _renorm_correction_and_moments(self, mean, variance, training,
+                                     inputs_size):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
     # Compute the average mean and standard deviation, as if they were
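
Hedged sketch of the update routing the rewritten branch relies on (MirroredStrategy, the variable, and _assign are illustrative stand-ins, not the commit's code): in cross-replica context, strategy.extended.update(var, fn, args) applies fn with the (value, momentum, inputs_size) arguments to each replica's copy of the mirrored moving statistic.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  moving_mean = tf.Variable([0.0, 0.0, 0.0], trainable=False)

def _assign(var, value, momentum, inputs_size):
  del inputs_size  # the zero-batch guard itself is sketched above
  return var.assign_sub((var - value) * (1.0 - momentum))

# Called from cross-replica context, mirroring the call sites in the diff.
strategy.extended.update(
    moving_mean, _assign, args=([1.0, 1.0, 1.0], 0.9, tf.constant(8)))
print(moving_mean.numpy())  # each replica copy moved toward the new mean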
@@ -534,7 +538,7 @@ class BatchNormalizationBase(Layer):
                               lambda: d,
                               lambda: array_ops.zeros_like(d))

-    def _update_renorm_variable(var, weight, value):
+    def _update_renorm_variable(var, weight, value, inputs_size):
       """Updates a moving average and weight, returns the unbiased value."""
       value = array_ops.identity(value)
       def _do_update():
@@ -547,9 +551,11 @@ class BatchNormalizationBase(Layer):
         # Make sure the weight is not updated until before r and d computation.
         with ops.control_dependencies([value]):
           weight_value = array_ops.constant(1., dtype=weight.dtype)
-        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
+        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
+                                              inputs_size)
         new_weight = self._assign_moving_average(weight, weight_value,
-                                                 self.renorm_momentum)
+                                                 self.renorm_momentum,
+                                                 inputs_size)
         # TODO(yuefengz): the updates to var and weighted can not be batched
         # together if we fetch their updated values here. Consider calculating
         # new values and delaying the updates.
@@ -561,16 +567,27 @@ class BatchNormalizationBase(Layer):
     # TODO(yuefengz): colocate the operations
     new_mean = _update_renorm_variable(self.renorm_mean,
-                                       self.renorm_mean_weight, mean)
+                                       self.renorm_mean_weight, mean,
+                                       inputs_size)
     new_stddev = _update_renorm_variable(self.renorm_stddev,
-                                         self.renorm_stddev_weight, stddev)
+                                         self.renorm_stddev_weight, stddev,
+                                         inputs_size)

     # Make sqrt(moving_variance + epsilon) = new_stddev.
     new_variance = math_ops.square(new_stddev) - self.epsilon
     return (r, d, new_mean, new_variance)

   def _moments(self, inputs, reduction_axes, keep_dims):
-    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
+    # code as well.
+    if distribution_strategy_context.has_strategy(
+    ) and not inputs.shape.is_fully_defined():
+      inputs_size = array_ops.size(inputs)
+      mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
+      variance = array_ops.where(inputs_size > 0, variance,
+                                 K.zeros_like(variance))
+    return mean, variance

   def call(self, inputs, training=None):
     if training is None:
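
The reason for the new where() guard in _moments, shown with a small standalone example (assumed TF2 eager; the layer uses the internal nn/array_ops modules): reducing over a zero-sized batch yields NaN mean and variance, which the guard replaces with zeros so the moving-average updates above stay finite.

import tensorflow as tf

empty = tf.zeros([0, 4, 4, 3])
mean, variance = tf.nn.moments(empty, axes=[0, 1, 2])
print(mean.numpy())      # [nan nan nan], a 0/0 from the empty reduction

inputs_size = tf.size(empty)
mean = tf.where(inputs_size > 0, mean, tf.zeros_like(mean))
variance = tf.where(inputs_size > 0, variance, tf.zeros_like(variance))
print(mean.numpy(), variance.numpy())   # [0. 0. 0.] [0. 0. 0.]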
@@ -667,9 +684,14 @@ class BatchNormalizationBase(Layer):
       else:
         new_mean, new_variance = mean, variance

+      if distribution_strategy_context.has_strategy(
+      ) and not inputs.shape.is_fully_defined():
+        inputs_size = array_ops.size(inputs)
+      else:
+        inputs_size = None
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            new_mean, new_variance, training)
+            new_mean, new_variance, training, inputs_size)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
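
The comment above states the folding identity that lets the renorm correction be absorbed into gamma and beta; a tiny numeric check (standalone, illustrative values only):

import numpy as np

x = np.random.randn(8, 3)   # normalized activations
gamma, beta = 1.5, 0.2
r, d = 0.9, 0.1             # renorm correction and offset

with_renorm = (x * r + d) * gamma + beta
folded = x * (r * gamma) + (d * gamma + beta)
print(np.allclose(with_renorm, folded))  # True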
@@ -683,7 +705,8 @@ class BatchNormalizationBase(Layer):
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
           return strategy.extended.update(
-              var, self._assign_moving_average, (value, self.momentum),
+              var,
+              self._assign_moving_average, (value, self.momentum, inputs_size),
               group=False)
         # We need to unwrap the moving_mean or moving_variance in the case of
         # training being false to match the output of true_fn and false_fn
@@ -700,7 +723,9 @@ class BatchNormalizationBase(Layer):
       else:

         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          return self._assign_moving_average(var, value, self.momentum)
+          return self._assign_moving_average(var, value, self.momentum,
+                                             inputs_size)

         def mean_update():
           true_branch = lambda: _do_update(self.moving_mean, new_mean)