Set static shape information for reshape layer.

See https://github.com/tensorflow/tensorflow/issues/36363 for more details.

PiperOrigin-RevId: 312404506
Change-Id: I6409fb5335f4ac76df7a80b15f672369143bd1fd
This commit is contained in:
Scott Zhu 2020-05-19 19:49:43 -07:00 committed by TensorFlower Gardener
parent 7d65cad2ce
commit 4a37d3fecd
2 changed files with 17 additions and 4 deletions

View File

@ -460,7 +460,7 @@ class Reshape(Layer):
>>> # also supports shape inference using `-1` as dimension
>>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
>>> model.output_shape
(None, None, 2, 2)
(None, 3, 2, 2)
"""
def __init__(self, target_shape, **kwargs):
@ -495,7 +495,9 @@ class Reshape(Layer):
is specified.
"""
output_shape = list(output_shape)
msg = 'total size of new array must be unchanged'
msg = ('total size of new array must be unchanged, '
'input_shape = {}, output_shape = {}'
.format(input_shape, output_shape))
known, unknown = 1, None
for index, dim in enumerate(output_shape):
@ -529,8 +531,13 @@ class Reshape(Layer):
return tensor_shape.TensorShape(output_shape)
def call(self, inputs):
return array_ops.reshape(inputs,
(array_ops.shape(inputs)[0],) + self.target_shape)
result = array_ops.reshape(
inputs, (array_ops.shape(inputs)[0],) + self.target_shape)
if not context.executing_eagerly():
# Set the static shape for the result since it might lost during array_ops
# reshape, eg, some `None` dim in the result could be inferred.
result.set_shape(self.compute_output_shape(inputs.shape))
return result
def get_config(self):
config = {'target_shape': self.target_shape}

View File

@ -430,6 +430,12 @@ class CoreLayersTest(keras_parameterized.TestCase):
kwargs={'target_shape': (-1, 1)},
input_shape=(None, None, 2))
def test_reshape_set_static_shape(self):
input_layer = keras.Input(batch_shape=(1, None))
reshaped = keras.layers.Reshape((1, 100))(input_layer)
# Make sure the batch dim is not lost after array_ops.reshape.
self.assertEqual(reshaped.shape, [1, 1, 100])
def test_permute(self):
testing_utils.layer_test(
keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))