From 4a37d3fecd759f2e8a02f917f4256f9089cd44f4 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 19 May 2020 19:49:43 -0700 Subject: [PATCH] Set static shape information for reshape layer. See https://github.com/tensorflow/tensorflow/issues/36363 for more details. PiperOrigin-RevId: 312404506 Change-Id: I6409fb5335f4ac76df7a80b15f672369143bd1fd --- tensorflow/python/keras/layers/core.py | 15 +++++++++++---- tensorflow/python/keras/layers/core_test.py | 6 ++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index db9c47eca17..60834fad30b 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -460,7 +460,7 @@ class Reshape(Layer): >>> # also supports shape inference using `-1` as dimension >>> model.add(tf.keras.layers.Reshape((-1, 2, 2))) >>> model.output_shape - (None, None, 2, 2) + (None, 3, 2, 2) """ def __init__(self, target_shape, **kwargs): @@ -495,7 +495,9 @@ class Reshape(Layer): is specified. """ output_shape = list(output_shape) - msg = 'total size of new array must be unchanged' + msg = ('total size of new array must be unchanged, ' + 'input_shape = {}, output_shape = {}' + .format(input_shape, output_shape)) known, unknown = 1, None for index, dim in enumerate(output_shape): @@ -529,8 +531,13 @@ class Reshape(Layer): return tensor_shape.TensorShape(output_shape) def call(self, inputs): - return array_ops.reshape(inputs, - (array_ops.shape(inputs)[0],) + self.target_shape) + result = array_ops.reshape( + inputs, (array_ops.shape(inputs)[0],) + self.target_shape) + if not context.executing_eagerly(): + # Set the static shape for the result since it might lost during array_ops + # reshape, eg, some `None` dim in the result could be inferred. + result.set_shape(self.compute_output_shape(inputs.shape)) + return result def get_config(self): config = {'target_shape': self.target_shape} diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index 3daa187f1ce..70ad63c17eb 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -430,6 +430,12 @@ class CoreLayersTest(keras_parameterized.TestCase): kwargs={'target_shape': (-1, 1)}, input_shape=(None, None, 2)) + def test_reshape_set_static_shape(self): + input_layer = keras.Input(batch_shape=(1, None)) + reshaped = keras.layers.Reshape((1, 100))(input_layer) + # Make sure the batch dim is not lost after array_ops.reshape. + self.assertEqual(reshaped.shape, [1, 1, 100]) + def test_permute(self): testing_utils.layer_test( keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))