Set static shape information for reshape layer.
See https://github.com/tensorflow/tensorflow/issues/36363 for more details. PiperOrigin-RevId: 312404506 Change-Id: I6409fb5335f4ac76df7a80b15f672369143bd1fd
This commit is contained in:
parent
7d65cad2ce
commit
4a37d3fecd
@ -460,7 +460,7 @@ class Reshape(Layer):
|
|||||||
>>> # also supports shape inference using `-1` as dimension
|
>>> # also supports shape inference using `-1` as dimension
|
||||||
>>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
|
>>> model.add(tf.keras.layers.Reshape((-1, 2, 2)))
|
||||||
>>> model.output_shape
|
>>> model.output_shape
|
||||||
(None, None, 2, 2)
|
(None, 3, 2, 2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, target_shape, **kwargs):
|
def __init__(self, target_shape, **kwargs):
|
||||||
@ -495,7 +495,9 @@ class Reshape(Layer):
|
|||||||
is specified.
|
is specified.
|
||||||
"""
|
"""
|
||||||
output_shape = list(output_shape)
|
output_shape = list(output_shape)
|
||||||
msg = 'total size of new array must be unchanged'
|
msg = ('total size of new array must be unchanged, '
|
||||||
|
'input_shape = {}, output_shape = {}'
|
||||||
|
.format(input_shape, output_shape))
|
||||||
|
|
||||||
known, unknown = 1, None
|
known, unknown = 1, None
|
||||||
for index, dim in enumerate(output_shape):
|
for index, dim in enumerate(output_shape):
|
||||||
@ -529,8 +531,13 @@ class Reshape(Layer):
|
|||||||
return tensor_shape.TensorShape(output_shape)
|
return tensor_shape.TensorShape(output_shape)
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
return array_ops.reshape(inputs,
|
result = array_ops.reshape(
|
||||||
(array_ops.shape(inputs)[0],) + self.target_shape)
|
inputs, (array_ops.shape(inputs)[0],) + self.target_shape)
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
# Set the static shape for the result since it might lost during array_ops
|
||||||
|
# reshape, eg, some `None` dim in the result could be inferred.
|
||||||
|
result.set_shape(self.compute_output_shape(inputs.shape))
|
||||||
|
return result
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = {'target_shape': self.target_shape}
|
config = {'target_shape': self.target_shape}
|
||||||
|
@ -430,6 +430,12 @@ class CoreLayersTest(keras_parameterized.TestCase):
|
|||||||
kwargs={'target_shape': (-1, 1)},
|
kwargs={'target_shape': (-1, 1)},
|
||||||
input_shape=(None, None, 2))
|
input_shape=(None, None, 2))
|
||||||
|
|
||||||
|
def test_reshape_set_static_shape(self):
|
||||||
|
input_layer = keras.Input(batch_shape=(1, None))
|
||||||
|
reshaped = keras.layers.Reshape((1, 100))(input_layer)
|
||||||
|
# Make sure the batch dim is not lost after array_ops.reshape.
|
||||||
|
self.assertEqual(reshaped.shape, [1, 1, 100])
|
||||||
|
|
||||||
def test_permute(self):
|
def test_permute(self):
|
||||||
testing_utils.layer_test(
|
testing_utils.layer_test(
|
||||||
keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))
|
keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))
|
||||||
|
Loading…
Reference in New Issue
Block a user