Fix input size used for batch normalization.

Inputs_size (array_ops.size()) used to determine whether to use optional_get_next() API code path defaults to using int32 dtype. If input size is big enough this can lead to integer overflow and cause model to diverge.

Correct usage will be to use inputs.get_shape()[0] to get the batch size -- instead of using array_ops.size() which returns the number of elements in inputs tensor which can be arbitrarily large.

PiperOrigin-RevId: 305823718
Change-Id: Idc5660d80406fe233b162b73330c6fce4d5357b4
This commit is contained in:
A. Unique TensorFlower 2020-04-09 21:58:09 -07:00 committed by TensorFlower Gardener
parent a7543c1e44
commit 2e23d38ce7
2 changed files with 69 additions and 13 deletions

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
@ -158,5 +159,53 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
test_step().numpy())
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy,
],
mode=["eager"],
fused=[True, False]))
def testBNWithDynamicBatchInputEager(self, distribution, fused):
distribution.extended.experimental_enable_get_next_as_optional = True
with distribution.scope():
# Explicitly create dataset with drop_remainder=False.
# This would make batch size unknown.
inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
10, drop_remainder=False).repeat()
dataset_iterator = iter(
distribution.experimental_distribute_dataset(dataset))
bn = normalization.BatchNormalization(
axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
optimizer = gradient_descent.GradientDescentOptimizer(0.01)
@def_function.function
def train_step(iterator):
def step_fn(inputs):
features, targets = inputs
with backprop.GradientTape() as tape:
outputs = bn(features, training=True)
loss = losses.mean_squared_error(targets, outputs)
grads = tape.gradient(loss, bn.variables)
optimizer.apply_gradients(zip(grads, bn.variables))
return loss
return distribution.run(step_fn, args=(next(iterator),))
for _ in range(100):
train_step(dataset_iterator).numpy()
# Verify that the statistics and weights are updated.
self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.moving_mean.numpy())
self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.moving_variance.numpy())
self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.gamma.numpy())
self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.beta.numpy())
if __name__ == "__main__":
test.main()

View File

@ -537,9 +537,11 @@ class BatchNormalizationBase(Layer):
# TODO(b/129279393): Support zero batch input in non DistributionStrategy
# code as well.
if self._support_zero_size_input():
inputs_size = array_ops.size(inputs)
# Keras assumes that batch dimension is the first dimension for Batch
# Normalization.
input_batch_size = array_ops.shape(inputs)[0]
else:
inputs_size = None
input_batch_size = None
# TODO(rmlarsen): Support using fused avg updates for non-eager execution
# after fixing graph pattern matching and enabling fused_batch_norm to
@ -600,10 +602,12 @@ class BatchNormalizationBase(Layer):
data_format=self._data_format)
train_op = _fused_batch_norm_training
if use_fused_avg_updates and inputs_size is not None:
train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
if use_fused_avg_updates and input_batch_size is not None:
# pylint: disable=g-long-lambda
train_op = lambda: tf_utils.smart_cond(input_batch_size > 0,
_fused_batch_norm_training,
_fused_batch_norm_training_empty)
# pylint: enable=g-long-lambda
output, mean, variance = tf_utils.smart_cond(training, train_op,
_fused_batch_norm_inference)
@ -624,7 +628,7 @@ class BatchNormalizationBase(Layer):
return self._assign_new_value(self.moving_mean, mean)
else:
return self._assign_moving_average(self.moving_mean, mean, momentum,
inputs_size)
input_batch_size)
def variance_update():
"""Update self.moving_variance with the most recent data point."""
@ -632,7 +636,7 @@ class BatchNormalizationBase(Layer):
return self._assign_new_value(self.moving_variance, variance)
else:
return self._assign_moving_average(self.moving_variance, variance,
momentum, inputs_size)
momentum, input_batch_size)
self.add_update(mean_update)
self.add_update(variance_update)
@ -706,9 +710,9 @@ class BatchNormalizationBase(Layer):
# TODO(b/129279393): Support zero batch input in non DistributionStrategy
# code as well.
if self._support_zero_size_input():
inputs_size = array_ops.size(inputs)
mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
variance = array_ops.where(inputs_size > 0, variance,
input_batch_size = array_ops.shape(inputs)[0]
mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
variance = array_ops.where(input_batch_size > 0, variance,
K.zeros_like(variance))
return mean, variance
@ -822,12 +826,15 @@ class BatchNormalizationBase(Layer):
new_mean, new_variance = mean, variance
if self._support_zero_size_input():
inputs_size = array_ops.size(inputs)
# Keras assumes that batch dimension is the first dimension for Batch
# Normalization.
input_batch_size = array_ops.shape(inputs)[0]
else:
inputs_size = None
input_batch_size = None
if self.renorm:
r, d, new_mean, new_variance = self._renorm_correction_and_moments(
new_mean, new_variance, training, inputs_size)
new_mean, new_variance, training, input_batch_size)
# When training, the normalized values (say, x) will be transformed as
# x * gamma + beta without renorm, and (x * r + d) * gamma + beta
# = x * (r * gamma) + (d * gamma + beta) with renorm.
@ -838,7 +845,7 @@ class BatchNormalizationBase(Layer):
def _do_update(var, value):
"""Compute the updates for mean and variance."""
return self._assign_moving_average(var, value, self.momentum,
inputs_size)
input_batch_size)
def mean_update():
true_branch = lambda: _do_update(self.moving_mean, new_mean)