Fix input size used for batch normalization.
inputs_size (computed with array_ops.size()), which is used to decide whether to take the optional_get_next() API code path, defaults to the int32 dtype. If the input is large enough, this can overflow and cause the model to diverge. The correct approach is to use inputs.get_shape()[0] to get the batch size, instead of array_ops.size(), which returns the number of elements in the inputs tensor and can be arbitrarily large.

PiperOrigin-RevId: 305823718
Change-Id: Idc5660d80406fe233b162b73330c6fce4d5357b4
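A minimal sketch, not part of the commit, contrasting the two quantities (the public tf.size / tf.shape aliases are used here instead of the internal array_ops module; the shape is illustrative):

import tensorflow as tf

inputs = tf.random.normal([11, 4, 4, 3])

inputs_size = tf.size(inputs)           # total element count, int32 by default: 11 * 4 * 4 * 3 = 528
input_batch_size = tf.shape(inputs)[0]  # batch dimension only: 11

# Both can gate an "is this batch empty?" check, but only the batch dimension
# stays small: for a large enough activation tensor the element count can
# exceed the int32 range, while the batch size cannot.
has_data = input_batch_size > 0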
parent a7543c1e44 · commit 2e23d38ce7
@@ -21,6 +21,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
@@ -158,5 +159,53 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
       self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
                           test_step().numpy())
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy,
+          ],
+          mode=["eager"],
+          fused=[True, False]))
+  def testBNWithDynamicBatchInputEager(self, distribution, fused):
+    distribution.extended.experimental_enable_get_next_as_optional = True
+    with distribution.scope():
+      # Explicitly create dataset with drop_remainder=False.
+      # This would make batch size unknown.
+      inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
+      targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
+          10, drop_remainder=False).repeat()
+      dataset_iterator = iter(
+          distribution.experimental_distribute_dataset(dataset))
+
+      bn = normalization.BatchNormalization(
+          axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
+      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
+
+      @def_function.function
+      def train_step(iterator):
+
+        def step_fn(inputs):
+          features, targets = inputs
+          with backprop.GradientTape() as tape:
+            outputs = bn(features, training=True)
+            loss = losses.mean_squared_error(targets, outputs)
+
+          grads = tape.gradient(loss, bn.variables)
+          optimizer.apply_gradients(zip(grads, bn.variables))
+          return loss
+
+        return distribution.run(step_fn, args=(next(iterator),))
+
+      for _ in range(100):
+        train_step(dataset_iterator).numpy()
+
+      # Verify that the statistics and weights are updated.
+      self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.moving_mean.numpy())
+      self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.moving_variance.numpy())
+      self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.gamma.numpy())
+      self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.beta.numpy())
+
 
 if __name__ == "__main__":
   test.main()
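As an aside, a hedged sketch (using the public tf.data API rather than the test's internal imports) of the dynamic-batch setup the new test exercises: 11 examples batched by 10 with drop_remainder=False leave a final partial batch, so the static batch dimension is unknown.

import numpy as np
import tensorflow as tf

inputs = np.random.random((11, 4, 4, 3)).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(inputs).batch(10, drop_remainder=False)

print(dataset.element_spec.shape)  # (None, 4, 4, 3): batch size unknown statically
for batch in dataset:
  print(batch.shape)               # (10, 4, 4, 3), then (1, 4, 4, 3)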
@@ -537,9 +537,11 @@ class BatchNormalizationBase(Layer):
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
+      # Keras assumes that batch dimension is the first dimension for Batch
+      # Normalization.
+      input_batch_size = array_ops.shape(inputs)[0]
     else:
-      inputs_size = None
+      input_batch_size = None
 
     # TODO(rmlarsen): Support using fused avg updates for non-eager execution
     # after fixing graph pattern matching and enabling fused_batch_norm to
@@ -600,10 +602,12 @@ class BatchNormalizationBase(Layer):
           data_format=self._data_format)
 
     train_op = _fused_batch_norm_training
-    if use_fused_avg_updates and inputs_size is not None:
-      train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
+    if use_fused_avg_updates and input_batch_size is not None:
+      # pylint: disable=g-long-lambda
+      train_op = lambda: tf_utils.smart_cond(input_batch_size > 0,
                                              _fused_batch_norm_training,
                                              _fused_batch_norm_training_empty)
+      # pylint: enable=g-long-lambda
 
     output, mean, variance = tf_utils.smart_cond(training, train_op,
                                                  _fused_batch_norm_inference)
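The change in this hunk only swaps the quantity used in the guard; the gating pattern itself is unchanged. A hedged sketch of that pattern with the public tf.cond (the layer itself uses the internal tf_utils.smart_cond), where guarded_train_op, train_fn and empty_fn are illustrative names standing in for the lambdas above:

import tensorflow as tf

def guarded_train_op(inputs, train_fn, empty_fn):
  # The dynamic batch dimension is a small int32 scalar no matter how large
  # each example is, so the comparison cannot overflow.
  input_batch_size = tf.shape(inputs)[0]
  return tf.cond(input_batch_size > 0, train_fn, empty_fn)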
@@ -624,7 +628,7 @@ class BatchNormalizationBase(Layer):
         return self._assign_new_value(self.moving_mean, mean)
       else:
         return self._assign_moving_average(self.moving_mean, mean, momentum,
-                                           inputs_size)
+                                           input_batch_size)
 
     def variance_update():
       """Update self.moving_variance with the most recent data point."""
@@ -632,7 +636,7 @@ class BatchNormalizationBase(Layer):
         return self._assign_new_value(self.moving_variance, variance)
       else:
         return self._assign_moving_average(self.moving_variance, variance,
-                                           momentum, inputs_size)
+                                           momentum, input_batch_size)
 
     self.add_update(mean_update)
     self.add_update(variance_update)
@@ -706,9 +710,9 @@ class BatchNormalizationBase(Layer):
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
-      mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
-      variance = array_ops.where(inputs_size > 0, variance,
+      input_batch_size = array_ops.shape(inputs)[0]
+      mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
+      variance = array_ops.where(input_batch_size > 0, variance,
                                  K.zeros_like(variance))
     return mean, variance
 
@@ -822,12 +826,15 @@ class BatchNormalizationBase(Layer):
       new_mean, new_variance = mean, variance
 
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
+      # Keras assumes that batch dimension is the first dimension for Batch
+      # Normalization.
+      input_batch_size = array_ops.shape(inputs)[0]
     else:
-      inputs_size = None
+      input_batch_size = None
 
     if self.renorm:
       r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-          new_mean, new_variance, training, inputs_size)
+          new_mean, new_variance, training, input_batch_size)
       # When training, the normalized values (say, x) will be transformed as
       # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
       # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -838,7 +845,7 @@ class BatchNormalizationBase(Layer):
     def _do_update(var, value):
       """Compute the updates for mean and variance."""
       return self._assign_moving_average(var, value, self.momentum,
-                                         inputs_size)
+                                         input_batch_size)
 
     def mean_update():
       true_branch = lambda: _do_update(self.moving_mean, new_mean)