Set static shape information for reshape layer.
See https://github.com/tensorflow/tensorflow/issues/36363 for more details. PiperOrigin-RevId: 312404506 Change-Id: I6409fb5335f4ac76df7a80b15f672369143bd1fd
This commit is contained in:
parent
7d65cad2ce
commit
4a37d3fecd
@ -460,7 +460,7 @@ class Reshape(Layer):
|
||||
>>> # also supports shape inference using `-1` as dimension
|
||||
>>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
|
||||
>>> model.output_shape
|
||||
(None, None, 2, 2)
|
||||
(None, 3, 2, 2)
|
||||
"""
|
||||
|
||||
def __init__(self, target_shape, **kwargs):
|
||||
@ -495,7 +495,9 @@ class Reshape(Layer):
|
||||
is specified.
|
||||
"""
|
||||
output_shape = list(output_shape)
|
||||
msg = 'total size of new array must be unchanged'
|
||||
msg = ('total size of new array must be unchanged, '
|
||||
'input_shape = {}, output_shape = {}'
|
||||
.format(input_shape, output_shape))
|
||||
|
||||
known, unknown = 1, None
|
||||
for index, dim in enumerate(output_shape):
|
||||
@ -529,8 +531,13 @@ class Reshape(Layer):
|
||||
return tensor_shape.TensorShape(output_shape)
|
||||
|
||||
def call(self, inputs):
|
||||
return array_ops.reshape(inputs,
|
||||
(array_ops.shape(inputs)[0],) + self.target_shape)
|
||||
result = array_ops.reshape(
|
||||
inputs, (array_ops.shape(inputs)[0],) + self.target_shape)
|
||||
if not context.executing_eagerly():
|
||||
# Set the static shape for the result since it might lost during array_ops
|
||||
# reshape, eg, some `None` dim in the result could be inferred.
|
||||
result.set_shape(self.compute_output_shape(inputs.shape))
|
||||
return result
|
||||
|
||||
def get_config(self):
|
||||
config = {'target_shape': self.target_shape}
|
||||
|
@ -430,6 +430,12 @@ class CoreLayersTest(keras_parameterized.TestCase):
|
||||
kwargs={'target_shape': (-1, 1)},
|
||||
input_shape=(None, None, 2))
|
||||
|
||||
def test_reshape_set_static_shape(self):
|
||||
input_layer = keras.Input(batch_shape=(1, None))
|
||||
reshaped = keras.layers.Reshape((1, 100))(input_layer)
|
||||
# Make sure the batch dim is not lost after array_ops.reshape.
|
||||
self.assertEqual(reshaped.shape, [1, 1, 100])
|
||||
|
||||
def test_permute(self):
|
||||
testing_utils.layer_test(
|
||||
keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))
|
||||
|
Loading…
Reference in New Issue
Block a user