Automated rollback of commit 308bb3c69b

PiperOrigin-RevId: 241024319
Author: Ruoxin Sang (2019-03-29 12:00:25 -07:00), committed by TensorFlower Gardener
parent 4f73ae879a
commit d208ae26ed
3 changed files with 21 additions and 164 deletions

View File

@@ -2,6 +2,7 @@
 load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
 load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -839,17 +840,3 @@ tf_xla_py_test(
         "//tensorflow/python/training/tracking:util",
     ],
 )
-
-distribute_py_test(
-    name = "zero_batch_test",
-    srcs = ["zero_batch_test.py"],
-    main = "zero_batch_test.py",
-    deps = [
-        ":mirrored_strategy",
-        ":tpu_strategy",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:strategy_combinations",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)

View File

@@ -1,109 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test DistributionStrategy in the zero batch case."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.distribute import combinations
-from tensorflow.python.distribute import strategy_combinations
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.layers import normalization
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.ops.losses import losses
-from tensorflow.python.platform import test
-from tensorflow.python.training import gradient_descent
-
-
-all_combinations = combinations.combine(
-    distribution=[
-        strategy_combinations.one_device_strategy,
-    ],
-    mode=["graph"])
-
-
-class NormalizationTest(test.TestCase, parameterized.TestCase):
-
-  @combinations.generate(
-      combinations.times(all_combinations,
-                         combinations.combine(fused=[True, False])))
-  def testBNWithZeroBatchInput(self, distribution, fused):
-    with distribution.scope(), self.cached_session() as sess:
-      bn_list = []
-      inputs = ops.convert_to_tensor(
-          np.random.random((0, 4, 4, 3)) + 100, dtype=dtypes.float32)
-      targets = ops.convert_to_tensor(
-          np.random.random((0, 4, 4, 3)), dtype=dtypes.float32)
-
-      def step_fn(is_training, inputs, targets=None):
-        bn = normalization.BatchNormalization(
-            axis=3, epsilon=1e-3, momentum=0.9, fused=fused)
-        bn_list.append(bn)
-        outputs = bn.apply(inputs, training=is_training)
-        if not is_training:
-          return outputs
-
-        loss = losses.mean_squared_error(targets, outputs)
-        optimizer = gradient_descent.GradientDescentOptimizer(0.01)
-        train_op = optimizer.minimize(loss)
-        with ops.control_dependencies([train_op]):
-          return array_ops.identity(loss)
-
-      train_op = distribution.extended.call_for_each_replica(
-          step_fn, args=(True, inputs, targets))
-      predict_op = distribution.extended.call_for_each_replica(
-          step_fn, args=(False, inputs))
-      bn = bn_list[0]
-
-      self.evaluate(variables.global_variables_initializer())
-
-      # Check for initial statistics and weights.
-      moving_mean, moving_var = self.evaluate(
-          [bn.moving_mean, bn.moving_variance])
-      self.assertAllEqual([0, 0, 0], moving_mean)
-      self.assertAllEqual([1, 1, 1], moving_var)
-
-      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
-      self.assertAllEqual([1, 1, 1], np_gamma)
-      self.assertAllEqual([0, 0, 0], np_beta)
-
-      for _ in range(100):
-        np_output, _, _ = sess.run([train_op] + bn.updates)
-        self.assertEqual(0.0, np_output)
-
-      # Verify that the statistics and weights are not changed after training.
-      moving_mean, moving_var = self.evaluate(
-          [bn.moving_mean, bn.moving_variance])
-      self.assertAllEqual([0, 0, 0], moving_mean)
-      self.assertAllEqual([1, 1, 1], moving_var)
-
-      np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
-      self.assertAllEqual([1, 1, 1], np_gamma)
-      self.assertAllEqual([0, 0, 0], np_beta)
-
-      # Test inference.
-      np_output = sess.run(predict_op)
-      self.assertEqual([], np_output.tolist())
-
-
-if __name__ == "__main__":
-  test.main()
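
Note: the deleted test drives BatchNormalization with zero-sized batches (shape (0, 4, 4, 3)) and asserts that the moving statistics and the gamma/beta weights keep their initial values after 100 training steps. A small NumPy illustration of why that check matters (not part of the test; the variable names below are made up): the per-channel mean of an empty batch is NaN, and naively folding NaN into a moving average destroys the statistics, which is exactly what the rolled-back guard prevented.

import numpy as np

batch = np.random.random((0, 4, 4, 3))     # zero-sized batch, as in the test
batch_mean = batch.mean(axis=(0, 1, 2))    # -> [nan nan nan], since 0 / 0

momentum = 0.9
moving_mean = np.zeros(3)                  # BatchNormalization initializes moving_mean to zeros
naive_update = moving_mean * momentum + batch_mean * (1 - momentum)
print(batch_mean)     # [nan nan nan]
print(naive_update)   # [nan nan nan] -- the moving statistics would be ruined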

View File

@@ -424,7 +424,7 @@ class BatchNormalizationBase(Layer):
         self._scope.set_partitioner(partitioner)
     self.built = True
 
-  def _assign_moving_average(self, variable, value, momentum, inputs_size):
+  def _assign_moving_average(self, variable, value, momentum):
     with ops.name_scope(None, 'AssignMovingAvg',
                         [variable, value, momentum]) as scope:
       with ops.colocate_with(variable):
@@ -433,19 +433,12 @@ class BatchNormalizationBase(Layer):
           decay = math_ops.cast(decay, variable.dtype.base_dtype)
         update_delta = (
             variable - math_ops.cast(value, variable.dtype)) * decay
-        # TODO(b/129279393): Support zero batch input in non
-        # DistributionStrategy code as well.
-        if distribution_strategy_context.has_strategy():
-          update_delta = tf_utils.smart_cond(
-              inputs_size > 0,
-              lambda: update_delta, lambda: K.zeros_like(update_delta))
         return state_ops.assign_sub(variable, update_delta, name=scope)
 
   def _fused_batch_norm(self, inputs, training):
     """Returns the output of fused batch norm."""
     beta = self.beta if self.center else self._beta_const
     gamma = self.gamma if self.scale else self._gamma_const
-    inputs_size = array_ops.size(inputs)
 
     def _fused_batch_norm_training():
       return nn.fused_batch_norm(
@@ -489,22 +482,21 @@ class BatchNormalizationBase(Layer):
         strategy = distribution_strategy_context.get_strategy()
         mean_update = strategy.extended.update(
             self.moving_mean, self._assign_moving_average,
-            (mean, self.momentum, inputs_size))
+            (mean, self.momentum))
         variance_update = strategy.extended.update(
             self.moving_variance, self._assign_moving_average,
-            (variance, self.momentum, inputs_size))
+            (variance, self.momentum))
       else:
         mean_update = self._assign_moving_average(self.moving_mean, mean,
-                                                  momentum, inputs_size)
-        variance_update = self._assign_moving_average(
-            self.moving_variance, variance, momentum, inputs_size)
+                                                  momentum)
+        variance_update = self._assign_moving_average(self.moving_variance,
+                                                      variance, momentum)
       self.add_update(mean_update, inputs=True)
       self.add_update(variance_update, inputs=True)
 
     return output
 
-  def _renorm_correction_and_moments(self, mean, variance, training,
-                                     inputs_size):
+  def _renorm_correction_and_moments(self, mean, variance, training):
     """Returns the correction and update values for renorm."""
     stddev = math_ops.sqrt(variance + self.epsilon)
     # Compute the average mean and standard deviation, as if they were
@@ -535,7 +527,7 @@ class BatchNormalizationBase(Layer):
                             lambda: d,
                             lambda: array_ops.zeros_like(d))
 
-    def _update_renorm_variable(var, weight, value, inputs_size):
+    def _update_renorm_variable(var, weight, value):
       """Updates a moving average and weight, returns the unbiased value."""
       value = array_ops.identity(value)
       def _do_update():
@@ -548,10 +540,9 @@ class BatchNormalizationBase(Layer):
         # Make sure the weight is not updated until before r and d computation.
         with ops.control_dependencies([value]):
           weight_value = array_ops.constant(1., dtype=weight.dtype)
-        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
-                                              inputs_size)
-        new_weight = self._assign_moving_average(
-            weight, weight_value, self.renorm_momentum, inputs_size)
+        new_var = self._assign_moving_average(var, value, self.renorm_momentum)
+        new_weight = self._assign_moving_average(weight, weight_value,
+                                                 self.renorm_momentum)
         # TODO(yuefengz): the updates to var and weighted can not be batched
         # together if we fetch their updated values here. Consider calculating
         # new values and delaying the updates.
@@ -562,26 +553,17 @@ class BatchNormalizationBase(Layer):
       return tf_utils.smart_cond(training, _do_update, _fake_update)
 
     # TODO(yuefengz): colocate the operations
-    new_mean = _update_renorm_variable(
-        self.renorm_mean, self.renorm_mean_weight, mean, inputs_size)
-    new_stddev = _update_renorm_variable(
-        self.renorm_stddev, self.renorm_stddev_weight, stddev, inputs_size)
+    new_mean = _update_renorm_variable(self.renorm_mean,
+                                       self.renorm_mean_weight, mean)
+    new_stddev = _update_renorm_variable(self.renorm_stddev,
+                                         self.renorm_stddev_weight, stddev)
     # Make sqrt(moving_variance + epsilon) = new_stddev.
     new_variance = math_ops.square(new_stddev) - self.epsilon
     return (r, d, new_mean, new_variance)
 
   def _moments(self, inputs, reduction_axes, keep_dims):
-    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
-    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
-    # code as well.
-    if distribution_strategy_context.has_strategy():
-      inputs_size = array_ops.size(inputs)
-      mean = tf_utils.smart_cond(
-          inputs_size > 0, lambda: mean, lambda: K.zeros_like(mean))
-      variance = tf_utils.smart_cond(
-          inputs_size > 0, lambda: variance, lambda: K.zeros_like(variance))
-    return mean, variance
+    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
 
   def call(self, inputs, training=None):
     if training is None:
@@ -679,10 +661,9 @@ class BatchNormalizationBase(Layer):
       else:
         new_mean, new_variance = mean, variance
 
-      inputs_size = array_ops.size(inputs)
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            new_mean, new_variance, training, inputs_size)
+            new_mean, new_variance, training)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -698,8 +679,8 @@ class BatchNormalizationBase(Layer):
           if in_eager_mode and not self.trainable:
             return
           return strategy.extended.update(
-              var, self._assign_moving_average,
-              (value, self.momentum, inputs_size), group=False)
+              var, self._assign_moving_average, (value, self.momentum),
+              group=False)
         # We need to unwrap the moving_mean or moving_variance in the case of
         # training being false to match the output of true_fn and false_fn
         # in the smart cond.
@@ -716,9 +697,7 @@ class BatchNormalizationBase(Layer):
           """Compute the updates for mean and variance."""
           if in_eager_mode and not self.trainable:
             return
-          return self._assign_moving_average(var, value, self.momentum,
-                                             inputs_size)
-
+          return self._assign_moving_average(var, value, self.momentum)
         mean_update = tf_utils.smart_cond(
             training,
             lambda: _do_update(self.moving_mean, new_mean),
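
The substance of the rollback in this file is the removal of the zero-batch guard: with the rolled-back change, _assign_moving_average and _moments wrapped their results in tf_utils.smart_cond(inputs_size > 0, ...) whenever a DistributionStrategy was active, so an empty batch contributed no update to the moving statistics. Below is a minimal standalone sketch of that guard pattern, not the TensorFlow implementation itself: it uses the public TF 2.x API (tf.cond and tf.zeros_like in place of the internal tf_utils.smart_cond and K.zeros_like), and the function name assign_moving_average_guarded is made up for illustration.

import tensorflow as tf

def assign_moving_average_guarded(variable, value, momentum, inputs_size):
  """Moving-average update that is a no-op when the batch is empty."""
  decay = tf.cast(1.0 - momentum, variable.dtype)
  update_delta = (variable - tf.cast(value, variable.dtype)) * decay
  # This conditional is what the rollback removes: with an empty batch the
  # batch statistics are NaN, so the delta is replaced by zeros.
  update_delta = tf.cond(inputs_size > 0,
                         lambda: update_delta,
                         lambda: tf.zeros_like(update_delta))
  return variable.assign_sub(update_delta)

# An empty batch leaves the moving mean untouched.
moving_mean = tf.Variable([0.0, 0.0, 0.0])
batch = tf.zeros([0, 3])                    # zero-sized batch
batch_mean = tf.reduce_mean(batch, axis=0)  # NaN per channel
assign_moving_average_guarded(moving_mean, batch_mean, 0.9, tf.size(batch))
print(moving_mean.numpy())                  # still [0. 0. 0.]

After the rollback the smart_cond is gone, so NaN batch statistics from an empty batch would flow straight into the moving averages, which is why zero_batch_test and its distribute_py_test target are removed in the same commit.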