From 10396ff82a9216a3ce5c5770eefa85bc0a973d38 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sun, 3 May 2020 20:22:18 +0000
Subject: [PATCH 1/3] Fix BatchNormalization issue with virtual_batch_size when
 shape has None

This PR tries to address the issue raised in 32380 where
BatchNormalization with virtual_batch_size will throw an error if shape
has None:

```
TypeError: Failed to convert object of type <class 'list'> to Tensor.
Contents: [8, -1, None, None, 3]. Consider casting elements to a
supported type.
```

This PR converts None to -1 so that it could be passed as a tensor to
`reshape`.

This PR fixes 32380.

Signed-off-by: Yong Tang
---
 tensorflow/python/keras/layers/normalization.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 97da2954b65..482157447c2 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -736,7 +736,9 @@ class BatchNormalizationBase(Layer):
     if self.virtual_batch_size is not None:
       # Virtual batches (aka ghost batches) can be simulated by reshaping the
       # Tensor and reusing the existing batch norm implementation
-      original_shape = [-1] + inputs.shape.as_list()[1:]
+      original_shape = [
+          d if d is not None else -1 for d in inputs.shape.as_list()]
+      original_shape = [-1] + original_shape[1:]
       expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
 
       # Will cause errors if virtual_batch_size does not divide the batch size

From d5598f160946b6faf589da473129237153c874f3 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Sun, 3 May 2020 20:25:12 +0000
Subject: [PATCH 2/3] Add test case for BatchNormalization with
 virtual_batch_size and shape has None.

Signed-off-by: Yong Tang
---
 tensorflow/python/keras/layers/normalization_test.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index 6a615c2ecdc..ad5d00eb4d9 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -354,6 +354,13 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     # Updates should be tracked in a `wrap_function`.
     self.assertLen(layer.updates, 2)
 
+  @keras_parameterized.run_all_keras_modes
+  def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
+    # Test case for GitHub issue for 32380
+    norm = normalization_v2.BatchNormalization(virtual_batch_size=8)
+    inp = keras.layers.Input(shape=(None, None, 3))
+    _ = norm(inp)
+
 
 def _run_batchnorm_correctness_test(layer, dtype='float32', fused=False):
   model = keras.models.Sequential()

From 1f44a4cd0534a06e068c2c3d9b7597c566121c17 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Mon, 4 May 2020 17:06:51 +0000
Subject: [PATCH 3/3] Update to use tf.shape to get the shape of the tensor,
 from review comment

Signed-off-by: Yong Tang
---
 tensorflow/python/keras/layers/normalization.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 482157447c2..5a05b5b252d 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -736,10 +736,13 @@ class BatchNormalizationBase(Layer):
     if self.virtual_batch_size is not None:
       # Virtual batches (aka ghost batches) can be simulated by reshaping the
       # Tensor and reusing the existing batch norm implementation
-      original_shape = [
-          d if d is not None else -1 for d in inputs.shape.as_list()]
-      original_shape = [-1] + original_shape[1:]
-      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]
+      original_shape = array_ops.shape(inputs)
+      original_shape = array_ops.concat([
+          constant_op.constant([-1]),
+          original_shape[1:]], axis=0)
+      expanded_shape = array_ops.concat([
+          constant_op.constant([self.virtual_batch_size, -1]),
+          original_shape[1:]], axis=0)
 
       # Will cause errors if virtual_batch_size does not divide the batch size
       inputs = array_ops.reshape(inputs, expanded_shape)