diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index 13d1d1f6f8e..0b7f4c0bcc7 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -328,9 +328,9 @@ class FromSavedModelTest(TestModels):
     self.assertIn('This converter can only convert a single ConcreteFunction',
                   str(error.exception))
 
+  @test_util.run_v2_only
   def testKerasSequentialModel(self):
     """Test a simple sequential tf.Keras model."""
-    self.skipTest('b/134660903')
     input_data = constant_op.constant(1., shape=[1, 1])
 
     x = np.array([[1.], [2.]])
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 24377614031..3d03a45335e 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -164,7 +164,7 @@ def _compatible_shapes(flat_relaxed, flat_to_check):
              for relaxed, to_check in zip(flat_relaxed, flat_to_check))
 
 
-def _common_shape(x, y):
+def common_shape(x, y):
   """Find a `TensorShape` that is compatible with both `x` and `y`."""
   if x is None != y is None:
     raise RuntimeError(
@@ -1577,7 +1577,7 @@ class Function(object):
                            "relaxed_arg_shapes len: %d vs. %d"
                            % (len(arg_shapes), len(relaxed_arg_shapes)))
       relaxed_arg_shapes = [
-          _common_shape(x, y) for (x, y) in zip(
+          common_shape(x, y) for (x, y) in zip(
               arg_shapes, relaxed_arg_shapes)]
     self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
         relaxed_arg_shapes)
@@ -1679,8 +1679,9 @@ def register(func, *args, **kwargs):
 def validate_signature(signature):
   if any(not isinstance(arg, tensor_spec.TensorSpec)
          for arg in nest.flatten(signature, expand_composites=True)):
-    raise TypeError("Invalid input_signature %s; input_signature must be "
-                    "a possibly nested sequence of TensorSpec objects.")
+    raise TypeError("Invalid input_signature {}; input_signature must be "
+                    "a possibly nested sequence of TensorSpec objects."
+                    .format(signature))
 
 
 def defun(func=None,
diff --git a/tensorflow/python/keras/saving/saved_model.py b/tensorflow/python/keras/saving/saved_model.py
index 8d9fc094dad..ea387eedbbb 100644
--- a/tensorflow/python/keras/saving/saved_model.py
+++ b/tensorflow/python/keras/saving/saved_model.py
@@ -20,12 +20,12 @@ from __future__ import print_function
 import functools
 import json
 import os
+import weakref
 import six
 
 from tensorflow.python.client import session
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
@@ -38,10 +38,10 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.saving import model_from_json
 from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.utils import mode_keys
+from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import builder as saved_model_builder
@@ -713,10 +713,20 @@ def serialize_all_attributes(layer, serialization_cache):
   if _should_skip_serialization(layer):
     return serialized_attr
 
+  function_dict = {}
+  if save_model_default_signature:
+    # For compatibility with the TFLite converter, the default save signature
+    # should be traced without nested calls to other wrapped functions.
+    # TODO(kathywu): Investigate why having nested calls results in a stateful
+    # function. Perhaps something to do with losses, which are traced in nested
+    # calls but not in the flat call.
+    function_dict['_default_save_signature'] = _default_save_signature(layer)
+  else:
+    function_dict['_default_save_signature'] = None
+
   object_dict = _wrap_layer_objects(layer, serialization_cache)
   try:
-    function_dict = _wrap_layer_functions(layer, serialization_cache,
-                                          save_model_default_signature)
+    function_dict.update(_wrap_layer_functions(layer, serialization_cache))
   except (ValueError, TypeError) as e:
     logging.warning('Skipping full serialization of object {}, because an '
                     'error occurred while tracing layer functions. Error '
@@ -799,44 +809,53 @@ def _wrap_layer_objects(layer, serialization_cache):
           wrapped_loss_functions))
 
 
-def _wrap_layer_functions(layer, serialization_cache,
-                          save_model_default_signature=False):
+def _wrap_layer_functions(layer, serialization_cache):
   """Returns dict of wrapped layer call function and losses in tf.functions.
 
   Args:
     layer: Keras Layer object.
     serialization_cache: Dictionary shared between all objects during
       serialization.
-    save_model_default_signature: Whether to save traced model call function.
 
   Returns:
     A dictionary containing all keras tf.functions to serialize. See
     LayerAttributes and ModelAttributes for the list of all attributes.
   """
+  # Since Sequential models may be modified in place using model.add() or
+  # model.pop(), don't use saved functions.
+  if (isinstance(layer, RevivedLayer) and
+      not isinstance(layer, RevivedSequential)):
+    return {fn_name: getattr(layer.keras_api, fn_name, None)
+            for fn_name in LayerAttributes.all_functions}
+
   # Reset the losses of the layer and its children. The call function in each
   # child layer is replaced with tf.functions.
-  original_attrs = _replace_child_layer_functions(layer, serialization_cache)
-  original_layer_losses = layer._losses[:]  # pylint: disable=protected-access
-  with trackable.no_automatic_dependency_tracking_scope(layer):
-    layer._losses = []  # pylint: disable=protected-access
-    # Note that eager losses do not need to be saved since these functions
-    # create symbolic losses.
+  original_fns = _replace_child_layer_functions(layer, serialization_cache)
+  original_losses = _reset_layer_losses(layer)
 
   # Wrap all the layer call and activity regularizer functions.
-  call_fn_with_losses = _wrap_call_and_conditional_losses(layer)
-  fns = {'call_and_return_conditional_losses': call_fn_with_losses,
-         '__call__': _extract_outputs_from_fn(layer, call_fn_with_losses)}
 
-  if save_model_default_signature:
-    fns['_default_save_signature'] = saving_utils.trace_model_call(layer)
-  else:
-    fns['_default_save_signature'] = None
+  # Use LayerCallCollection to ensure that all layer call functions (__call__,
+  # call with losses) are traced with the same inputs.
+  call_collection = LayerCallCollection(layer)
+  call_fn_with_losses = call_collection.add_function(
+      _wrap_call_and_conditional_losses(layer),
+      '{}_layer_call_and_return_conditional_losses'.format(layer.name))
+  call_fn = call_collection.add_function(
+      _extract_outputs_from_fn(layer, call_fn_with_losses),
+      '{}_layer_call_fn'.format(layer.name))
+
+  fns = {'call_and_return_conditional_losses': call_fn_with_losses,
+         '__call__': call_fn}
 
   if layer.activity_regularizer is not None:
     fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
     fns['call_and_return_all_conditional_losses'] = (
-        _append_activity_regularizer_loss(
-            layer, call_fn_with_losses, fns['activity_regularizer_fn']))
+        call_collection.add_function(
+            _append_activity_regularizer_loss(call_fn_with_losses,
+                                              fns['activity_regularizer_fn']),
+            '{}_layer_call_and_return_all_conditional_losses'.format(
+                layer.name)))
   else:
     fns['activity_regularizer_fn'] = None
     fns['call_and_return_all_conditional_losses'] = call_fn_with_losses
@@ -849,14 +868,21 @@ def _wrap_layer_functions(layer, serialization_cache,
       if fn is not None and fn.input_signature is not None:
         fn.get_concrete_function()
 
-  # Restore overwritten functions/losses
-  with trackable.no_automatic_dependency_tracking_scope(layer):
-    layer._losses = original_layer_losses  # pylint: disable=protected-access
-  _restore_child_layer_functions(original_attrs)
+  # Restore overwritten functions and losses
+  _restore_child_layer_functions(original_fns)
+  _restore_layer_losses(original_losses)
 
   return fns
 
 
+def _default_save_signature(layer):
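+  # Clear accumulated losses so the traced signature stays stateless, force
+  # the trace with get_concrete_function(), then restore the original losses.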
+  original_losses = _reset_layer_losses(layer)
+  fn = saving_utils.trace_model_call(layer)
+  fn.get_concrete_function()
+  _restore_layer_losses(original_losses)
+  return fn
+
+
 def _list_all_layers(obj):
   if isinstance(obj, training_lib.Model):
     return obj.layers
@@ -888,11 +914,9 @@ def _replace_child_layer_functions(layer, serialization_cache):
         Child layer 2: ...
       }
   """
-  original_attrs = {}
+  # pylint: disable=protected-access
+  original_fns = {}
   for child_layer in _list_all_layers(layer):
-    # Save symbolic layer losses, which will be restored to maintain the same
-    # state.
-    original_attrs[child_layer] = {'losses': child_layer._losses[:]}  # pylint: disable=protected-access
     if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
       layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
                    .functions)
@@ -906,27 +930,46 @@ def _replace_child_layer_functions(layer, serialization_cache):
       #     wrapped. In this case, no replacement is necessary so move on to the
       #     next child.
       continue
-
-    original_attrs[child_layer]['call'] = child_layer.call
-    original_attrs[child_layer]['activity_regularizer'] = (
-        child_layer.activity_regularizer)
+    original_fns[child_layer] = {
+        'call': child_layer.call,
+        'activity_regularizer': child_layer.activity_regularizer
+    }
     with trackable.no_automatic_dependency_tracking_scope(child_layer):
       child_layer.activity_regularizer = layer_fns.get(
           'activity_regularizer_fn')
       child_layer.call = _use_wrapped_call(
           child_layer, layer_fns['call_and_return_conditional_losses'])
-      child_layer._losses = []  # pylint: disable=protected-access
-  return original_attrs
+  return original_fns
+  # pylint: enable=protected-access
 
 
-def _restore_child_layer_functions(original_attrs):
+def _restore_child_layer_functions(original_fns):
   """Restores attributes replaced with `_replace_child_layer_functions`."""
-  for child_layer, attrs in original_attrs.items():
+  for child_layer, fns in original_fns.items():
     with trackable.no_automatic_dependency_tracking_scope(child_layer):
-      child_layer._losses = attrs['losses']  # pylint: disable=protected-access
-      if 'call' in attrs:
-        child_layer.call = attrs['call']
-        child_layer.activity_regularizer = attrs['activity_regularizer']
+      child_layer.call = fns['call']
+      child_layer.activity_regularizer = fns['activity_regularizer']
+
+
+# pylint: disable=protected-access
+def _reset_layer_losses(parent_layer):
+  """Resets losses of layer and its sublayers, and returns original losses."""
+  losses_dict = {}
+  for layer in _list_all_layers(parent_layer) + [parent_layer]:
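+    # Record both symbolic and eager losses so they can be restored after
+    # tracing; the traced call functions regenerate the symbolic ones.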
+    losses_dict[layer] = {'losses': layer._losses[:],
+                          'eager_losses': layer._eager_losses[:]}
+    with trackable.no_automatic_dependency_tracking_scope(layer):
+      layer._losses = []
+      layer._eager_losses = []
+  return losses_dict
+
+
+def _restore_layer_losses(losses_dict):
+  for layer in losses_dict:
+    with trackable.no_automatic_dependency_tracking_scope(layer):
+      layer._losses = losses_dict[layer]['losses']
+      layer._eager_losses = losses_dict[layer]['eager_losses']
+# pylint: enable=protected-access
 
 
 def _use_wrapped_call(layer, call_fn):
@@ -947,8 +990,10 @@ def _use_wrapped_call(layer, call_fn):
       training = kwargs.pop('training', None)
       if training is None:
         training = K.learning_phase()
-      training = math_ops.cast(training, dtypes.bool)
-      outputs, losses = call_fn(inputs, training=training)
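+      # smart_cond calls one branch directly when `training` is a python
+      # bool, and builds a tf.cond when it is a symbolic tensor, so call_fn
+      # is always traced with a python boolean `training` argument.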
+      outputs, losses = tf_utils.smart_cond(
+          training,
+          lambda: call_fn(inputs, training=True),
+          lambda: call_fn(inputs, training=False))
     else:
       outputs, losses = call_fn(inputs)
     layer.add_loss(losses, inputs)
@@ -956,6 +1001,128 @@ def _use_wrapped_call(layer, call_fn):
   return wrapped_call
 
 
+class LayerCallCollection(object):
+  """Groups wrapped layer call functions.
+
+  This is used to ensure that all layer call functions are traced with the same
+  inputs:
+    - call
+    - call_and_return_conditional_losses
+    - call_and_return_all_conditional_losses
+  """
+
+  def __init__(self, layer):
+    self._layer = layer
+    self._expects_training_arg = layer._expects_training_arg  # pylint: disable=protected-access
+    self._input_signature = self._generate_input_signature(layer)
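+    # Hold the wrapped functions weakly: each LayerCall keeps a strong
+    # reference back to this collection, so strong references here would
+    # create a cycle.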
+    self._functions = weakref.WeakValueDictionary()
+    # Bool indicating whether this object is currently tracing the layer call
+    # functions.
+    self.tracing = False
+
+  def _generate_input_signature(self, layer):
+    """Inspects layer object and returns the inferred input signature.
+
+    Args:
+      layer: Layer object.
+
+    Returns:
+      List of possibly nested TensorSpecs of the layer call function inputs.
+      The list does not contain the `training` argument.
+    """
+    if (isinstance(layer.call, def_function.Function) and
+        layer.call.input_signature is not None):
+      return layer.call.input_signature
+    else:
+      if isinstance(layer, training_lib.Model):
+        return saving_utils.model_input_signature(layer)
+      elif layer.input_spec is not None:
+
+        def to_tensor_spec_or_none(x):
+          spec = input_spec.to_tensor_spec(x, layer.dtype)
+          # If the shape is too general (e.g. multiple dimensions are allowed),
+          # return None so that separate functions can be generated for each
+          # inferred input signature.
+          # TODO(b/134962016): currently partial signatures are not supported.
+          if spec.shape == tensor_shape.TensorShape(None):
+            return None
+          return spec
+        input_signature = [nest.map_structure(
+            to_tensor_spec_or_none, layer.input_spec)]
+
+        return input_signature
+      else:
+        return None
+
+  def add_trace(self, *args, **kwargs):
+    """Traces all functions with the same args and kwargs.
+
+    Args:
+      *args: Positional args passed to the original function.
+      **kwargs: Keyword args passed to the original function.
+    """
+    kwargs = kwargs.copy()
+    self.tracing = True
+    for fn in self._functions.values():
+      # TODO(kathywu): Replace arguments with broader shapes defined in the
+      # input signature.
+      if self._expects_training_arg:
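+        # Trace twice, with training=False and training=True, so the saved
+        # function covers both modes.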
+        kwargs['training'] = False
+        fn.original_get_concrete_function(*args, **kwargs)
+        kwargs['training'] = True
+        fn.original_get_concrete_function(*args, **kwargs)
+      else:
+        fn.original_get_concrete_function(*args, **kwargs)
+    self.tracing = False
+
+  @property
+  def fn_input_signature(self):
+    """Returns input signature for the wrapped layer call function."""
+    if self._expects_training_arg:
+      # The training arg is left as a python boolean, so the call functions
+      # will not have an input signature (input signatures may only describe
+      # tensor arguments).
+      return None
+    if None in nest.flatten(self._input_signature):
+      # TODO(b/134962016): Partial input signatures are not yet supported.
+      return None
+    return self._input_signature
+
+  def add_function(self, python_function, name):
+    """Adds a layer call function to the collection."""
+    self._functions[name] = fn = LayerCall(
+        self, python_function, name,
+        input_signature=self.fn_input_signature)
+
+    if (None not in nest.flatten(self._input_signature) and
+        self._expects_training_arg):
+      # Manually add traces for layers that expect a training argument and have
+      # a fully defined input signature.
+      self.add_trace(*self._input_signature)
+    return fn
+
+
+class LayerCall(def_function.Function):
+  """Function that triggers traces of other functions in the same collection."""
+
+  def __init__(self, call_collection, *args, **kwargs):
+    super(LayerCall, self).__init__(*args, **kwargs)
+    self.call_collection = call_collection
+
+  def __call__(self, *args, **kwargs):
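+    # `tracing` guards against re-entry: while `add_trace` is running, nested
+    # calls to other functions in the collection must not trigger further
+    # traces.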
+    if not self.call_collection.tracing:
+      self.call_collection.add_trace(*args, **kwargs)
+    return super(LayerCall, self).__call__(*args, **kwargs)
+
+  def get_concrete_function(self, *args, **kwargs):
+    if not self.call_collection.tracing:
+      self.call_collection.add_trace(*args, **kwargs)
+    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
+
+  def original_get_concrete_function(self, *args, **kwargs):
+    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
+
+
 def _wrap_call_and_conditional_losses(layer):
   """Wraps call function that returns a tuple of (outputs, losses).
 
@@ -966,51 +1133,19 @@ def _wrap_call_and_conditional_losses(layer):
     layer: a Keras layer object
 
   Returns:
-    call function that returns outputs and conditional losses -- excludes
+    Python call function that returns outputs and conditional losses -- excludes
     activity regularizer
   """
-  if isinstance(layer, RevivedLayer):
-    return layer.keras_api.call_and_return_conditional_losses
-
-  if (isinstance(layer.call, def_function.Function) and
-      layer.call.input_signature is not None):
-    input_signature = layer.call.input_signature
-  else:
-    if (isinstance(layer, training_lib.Model) and
-        saving_utils.model_input_signature(layer) is not None):
-      input_signature = saving_utils.model_input_signature(layer)
-    elif layer.input_spec is not None:
-      input_signature = [nest.map_structure(
-          lambda x: input_spec.to_tensor_spec(x, layer.dtype),
-          layer.input_spec)]
-      # If input spec is too general, then don't define an input signature.
-      for spec in nest.flatten(input_signature):
-        if spec.shape == tensor_shape.TensorShape(None):
-          input_signature = None
-          break
-    else:
-      input_signature = None
-
-    if input_signature is not None and layer._expects_training_arg:  # pylint: disable=protected-access
-      input_signature.append(
-          tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool))
-
   # Create function that generates both outputs and losses
   layer_call = layer.call
   if layer._expects_training_arg:  # pylint: disable=protected-access
-    def call_and_return_conditional_losses(inputs, training):
-      _set_symbolic_learning_phase(training)
+    def call_and_return_conditional_losses(inputs, training=False):
       return layer_call(inputs, training=training), layer.get_losses_for(inputs)
   else:
     def call_and_return_conditional_losses(inputs):
       K.set_learning_phase(0)
       return layer_call(inputs), layer.get_losses_for(inputs)
-  return def_function.Function(
-      call_and_return_conditional_losses,
-      '{}_layer_call_and_return_conditional_losses'.format(layer.name),
-      input_signature=input_signature,
-      # TODO(kathywu): Investigate autograph error.
-      autograph=False)
+  return call_and_return_conditional_losses
 
 
 def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
@@ -1018,50 +1153,22 @@ def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
   if isinstance(layer, RevivedLayer):
     return layer.keras_api.__call__  # pylint: disable=protected-access
   if layer._expects_training_arg:  # pylint: disable=protected-access
-    def call(inputs, training):
-      return call_and_return_conditional_losses(inputs, training)[0]
+    def call(inputs, training=False):
+      return call_and_return_conditional_losses(inputs, training=training)[0]
   else:
     def call(inputs):
       return call_and_return_conditional_losses(inputs)[0]
-  return def_function.Function(
-      call, '{}_layer_call_fn'.format(layer.name),
-      input_signature=call_and_return_conditional_losses.input_signature,
-      # TODO(kathywu): Investigate autograph error.
-      autograph=False)
-
-
-def _set_symbolic_learning_phase(value):
-  """Set learning phase to a tensor value (for internal use only).
-
-  This is used when wrapping call functions as tf.functions that have training
-  as a tensor input. Thus, when `learning_phase()` is called, the training
-  tensor is returned. This function is called when saving a model to SavedModel.
-
-  Args:
-    value: A Tensor object.
-
-  Raises:
-    ValueError: If the input value is not a graph tensor
-  """
-  graph = K.get_graph()
-  if not isinstance(value, ops.Tensor):
-    raise ValueError('Symbolic learning phase must be a graph tensor.')
-  K._GRAPH_LEARNING_PHASES[graph] = value  # pylint: disable=protected-access
+  return call
 
 
 def _append_activity_regularizer_loss(
-    layer, call_fn_with_losses, activity_regularizer_fn):
+    call_fn_with_losses, activity_regularizer_fn):
   """Appends activity regularizer loss to losses returned by the wrapped fn."""
-  def fn(*args):
-    outputs, losses = call_fn_with_losses(*args)
+  def fn(*args, **kwargs):
+    outputs, losses = call_fn_with_losses(*args, **kwargs)
     losses.append(activity_regularizer_fn(outputs))
     return outputs, losses
-  return def_function.Function(
-      fn,
-      '{}_layer_call_and_return_all_conditional_losses'.format(layer.name),
-      input_signature=call_fn_with_losses.input_signature,
-      # TODO(kathywu): Investigate autograph error.
-      autograph=False)
+  return fn
 
 
 def _wrap_unconditional_loss(loss_fn, index):
@@ -1135,9 +1242,11 @@ class KerasObjectLoader(load.Loader):
     # pylint: disable=protected-access
     for node in self._nodes:
       if isinstance(node, RevivedModel):
-        input_signature = (
-            node.keras_api.call_and_return_conditional_losses.input_signature[0]
-            )
+        call_fn = node.keras_api.call_and_return_conditional_losses
+        if call_fn.input_signature is None:
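+          # No input signature is saved when, e.g., the call function takes a
+          # python `training` argument; recover the input spec from the traced
+          # concrete functions instead.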
+          inputs = infer_inputs_from_restored_call_function(call_fn)
+        else:
+          inputs = call_fn.input_signature[0]
         if isinstance(node, RevivedSequential):
           with trackable.no_automatic_dependency_tracking_scope(node):
             node._layers = []
@@ -1147,7 +1256,7 @@ class KerasObjectLoader(load.Loader):
         if not node.inputs:
           # Since this revived object is technically a subclassed model (even if
           # the original model is functional/sequential), inputs should be set.
-          node._set_inputs(input_signature)
+          node._set_inputs(inputs)
       if isinstance(node, RevivedLayer):
         losses = node._serialized_attributes.get('regularization_losses', [])
         for loss in losses:
@@ -1276,6 +1385,26 @@ def recursively_deserialize_keras_object(config, module_objects=None):
     raise ValueError('Unable to decode config: {}'.format(config))
 
 
+def infer_inputs_from_restored_call_function(fn):
+  """Returns TensorSpec of inputs from a restored call function.
+
+  Args:
+    fn: Restored layer call function. It is assumed that the inputs are entirely
+      in the first argument.
+
+  Returns:
+    TensorSpec of call function inputs.
+  """
+  def common_spec(x, y):
+    return tensor_spec.TensorSpec(defun.common_shape(x.shape, y.shape),
+                                  x.dtype, x.name)
+  spec = fn.concrete_functions[0].structured_input_signature[0][0]
+  for concrete in fn.concrete_functions[1:]:
+    spec2 = concrete.structured_input_signature[0][0]
+    spec = nest.map_structure(common_spec, spec, spec2)
+  return spec
+
+
 class RevivedNetwork(RevivedLayer):
   """Keras network of layers loaded from a SavedModel."""
 
diff --git a/tensorflow/python/keras/saving/saved_model_test.py b/tensorflow/python/keras/saving/saved_model_test.py
index 919ae45972d..732bf820868 100644
--- a/tensorflow/python/keras/saving/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model_test.py
@@ -705,14 +705,8 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
     expected_layers = len(model.layers)
     self.assertEqual(expected_layers, len(loaded.keras_api.layers))
     input_arr = array_ops.ones((4, 3))
-    training_bool = constant_op.constant(False)
-
-    if model._expects_training_arg:
-      call_args = [input_arr, training_bool]
-    else:
-      call_args = [input_arr]
     self.assertAllClose(self.evaluate(model(input_arr)),
-                        self.evaluate(loaded(*call_args)))
+                        self.evaluate(loaded(input_arr)))
 
   @keras_parameterized.run_with_all_model_types
   def test_compiled_model(self):
@@ -765,6 +759,20 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
     self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
     self.assertEqual('float16', loaded.input_spec['b'].dtype)
 
+  def test_multi_input_model(self):
+    input_1 = keras.layers.Input(shape=(3,))
+    input_2 = keras.layers.Input(shape=(5,))
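+    # Pass-through model: each output echoes the corresponding input.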
+    model = keras.Model([input_1, input_2], [input_1, input_2])
+    saved_model_dir = self._save_model_dir()
+
+    model.save(saved_model_dir, save_format='tf')
+    loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
+    input_arr_1 = np.random.random((1, 3)).astype('float32')
+    input_arr_2 = np.random.random((1, 5)).astype('float32')
+
+    outputs = loaded([input_arr_1, input_arr_2])
+    self.assertAllEqual(input_arr_1, outputs[0])
+    self.assertAllEqual(input_arr_2, outputs[1])
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/saving/saving_utils.py b/tensorflow/python/keras/saving/saving_utils.py
index 866f596fca3..718c2ad5340 100644
--- a/tensorflow/python/keras/saving/saving_utils.py
+++ b/tensorflow/python/keras/saving/saving_utils.py
@@ -66,7 +66,7 @@ def model_input_signature(model):
 
   Returns:
     A list containing either a single TensorSpec or an object with nested
-    TensorSpecs.
+    TensorSpecs. This list does not contain the `training` argument.
   """
   try:
     inputs = model.inputs
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index 94618989e4f..4804e4d0e0d 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -177,11 +177,11 @@ class RestoredFunction(def_function.Function):
     # TODO(mdan): We may enable autograph once exceptions are supported.
     super(RestoredFunction, self).__init__(
         python_function, name, autograph=False)
-    self._concrete_functions = concrete_functions
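+    # Note: exposed without a leading underscore so that Keras SavedModel
+    # loading can inspect the traced signatures when inferring model inputs.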
+    self.concrete_functions = concrete_functions
     self._function_spec = function_spec
 
   def _list_all_concrete_functions_for_serialization(self):
-    return self._concrete_functions
+    return self.concrete_functions
 
   def _defun_with_scope(self, scope):
     func = super(RestoredFunction, self)._defun_with_scope(scope)
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index 23ff7093a4b..984c9ab2cde 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -179,7 +179,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     # Calling get_concrete_function wraps in a second call operation; we want to
     # inspect the original function body for the control output; digging into
     # graph.as_graph_def() and its FunctionDefLibrary is another option.
-    imported_concrete, = imported.f._concrete_functions
+    imported_concrete, = imported.f.concrete_functions
     imported_graph = imported_concrete.graph
     self.assertIn(
         imported_graph.get_operation_by_name("should_be_control_output"),