From 2e23d38ce733613fd0db938b8ba6bcf39c722ba1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Apr 2020 21:58:09 -0700
Subject: [PATCH] Fix input size used for batch normalization.

Inputs_size (array_ops.size()) used to determine whether to use optional_get_next() API code path defaults to using int32 dtype. If input size is big enough this can lead to integer overflow and cause model to diverge.

Correct usage will be to use inputs.get_shape()[0] to get the batch size -- instead of using array_ops.size() which returns the number of elements in inputs tensor which can be arbitrarily large.

PiperOrigin-RevId: 305823718
Change-Id: Idc5660d80406fe233b162b73330c6fce4d5357b4
---
 .../python/distribute/zero_batch_test.py      | 49 +++++++++++++++++++
 .../python/keras/layers/normalization.py      | 33 ++++++++-----
 2 files changed, 69 insertions(+), 13 deletions(-)

diff --git a/tensorflow/python/distribute/zero_batch_test.py b/tensorflow/python/distribute/zero_batch_test.py
index e590d815459..b41611a91e0 100644
--- a/tensorflow/python/distribute/zero_batch_test.py
+++ b/tensorflow/python/distribute/zero_batch_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import backprop
@@ -158,5 +159,53 @@ class NormalizationTest(test.TestCase, parameterized.TestCase):
       self.assertAllEqual(np.zeros(shape=(0, 4, 4, 3), dtype=np.float32),
                           test_step().numpy())
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy,
+          ],
+          mode=["eager"],
+          fused=[True, False]))
+  def testBNWithDynamicBatchInputEager(self, distribution, fused):
+    distribution.extended.experimental_enable_get_next_as_optional = True
+    with distribution.scope():
+      # Explicitly create dataset with drop_remainder=False.
+      # This would make batch size unknown.
+      inputs = np.random.random((11, 4, 4, 3)).astype(np.float32) + 100
+      targets = np.random.random((11, 4, 4, 3)).astype(np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).batch(
+          10, drop_remainder=False).repeat()
+      dataset_iterator = iter(
+          distribution.experimental_distribute_dataset(dataset))
+
+      bn = normalization.BatchNormalization(
+          axis=-1, epsilon=1e-3, momentum=0.9, fused=fused)
+      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
+
+      @def_function.function
+      def train_step(iterator):
+
+        def step_fn(inputs):
+          features, targets = inputs
+          with backprop.GradientTape() as tape:
+            outputs = bn(features, training=True)
+            loss = losses.mean_squared_error(targets, outputs)
+
+          grads = tape.gradient(loss, bn.variables)
+          optimizer.apply_gradients(zip(grads, bn.variables))
+          return loss
+
+        return distribution.run(step_fn, args=(next(iterator),))
+
+      for _ in range(100):
+        train_step(dataset_iterator).numpy()
+
+      # Verify that the statistics and weights are updated.
+      self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.moving_mean.numpy())
+      self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.moving_variance.numpy())
+      self.assertNotAllEqual(np.ndarray([1, 1, 1]), bn.gamma.numpy())
+      self.assertNotAllEqual(np.ndarray([0, 0, 0]), bn.beta.numpy())
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index d43737dd8d3..c5062163889 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -537,9 +537,11 @@ class BatchNormalizationBase(Layer):
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
+      # Keras assumes that batch dimension is the first dimension for Batch
+      # Normalization.
+      input_batch_size = array_ops.shape(inputs)[0]
     else:
-      inputs_size = None
+      input_batch_size = None
 
     # TODO(rmlarsen): Support using fused avg updates for non-eager execution
     # after fixing graph pattern matching and enabling fused_batch_norm to
@@ -600,10 +602,12 @@ class BatchNormalizationBase(Layer):
           data_format=self._data_format)
 
     train_op = _fused_batch_norm_training
-    if use_fused_avg_updates and inputs_size is not None:
-      train_op = lambda: tf_utils.smart_cond(inputs_size > 0,
+    if use_fused_avg_updates and input_batch_size is not None:
+      # pylint: disable=g-long-lambda
+      train_op = lambda: tf_utils.smart_cond(input_batch_size > 0,
                                              _fused_batch_norm_training,
                                              _fused_batch_norm_training_empty)
+      # pylint: enable=g-long-lambda
 
     output, mean, variance = tf_utils.smart_cond(training, train_op,
                                                  _fused_batch_norm_inference)
@@ -624,7 +628,7 @@ class BatchNormalizationBase(Layer):
           return self._assign_new_value(self.moving_mean, mean)
         else:
           return self._assign_moving_average(self.moving_mean, mean, momentum,
-                                             inputs_size)
+                                             input_batch_size)
 
       def variance_update():
         """Update self.moving_variance with the most recent data point."""
@@ -632,7 +636,7 @@ class BatchNormalizationBase(Layer):
           return self._assign_new_value(self.moving_variance, variance)
         else:
           return self._assign_moving_average(self.moving_variance, variance,
-                                             momentum, inputs_size)
+                                             momentum, input_batch_size)
 
       self.add_update(mean_update)
       self.add_update(variance_update)
@@ -706,9 +710,9 @@ class BatchNormalizationBase(Layer):
     # TODO(b/129279393): Support zero batch input in non DistributionStrategy
     # code as well.
     if self._support_zero_size_input():
-      inputs_size = array_ops.size(inputs)
-      mean = array_ops.where(inputs_size > 0, mean, K.zeros_like(mean))
-      variance = array_ops.where(inputs_size > 0, variance,
+      input_batch_size = array_ops.shape(inputs)[0]
+      mean = array_ops.where(input_batch_size > 0, mean, K.zeros_like(mean))
+      variance = array_ops.where(input_batch_size > 0, variance,
                                  K.zeros_like(variance))
     return mean, variance
 
@@ -822,12 +826,15 @@ class BatchNormalizationBase(Layer):
         new_mean, new_variance = mean, variance
 
       if self._support_zero_size_input():
-        inputs_size = array_ops.size(inputs)
+        # Keras assumes that batch dimension is the first dimension for Batch
+        # Normalization.
+        input_batch_size = array_ops.shape(inputs)[0]
       else:
-        inputs_size = None
+        input_batch_size = None
+
       if self.renorm:
         r, d, new_mean, new_variance = self._renorm_correction_and_moments(
-            new_mean, new_variance, training, inputs_size)
+            new_mean, new_variance, training, input_batch_size)
         # When training, the normalized values (say, x) will be transformed as
         # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
         # = x * (r * gamma) + (d * gamma + beta) with renorm.
@@ -838,7 +845,7 @@ class BatchNormalizationBase(Layer):
       def _do_update(var, value):
         """Compute the updates for mean and variance."""
         return self._assign_moving_average(var, value, self.momentum,
-                                           inputs_size)
+                                           input_batch_size)
 
       def mean_update():
         true_branch = lambda: _do_update(self.moving_mean, new_mean)