Fix input size used for batch normalization.

Inputs_size (array_ops.size()) used to determine whether to use optional_get_next() API code path defaults to using int32 dtype. If input size is big enough this can lead to integer overflow and cause model to diverge.

Correct usage will be to use inputs.get_shape()[0] to get the batch size -- instead of using array_ops.size() which returns the number of elements in inputs tensor which can be arbitrarily large.

PiperOrigin-RevId: 305823718
Change-Id: Idc5660d80406fe233b162b73330c6fce4d5357b4
This commit is contained in:
A. Unique TensorFlower 2020-04-09 21:58:09 -07:00 committed by TensorFlower Gardener
parent a7543c1e44
commit 2e23d38ce7
2 changed files with 69 additions and 13 deletions

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
@ -158,5 +159,53 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32), self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
test_step().numpy()) test_step().numpy())
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
],
mode=["eager"],
fused=[True, False]))
def testBNWithDynamicBatchInputEager(self, distribution, fused):
distribution.extended.experimental_enable_get_next_as_optional = True
with distribution.scope():
# Explicitly create dataset with drop_remainder=False.
# This would make batch size unknown.
inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
10, drop_remainder=False).repeat()
dataset_iterator = iter(
distribution.experimental_distribute_dataset(dataset))
bn = normalization.BatchNormalization(
axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
optimizer = gradient_descent.GradientDescentOptimizer(0.01)
@def_function.function
def train_step(iterator):
def step_fn(inputs):
features, targets = inputs
with backprop.GradientTape() as tape:
outputs = bn(features, training=True)
loss = losses.mean_squared_error(targets, outputs)
grads = tape.gradient(loss, bn.variables)
optimizer.apply_gradients(zip(grads, bn.variables))
return loss
return distribution.run(step_fn, args=(next(iterator),))
for _ in range(100):
train_step(dataset_iterator).numpy()
# Verify that the statistics and weights are updated.
self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.moving_mean.numpy())
self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.moving_variance.numpy())
self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.gamma.numpy())
self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.beta.numpy())
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -537,9 +537,11 @@ class BatchNormalizationBase(Layer):
# TODO(b/129279393): Support zero batch input in non DistributionStrategy # TODO(b/129279393): Support zero batch input in non DistributionStrategy
# code as well. # code as well.
if self._support_zero_size_input(): if self._support_zero_size_input():
inputs_size = array_ops.size(inputs) # Keras assumes that batch dimension is the first dimension for Batch
# Normalization.
input_batch_size = array_ops.shape(inputs)[0]
else: else:
inputs_size = None input_batch_size = None
# TODO(rmlarsen): Support using fused avg updates for non-eager execution # TODO(rmlarsen): Support using fused avg updates for non-eager execution
# after fixing graph pattern matching and enabling fused_batch_norm to # after fixing graph pattern matching and enabling fused_batch_norm to
@ -600,10 +602,12 @@ class BatchNormalizationBase(Layer):
data_format=self._data_format) data_format=self._data_format)
train_op = _fused_batch_norm_training train_op = _fused_batch_norm_training
if use_fused_avg_updates and inputs_size is not None: if use_fused_avg_updates and input_batch_size is not None:
train_op = lambda: tf_utils.smart_cond(inputs_size > 0, # pylint: disable=g-long-lambda
train_op = lambda: tf_utils.smart_cond(input_batch_size > 0,
_fused_batch_norm_training, _fused_batch_norm_training,
_fused_batch_norm_training_empty) _fused_batch_norm_training_empty)
# pylint: enable=g-long-lambda
output, mean, variance = tf_utils.smart_cond(training, train_op, output, mean, variance = tf_utils.smart_cond(training, train_op,
_fused_batch_norm_inference) _fused_batch_norm_inference)
@ -624,7 +628,7 @@ class BatchNormalizationBase(Layer):
return self._assign_new_value(self.moving_mean, mean) return self._assign_new_value(self.moving_mean, mean)
else: else:
return self._assign_moving_average(self.moving_mean, mean, momentum, return self._assign_moving_average(self.moving_mean, mean, momentum,
inputs_size) input_batch_size)
def variance_update(): def variance_update():
"""Update self.moving_variance with the most recent data point.""" """Update self.moving_variance with the most recent data point."""
@ -632,7 +636,7 @@ class BatchNormalizationBase(Layer):
return self._assign_new_value(self.moving_variance, variance) return self._assign_new_value(self.moving_variance, variance)
else: else:
return self._assign_moving_average(self.moving_variance, variance, return self._assign_moving_average(self.moving_variance, variance,
momentum, inputs_size) momentum, input_batch_size)
self.add_update(mean_update) self.add_update(mean_update)
self.add_update(variance_update) self.add_update(variance_update)
@ -706,9 +710,9 @@ class BatchNormalizationBase(Layer):
# TODO(b/129279393): Support zero batch input in non DistributionStrategy # TODO(b/129279393): Support zero batch input in non DistributionStrategy
# code as well. # code as well.
if self._support_zero_size_input(): if self._support_zero_size_input():
inputs_size = array_ops.size(inputs) input_batch_size = array_ops.shape(inputs)[0]
mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean)) mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
variance = array_ops.where(inputs_size > 0, variance, variance = array_ops.where(input_batch_size > 0, variance,
K.zeros_like(variance)) K.zeros_like(variance))
return mean, variance return mean, variance
@ -822,12 +826,15 @@ class BatchNormalizationBase(Layer):
new_mean, new_variance = mean, variance new_mean, new_variance = mean, variance
if self._support_zero_size_input(): if self._support_zero_size_input():
inputs_size = array_ops.size(inputs) # Keras assumes that batch dimension is the first dimension for Batch
# Normalization.
input_batch_size = array_ops.shape(inputs)[0]
else: else:
inputs_size = None input_batch_size = None
if self.renorm: if self.renorm:
r, d, new_mean, new_variance = self._renorm_correction_and_moments( r, d, new_mean, new_variance = self._renorm_correction_and_moments(
new_mean, new_variance, training, inputs_size) new_mean, new_variance, training, input_batch_size)
# When training, the normalized values (say, x) will be transformed as # When training, the normalized values (say, x) will be transformed as
# x * gamma + beta without renorm, and (x * r + d) * gamma + beta # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
# = x * (r * gamma) + (d * gamma + beta) with renorm. # = x * (r * gamma) + (d * gamma + beta) with renorm.
@ -838,7 +845,7 @@ class BatchNormalizationBase(Layer):
def _do_update(var, value): def _do_update(var, value):
"""Compute the updates for mean and variance.""" """Compute the updates for mean and variance."""
return self._assign_moving_average(var, value, self.momentum, return self._assign_moving_average(var, value, self.momentum,
inputs_size) input_batch_size)
def mean_update(): def mean_update():
true_branch = lambda: _do_update(self.moving_mean, new_mean) true_branch = lambda: _do_update(self.moving_mean, new_mean)