diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 3bdaa866ee6..8d7a0cd3a18 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -136,6 +136,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
       auto* out_shape_and_type = handle_data.add_shape_and_type();
       ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
       out_shape_and_type->set_dtype(p.dtype);
+      out_shape_and_type->set_specialized_type(p.specialized_type);
     }
   }
   string result;
@@ -163,7 +164,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
     status->status =
         ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
     if (TF_GetCode(status) != TF_OK) return;
-    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
+    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
+                                  shape_and_type_proto.specialized_type());
   }
   ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
 }
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index bb79b278cb1..10b54476d18 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -133,9 +133,14 @@ struct DimensionOrConstant {
 struct ShapeAndType {
   ShapeAndType() {}
   ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {}
+  ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t)
+      : shape(s), dtype(t), specialized_type(specialized_t) {}
 
   ShapeHandle shape;
   DataType dtype = DT_INVALID;
+  // The type of a variant-dtype tensor sometimes affects graph building
+  // (e.g. for vectorization), and needs to be know statically in such cases.
+  SpecializedType specialized_type = ST_INVALID;
 };
 
 // Shape inference functions registered on ops in REGISTER_OP implement
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
index 900132c0db9..61549ae08ce 100644
--- a/tensorflow/core/framework/types.proto
+++ b/tensorflow/core/framework/types.proto
@@ -74,3 +74,14 @@ enum DataType {
 //    https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
 //    https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
 //    https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
+
+// For identifying the underlying type of a variant. For variants, the types
+// listed here are a subset of the types in the variant type registry,
+// corresponding to commonly used variants which must occasionally be
+// special-cased.
+enum SpecializedType {
+  // Invalid/unknown specialized type.
+  ST_INVALID = 0;
+  // "tensorflow::TensorList" in the variant type registry.
+  ST_TENSOR_LIST = 1;
+}
\ No newline at end of file
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index 4ad676c37ea..91bcc3be49a 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.pb.h"
 
 namespace tensorflow {
 namespace {
@@ -369,7 +370,7 @@ REGISTER_OP("TensorListFromTensor")
                                   &tensor_shape_except_first_dim));
       c->set_output_handle_shapes_and_types(
           0, std::vector<shape_inference::ShapeAndType>{
-                 {element_shape, element_dtype}});
+                 {element_shape, element_dtype, ST_TENSOR_LIST}});
       return Status::OK();
     });
 
@@ -409,7 +410,7 @@ REGISTER_OP("TensorListReserve")
       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
       c->set_output_handle_shapes_and_types(
           0, std::vector<shape_inference::ShapeAndType>{
-                 {element_shape, element_dtype}});
+                 {element_shape, element_dtype, ST_TENSOR_LIST}});
       return Status::OK();
     });
 
@@ -481,7 +482,7 @@ REGISTER_OP("TensorListSetItem")
         c->set_output_handle_shapes_and_types(0, *handle_data);
       } else {
         c->set_output_handle_shapes_and_types(
-            0, {{c->UnknownShape(), element_dtype}});
+            0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
       }
       return Status::OK();
     });
@@ -532,8 +533,8 @@ REGISTER_OP("TensorListScatter")
       shape_inference::ShapeHandle element_shape;
       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
           2, &element_shape));
-      c->set_output_handle_shapes_and_types(0,
-                                            {{element_shape, element_dtype}});
+      c->set_output_handle_shapes_and_types(
+          0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
       c->set_output(0, c->Scalar());
       return Status::OK();
     });
@@ -552,8 +553,8 @@ REGISTER_OP("TensorListScatterV2")
       shape_inference::ShapeHandle element_shape;
       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
           2, &element_shape));
-      c->set_output_handle_shapes_and_types(0,
-                                            {{element_shape, element_dtype}});
+      c->set_output_handle_shapes_and_types(
+          0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
       c->set_output(0, c->Scalar());
       return Status::OK();
     });
@@ -580,8 +581,8 @@ REGISTER_OP("TensorListScatterIntoExistingList")
         TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
         element_shape = GetElementShapeFromHandleData(*handle_data);
       }
