diff --git a/tensorflow/python/distribute/zero_batch_test.py b/tensorflow/python/distribute/zero_batch_test.py
index e590d815459..b41611a91e0 100644
--- a/tensorflow/python/distribute/zero_batch_test.py
+++ b/tensorflow/python/distribute/zero_batch_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
@@ -158,5 +159,53 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
                         test_step().numpy())
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy,
+          ],
+          mode=["eager"],
+          fused=[True, False]))
+  def testBNWithDynamicBatchInputEager(self, distribution, fused):
+    distribution.extended.experimental_enable_get_next_as_optional = True
+    with distribution.scope():
+      # Explicitly create a dataset with drop_remainder=False; the final
+      # partial batch makes the batch size statically unknown.
+      inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
+      targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
+          10, drop_remainder=False).repeat()
+      dataset_iterator = iter(
+          distribution.experimental_distribute_dataset(dataset))
+
+      bn = normalization.BatchNormalization(
+          axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
+      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
+
+      @def_function.function
+      def train_step(iterator):
+
+        def step_fn(inputs):
+          features, targets = inputs
+          with backprop.GradientTape() as tape:
+            outputs = bn(features, training=True)
+            loss = losses.mean_squared_error(targets, outputs)
+
+          grads = tape.gradient(loss, bn.variables)
+          optimizer.apply_gradients(zip(grads, bn.variables))
+          return loss
+
+        return distribution.run(step_fn, args=(next(iterator),))
+
+      for _ in range(100):
+        train_step(dataset_iterator).numpy()
+
+      # Verify that the statistics and weights are updated.
+      self.assertNotAllEqual(np.array([0., 0., 0.]), bn.moving_mean.numpy())
+      self.assertNotAllEqual(np.array([1., 1., 1.]), bn.moving_variance.numpy())
+      self.assertNotAllEqual(np.array([1., 1., 1.]), bn.gamma.numpy())
+      self.assertNotAllEqual(np.array([0., 0., 0.]), bn.beta.numpy())
+
+
 if __name__ == "__main__":
   test.main()
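The test above hinges on batch(10, drop_remainder=False): with 11 examples, the dataset yields one full batch and one partial batch of size 1, so the batch dimension of the iterator's elements cannot be known statically. A minimal standalone sketch of that effect, using only the public tf.data API (illustrative, not part of this patch):

import tensorflow as tf

# 11 examples batched by 10 with drop_remainder=False leave a final
# partial batch of size 1, so the static batch dimension must be None.
dataset = tf.data.Dataset.range(11).batch(10, drop_remainder=False)
print(dataset.element_spec)  # TensorSpec(shape=(None,), dtype=tf.int64, ...)
for batch in dataset:
  print(batch.shape)  # (10,) then (1,)
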
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index d43737dd8d3..c5062163889 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -537,9 +537,11 @@ class BatchNormalizationBase(Layer):
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
+      # Keras assumes that the batch dimension is the first dimension for
+      # Batch Normalization.
+      input_batch_size = array_ops.shape(inputs)[0]
     else:
-      inputs_size = None
+      input_batch_size = None
 
     # TODO(rmlarsen): Support using fused avg updates for non-eager execution
     # after fixing graph pattern matching and enabling fused_batch_norm to
@@ -600,10 +602,12 @@
         data_format=self._data_format)
 
     train_op = _fused_batch_norm_training
-    if use_fused_avg_updates and inputs_size is not None:
-      train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
+    if use_fused_avg_updates and input_batch_size is not None:
+      # pylint: disable=g-long-lambda
+      train_op = lambda: tf_utils.smart_cond(input_batch_size > 0,
                                              _fused_batch_norm_training,
                                              _fused_batch_norm_training_empty)
+      # pylint: enable=g-long-lambda
 
     output, mean, variance = tf_utils.smart_cond(training, train_op,
                                                  _fused_batch_norm_inference)
@@ -624,7 +628,7 @@
          return self._assign_new_value(self.moving_mean, mean)
        else:
          return self._assign_moving_average(self.moving_mean, mean, momentum,
-                                            inputs_size)
+                                            input_batch_size)
 
      def variance_update():
        """Update self.moving_variance with the most recent data point."""
@@ -632,7 +636,7 @@
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_variance, variance)
        else:
          return self._assign_moving_average(self.moving_variance, variance,
-                                            momentum, inputs_size)
+                                            momentum, input_batch_size)
 
      self.add_update(mean_update)
      self.add_update(variance_update)
@@ -706,9 +710,9 @@
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
-      mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
-      variance = array_ops.where(inputs_size > 0, variance,
+      input_batch_size = array_ops.shape(inputs)[0]
+      mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
+      variance = array_ops.where(input_batch_size > 0, variance,
                                  K.zeros_like(variance))
     return mean, variance
 
@@ -822,12 +826,15 @@
       new_mean, new_variance = mean, variance
 
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
+      # Keras assumes that the batch dimension is the first dimension for
+      # Batch Normalization.
+      input_batch_size = array_ops.shape(inputs)[0]
     else:
-      inputs_size = None
+      input_batch_size = None
+
     if self.renorm:
       r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-          new_mean, new_variance, training, inputs_size)
+          new_mean, new_variance, training, input_batch_size)
       # When training, the normalized values (say, x) will be transformed as
       # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
       # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -838,7 +845,7 @@
      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return self._assign_moving_average(var, value, self.momentum,
-                                          inputs_size)
+                                          input_batch_size)
 
      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
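The core of the normalization.py change is the guard value itself: array_ops.size(inputs) counts every element of the input, while array_ops.shape(inputs)[0] isolates the number of examples in the batch, which is what the input_batch_size > 0 checks and the moving-average updates are meant to reflect. A rough standalone comparison (illustrative, not part of this patch; tf.size and tf.shape are the public counterparts of the array_ops functions used above):

import tensorflow as tf

x_empty = tf.zeros([0, 4, 4, 3])  # zero-size batch, as in the zero-batch tests
x_small = tf.zeros([2, 4, 4, 3])  # batch of two examples

# tf.size counts all elements; tf.shape(x)[0] is only the batch dimension.
print(int(tf.size(x_empty)), int(tf.shape(x_empty)[0]))  # 0 0
print(int(tf.size(x_small)), int(tf.shape(x_small)[0]))  # 96 2

Both expressions flag an empty batch, but the element count conflates batch size with the spatial and channel dimensions; taking shape[0] matches the assumption, stated in the new comments, that Keras puts the batch dimension first.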