From 1f44a4cd0534a06e068c2c3d9b7597c566121c17 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 4 May 2020 17:06:51 +0000 Subject: [PATCH] Update to use tf.shape to get the shape of the tensor, from review comment Signed-off-by: Yong Tang --- tensorflow/python/keras/layers/normalization.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 482157447c2..5a05b5b252d 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -736,10 +736,13 @@ class BatchNormalizationBase(Layer): if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation - original_shape = [ - d if d is not None else -1 for d in inputs.shape.as_list()] - original_shape = [-1] + original_shape[1:] - expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:] + original_shape = array_ops.shape(inputs) + original_shape = array_ops.concat([ + constant_op.constant([-1]), + original_shape[1:]], axis=0) + expanded_shape = array_ops.concat([ + constant_op.constant([self.virtual_batch_size, -1]), + original_shape[1:]], axis=0) # Will cause errors if virtual_batch_size does not divide the batch size inputs = array_ops.reshape(inputs, expanded_shape)