diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 741cc831f02..0ef4840b651 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -544,6 +544,7 @@ class Functional(training_lib.Model):
       t_rank = t_shape.rank
       ref_shape = ref_input.shape
       ref_rank = ref_shape.rank
+      keras_history = getattr(tensor, '_keras_history', None)
       if t_rank is not None and ref_rank is not None:
         # Should squeeze last dimension.
         # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
@@ -553,6 +554,8 @@ class Functional(training_lib.Model):
         # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
         elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
           tensor = array_ops.expand_dims_v2(tensor, axis=-1)
+      if keras_history is not None:  # Restore keras history.
+        tensor._keras_history = keras_history
 
       # Add shape hints to Tensors that may have None shape dims but have shapes
       # defined by the `keras.Input` (not applicable in eager mode).
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index b877e81af15..68b40caad9b 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -2068,5 +2068,18 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
     # `None` value passed during construction is overridden.
     self.assertAllEqual(network(x, training=False), x * 0.0)
 
+  def test_keras_history_propagation_(self):
+    for input_shape in [(1,), (1, 1)]:
+      sub_in = input_layer_lib.Input((1,))
+      relu_layer = layers.ReLU()
+      sub_out = relu_layer(sub_in)
+      submodel = functional.Functional(sub_in, sub_out)
+      self.assertLen(relu_layer._inbound_nodes, 1)
+
+      inp = input_layer_lib.Input(input_shape)
+      submodel(inp)
+      self.assertLen(relu_layer._inbound_nodes, 2)
+
+
 if __name__ == '__main__':
   test.main()