Make sure `build` returns correct input_shape if `compute_output_shape` is called manually.
PiperOrigin-RevId: 305589497 Change-Id: Ifbcb13569a56759e3ff463b324cec1be5b8bbb02
This commit is contained in:
parent
27058058e3
commit
a807a425a2
|
@ -2360,8 +2360,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
else:
|
||||
self._dtype_policy = policy.Policy(dtype)
|
||||
input_shapes = None
|
||||
# Converts Tensors / CompositeTensors to TensorShapes.
|
||||
if all(hasattr(x, 'shape') for x in input_list):
|
||||
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
|
||||
input_shapes = tf_utils.get_shapes(inputs)
|
||||
else:
|
||||
# Converts input shape to TensorShapes.
|
||||
try:
|
||||
input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False)
|
||||
except ValueError:
|
||||
pass
|
||||
# Only call `build` if the user has manually overridden the build method.
|
||||
if not hasattr(self.build, '_is_default'):
|
||||
# Any setup work performed only once should happen in an `init_scope`
|
||||
|
|
|
@ -125,6 +125,7 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||
|
||||
def build(self, input_shape):
|
||||
self.build_counter += 1
|
||||
self.build_shape = input_shape
|
||||
|
||||
def call(self, inputs):
|
||||
return inputs
|
||||
|
@ -132,14 +133,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||
layer = BuildCounter(dtype=dtypes.float64)
|
||||
output_shape = layer.compute_output_shape((None, 10))
|
||||
self.assertEqual(layer.build_counter, 1)
|
||||
self.assertEqual(layer.build_shape.as_list(), [None, 10])
|
||||
self.assertEqual(output_shape.as_list(), [None, 10])
|
||||
output_signature = layer.compute_output_signature(
|
||||
tensor_spec.TensorSpec(dtype=dtypes.float64, shape=[None, 10]))
|
||||
self.assertEqual(layer.build_counter, 1)
|
||||
self.assertEqual(layer.build_shape.as_list(), [None, 10])
|
||||
self.assertEqual(output_signature.dtype, dtypes.float64)
|
||||
self.assertEqual(output_signature.shape.as_list(), [None, 10])
|
||||
layer(np.ones((5, 10)))
|
||||
self.assertEqual(layer.build_counter, 1)
|
||||
self.assertEqual(layer.build_shape.as_list(), [None, 10])
|
||||
|
||||
def test_eager_switch_case_input(self):
|
||||
task = input_layer.Input(shape=(), dtype=dtypes.int32)
|
||||
|
|
|
@ -187,6 +187,11 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
|
|||
return nest._sequence_like(nested, mapped_values)
|
||||
|
||||
|
||||
def get_shapes(tensors):
|
||||
"""Gets shapes from tensors."""
|
||||
return nest.map_structure(lambda x: x.shape, tensors)
|
||||
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue