From 308bb3c69b850535a49d49a63ca74d0a7ba61fc1 Mon Sep 17 00:00:00 2001
From: Ruoxin Sang
Date: Wed, 27 Mar 2019 14:37:27 -0700
Subject: [PATCH] Handle zero batch input in BatchNorm correctly if inside a DistributionStrategy scope.

PiperOrigin-RevId: 240643242
---
 tensorflow/contrib/distribute/python/BUILD   |  15 ++-
 .../distribute/python/zero_batch_test.py     | 109 ++++++++++++++++++
 .../python/keras/layers/normalization.py     |  61 ++++++----
 3 files changed, 164 insertions(+), 21 deletions(-)
 create mode 100644 tensorflow/contrib/distribute/python/zero_batch_test.py

diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index fb09339ef29..5f7923eeab3 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -2,7 +2,6 @@
 
 load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
 load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test")
-load("//tensorflow:tensorflow.bzl", "py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -805,3 +804,17 @@ tf_xla_py_test(
         "//tensorflow/python/training/tracking:util",
     ],
 )
+
+distribute_py_test(
+    name = "zero_batch_test",
+    srcs = ["zero_batch_test.py"],
+    main = "zero_batch_test.py",
+    deps = [
+        ":mirrored_strategy",
+        ":tpu_strategy",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
diff --git a/tensorflow/contrib/distribute/python/zero_batch_test.py b/tensorflow/contrib/distribute/python/zero_batch_test.py
new file mode 100644
index 00000000000..2aeffa3aadf
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/zero_batch_test.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test DistributionStrategy in the zero batch case."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import normalization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
+
+
+all_combinations = combinations.combine(
+    distribution=[
+        strategy_combinations.one_device_strategy,
+    ],
+    mode=["graph"])
+
+
+class NormalizationTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.times(all_combinations,
+                         combinations.combine(fused=[True, False])))
+  def testBNWithZeroBatchInput(self, distribution, fused):
+    with distribution.scope(), self.cached_session() as sess:
+      bn_list = []
+      inputs = ops.convert_to_tensor(
+          np.random.random((0, 4, 4, 3)) + 100, dtype=dtypes.float32)
+      targets = ops.convert_to_tensor(
+          np.random.random((0, 4, 4, 3)), dtype=dtypes.float32)
+
+      def step_fn(is_training, inputs, targets=None):
+        bn = normalization.BatchNormalization(
+            axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
+        bn_list.append(bn)
+        outputs = bn.apply(inputs, training=is_training)
+        if not is_training:
+          return outputs
+
+        loss = losses.mean_squared_error(targets, outputs)
+        optimizer = gradient_descent.GradientDescentOptimizer(0.01)
+        train_op = optimizer.minimize(loss)
+        with ops.control_dependencies([train_op]):
+          return array_ops.identity(loss)
+
+      train_op = distribution.extended.call_for_each_replica(
+          step_fn, args=(True, inputs, targets))
+      predict_op = distribution.extended.call_for_each_replica(
+          step_fn, args=(False, inputs))
+      bn = bn_list[0]
+
+      self.evaluate(variables.global_variables_initializer())
+
+      # Check for initial statistics and weights.
+      moving_mean, moving_var = self.evaluate(
+          [bn.moving_mean, bn.moving_variance])
+      self.assertAllEqual([0, 0, 0], moving_mean)
+      self.assertAllEqual([1, 1, 1], moving_var)
+
+      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
+      self.assertAllEqual([1, 1, 1], np_gamma)
+      self.assertAllEqual([0, 0, 0], np_beta)
+
+      for _ in range(100):
+        np_output, _, _ = sess.run([train_op] + bn.updates)
+        self.assertEqual(0.0, np_output)
+
+      # Verify that the statistics and weights are not changed after training.
+      moving_mean, moving_var = self.evaluate(
+          [bn.moving_mean, bn.moving_variance])
+      self.assertAllEqual([0, 0, 0], moving_mean)
+      self.assertAllEqual([1, 1, 1], moving_var)
+
+      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
+      self.assertAllEqual([1, 1, 1], np_gamma)
+      self.assertAllEqual([0, 0, 0], np_beta)
+
+      # Test inference.
+      np_output = sess.run(predict_op)
+      self.assertEqual([], np_output.tolist())
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 24a02a9d7fa..3f0e8ec86ab 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -424,7 +424,7 @@ class BatchNormalizationBase(Layer):
       self._scope.set_partitioner(partitioner)
     self.built = True
 
-  def _assign_moving_average(self, variable, value, momentum):
+  def _assign_moving_average(self, variable, value, momentum, inputs_size):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
       with ops.colocate_with(variable):
@@ -433,12 +433,19 @@
           decay = math_ops.cast(decay, variable.dtype.base_dtype)
         update_delta = (
             variable - math_ops.cast(value, variable.dtype)) * decay
+        # TODO(b/129279393): Support zero batch input in non
+        # DistributionStrategy code as well.
+        if distribution_strategy_context.has_strategy():
+          update_delta = tf_utils.smart_cond(
+              inputs_size > 0,
+              lambda: update_delta, lambda: K.zeros_like(update_delta))
         return state_ops.assign_sub(variable, update_delta, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
     """Returns the output of fused batch norm."""
     beta = self.beta if self.center else self._beta_const
     gamma = self.gamma if self.scale else self._gamma_const
+    inputs_size = array_ops.size(inputs)
 
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
@@ -482,21 +489,22 @@
         strategy = distribution_strategy_context.get_strategy()
         mean_update = strategy.extended.update(
             self.moving_mean, self._assign_moving_average,
-            (mean, self.momentum))
+            (mean, self.momentum, inputs_size))
         variance_update = strategy.extended.update(
             self.moving_variance, self._assign_moving_average,
-            (variance, self.momentum))
+            (variance, self.momentum, inputs_size))
       else:
         mean_update = self._assign_moving_average(self.moving_mean, mean,
-                                                  momentum)
-        variance_update = self._assign_moving_average(self.moving_variance,
-                                                      variance, momentum)
+                                                  momentum, inputs_size)
+        variance_update = self._assign_moving_average(
+            self.moving_variance, variance, momentum, inputs_size)
       self.add_update(mean_update, inputs=True)
       self.add_update(variance_update, inputs=True)
 
     return output
 
-  def _renorm_correction_and_moments(self, mean, variance, training):
+  def _renorm_correction_and_moments(self, mean, variance, training,
+                                     inputs_size):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
     # Compute the average mean and standard deviation, as if they were
@@ -527,7 +535,7 @@
                             lambda: d,
                             lambda: array_ops.zeros_like(d))
 
-    def _update_renorm_variable(var, weight, value):
+    def _update_renorm_variable(var, weight, value, inputs_size):
       """Updates a moving average and weight, returns the unbiased value."""
       value = array_ops.identity(value)
       def _do_update():
@@ -540,9 +548,10 @@
        # Make sure the weight is not updated until before r and d computation.
        with ops.control_dependencies([value]):
          weight_value = array_ops.constant(1., dtype=weight.dtype)
-        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
-        new_weight = self._assign_moving_average(weight, weight_value,
-                                                 self.renorm_momentum)
+        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
+                                              inputs_size)
+        new_weight = self._assign_moving_average(
+            weight, weight_value, self.renorm_momentum, inputs_size)
        # TODO(yuefengz): the updates to var and weighted can not be batched
        # together if we fetch their updated values here. Consider calculating
        # new values and delaying the updates.
@@ -553,17 +562,26 @@
       return tf_utils.smart_cond(training, _do_update, _fake_update)
 
     # TODO(yuefengz): colocate the operations
-    new_mean = _update_renorm_variable(self.renorm_mean,
-                                       self.renorm_mean_weight, mean)
-    new_stddev = _update_renorm_variable(self.renorm_stddev,
-                                         self.renorm_stddev_weight, stddev)
+    new_mean = _update_renorm_variable(
+        self.renorm_mean, self.renorm_mean_weight, mean, inputs_size)
+    new_stddev = _update_renorm_variable(
+        self.renorm_stddev, self.renorm_stddev_weight, stddev, inputs_size)
     # Make sqrt(moving_variance + epsilon) = new_stddev.
     new_variance = math_ops.square(new_stddev) - self.epsilon
     return (r, d, new_mean, new_variance)
 
   def _moments(self, inputs, reduction_axes, keep_dims):
-    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
+    # code as well.
+    if distribution_strategy_context.has_strategy():
+      inputs_size = array_ops.size(inputs)
+      mean = tf_utils.smart_cond(
+          inputs_size > 0, lambda: mean, lambda: K.zeros_like(mean))
+      variance = tf_utils.smart_cond(
+          inputs_size > 0, lambda: variance, lambda: K.zeros_like(variance))
+    return mean, variance
 
   def call(self, inputs, training=None):
     if training is None:
@@ -661,9 +679,10 @@
     else:
       new_mean, new_variance = mean, variance
 
+    inputs_size = array_ops.size(inputs)
     if self.renorm:
       r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-          new_mean, new_variance, training)
+          new_mean, new_variance, training, inputs_size)
       # When training, the normalized values (say, x) will be transformed as
       # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
       # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -679,8 +698,8 @@
           if in_eager_mode and not self.trainable:
             return
           return strategy.extended.update(
-              var, self._assign_moving_average, (value, self.momentum),
-              group=False)
+              var, self._assign_moving_average,
+              (value, self.momentum, inputs_size), group=False)
         # We need to unwrap the moving_mean or moving_variance in the case of
         # training being false to match the output of true_fn and false_fn
         # in the smart cond.
@@ -697,7 +716,9 @@
           """Compute the updates for mean and variance."""
          if in_eager_mode and not self.trainable:
            return
-          return self._assign_moving_average(var, value, self.momentum)
+          return self._assign_moving_average(var, value, self.momentum,
+                                             inputs_size)
+
        mean_update = tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_mean, new_mean),
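The core of the fix is a size-conditional guard around the moving-average update: when the per-replica input has zero elements, the update delta is replaced with zeros, so an empty batch leaves the layer's statistics untouched instead of writing NaNs into them. The following standalone sketch illustrates that idea only; it uses public TensorFlow 2.x ops (tf.cond, tf.size, tf.zeros_like) in place of the internal tf_utils.smart_cond and backend helpers used in normalization.py, and the assign_moving_average function below is a hypothetical stand-in for the layer method, not code from this patch.

# Standalone sketch (not part of the patch): guarded moving-average update.
import tensorflow as tf


def assign_moving_average(variable, value, momentum, inputs_size):
  """Moving-average update that becomes a no-op when the batch is empty."""
  decay = tf.cast(1.0 - momentum, variable.dtype)
  update_delta = (variable - tf.cast(value, variable.dtype)) * decay
  # Same idea as the smart_cond guard above: subtract zeros instead of a
  # NaN-filled delta whenever the input tensor has no elements.
  update_delta = tf.cond(inputs_size > 0,
                         lambda: update_delta,
                         lambda: tf.zeros_like(update_delta))
  return variable.assign_sub(update_delta)


moving_mean = tf.Variable([0., 0., 0.])
empty_batch = tf.zeros([0, 3])                    # zero-batch input, as in the test
batch_mean = tf.reduce_mean(empty_batch, axis=0)  # NaN: reduction over zero rows
assign_moving_average(moving_mean, batch_mean, 0.9, tf.size(empty_batch))
print(moving_mean.numpy())  # [0. 0. 0.] unchanged; without the guard it would be NaN

In the patch itself the same conditional appears in _assign_moving_average and _moments, covering both the fused and non-fused paths, and the new zero_batch_test.py asserts that the moving statistics and gamma/beta stay at their initial values after training on a zero-batch input.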