Fix missing connectivity metadata on functional models.

`Functional._conform_to_reference_input` may squeeze or expand an input tensor to match the reference input's rank; the op returns a new tensor that loses the original's `_keras_history`. Save that metadata up front and restore it afterwards so layer connectivity keeps propagating.

PiperOrigin-RevId: 315381292
Change-Id: I3a7e61a0afbda0ae1984a4152dd297e350b29775
Author: Pavithra Vijay
Date: 2020-06-08 16:50:13 -07:00
Committed by: TensorFlower Gardener
Parent: b546504c58
Commit: 939772a64e
2 changed files with 16 additions and 0 deletions

@@ -544,6 +544,7 @@ class Functional(training_lib.Model):
       t_rank = t_shape.rank
       ref_shape = ref_input.shape
       ref_rank = ref_shape.rank
+      keras_history = getattr(tensor, '_keras_history', None)
       if t_rank is not None and ref_rank is not None:
         # Should squeeze last dimension.
         # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
@@ -553,6 +554,8 @@ class Functional(training_lib.Model):
         # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
         elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
           tensor = array_ops.expand_dims_v2(tensor, axis=-1)
+      if keras_history is not None:  # Restore keras history.
+        tensor._keras_history = keras_history
 
       # Add shape hints to Tensors that may have None shape dims but have shapes
       # defined by the `keras.Input` (not applicable in eager mode).
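
For context, the pattern this hunk applies can be sketched standalone. The sketch below is not the TensorFlow source: `conform_to_reference_rank` is a hypothetical stand-in for `Functional._conform_to_reference_input`, it uses the public `tf.squeeze`/`tf.expand_dims` instead of `array_ops`, and it assumes symbolic (graph-mode) tensors, on which Python attributes such as `_keras_history` can be set:

import tensorflow as tf

def conform_to_reference_rank(tensor, ref_shape):
  """Rank-adjust `tensor` toward `ref_shape` without losing Keras metadata."""
  # Save the connectivity metadata up front: squeeze/expand_dims return new
  # tensor objects, which would otherwise silently drop `_keras_history`.
  keras_history = getattr(tensor, '_keras_history', None)
  t_rank = tensor.shape.rank
  ref_rank = ref_shape.rank
  if t_rank is not None and ref_rank is not None:
    # Tensor is (BATCH, ..., 1) and reference is (BATCH, ...): squeeze.
    if t_rank == ref_rank + 1 and tensor.shape[-1] == 1:
      tensor = tf.squeeze(tensor, axis=-1)
    # Tensor is (BATCH, ...) and reference is (BATCH, ..., 1): expand.
    elif t_rank == ref_rank - 1 and ref_shape[-1] == 1:
      tensor = tf.expand_dims(tensor, axis=-1)
  if keras_history is not None:  # Restore keras history.
    tensor._keras_history = keras_history
  return tensor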

@@ -2068,5 +2068,18 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
     # `None` value passed during construction is overridden.
     self.assertAllEqual(network(x, training=False), x * 0.0)
 
+  def test_keras_history_propagation_(self):
+    for input_shape in [(1,), (1, 1)]:
+      sub_in = input_layer_lib.Input((1,))
+      relu_layer = layers.ReLU()
+      sub_out = relu_layer(sub_in)
+      submodel = functional.Functional(sub_in, sub_out)
+      self.assertLen(relu_layer._inbound_nodes, 1)
+
+      inp = input_layer_lib.Input(input_shape)
+      submodel(inp)
+      self.assertLen(relu_layer._inbound_nodes, 2)
+
+
 if __name__ == '__main__':
   test.main()
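
The scenario the new test exercises, sketched against the public tf.keras API (tf.keras.Input and tf.keras.Model stand in for the test's input_layer_lib and functional internals; the `_inbound_nodes` length check pokes at the same private attribute the test uses, and the node counts assume the pre-KerasTensor behavior this commit targets, where calling a functional model on a symbolic input re-invokes its inner layers):

import tensorflow as tf

# Nested functional submodel: (batch, 1) -> ReLU.
sub_in = tf.keras.Input((1,))
relu_layer = tf.keras.layers.ReLU()
submodel = tf.keras.Model(sub_in, relu_layer(sub_in))
assert len(relu_layer._inbound_nodes) == 1  # node created at construction

# Calling the submodel on a (batch, 1, 1) input takes the squeeze path in
# `_conform_to_reference_input`. Before this fix the squeezed tensor lost its
# `_keras_history`, so connectivity metadata failed to propagate and the
# inner layer never received the second inbound node asserted here.
outer_in = tf.keras.Input((1, 1))
submodel(outer_in)
assert len(relu_layer._inbound_nodes) == 2  # node created by this call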