From 4a1931ae46b7f77180ab098253bfa60ac72411c9 Mon Sep 17 00:00:00 2001
From: Ruoxin Sang
Date: Tue, 16 Apr 2019 22:54:42 -0700
Subject: [PATCH] Automated rollback of commit d208ae26ed00b260ffefa8100437242260278ae8

PiperOrigin-RevId: 243939780
---
 .../python/distribute/zero_batch_test.py |  2 +-
 .../python/keras/layers/normalization.py | 83 ++++++++++++-------
 2 files changed, 55 insertions(+), 30 deletions(-)

diff --git a/tensorflow/python/distribute/zero_batch_test.py b/tensorflow/python/distribute/zero_batch_test.py
index cb8ce071e93..39c577a682e 100644
--- a/tensorflow/python/distribute/zero_batch_test.py
+++ b/tensorflow/python/distribute/zero_batch_test.py
@@ -43,7 +43,7 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.times(all_combinations,
                          combinations.combine(fused=[True, False])))
-  def disabled_testBNWithZeroBatchInput(self, distribution, fused):
+  def testBNWithZeroBatchInput(self, distribution, fused):
     with distribution.scope(), self.cached_session() as sess:
       bn_list = []
       inputs = np.random.random((0, 4, 4, 3)) + 100
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index d27dc010b01..8549a37a1e9 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -423,7 +423,7 @@ class BatchNormalizationBase(Layer):
         self._scope.set_partitioner(partitioner)
     self.built = True
 
-  def _assign_moving_average(self, variable, value, momentum):
+  def _assign_moving_average(self, variable, value, momentum, inputs_size):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
       with ops.colocate_with(variable):
@@ -432,6 +432,9 @@ class BatchNormalizationBase(Layer):
           decay = math_ops.cast(decay, variable.dtype.base_dtype)
         update_delta = (
             variable - math_ops.cast(value, variable.dtype)) * decay
+        if inputs_size is not None:
+          update_delta = array_ops.where(inputs_size > 0, update_delta,
+                                         K.zeros_like(update_delta))
         return state_ops.assign_sub(variable, update_delta, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
@@ -439,6 +442,14 @@ class BatchNormalizationBase(Layer):
     beta = self.beta if self.center else self._beta_const
     gamma = self.gamma if self.scale else self._gamma_const
 
+    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
+    # code as well.
+    if distribution_strategy_context.has_strategy(
+    ) and not inputs.shape.is_fully_defined():
+      inputs_size = array_ops.size(inputs)
+    else:
+      inputs_size = None
+
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
           inputs,
@@ -479,31 +490,24 @@ class BatchNormalizationBase(Layer):
     if training_value or training_value is None:
       if distribution_strategy_context.in_cross_replica_context():
         strategy = distribution_strategy_context.get_strategy()
-
-        def mean_update():
-          return strategy.extended.update(self.moving_mean,
-                                          self._assign_moving_average,
-                                          (mean, self.momentum))
-
-        def variance_update():
-          return strategy.extended.update(self.moving_variance,
-                                          self._assign_moving_average,
-                                          (variance, self.momentum))
+        mean_update = strategy.extended.update(
+            self.moving_mean, self._assign_moving_average,
+            (mean, self.momentum, inputs_size))
+        variance_update = strategy.extended.update(
+            self.moving_variance, self._assign_moving_average,
+            (variance, self.momentum, inputs_size))
       else:
-
-        def mean_update():
-          return self._assign_moving_average(self.moving_mean, mean, momentum)
-
-        def variance_update():
-          return self._assign_moving_average(self.moving_variance, variance,
-                                             momentum)
-
+        mean_update = self._assign_moving_average(self.moving_mean, mean,
+                                                  momentum, inputs_size)
+        variance_update = self._assign_moving_average(
+            self.moving_variance, variance, momentum, inputs_size)
       self.add_update(mean_update, inputs=True)
       self.add_update(variance_update, inputs=True)
 
     return output
 
-  def _renorm_correction_and_moments(self, mean, variance, training):
+  def _renorm_correction_and_moments(self, mean, variance, training,
+                                     inputs_size):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
     # Compute the average mean and standard deviation, as if they were
@@ -534,7 +538,7 @@ class BatchNormalizationBase(Layer):
                             lambda: d,
                             lambda: array_ops.zeros_like(d))
 
-    def _update_renorm_variable(var, weight, value):
+    def _update_renorm_variable(var, weight, value, inputs_size):
       """Updates a moving average and weight, returns the unbiased value."""
       value = array_ops.identity(value)
       def _do_update():
@@ -547,9 +551,11 @@ class BatchNormalizationBase(Layer):
         # Make sure the weight is not updated until before r and d computation.
         with ops.control_dependencies([value]):
           weight_value = array_ops.constant(1., dtype=weight.dtype)
-        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
+        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
+                                              inputs_size)
         new_weight = self._assign_moving_average(weight, weight_value,
-                                                 self.renorm_momentum)
+                                                 self.renorm_momentum,
+                                                 inputs_size)
         # TODO(yuefengz): the updates to var and weighted can not be batched
         # together if we fetch their updated values here. Consider calculating
         # new values and delaying the updates.
@@ -561,16 +567,27 @@ class BatchNormalizationBase(Layer):
 
     # TODO(yuefengz): colocate the operations
     new_mean = _update_renorm_variable(self.renorm_mean,
-                                       self.renorm_mean_weight, mean)
+                                       self.renorm_mean_weight, mean,
+                                       inputs_size)
     new_stddev = _update_renorm_variable(self.renorm_stddev,
-                                         self.renorm_stddev_weight, stddev)
+                                         self.renorm_stddev_weight, stddev,
+                                         inputs_size)
     # Make sqrt(moving_variance + epsilon) = new_stddev.
     new_variance = math_ops.square(new_stddev) - self.epsilon
 
     return (r, d, new_mean, new_variance)
 
   def _moments(self, inputs, reduction_axes, keep_dims):
-    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
+    # code as well.
+    if distribution_strategy_context.has_strategy(
+    ) and not inputs.shape.is_fully_defined():
+      inputs_size = array_ops.size(inputs)
+      mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
+      variance = array_ops.where(inputs_size > 0, variance,
+                                 K.zeros_like(variance))
+    return mean, variance
 
   def call(self, inputs, training=None):
     if training is None:
@@ -667,9 +684,14 @@ class BatchNormalizationBase(Layer):
       else:
         new_mean, new_variance = mean, variance
 
+      if distribution_strategy_context.has_strategy(
+      ) and not inputs.shape.is_fully_defined():
+        inputs_size = array_ops.size(inputs)
+      else:
+        inputs_size = None
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            new_mean, new_variance, training)
+            new_mean, new_variance, training, inputs_size)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -683,7 +705,8 @@ class BatchNormalizationBase(Layer):
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
           return strategy.extended.update(
-              var, self._assign_moving_average, (value, self.momentum),
+              var,
+              self._assign_moving_average, (value, self.momentum, inputs_size),
               group=False)
         # We need to unwrap the moving_mean or moving_variance in the case of
         # training being false to match the output of true_fn and false_fn
@@ -700,7 +723,9 @@ class BatchNormalizationBase(Layer):
       else:
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          return self._assign_moving_average(var, value, self.momentum)
+          return self._assign_moving_average(var, value, self.momentum,
+                                             inputs_size)
+
       def mean_update():
         true_branch = lambda: _do_update(self.moving_mean, new_mean)