From 939772a64eb30fe7ccd632547d98f1511de87637 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Mon, 8 Jun 2020 16:50:13 -0700 Subject: [PATCH] Fix connectivity metadata missing issue on functional model. PiperOrigin-RevId: 315381292 Change-Id: I3a7e61a0afbda0ae1984a4152dd297e350b29775 --- tensorflow/python/keras/engine/functional.py | 3 +++ tensorflow/python/keras/engine/functional_test.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 741cc831f02..0ef4840b651 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -544,6 +544,7 @@ class Functional(training_lib.Model): t_rank = t_shape.rank ref_shape = ref_input.shape ref_rank = ref_shape.rank + keras_history = getattr(tensor, '_keras_history', None) if t_rank is not None and ref_rank is not None: # Should squeeze last dimension. # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...). @@ -553,6 +554,8 @@ class Functional(training_lib.Model): # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1). elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1): tensor = array_ops.expand_dims_v2(tensor, axis=-1) + if keras_history is not None: # Restore keras history. + tensor._keras_history = keras_history # Add shape hints to Tensors that may have None shape dims but have shapes # defined by the `keras.Input` (not applicable in eager mode). diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index b877e81af15..68b40caad9b 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -2068,5 +2068,18 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): # `None` value passed during construction is overridden. self.assertAllEqual(network(x, training=False), x * 0.0) + def test_keras_history_propagation_(self): + for input_shape in [(1,), (1, 1)]: + sub_in = input_layer_lib.Input((1,)) + relu_layer = layers.ReLU() + sub_out = relu_layer(sub_in) + submodel = functional.Functional(sub_in, sub_out) + self.assertLen(relu_layer._inbound_nodes, 1) + + inp = input_layer_lib.Input(input_shape) + submodel(inp) + self.assertLen(relu_layer._inbound_nodes, 2) + + if __name__ == '__main__': test.main()