-      c->set_output_handle_shapes_and_types(0,
-                                            {{element_shape, element_dtype}});
+      c->set_output_handle_shapes_and_types(
+          0, {{element_shape, element_dtype, ST_TENSOR_LIST}});
       c->set_output(0, c->Scalar());
       return Status::OK();
     });
@@ -606,7 +607,7 @@ REGISTER_OP("TensorListConcatLists")
       bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
       if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
         c->set_output_handle_shapes_and_types(
-            0, {{c->UnknownShape(), element_dtype}});
+            0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}});
         return Status::OK();
       }
       shape_inference::ShapeAndType list_shape_type_a =
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 02f167b4688..ba424193532 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -2174,6 +2174,7 @@ class ConcreteFunction(object):
     j = 0
     for i, o in enumerate(outputs_list):
       if o is not None:
+        custom_gradient.copy_handle_data(self.outputs[j], result[j])
         outputs_list[i] = result[j]
         j += 1
     ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
diff --git a/tensorflow/python/framework/cpp_shape_inference.proto b/tensorflow/python/framework/cpp_shape_inference.proto
index 1bf14570292..aa4df78c40b 100644
--- a/tensorflow/python/framework/cpp_shape_inference.proto
+++ b/tensorflow/python/framework/cpp_shape_inference.proto
@@ -11,6 +11,10 @@ message CppShapeInferenceResult {
   message HandleShapeAndType {
     TensorShapeProto shape = 1;
     DataType dtype = 2;
+    // For dtype==DT_VARIANT, specialized_type may indicate a more specific
+    // type. For other dtypes or when the information is unavailable it is set
+    // to ST_INVALID.
+    SpecializedType specialized_type = 3;
   }
   message HandleData {
     bool is_set = 1;
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 33838aa502e..d874f4f685c 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -121,6 +121,7 @@ cuda_py_test(
         "noasan",  # TODO(b/155406705): flaky
     ],
     deps = [
+        "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index f792cda6ea1..f2f6dc33b84 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np  # pylint: disable=unused-import
 
+from tensorflow.core.framework import types_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -40,6 +41,7 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope as vs
@@ -1600,9 +1602,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     def func():
       t = constant_op.constant([1., 2., 3.])
       l = list_ops.tensor_list_from_tensor(t, element_shape=[])
+      handle_data = resource_variable_ops.get_eager_safe_handle_data(l)
+      self.assertTrue(handle_data.is_set)
+      self.assertEqual(types_pb2.ST_TENSOR_LIST,
+                       handle_data.shape_and_type[0].specialized_type)
       return l
 
     tensor_list = func()
+    handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor_list)
+    self.assertTrue(handle_data.is_set)
+    self.assertEqual(dtypes.float32, handle_data.shape_and_type[0].dtype)
+    self.assertEqual(types_pb2.ST_TENSOR_LIST,
+                     handle_data.shape_and_type[0].specialized_type)
     element = list_ops.tensor_list_get_item(
         tensor_list, 0, element_dtype=dtypes.float32)
     self.assertAllEqual(element.shape.as_list(), [])
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index f081f036b58..33156f7c9c7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -67,22 +67,13 @@ def copy_handle_data(source_t, target_t):
         and handle_data.is_set
         and handle_data.shape_and_type):
       # pylint: disable=protected-access
+      if isinstance(target_t, ops.EagerTensor):
+        target_t._handle_data = handle_data
+        return
       pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
                                               target_t._as_tf_output(),
                                               handle_data.SerializeToString())
       # pylint: enable=protected-access
-      # Ensure that shapes and dtypes are propagated.
-      shapes, types = zip(*[(pair.shape, pair.dtype)
-                            for pair in handle_data.shape_and_type])
-      ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
-      shapes = [[d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
-                if not s.unknown_rank else None for s in shapes]
-      pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
-          target_t._op._graph._c_graph,  # pylint: disable=protected-access
-          target_t._as_tf_output(),  # pylint: disable=protected-access
-          shapes,
-          ranks,
-          types)
 
 
 @tf_export("custom_gradient")