Make sure `build` returns correct input_shape if `compute_output_shape` is called manually.

PiperOrigin-RevId: 305589497
Change-Id: Ifbcb13569a56759e3ff463b324cec1be5b8bbb02
This commit is contained in:
Yanhui Liang 2020-04-08 17:36:58 -07:00 committed by TensorFlower Gardener
parent 27058058e3
commit a807a425a2
3 changed files with 17 additions and 1 deletions

View File

@ -2360,8 +2360,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
else:
self._dtype_policy = policy.Policy(dtype)
input_shapes = None
# Converts Tensors / CompositeTensors to TensorShapes.
if all(hasattr(x, 'shape') for x in input_list):
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
input_shapes = tf_utils.get_shapes(inputs)
else:
# Converts input shape to TensorShapes.
try:
input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False)
except ValueError:
pass
# Only call `build` if the user has manually overridden the build method.
if not hasattr(self.build, '_is_default'):
# Any setup work performed only once should happen in an `init_scope`

View File

@ -125,6 +125,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
def build(self, input_shape):
self.build_counter += 1
self.build_shape = input_shape
def call(self, inputs):
return inputs
@ -132,14 +133,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
layer = BuildCounter(dtype=dtypes.float64)
output_shape = layer.compute_output_shape((None, 10))
self.assertEqual(layer.build_counter, 1)
self.assertEqual(layer.build_shape.as_list(), [None, 10])
self.assertEqual(output_shape.as_list(), [None, 10])
output_signature = layer.compute_output_signature(
tensor_spec.TensorSpec(dtype=dtypes.float64, shape=[None, 10]))
self.assertEqual(layer.build_counter, 1)
self.assertEqual(layer.build_shape.as_list(), [None, 10])
self.assertEqual(output_signature.dtype, dtypes.float64)
self.assertEqual(output_signature.shape.as_list(), [None, 10])
layer(np.ones((5, 10)))
self.assertEqual(layer.build_counter, 1)
self.assertEqual(layer.build_shape.as_list(), [None, 10])
def test_eager_switch_case_input(self):
task = input_layer.Input(shape=(), dtype=dtypes.int32)

View File

@ -187,6 +187,11 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
return nest._sequence_like(nested, mapped_values)
def get_shapes(tensors):
"""Gets shapes from tensors."""
return nest.map_structure(lambda x: x.shape, tensors)
# pylint: enable=protected-access