diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 81a7ab0d04a..36dd22f97cd 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -2,6 +2,7 @@
 
 load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
 load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -839,17 +840,3 @@ tf_xla_py_test(
         "//tensorflow/python/training/tracking:util",
     ],
 )
-
-distribute_py_test(
-    name = "zero_batch_test",
-    srcs = ["zero_batch_test.py"],
-    main = "zero_batch_test.py",
-    deps = [
-        ":mirrored_strategy",
-        ":tpu_strategy",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
diff --git a/tensorflow/contrib/distribute/python/zero_batch_test.py b/tensorflow/contrib/distribute/python/zero_batch_test.py
deleted file mode 100644
index 2aeffa3aadf..00000000000
--- a/tensorflow/contrib/distribute/python/zero_batch_test.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test DistributionStrategy in the zero batch case."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import strategy_combinations
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.layers import normalization
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.ops.losses import losses
-from tensorflow.python.platform import test
-from tensorflow.python.training import gradient_descent
-
-
-all_combinations = combinations.combine(
-    distribution=[
-        strategy_combinations.one_device_strategy,
-    ],
-    mode=["graph"])
-
-
-class NormalizationTest(test.TestCase, parameterized.TestCase):
-
-  @combinations.generate(
-      combinations.times(all_combinations,
-                         combinations.combine(fused=[True, False])))
-  def testBNWithZeroBatchInput(self, distribution, fused):
-    with distribution.scope(), self.cached_session() as sess:
-      bn_list = []
-      inputs = ops.convert_to_tensor(
-          np.random.random((0, 4, 4, 3)) + 100, dtype=dtypes.float32)
-      targets = ops.convert_to_tensor(
-          np.random.random((0, 4, 4, 3)), dtype=dtypes.float32)
-
-      def step_fn(is_training, inputs, targets=None):
-        bn = normalization.BatchNormalization(
-            axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
-        bn_list.append(bn)
-        outputs = bn.apply(inputs, training=is_training)
-        if not is_training:
-          return outputs
-
-        loss = losses.mean_squared_error(targets, outputs)
-        optimizer = gradient_descent.GradientDescentOptimizer(0.01)
-        train_op = optimizer.minimize(loss)
-        with ops.control_dependencies([train_op]):
-          return array_ops.identity(loss)
-
-      train_op = distribution.extended.call_for_each_replica(
-          step_fn, args=(True, inputs, targets))
-      predict_op = distribution.extended.call_for_each_replica(
-          step_fn, args=(False, inputs))
-      bn = bn_list[0]
-
-      self.evaluate(variables.global_variables_initializer())
-
-      # Check for initial statistics and weights.
-      moving_mean, moving_var = self.evaluate(
-          [bn.moving_mean, bn.moving_variance])
-      self.assertAllEqual([0, 0, 0], moving_mean)
-      self.assertAllEqual([1, 1, 1], moving_var)
-
-      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
-      self.assertAllEqual([1, 1, 1], np_gamma)
-      self.assertAllEqual([0, 0, 0], np_beta)
-
-      for _ in range(100):
-        np_output, _, _ = sess.run([train_op] + bn.updates)
-        self.assertEqual(0.0, np_output)
-
-      # Verify that the statistics and weights are not changed after training.
-      moving_mean, moving_var = self.evaluate(
-          [bn.moving_mean, bn.moving_variance])
-      self.assertAllEqual([0, 0, 0], moving_mean)
-      self.assertAllEqual([1, 1, 1], moving_var)
-
-      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
-      self.assertAllEqual([1, 1, 1], np_gamma)
-      self.assertAllEqual([0, 0, 0], np_beta)
-
-      # Test inference.
-      np_output = sess.run(predict_op)
-      self.assertEqual([], np_output.tolist())
-
-
-if __name__ == "__main__":
-  test.main()
-
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 3f0e8ec86ab..24a02a9d7fa 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -424,7 +424,7 @@ class BatchNormalizationBase(Layer):
       self._scope.set_partitioner(partitioner)
     self.built = True
 
-  def _assign_moving_average(self, variable, value, momentum, inputs_size):
+  def _assign_moving_average(self, variable, value, momentum):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
       with ops.colocate_with(variable):
@@ -433,19 +433,12 @@ class BatchNormalizationBase(Layer):
           decay = math_ops.cast(decay, variable.dtype.base_dtype)
         update_delta = (
             variable - math_ops.cast(value, variable.dtype)) * decay
-        # TODO(b/129279393): Support zero batch input in non
-        # DistributionStrategy code as well.
-        if distribution_strategy_context.has_strategy():
-          update_delta = tf_utils.smart_cond(
-              inputs_size > 0,
-              lambda: update_delta, lambda: K.zeros_like(update_delta))
         return state_ops.assign_sub(variable, update_delta, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
     """Returns the output of fused batch norm."""
     beta = self.beta if self.center else self._beta_const
     gamma = self.gamma if self.scale else self._gamma_const
-    inputs_size = array_ops.size(inputs)
 
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
@@ -489,22 +482,21 @@ class BatchNormalizationBase(Layer):
         strategy = distribution_strategy_context.get_strategy()
         mean_update = strategy.extended.update(
             self.moving_mean, self._assign_moving_average,
-            (mean, self.momentum, inputs_size))
+            (mean, self.momentum))
         variance_update = strategy.extended.update(
             self.moving_variance, self._assign_moving_average,
-            (variance, self.momentum, inputs_size))
+            (variance, self.momentum))
       else:
         mean_update = self._assign_moving_average(self.moving_mean, mean,
-                                                  momentum, inputs_size)
-        variance_update = self._assign_moving_average(
-            self.moving_variance, variance, momentum, inputs_size)
+                                                  momentum)
+        variance_update = self._assign_moving_average(self.moving_variance,
+                                                      variance, momentum)
       self.add_update(mean_update, inputs=True)
       self.add_update(variance_update, inputs=True)
 
     return output
 
-  def _renorm_correction_and_moments(self, mean, variance, training,
-                                     inputs_size):
+  def _renorm_correction_and_moments(self, mean, variance, training):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
     # Compute the average mean and standard deviation, as if they were
@@ -535,7 +527,7 @@ class BatchNormalizationBase(Layer):
                             lambda: d,
                             lambda: array_ops.zeros_like(d))
 
-    def _update_renorm_variable(var, weight, value, inputs_size):
+    def _update_renorm_variable(var, weight, value):
       """Updates a moving average and weight, returns the unbiased value."""
       value = array_ops.identity(value)
       def _do_update():
@@ -548,10 +540,9 @@ class BatchNormalizationBase(Layer):
         # Make sure the weight is not updated until before r and d computation.
         with ops.control_dependencies([value]):
           weight_value = array_ops.constant(1., dtype=weight.dtype)
-        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
-                                              inputs_size)
-        new_weight = self._assign_moving_average(
-            weight, weight_value, self.renorm_momentum, inputs_size)
+        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
+        new_weight = self._assign_moving_average(weight, weight_value,
+                                                 self.renorm_momentum)
         # TODO(yuefengz): the updates to var and weighted can not be batched
         # together if we fetch their updated values here. Consider calculating
         # new values and delaying the updates.
@@ -562,26 +553,17 @@ class BatchNormalizationBase(Layer):
       return tf_utils.smart_cond(training, _do_update, _fake_update)
 
     # TODO(yuefengz): colocate the operations
-    new_mean = _update_renorm_variable(
-        self.renorm_mean, self.renorm_mean_weight, mean, inputs_size)
-    new_stddev = _update_renorm_variable(
-        self.renorm_stddev, self.renorm_stddev_weight, stddev, inputs_size)
+    new_mean = _update_renorm_variable(self.renorm_mean,
+                                       self.renorm_mean_weight, mean)
+    new_stddev = _update_renorm_variable(self.renorm_stddev,
+                                         self.renorm_stddev_weight, stddev)
     # Make sqrt(moving_variance + epsilon) = new_stddev.
     new_variance = math_ops.square(new_stddev) - self.epsilon
     return (r, d, new_mean, new_variance)
 
   def _moments(self, inputs, reduction_axes, keep_dims):
-    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
-    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
-    # code as well.
-    if distribution_strategy_context.has_strategy():
-      inputs_size = array_ops.size(inputs)
-      mean = tf_utils.smart_cond(
-          inputs_size > 0, lambda: mean, lambda: K.zeros_like(mean))
-      variance = tf_utils.smart_cond(
-          inputs_size > 0, lambda: variance, lambda: K.zeros_like(variance))
-    return mean, variance
+    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
 
   def call(self, inputs, training=None):
     if training is None:
@@ -679,10 +661,9 @@ class BatchNormalizationBase(Layer):
       else:
        new_mean, new_variance = mean, variance
 
-      inputs_size = array_ops.size(inputs)
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            new_mean, new_variance, training, inputs_size)
+            new_mean, new_variance, training)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -698,8 +679,8 @@ class BatchNormalizationBase(Layer):
         if in_eager_mode and not self.trainable:
           return
         return strategy.extended.update(
-            var, self._assign_moving_average,
-            (value, self.momentum, inputs_size), group=False)
+            var, self._assign_moving_average, (value, self.momentum),
+            group=False)
       # We need to unwrap the moving_mean or moving_variance in the case of
       # training being false to match the output of true_fn and false_fn
       # in the smart cond.
@@ -716,9 +697,7 @@ class BatchNormalizationBase(Layer):
        """Compute the updates for mean and variance."""
        if in_eager_mode and not self.trainable:
          return
-        return self._assign_moving_average(var, value, self.momentum,
-                                           inputs_size)
-
+        return self._assign_moving_average(var, value, self.momentum)
      mean_update = tf_utils.smart_cond(
          training,
          lambda: _do_update(self.moving_mean, new_mean),
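Illustration (not part of the patch; a simplified NumPy stand-in, not the real TensorFlow implementation): after this change, _assign_moving_average applies the update variable -= (variable - value) * (1 - momentum) unconditionally, because the inputs_size > 0 guard that previously zeroed the delta for empty batches under a DistributionStrategy is removed. The sketch below mirrors only that arithmetic on plain arrays; the function name and sample values are hypothetical.

import numpy as np

def assign_moving_average(variable, value, momentum):
  # Mirrors the patched method: decay = 1 - momentum,
  # update_delta = (variable - value) * decay, then variable -= update_delta.
  decay = 1.0 - momentum
  update_delta = (variable - value) * decay
  return variable - update_delta

moving_mean = np.zeros(3, dtype=np.float32)            # initial moving statistic
batch_mean = np.array([0.5, 1.0, 1.5], dtype=np.float32)  # statistic from one batch
moving_mean = assign_moving_average(moving_mean, batch_mean, momentum=0.9)
print(moving_mean)  # [0.05 0.1  0.15] -- one 10% step toward the batch statistic

With the guard gone, the same step runs even when the batch statistics come from a zero-size input, which is the scenario the deleted zero_batch_test.py exercised.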