Fix input size used for batch normalization.
inputs_size (computed with array_ops.size()), which is used to decide whether to take the optional_get_next() API code path, defaults to the int32 dtype. If the input is big enough, this can overflow and cause the model to diverge. The correct approach is to use inputs.get_shape()[0] to get the batch size instead of array_ops.size(), which returns the number of elements in the inputs tensor and can be arbitrarily large.

PiperOrigin-RevId: 305823718
Change-Id: Idc5660d80406fe233b162b73330c6fce4d5357b4
parent a7543c1e44
commit 2e23d38ce7
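For context, a minimal sketch (not part of the commit) of the overflow described above; it uses the public tf.size and tf.shape aliases rather than the internal array_ops module referenced in the diff:

# tf.size returns the total element count, int32 by default, so a large
# activation tensor can wrap around; tf.shape(inputs)[0] returns only the
# batch dimension, which stays small.
import tensorflow as tf

inputs = tf.zeros((8, 4, 4, 3))

num_elements = tf.size(inputs)    # 8 * 4 * 4 * 3 = 384, dtype int32
batch_size = tf.shape(inputs)[0]  # 8, dtype int32

# A float32 tensor of shape (2048, 1024, 1024, 2), for example, holds 2**32
# elements, which overflows int32, while its batch dimension (2048) does not.
print(int(num_elements), int(batch_size))  # 384 8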
@@ -21,6 +21,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
@@ -158,5 +159,53 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
                         test_step().numpy())
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy,
+          ],
+          mode=["eager"],
+          fused=[True, False]))
+  def testBNWithDynamicBatchInputEager(self, distribution, fused):
+    distribution.extended.experimental_enable_get_next_as_optional = True
+    with distribution.scope():
+      # Explicitly create dataset with drop_remainder=False.
+      # This would make batch size unknown.
+      inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
+      targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
+          10, drop_remainder=False).repeat()
+      dataset_iterator = iter(
+          distribution.experimental_distribute_dataset(dataset))
+
+      bn = normalization.BatchNormalization(
+          axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
+      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
+
+      @def_function.function
+      def train_step(iterator):
+
+        def step_fn(inputs):
+          features, targets = inputs
+          with backprop.GradientTape() as tape:
+            outputs = bn(features, training=True)
+            loss = losses.mean_squared_error(targets, outputs)
+
+          grads = tape.gradient(loss, bn.variables)
+          optimizer.apply_gradients(zip(grads, bn.variables))
+          return loss
+
+        return distribution.run(step_fn, args=(next(iterator),))
+
+      for _ in range(100):
+        train_step(dataset_iterator).numpy()
+
+      # Verify that the statistics and weights are updated.
+      self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.moving_mean.numpy())
+      self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.moving_variance.numpy())
+      self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.gamma.numpy())
+      self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.beta.numpy())
+
 
 if __name__ == "__main__":
   test.main()
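For reference, a small sketch of why drop_remainder=False leaves the batch dimension unknown; it is written against the public tf.data API, so the names differ from the internal dataset_ops module used in the test above:

# 11 examples batched by 10 with drop_remainder=False yield one full and one
# partial batch, so the static batch dimension is None and the layer has to
# handle a dynamic batch size.
import numpy as np
import tensorflow as tf

data = np.random.random((11, 4, 4, 3)).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data).batch(10, drop_remainder=False)

print(dataset.element_spec.shape)  # (None, 4, 4, 3)
for batch in dataset:
  print(batch.shape)               # (10, 4, 4, 3), then (1, 4, 4, 3)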
@@ -537,9 +537,11 @@ class BatchNormalizationBase(Layer):
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
+      # Keras assumes that batch dimension is the first dimension for Batch
+      # Normalization.
+      input_batch_size = array_ops.shape(inputs)[0]
     else:
-      inputs_size = None
+      input_batch_size = None
 
     # TODO(rmlarsen): Support using fused avg updates for non-eager execution
     # after fixing graph pattern matching and enabling fused_batch_norm to
@@ -600,10 +602,12 @@ class BatchNormalizationBase(Layer):
           data_format=self._data_format)
 
       train_op = _fused_batch_norm_training
-      if use_fused_avg_updates and inputs_size is not None:
-        train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
+      if use_fused_avg_updates and input_batch_size is not None:
+        # pylint: disable=g-long-lambda
+        train_op = lambda: tf_utils.smart_cond(input_batch_size > 0,
                                                _fused_batch_norm_training,
                                                _fused_batch_norm_training_empty)
+        # pylint: enable=g-long-lambda
 
       output, mean, variance = tf_utils.smart_cond(training, train_op,
                                                    _fused_batch_norm_inference)
@@ -624,7 +628,7 @@ class BatchNormalizationBase(Layer):
         return self._assign_new_value(self.moving_mean, mean)
       else:
         return self._assign_moving_average(self.moving_mean, mean, momentum,
-                                           inputs_size)
+                                           input_batch_size)
 
     def variance_update():
       """Update self.moving_variance with the most recent data point."""
@@ -632,7 +636,7 @@ class BatchNormalizationBase(Layer):
         return self._assign_new_value(self.moving_variance, variance)
       else:
         return self._assign_moving_average(self.moving_variance, variance,
-                                           momentum, inputs_size)
+                                           momentum, input_batch_size)
 
     self.add_update(mean_update)
     self.add_update(variance_update)
@@ -706,9 +710,9 @@ class BatchNormalizationBase(Layer):
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
-      mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
-      variance = array_ops.where(inputs_size > 0, variance,
+      input_batch_size = array_ops.shape(inputs)[0]
+      mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
+      variance = array_ops.where(input_batch_size > 0, variance,
                                  K.zeros_like(variance))
     return mean, variance
 
@@ -822,12 +826,15 @@ class BatchNormalizationBase(Layer):
       new_mean, new_variance = mean, variance
 
       if self._support_zero_size_input():
-        inputs_size = array_ops.size(inputs)
+        # Keras assumes that batch dimension is the first dimension for Batch
+        # Normalization.
+        input_batch_size = array_ops.shape(inputs)[0]
       else:
-        inputs_size = None
+        input_batch_size = None
 
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            new_mean, new_variance, training, inputs_size)
+            new_mean, new_variance, training, input_batch_size)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -838,7 +845,7 @@ class BatchNormalizationBase(Layer):
       def _do_update(var, value):
         """Compute the updates for mean and variance."""
         return self._assign_moving_average(var, value, self.momentum,
-                                           inputs_size)
+                                           input_batch_size)
 
       def mean_update():
         true_branch = lambda: _do_update(self.moving_mean, new_mean)
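Taken together, the change above amounts to the following guard pattern. This is a minimal sketch, not the layer's actual code, and it assumes tf.cond as a stand-in for tf_utils.smart_cond:

import tensorflow as tf

def guarded_train_op(inputs, train_fn, empty_fn):
  # Keras batch normalization treats the first dimension as the batch, so the
  # guard keys on the batch dimension rather than the total element count.
  input_batch_size = tf.shape(inputs)[0]
  return tf.cond(input_batch_size > 0, train_fn, empty_fn)

# An empty per-replica batch takes the empty branch and leaves the moving
# statistics untouched.
empty = tf.zeros((0, 4, 4, 3))
result = guarded_train_op(empty,
                          lambda: tf.constant(1),  # would update statistics
                          lambda: tf.constant(0))  # no-op for empty batches
print(int(result))  # 0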