From be6b1fdb0699d4000b70ad32cc23d1503e5c7511 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Wed, 14 Oct 2020 09:41:17 -0700
Subject: [PATCH] Added gradients for RaggedTensorToVariant and
 RaggedTensorFromVariant.  (This allows gradients to pass through map_fn when
 it is applied to ragged tensors.)

PiperOrigin-RevId: 337108621
Change-Id: I73d5f3296181877f0cc4c7a6273b693bcf8310ab
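
For example, gradients can now flow through tf.map_fn applied to a
RaggedTensor, since map_fn round-trips each row through these two ops.  A
minimal sketch (the values, shapes, and fn_output_signature below are
illustrative only, not part of this change):

    import tensorflow as tf

    rt = tf.ragged.constant([[1.0, 2.0, 3.0], [4.0, 5.0]])
    with tf.GradientTape() as tape:
      tape.watch(rt.flat_values)
      # map_fn encodes each row with RaggedTensorToVariant and decodes the
      # mapped rows with RaggedTensorFromVariant; both ops now have
      # registered gradients, so the tape can trace through them.
      doubled = tf.map_fn(
          lambda row: row * 2.0, rt,
          fn_output_signature=tf.RaggedTensorSpec(shape=[None],
                                                  dtype=tf.float32))
      loss = tf.reduce_sum(doubled.flat_values)
    print(tape.gradient(loss, rt.flat_values))  # [2. 2. 2. 2. 2.]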
---
 RELEASE.md                                    |   1 +
 ...pi_def_RaggedTensorToVariantGradient.pbtxt |  38 ++
 tensorflow/core/kernels/BUILD                 |  15 +
 .../core/kernels/data/experimental/BUILD      |   1 +
 .../experimental/parse_example_dataset_op.cc  |  10 +-
 .../kernels/ragged_tensor_from_variant_op.cc  | 168 +++----
 .../ragged_tensor_from_variant_op_test.cc     | 160 +++----
 .../kernels/ragged_tensor_to_variant_op.cc    | 180 +++++---
 .../ragged_tensor_to_variant_op_test.cc       | 427 +++++-------------
 .../core/kernels/ragged_tensor_variant.cc     |  86 ++++
 .../core/kernels/ragged_tensor_variant.h      | 110 +++++
 tensorflow/core/ops/ragged_conversion_ops.cc  |  20 +-
 tensorflow/python/ops/ragged/BUILD            |   1 +
 .../ops/ragged/ragged_conversion_ops.py       |  39 ++
 tensorflow/python/ops/ragged/ragged_tensor.py |   3 -
 .../python/ops/ragged/ragged_tensor_test.py   | 146 +++++-
 .../api/golden/v1/tensorflow.raw_ops.pbtxt    |   4 +
 .../api/golden/v2/tensorflow.raw_ops.pbtxt    |   4 +
 18 files changed, 820 insertions(+), 593 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt
 create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.cc
 create mode 100644 tensorflow/core/kernels/ragged_tensor_variant.h

diff --git a/RELEASE.md b/RELEASE.md
index 23324d56ca7..0886b6e116c 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -138,6 +138,7 @@
         stateful ops.
     *   Added `tf.config.experimental.get_memory_usage` to return total memory
         usage of the device.
+    * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`.
 *   `tf.data`:
     *   tf.data service:
     *   Added new `tf.data.experimental.service.register_dataset` and
diff --git a/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt
new file mode 100644
index 00000000000..066d6b5eae4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RaggedTensorToVariantGradient.pbtxt
@@ -0,0 +1,38 @@
+op {
+  graph_op_name: "RaggedTensorToVariantGradient"
+  visibility: HIDDEN
+  in_arg {
+    name: "encoded_ragged_grad"
+    description: <<END
+A `variant` Tensor containing encoded `RaggedTensor` gradients.
+END
+  }
+  in_arg {
+    name: "row_splits"
+    description: <<END
+Outermost row-splits that were used as input to the RaggedTensorToVariant op.
+END
+  }
+  in_arg {
+    name: "dense_values_shape"
+    description: <<END
+Shape of the dense_values that was used as an input to the
+RaggedTensorToVariant op.
+END
+  }
+  out_arg {
+    name: "dense_values_grad"
+    description: <<END
+Gradient for the dense_values of the RaggedTensorToVariant op.
+END
+  }
+  summary: <<END
+Helper used to compute the gradient for `RaggedTensorToVariant`.
+END
+  description: <<END
+Computes the gradient for the dense_values input to the RaggedTensorToVariant
+op, given the variant-encoded ragged gradients of the outputs, along with
+the outer row-splits and the shape of the dense_values that were provided as
+inputs to the RaggedTensorToVariant op.
+END
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 6c045152434..f5874474ef8 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1419,10 +1419,22 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "ragged_tensor_variant",
+    srcs = ["ragged_tensor_variant.cc"],
+    hdrs = ["ragged_tensor_variant.h"],
+    deps = [
+        ":cwise_op",
+        "//tensorflow/core:framework",
+    ],
+)
+
 tf_kernel_library(
     name = "ragged_tensor_to_variant_op",
     srcs = ["ragged_tensor_to_variant_op.cc"],
     deps = [
+        ":concat_lib",
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
     ],
@@ -1432,6 +1444,7 @@ tf_kernel_library(
     name = "ragged_tensor_from_variant_op",
     srcs = ["ragged_tensor_from_variant_op.cc"],
     deps = [
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
     ],
@@ -1444,6 +1457,7 @@ tf_cc_test(
     deps = [
         ":ops_testutil",
         ":ragged_tensor_to_variant_op",
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
@@ -1460,6 +1474,7 @@ tf_cc_test(
     deps = [
         ":ops_testutil",
         ":ragged_tensor_from_variant_op",
+        ":ragged_tensor_variant",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index c3ef8122ebb..92bc48159ad 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -424,6 +424,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:functional_ops_op_lib",
         "//tensorflow/core:lib",
+        "//tensorflow/core/kernels:ragged_tensor_variant",
         "//tensorflow/core/kernels/data:dataset_utils",
         "//tensorflow/core/kernels/data:name_utils",
         "//tensorflow/core/kernels/data:parallel_map_dataset_op",
diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
index 16cf7fe6416..80f23bb5a0c 100644
--- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/data/name_utils.h"
 #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h"
 #include "tensorflow/core/kernels/data/stats_utils.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/stringprintf.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -678,12 +679,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
         for (int d = 0; d < dataset()->ragged_keys_.size(); ++d) {
           int output_index =
               dataset()->key_to_output_index_.at(dataset()->ragged_keys_[d]);
-          (*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
-          Tensor serialized_ragged =
-              Tensor(ctx->allocator({}), DT_VARIANT, {2});
-          auto serialized_ragged_t = serialized_ragged.vec<Variant>();
-          serialized_ragged_t(0) = example_result.ragged_splits[d];
-          serialized_ragged_t(1) = example_result.ragged_values[d];
+          RaggedTensorVariant serialized_ragged;
+          serialized_ragged.append_splits(example_result.ragged_splits[d]);
+          serialized_ragged.set_values(example_result.ragged_values[d]);
           (*output)[output_index] = Tensor(ctx->allocator({}), DT_VARIANT, {});
           Tensor& ragged_wrapper = (*output)[output_index];
           ragged_wrapper.scalar<Variant>()() = serialized_ragged;
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
index aa736ad7f60..d9993bb6d39 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc
@@ -20,110 +20,76 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 namespace {
 
-struct RaggedTensor {
-  Tensor values;
-  std::vector<Tensor> nested_splits;
-};
-
-Status RaggedComponentsFromVariant(const Tensor& encoded_variant,
-                                   int ragged_rank, DataType value_dtype,
-                                   DataType split_dtype,
-                                   std::vector<RaggedTensor>* decoded_ragged) {
+Status RaggedComponentsFromVariant(
+    const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
+    DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
   const auto& flat_variants = encoded_variant.flat<Variant>();
-  decoded_ragged->resize(flat_variants.size());
-  // Step 1: Extract the 1-D DT_VARIANT Tensor from each Variant element in the
-  // input.
+  decoded_ragged->reserve(flat_variants.size());
+
   for (int i = 0; i < flat_variants.size(); i++) {
     const auto& flat_variant = flat_variants(i);
-    const Tensor* encoded_list = flat_variant.get<Tensor>();
-    if (encoded_list == nullptr) {
+    const RaggedTensorVariant* decoded =
+        flat_variant.get<RaggedTensorVariant>();
+    if (decoded == nullptr) {
       return errors::InvalidArgument(
           "Input Variant element at index ", i,
-          " doesn't hold a Tensor: ", flat_variant.DebugString());
+          " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
     }
-    if (encoded_list->dims() != 1) {
+    decoded_ragged->push_back(*decoded);
+    decoded = &decoded_ragged->back();
+    // Check ragged rank & types
+    if (decoded->ragged_rank() != ragged_rank) {
       return errors::InvalidArgument(
-          "Encoded input Variant must have rank 1, but found rank: ",
-          encoded_list->dims(),
-          ". encoded input Variant: ", encoded_list->DebugString());
+          "Encoded input RaggedTensorVariant has ragged_rank=",
+          decoded->ragged_rank(), ".  Expected ragged_rank=", ragged_rank, ".");
     }
-    if (encoded_list->NumElements() != (ragged_rank + 1) &&
-        encoded_list->NumElements() != 1) {
-      return errors::InvalidArgument(
-          "Encoded input Variant must hold either input_ragged_rank + 1 "
-          "Tensors or an empty Tensor (zero splits Tensors, 1 values Tensor), "
-          "input_ragged_rank: ",
-          ragged_rank,
-          ", encoded input Variant: ", encoded_list->DebugString());
-    }
-    const auto& input_vec = encoded_list->vec<Variant>();
-
-    // Step 2: Get the splits and value Tensors from the 1-D DT_VARIANT Tensor
-    // to create the component RaggedTensors.
-    (*decoded_ragged)[i].nested_splits.reserve(ragged_rank);
-    for (int j = 0; j < ragged_rank; j++) {
-      const Tensor* split_tensor = input_vec(j).get<Tensor>();
-      if (split_tensor == nullptr) {
-        return errors::InvalidArgument(
-            "Encoded scalar element at index ", i,
-            " doesn't have a splits Tensor at split_index ", j, ": ",
-            input_vec(j).DebugString());
-      }
-      Tensor splits_tensor = *split_tensor;
-      if (splits_tensor.dtype() != split_dtype) {
-        return errors::InvalidArgument(
-            "Expected splits Tensor dtype: ", split_dtype,
-            ", found: ", splits_tensor.dtype());
-      }
-      if (splits_tensor.dims() != 1) {
-        return errors::InvalidArgument(
-            "Ragged splits must have rank 1; encoded scalar element at index ",
-            i, " has splits Tensor at split_index ", j, ": ",
-            splits_tensor.DebugString());
-      }
-      (*decoded_ragged)[i].nested_splits.push_back(splits_tensor);
-    }
-    const Tensor* values_tensor = input_vec(ragged_rank).get<Tensor>();
-    if (values_tensor == nullptr) {
-      return errors::InvalidArgument("Encoded scalar element at index ", i,
-                                     " doesn't have a values Tensor: ",
-                                     input_vec(ragged_rank).DebugString());
-    }
-    if (values_tensor->dtype() != value_dtype) {
+    if (decoded->values().dtype() != value_dtype) {
       return errors::InvalidArgument(
           "Expected values Tensor dtype: ", DataTypeString(value_dtype),
-          ", found: ", DataTypeString(values_tensor->dtype()));
+          ", found: ", DataTypeString(decoded->values().dtype()));
     }
-    if (values_tensor->dims() < 1) {
+    if (decoded->values().dims() < 1) {
       return errors::InvalidArgument(
           "Ragged values must have rank >= 1; encoded scalar element at index ",
-          i, " has values Tensor: ", values_tensor->DebugString());
+          i, " has values Tensor: ", decoded->values().DebugString());
+    }
+    for (const auto& splits : decoded->nested_splits()) {
+      if (splits.dtype() != split_dtype) {
+        return errors::InvalidArgument(
+            "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
+            ", found: ", DataTypeString(splits.dtype()));
+      }
+      if (splits.dims() != 1) {
+        return errors::InvalidArgument(
+            "Ragged splits must have rank 1; encoded scalar element at index ",
+            i, " has splits Tensor ", splits.DebugString());
+      }
     }
-    (*decoded_ragged)[i].values = *values_tensor;
   }
   return Status::OK();
 }
 
 template <typename VALUE_TYPE, typename SPLIT_TYPE>
 Status NestedStackRaggedTensors(
-    const std::vector<RaggedTensor>& ragged_components,
+    const std::vector<RaggedTensorVariant>& ragged_components,
     const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
-    const int output_ragged_rank, RaggedTensor* output_ragged) {
-  output_ragged->nested_splits.reserve(output_ragged_rank);
+    const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
+  output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
   const int dims = nested_dim_sizes.size();
 
   // Populate first `dims - 1` splits.
   for (int i = 0; i < dims - 1; i++) {
     int dims_splits_size = nested_dim_sizes[i] + 1;
-    output_ragged->nested_splits.push_back(Tensor(
-        DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({dims_splits_size})));
-    auto splits_vec = output_ragged->nested_splits[i].vec<SPLIT_TYPE>();
+    output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
+                                        TensorShape({dims_splits_size})));
+    auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
     int split_diff = nested_dim_sizes[i + 1];
     for (int j = 0; j < dims_splits_size; j++) {
       splits_vec(j) = j * split_diff;
@@ -132,15 +98,15 @@ Status NestedStackRaggedTensors(
 
   // Populate `dims`-th split.
   int splits_size = ragged_components.size() + 1;
-  output_ragged->nested_splits.push_back(
+  output_ragged->append_splits(
       Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
   auto dims_splits_vec =
-      output_ragged->nested_splits[dims - 1].vec<SPLIT_TYPE>();
+      output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
   dims_splits_vec(0) = 0;
   for (int i = 0; i < ragged_components.size(); i++) {
-    int split_val = ragged_components[i].values.shape().dim_size(0);
-    if (input_ragged_rank != 0 && !ragged_components[i].nested_splits.empty()) {
-      split_val = ragged_components[i].nested_splits[0].NumElements() - 1;
+    int split_val = ragged_components[i].values().shape().dim_size(0);
+    if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
+      split_val = ragged_components[i].splits(0).NumElements() - 1;
     }
     dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
   }
@@ -150,24 +116,24 @@ Status NestedStackRaggedTensors(
     int split_index = dims + i;
     int split_size = 1;
     for (int j = 0; j < ragged_components.size(); j++) {
-      if (!ragged_components[j].nested_splits.empty()) {
-        split_size += ragged_components[j].nested_splits[i].NumElements() - 1;
+      if (!ragged_components[j].nested_splits().empty()) {
+        split_size += ragged_components[j].splits(i).NumElements() - 1;
       }
     }
-    output_ragged->nested_splits.push_back(
+    output_ragged->append_splits(
         Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
     auto splits_vec =
-        output_ragged->nested_splits[split_index].vec<SPLIT_TYPE>();
+        output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
     splits_vec(0) = 0;
     SPLIT_TYPE last_split_value = 0;
     int index = 1;
     for (int j = 0; j < ragged_components.size(); j++) {
-      if (ragged_components[j].nested_splits.empty()) {
+      if (ragged_components[j].nested_splits().empty()) {
         // Corner case: empty row. e.g [ [[x], [x]], [] ]
         continue;
       }
       auto component_splits_vec =
-          ragged_components[j].nested_splits[i].vec<SPLIT_TYPE>();
+          ragged_components[j].splits(i).vec<SPLIT_TYPE>();
       for (int k = 1; k < component_splits_vec.size(); k++, index++) {
         splits_vec(index) = component_splits_vec(k) + last_split_value;
       }
@@ -187,35 +153,35 @@ Status NestedStackRaggedTensors(
   if (ragged_components.empty()) {
     component_values_shape = TensorShape({0});
   } else {
-    component_values_shape = ragged_components[0].values.shape();
+    component_values_shape = ragged_components[0].values().shape();
   }
 
   // Populate values.
   int values_size = component_values_shape.dim_size(0);
   for (int i = 1; i < ragged_components.size(); i++) {
-    if (ragged_components[i].values.dims() != component_values_shape.dims()) {
+    if (ragged_components[i].values().dims() != component_values_shape.dims()) {
       return errors::InvalidArgument(
           "Rank of values must match for all "
           "components; values shape at index 0: ",
           component_values_shape.DebugString(), ", values shape at index ", i,
-          ": ", ragged_components[i].values.shape().DebugString());
+          ": ", ragged_components[i].values().shape().DebugString());
     }
-    values_size += ragged_components[i].values.shape().dim_size(0);
+    values_size += ragged_components[i].values().shape().dim_size(0);
   }
   component_values_shape.set_dim(0, values_size);
-  output_ragged->values =
-      Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape);
+  output_ragged->set_values(
+      Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
   auto output_values_flat =
-      output_ragged->values.flat_outer_dims<VALUE_TYPE, 2>();
+      output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
   int values_index = 0;
   for (int i = 0; i < ragged_components.size(); i++) {
     auto component_values_flat =
-        ragged_components[i].values.flat_outer_dims<VALUE_TYPE, 2>();
-    int num_inner_elements = ragged_components[i].values.NumElements();
-    if (ragged_components[i].values.dim_size(0) > 0) {
-      num_inner_elements /= ragged_components[i].values.dim_size(0);
+        ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
+    int num_inner_elements = ragged_components[i].values().NumElements();
+    if (ragged_components[i].values().dim_size(0) > 0) {
+      num_inner_elements /= ragged_components[i].values().dim_size(0);
     }
-    for (int j = 0; j < ragged_components[i].values.dim_size(0);
+    for (int j = 0; j < ragged_components[i].values().dim_size(0);
          j++, values_index++) {
       for (int k = 0; k < num_inner_elements; k++) {
         output_values_flat(values_index, k) = component_values_flat(j, k);
@@ -265,7 +231,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
     // Decode all variants.
     const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
     const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
-    std::vector<RaggedTensor> decoded_components;
+    std::vector<RaggedTensorVariant> decoded_components;
     OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
                                 encoded_variant, input_ragged_rank_,
                                 value_dtype, split_dtype, &decoded_components));
@@ -281,7 +247,7 @@ class RaggedTensorFromVariantOp : public OpKernel {
     for (int i = 0; i < encoded_variant.dims(); i++) {
       encoded_dim_sizes[i] = encoded_variant.dim_size(i);
     }
-    RaggedTensor output_ragged;
+    RaggedTensorVariant output_ragged;
     OP_REQUIRES_OK(
         context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
                      decoded_components, encoded_dim_sizes, input_ragged_rank_,
@@ -296,15 +262,15 @@ class RaggedTensorFromVariantOp : public OpKernel {
   int output_ragged_rank_;
 
   void ReturnRaggedTensor(OpKernelContext* context,
-                          RaggedTensor ragged_tensor) {
-    int ragged_rank = ragged_tensor.nested_splits.size();
+                          const RaggedTensorVariant& ragged_tensor) {
+    int ragged_rank = ragged_tensor.ragged_rank();
     OpOutputList splits_out;
     OP_REQUIRES_OK(context,
                    context->output_list("output_nested_splits", &splits_out));
     for (int i = 0; i < ragged_rank; i++) {
-      splits_out.set(i, ragged_tensor.nested_splits[i]);
+      splits_out.set(i, ragged_tensor.splits(i));
     }
-    context->set_output(ragged_rank, ragged_tensor.values);
+    context->set_output(ragged_rank, ragged_tensor.values());
   }
 };
 
diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc
index bdf321d0515..fc46283c90e 100644
--- a/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc
+++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -55,28 +56,22 @@ class RaggedTensorFromVariantKernelTest : public ::tensorflow::OpsTestBase {
   }
 
   template <typename VALUE_TYPE, typename SPLIT_TYPE>
-  Tensor CreateVariantFromRagged(
+  RaggedTensorVariant CreateVariantFromRagged(
       const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
       const TensorShape& ragged_values_shape,
       const std::vector<VALUE_TYPE>& ragged_values) {
-    // Step 1: Create Tensors out of ragged splits and values.
-    std::vector<Variant> ragged_components;
+    RaggedTensorVariant encoded;
     for (auto ragged_split : ragged_splits) {
       int splits_size = ragged_split.size();
       Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(),
                     TensorShape({splits_size}));
       test::FillValues<SPLIT_TYPE>(&splits, ragged_split);
-      ragged_components.push_back(splits);
+      encoded.append_splits(splits);
     }
     Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape);
     test::FillValues<VALUE_TYPE>(&values, ragged_values);
-    ragged_components.push_back(values);
-
-    // Step 2: Encode into a 1-D Variant Tensor.
-    int num_splits = ragged_splits.size();
-    Tensor encoded_list(DT_VARIANT, TensorShape({num_splits + 1}));
-    test::FillValues<Variant>(&encoded_list, ragged_components);
-    return encoded_list;
+    encoded.set_values(values);
+    return encoded;
   }
 };
 
@@ -85,7 +80,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, ScalarInput) {
   const std::vector<int64> split_2 = {0, 1, 2, 5, 6, 7};
   const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4};
 
-  Tensor encoded_variant = CreateVariantFromRagged<int, int64>(
+  auto encoded_variant = CreateVariantFromRagged<int, int64>(
       {split_1, split_2}, TensorShape({7}), values);
   Tensor expected_splits_1(DT_INT64, TensorShape({6}));
   Tensor expected_splits_2(DT_INT64, TensorShape({6}));
@@ -113,7 +108,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, OneInputElement) {
   const std::vector<int> values = {0, 1, 1, 2, 2, 3, 4};
   const std::vector<int64> batched_splits_1 = {0, 5};
 
-  Tensor encoded_variant = CreateVariantFromRagged<int, int64>(
+  auto encoded_variant = CreateVariantFromRagged<int, int64>(
       {split_1, split_2}, TensorShape({7}), values);
   Tensor expected_splits_1(DT_INT64, TensorShape({2}));
   Tensor expected_splits_2(DT_INT64, TensorShape({6}));
@@ -157,13 +152,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, TensorIn2DOut) {
   const std::vector<int64> batched_splits_2 = {0, 3, 3, 5, 6};
   const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6};
 
-  Tensor component_variant_1 =
+  auto component_variant_1 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({3}), values_1);
-  Tensor component_variant_2 =
+  auto component_variant_2 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({0}), values_2);
-  Tensor component_variant_3 =
+  auto component_variant_3 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({2}), values_3);
-  Tensor component_variant_4 =
+  auto component_variant_4 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({1}), values_4);
 
   Tensor expected_splits_1(DT_INT64, TensorShape({3}));
@@ -223,15 +218,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOut) {
   test::FillValues<int64>(&expected_splits_3, batched_splits_3);
   test::FillValues<int>(&expected_values, batched_values);
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({1}), component_values_1);
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int64>(
       {component_split_2_1}, TensorShape({2}), component_values_2);
-  Tensor variant_component_3 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_3 = CreateVariantFromRagged<int, int64>(
       {component_split_3_1}, TensorShape({2}), component_values_3);
-  Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_4 = CreateVariantFromRagged<int, int64>(
       {component_split_4_1}, TensorShape({3}), component_values_4);
-  Tensor variant_component_5 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_5 = CreateVariantFromRagged<int, int64>(
       {component_split_5_1}, TensorShape({3}), component_values_5);
   int input_ragged_rank = 1;
   int output_ragged_rank = 3;
@@ -297,10 +292,10 @@ TEST_F(RaggedTensorFromVariantKernelTest,
   test::FillValues<int64>(&expected_splits_4, batched_splits_4);
   test::FillValues<int>(&expected_values, batched_values);
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1, component_split_1_2}, TensorShape({11}),
       component_values_1);
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int64>(
       {component_split_2_1, component_split_2_2}, TensorShape({11}),
       component_values_2);
   int input_ragged_rank = -1;
@@ -336,9 +331,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, EmptyRow1DIn2DOut) {
   test::FillValues<int64>(&expected_splits_2, batched_splits_2);
   test::FillValues<int>(&expected_values, batched_values);
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({3}), component_values_1);
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int64>(
       {component_split_2_1}, TensorShape({0}), {});  // Empty row.
   int input_ragged_rank = 1;
   int output_ragged_rank = 2;
@@ -371,9 +366,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, NDValues1DIn2DOut) {
   test::FillValues<int64>(&expected_splits_2, batched_splits_2);
   test::FillValues<int>(&expected_values, batched_values);
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({1, 2}), component_values_1);
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int64>(
       {component_split_2_1}, TensorShape({2, 2}), component_values_2);
   int input_ragged_rank = 1;
   int output_ragged_rank = 2;
@@ -423,15 +418,15 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) {
   test::FillValues<int>(&expected_splits_3, batched_splits_3);
   test::FillValues<int>(&expected_values, batched_values);
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int>(
       {component_split_1_1}, TensorShape({1}), component_values_1);
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int>(
       {component_split_2_1}, TensorShape({2}), component_values_2);
-  Tensor variant_component_3 = CreateVariantFromRagged<int, int>(
+  auto variant_component_3 = CreateVariantFromRagged<int, int>(
       {component_split_3_1}, TensorShape({2}), component_values_3);
-  Tensor variant_component_4 = CreateVariantFromRagged<int, int>(
+  auto variant_component_4 = CreateVariantFromRagged<int, int>(
       {component_split_4_1}, TensorShape({3}), component_values_4);
-  Tensor variant_component_5 = CreateVariantFromRagged<int, int>(
+  auto variant_component_5 = CreateVariantFromRagged<int, int>(
       {component_split_5_1}, TensorShape({3}), component_values_5);
   int input_ragged_rank = 1;
   int output_ragged_rank = 3;
@@ -451,13 +446,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, NonEmpty1DIn3DOutInt32Splits) {
 
 // Tests for invalid inputs.
 TEST_F(RaggedTensorFromVariantKernelTest, InvalidInferredInputRaggedRank) {
-  Tensor component_variant_1 =
+  auto component_variant_1 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({3}), {1, 2, 3});
-  Tensor component_variant_2 =
+  auto component_variant_2 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({0}), {});
-  Tensor component_variant_3 =
+  auto component_variant_3 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({2}), {1, 2});
-  Tensor component_variant_4 =
+  auto component_variant_4 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({1}), {1});
 
   int input_ragged_rank = -1;
@@ -478,9 +473,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) {
   const std::vector<int> component_values_1 = {0};
   const std::vector<int> component_values_2 = {0, 1};
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({1}), component_values_1);
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int64>(
       {component_split_2_1}, TensorShape({2}), component_values_2);
 
   int input_ragged_rank = 1;
@@ -493,33 +488,21 @@ TEST_F(RaggedTensorFromVariantKernelTest, InputDimsAndRaggedRankAttrsMismatch) {
                                "input_ragged_rank + encoded_ragged.dims()"));
 }
 
-TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldTensors) {
+TEST_F(RaggedTensorFromVariantKernelTest, InputDoesNotHoldRaggedTensorVariant) {
   int input_ragged_rank = 1;
   int output_ragged_rank = 2;
   BuildDecodeRaggedTensorGraph<int, int64>(
       input_ragged_rank, output_ragged_rank, TensorShape({2}), {1, 2});
   EXPECT_TRUE(absl::StartsWith(
       RunOpKernel().error_message(),
-      "Input Variant element at index 0 doesn't hold a Tensor"));
-}
-
-TEST_F(RaggedTensorFromVariantKernelTest, InputVariantTensorRankNotOne) {
-  Tensor variant_list(DT_VARIANT, TensorShape({2, 1}));
-  test::FillValues<Variant>(&variant_list, {1, 2});
-  int input_ragged_rank = 1;
-  int output_ragged_rank = 2;
-  BuildDecodeRaggedTensorGraph<int, int64>(
-      input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list});
-  EXPECT_TRUE(absl::StartsWith(
-      RunOpKernel().error_message(),
-      "Encoded input Variant must have rank 1, but found rank: 2"));
+      "Input Variant element at index 0 doesn't hold a RaggedTensorVariant"));
 }
 
 TEST_F(RaggedTensorFromVariantKernelTest,
        InputScalarElementDoesNotMatchInputRaggedRank) {
   const std::vector<int64> component_split_1_1 = {0, 1};
   const std::vector<int> component_values_1 = {1, 2};
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({1, 2}), component_values_1);
 
   int input_ragged_rank = 2;
@@ -527,31 +510,17 @@ TEST_F(RaggedTensorFromVariantKernelTest,
   BuildDecodeRaggedTensorGraph<int, int64>(input_ragged_rank,
                                            output_ragged_rank, TensorShape({1}),
                                            {variant_component_1});
-  EXPECT_TRUE(absl::StartsWith(
-      RunOpKernel().error_message(),
-      "Encoded input Variant must hold either input_ragged_rank + 1 "
-      "Tensors or an empty Tensor"));
-}
-
-TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitNotATensor) {
-  Tensor variant_list(DT_VARIANT, TensorShape({2}));
-  test::FillValues<Variant>(&variant_list, {1, 2});
-
-  int input_ragged_rank = 1;
-  int output_ragged_rank = 2;
-  BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank,
-                                         TensorShape({1}), {variant_list});
   EXPECT_TRUE(
       absl::StartsWith(RunOpKernel().error_message(),
-                       "Encoded scalar element at index 0 doesn't have a "
-                       "splits Tensor at split_index 0"));
+                       "Encoded input RaggedTensorVariant has ragged_rank=1.  "
+                       "Expected ragged_rank=2."));
 }
 
 TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) {
   const std::vector<int64> component_split_1_1 = {0, 1};
   const std::vector<int> component_values_1 = {0};
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({1}), component_values_1);
 
   int input_ragged_rank = 1;
@@ -559,46 +528,29 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitTypeMismatch) {
   BuildDecodeRaggedTensorGraph<int, int>(input_ragged_rank, output_ragged_rank,
                                          TensorShape({1}),
                                          {variant_component_1});
-  EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(),
-                               "Expected splits Tensor dtype: 3, found: 9"));
+  EXPECT_TRUE(absl::StartsWith(
+      RunOpKernel().error_message(),
+      "Expected row_splits Tensor dtype: int32, found: int64"));
 }
 
 TEST_F(RaggedTensorFromVariantKernelTest, RaggedSplitRankNotOne) {
-  Tensor splits(DT_INT64, TensorShape({2, 1}));
-  test::FillValues<int64>(&splits, {1, 2});
-  Tensor values(DT_INT32, {2});
-  test::FillValues<int>(&values, {1, 2});
-  Tensor encoded_list(DT_VARIANT, TensorShape({2}));
-  test::FillValues<Variant>(&encoded_list, {splits, values});
+  RaggedTensorVariant encoded(Tensor(DT_INT32, {2}),
+                              {Tensor(DT_INT64, {2, 1})});
+  test::FillValues<int64>(encoded.mutable_splits(0), {1, 2});
+  test::FillValues<int>(encoded.mutable_values(), {1, 2});
 
   int input_ragged_rank = 1;
   int output_ragged_rank = 2;
   BuildDecodeRaggedTensorGraph<int, int64>(
-      input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded_list});
+      input_ragged_rank, output_ragged_rank, TensorShape({1}), {encoded});
   EXPECT_TRUE(absl::StartsWith(RunOpKernel().error_message(),
                                "Ragged splits must have rank 1"));
 }
 
-TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesNotATensor) {
-  Tensor splits(DT_INT64, TensorShape({3}));
-  test::FillValues<int64>(&splits, {0, 2, 3});
-  Tensor variant_list(DT_VARIANT, TensorShape({2}));
-  test::FillValues<Variant>(&variant_list, {splits, 2});
-
-  int input_ragged_rank = 1;
-  int output_ragged_rank = 2;
-  BuildDecodeRaggedTensorGraph<int, int64>(
-      input_ragged_rank, output_ragged_rank, TensorShape({1}), {variant_list});
-  EXPECT_TRUE(
-      absl::StartsWith(RunOpKernel().error_message(),
-                       "Encoded scalar element at index 0 doesn't have a "
-                       "values Tensor"));
-}
-
 TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) {
   const std::vector<int64> component_split_1_1 = {0, 1};
   const std::vector<int> component_values_1 = {0};
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({1}), component_values_1);
   int input_ragged_rank = 1;
   int output_ragged_rank = 2;
@@ -611,7 +563,7 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesTypeMismatch) {
 }
 
 TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankNotGreaterThanOne) {
-  Tensor variant_component_1 =
+  auto variant_component_1 =
       CreateVariantFromRagged<int, int64>({{0, 1}}, TensorShape({}), {1});
   int input_ragged_rank = 1;
   int output_ragged_rank = 2;
@@ -628,9 +580,9 @@ TEST_F(RaggedTensorFromVariantKernelTest, RaggedValuesRankMismatch) {
   const std::vector<int> component_values_1 = {0};
   const std::vector<int> component_values_2 = {0, 1, 2, 3};
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {component_split_1_1}, TensorShape({1}), component_values_1);
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int64>(
       {component_split_2_1}, TensorShape({2, 2}), component_values_2);
   int input_ragged_rank = 1;
   int output_ragged_rank = 2;
@@ -711,13 +663,13 @@ TEST_F(RaggedTensorFromVariantKernelTest, 2DValuesTensorIn1DOut) {
   const std::vector<int> batched_values = {1, 1, 1, 1, 2, 2, 2, 2, 3, 3,
                                            3, 3, 4, 4, 4, 4, 5, 5, 5, 5};
 
-  Tensor variant_component_1 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_1 = CreateVariantFromRagged<int, int64>(
       {}, TensorShape({2, 2, 2}), {1, 1, 1, 1, 2, 2, 2, 2});
-  Tensor variant_component_2 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_2 = CreateVariantFromRagged<int, int64>(
       {}, TensorShape({1, 2, 2}), {3, 3, 3, 3});
-  Tensor variant_component_3 =
+  auto variant_component_3 =
       CreateVariantFromRagged<int, int64>({}, TensorShape({0, 2, 2}), {});
-  Tensor variant_component_4 = CreateVariantFromRagged<int, int64>(
+  auto variant_component_4 = CreateVariantFromRagged<int, int64>(
       {}, TensorShape({2, 2, 2}), {4, 4, 4, 4, 5, 5, 5, 5});
 
   Tensor expected_splits_1(DT_INT64, TensorShape({5}));
diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
index 64c372b005e..549dc68dfbf 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
@@ -18,50 +18,38 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
 
 namespace tensorflow {
 namespace {
 
-struct RaggedTensor {
-  Tensor values;
-  std::vector<Tensor> nested_splits;
-};
-
-Status RaggedToVariant(const RaggedTensor& ragged, Tensor* encoded_list) {
-  // Encode as a rank-1 Variant Tensor.
-  int ragged_rank = ragged.nested_splits.size();
-  *encoded_list = Tensor(DT_VARIANT, TensorShape({ragged_rank + 1}));
-  auto encoded_vec = encoded_list->vec<Variant>();
-  for (int i = 0; i < ragged_rank; i++) {
-    encoded_vec(i) = ragged.nested_splits[i];
-  }
-  encoded_vec(ragged_rank) = ragged.values;
-  return Status::OK();
-}
-
 template <typename VALUE_TYPE, typename SPLIT_TYPE>
-Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
-                              std::vector<RaggedTensor>* ragged_components) {
+Status UnbatchRaggedZerothDim(
+    const RaggedTensorVariant& batched_ragged,
+    std::vector<RaggedTensorVariant>* ragged_components) {
   // Set up the component Ragged Tensors.
-  int ragged_rank = batched_ragged.nested_splits.size();
-  auto batched_splits_top_vec =
-      batched_ragged.nested_splits[0].vec<SPLIT_TYPE>();
+  int ragged_rank = batched_ragged.ragged_rank();
+  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
   int num_components = batched_splits_top_vec.size() - 1;
   int num_splits = ragged_rank - 1;
   ragged_components->resize(num_components);
-  for (RaggedTensor ragged_component : *ragged_components) {
-    ragged_component.nested_splits.reserve(num_splits);
+  for (RaggedTensorVariant& ragged_component : *ragged_components) {
+    ragged_component.mutable_nested_splits()->reserve(num_splits);
   }
-  const auto& batched_flat = batched_ragged.values.flat<VALUE_TYPE>();
-  int num_inner_elems = batched_ragged.values.NumElements();
-  if (batched_ragged.values.dim_size(0) > 1) {
-    num_inner_elems /= batched_ragged.values.dim_size(0);
+  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
+  int num_inner_elems = batched_ragged.values().NumElements();
+  if (batched_ragged.values().dim_size(0) > 1) {
+    num_inner_elems /= batched_ragged.values().dim_size(0);
   }
-  TensorShape values_shape = batched_ragged.values.shape();
+  TensorShape values_shape = batched_ragged.values().shape();
 
   // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
   if (num_splits == 0) {
@@ -70,10 +58,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
       int limit = batched_splits_top_vec(i + 1);
       int num_values = limit - start;
       values_shape.set_dim(0, num_values);
-      (*ragged_components)[i].values =
-          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+      (*ragged_components)[i].set_values(
+          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
       auto ragged_component_values_flat =
-          (*ragged_components)[i].values.flat<VALUE_TYPE>();
+          (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
       for (int j = 0; j < num_values * num_inner_elems; j++) {
         ragged_component_values_flat(j) =
             batched_flat(j + start * num_inner_elems);
@@ -86,8 +74,7 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
   std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
   batched_splits_vec.reserve(ragged_rank);
   for (int i = 0; i < ragged_rank; i++) {
-    batched_splits_vec.push_back(
-        batched_ragged.nested_splits[i].vec<SPLIT_TYPE>());
+    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
   }
   std::vector<int> index(num_splits, 1);
   std::vector<int> ragged_component_values_size(num_components, 0);
@@ -104,10 +91,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
         int last_index = ragged_component_splits_vec[j - 1].size() - 1;
         split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
       }
-      (*ragged_components)[i].nested_splits.push_back(
+      (*ragged_components)[i].append_splits(
           Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
       ragged_component_splits_vec.push_back(
-          (*ragged_components)[i].nested_splits[j].vec<SPLIT_TYPE>());
+          (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
       SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
       ragged_component_splits_vec[j](0) = 0;
       for (int k = 1; k < split_size; k++, index[j]++) {
@@ -125,10 +112,10 @@ Status UnbatchRaggedZerothDim(const RaggedTensor& batched_ragged,
   for (int i = 0; i < num_components; i++) {
     int num_values = ragged_component_values_size[i];
     values_shape.set_dim(0, num_values);
-    (*ragged_components)[i].values =
-        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape);
+    (*ragged_components)[i].set_values(
+        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
     auto ragged_component_values_flat =
-        (*ragged_components)[i].values.flat<VALUE_TYPE>();
+        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
     for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
       ragged_component_values_flat(j) = batched_flat(value_index);
     }
@@ -152,46 +139,38 @@ class RaggedTensorToVariantOp : public OpKernel {
     OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                 &ragged_nested_splits_in));
     const int ragged_nested_splits_len = ragged_nested_splits_in.size();
-    RaggedTensor batched_ragged_input;
+    RaggedTensorVariant batched_ragged_input;
     // Read ragged_values input.
-    batched_ragged_input.values = context->input(ragged_nested_splits_len);
-    batched_ragged_input.nested_splits.reserve(ragged_nested_splits_len);
+    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
+    batched_ragged_input.mutable_nested_splits()->reserve(
+        ragged_nested_splits_len);
     for (int i = 0; i < ragged_nested_splits_len; i++) {
-      batched_ragged_input.nested_splits.push_back(ragged_nested_splits_in[i]);
+      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
     }
 
     if (!batched_input_) {
-      // Encode the input as is.
-      Tensor encoded_list;
-      OP_REQUIRES_OK(context,
-                     RaggedToVariant(batched_ragged_input, &encoded_list));
       // Encode as a Scalar Variant Tensor.
       Tensor* encoded_scalar;
       OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                        &encoded_scalar));
-      encoded_scalar->scalar<Variant>()() = std::move(encoded_list);
+      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
       return;
     }
 
     // Unbatch the Ragged Tensor and encode the components.
-    std::vector<RaggedTensor> ragged_components;
+    std::vector<RaggedTensorVariant> unbatched_ragged_input;
     OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
-                                batched_ragged_input, &ragged_components));
-    std::vector<Tensor> encoded_components(ragged_components.size());
-    for (int i = 0; i < ragged_components.size(); i++) {
-      OP_REQUIRES_OK(context, RaggedToVariant(ragged_components[i],
-                                              &encoded_components[i]));
-    }
+                                batched_ragged_input, &unbatched_ragged_input));
 
     // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
-    Tensor* encoded_ragged;
-    int output_size = ragged_components.size();
+    Tensor* encoded_vector;
+    int output_size = unbatched_ragged_input.size();
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({output_size}),
-                                            &encoded_ragged));
-    auto encoded_ragged_vec = encoded_ragged->vec<Variant>();
+                                            &encoded_vector));
+    auto encoded_vector_t = encoded_vector->vec<Variant>();
     for (int i = 0; i < output_size; i++) {
-      encoded_ragged_vec(i) = encoded_components[i];
+      encoded_vector_t(i) = unbatched_ragged_input[i];
     }
   }
 
@@ -199,12 +178,81 @@ class RaggedTensorToVariantOp : public OpKernel {
   bool batched_input_;
 };
 
-#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
-  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")               \
-                              .Device(DEVICE_CPU)                     \
-                              .TypeConstraint<value_type>("Tvalues")  \
-                              .TypeConstraint<split_type>("Tsplits"), \
-                          RaggedTensorToVariantOp<value_type, split_type>);
+template <typename VALUE_TYPE, typename SPLIT_TYPE>
+class RaggedTensorToVariantGradientOp : public OpKernel {
+ public:
+  using OpKernel::OpKernel;
+
+  void Compute(OpKernelContext* context) override {
+    // Read inputs.
+    Tensor encoded_variant = context->input(0);
+    Tensor row_splits = context->input(1);
+    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
+    TensorShape dense_values_shape;
+    OP_REQUIRES_OK(context,
+                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
+                                               &dense_values_shape));
+
+    const auto& flat_variants = encoded_variant.flat<Variant>();
+
+    // Get a Tensor containing the flat_values for each variant.
+    std::vector<Tensor> values;
+    for (int i = 0; i < flat_variants.size(); ++i) {
+      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
+        values.push_back(encoded->values());
+      } else {
+        // Missing value: this happens when only some of the variants
+        // produced by ragged_tensor_to_variant contributed to the value whose
+        // gradient is being computed.  Such variants show up here as
+        // default-constructed Variant objects, so treat each one as a zero
+        // tensor with the appropriate shape.
+        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
+        int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
+        TensorShape zeros_shape = dense_values_shape;
+        zeros_shape.set_dim(0, piece_size);
+        Tensor zero(value_dtype, zeros_shape);
+        zero.flat<VALUE_TYPE>() =
+            zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
+        values.push_back(zero);
+      }
+    }
+
+    if (values.size() == 1) {
+      // Just one flat_value tensor: return as-is.
+      context->set_output(0, values[0]);
+    } else {
+      // Multiple flat_values tensors: concatenate them together.
+      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
+      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
+      std::vector<std::unique_ptr<ConstPiece>> pieces;
+      pieces.reserve(values.size());
+      for (const Tensor& t : values) {
+        pieces.emplace_back(
+            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
+      }
+      Tensor* out = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, dense_values_shape, &out));
+      Piece out_flat =
+          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
+      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
+    }
+  }
+};
+
+#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
+  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<value_type>("Tvalues")        \
+                              .TypeConstraint<split_type>("Tsplits"),       \
+                          RaggedTensorToVariantOp<value_type, split_type>); \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("RaggedTensorToVariantGradient")                                 \
+          .Device(DEVICE_CPU)                                               \
+          .TypeConstraint<value_type>("Tvalues")                            \
+          .TypeConstraint<split_type>("Tsplits"),                           \
+      RaggedTensorToVariantGradientOp<value_type, split_type>);
+
 #define REGISTER_KERNELS(value_type)                  \
   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc
index c1438dd7af9..94f35673c8b 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -60,6 +61,43 @@ class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase {
     }
     AddInputFromArray<VALUE_TYPE>(ragged_values_shape, ragged_values);
   }
+
+  template <typename VALUE_TYPE, typename SPLIT_TYPE>
+  RaggedTensorVariant CreateVariantFromRagged(
+      const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
+      const TensorShape& ragged_values_shape,
+      const std::vector<VALUE_TYPE>& ragged_values) {
+    RaggedTensorVariant encoded;
+    for (auto ragged_split : ragged_splits) {
+      int splits_size = ragged_split.size();
+      Tensor splits(DataTypeToEnum<SPLIT_TYPE>::v(),
+                    TensorShape({splits_size}));
+      test::FillValues<SPLIT_TYPE>(&splits, ragged_split);
+      encoded.append_splits(splits);
+    }
+    Tensor values(DataTypeToEnum<VALUE_TYPE>::v(), ragged_values_shape);
+    test::FillValues<VALUE_TYPE>(&values, ragged_values);
+    encoded.set_values(values);
+    return encoded;
+  }
+
+  template <typename VALUE_TYPE, typename SPLIT_TYPE>
+  RaggedTensorVariant CreateVariantFromRagged(
+      const std::vector<std::vector<SPLIT_TYPE>>& ragged_splits,
+      const std::vector<VALUE_TYPE>& ragged_values) {
+    int num_values = ragged_values.size();
+    return CreateVariantFromRagged(ragged_splits, {num_values}, ragged_values);
+  }
+
+  template <typename VALUE_TYPE, typename SPLIT_TYPE>
+  void ExpectRaggedTensorVariantEqual(const RaggedTensorVariant& expected,
+                                      const RaggedTensorVariant& actual) {
+    test::ExpectTensorEqual<VALUE_TYPE>(actual.values(), expected.values());
+    EXPECT_EQ(actual.ragged_rank(), expected.ragged_rank());
+    for (int i = 0; i < actual.ragged_rank(); ++i) {
+      test::ExpectTensorEqual<SPLIT_TYPE>(actual.splits(i), expected.splits(i));
+    }
+  }
 };
 
 TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
@@ -67,18 +105,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
   const std::vector<int64> batched_splits_1 = {0, 2, 3, 3};
   const std::vector<int64> batched_splits_2 = {0, 0, 0, 0};
 
-  const std::vector<int64> component_splits_1_1 = {0, 0, 0};
-  const std::vector<int64> component_splits_2_1 = {0, 0};
-  const std::vector<int64> component_splits_3_1 = {0};
-
-  Tensor expected_splits_1_1(DT_INT64, TensorShape({3}));
-  Tensor expected_splits_2_1(DT_INT64, TensorShape({2}));
-  Tensor expected_splits_3_1(DT_INT64, TensorShape({1}));
-
-  test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
-  test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
-  test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1);
-
   BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
                                            TensorShape({0}), {}, true);
   TF_ASSERT_OK(RunOpKernel());
@@ -86,55 +112,26 @@ TEST_F(RaggedTensorToVariantKernelTest, NoValuesInput) {
   const auto& encoded_list = GetOutput(0)->vec<Variant>();
   EXPECT_EQ(encoded_list.size(), 3);
 
-  const Variant& encoded_splits_1_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_splits_2_1 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_splits_3_1 =
-      encoded_list(2).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_3 =
-      encoded_list(2).get<Tensor>()->vec<Variant>()(1);
-
-  test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
-                                 expected_splits_1_1);
-  test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
-                                 expected_splits_2_1);
-  test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(),
-                                 expected_splits_3_1);
-  test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
-                               Tensor(DT_INT32, TensorShape({0})));
-  test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
-                               Tensor(DT_INT32, TensorShape({0})));
-  test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
-                               Tensor(DT_INT32, TensorShape({0})));
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0, 0, 0}}, {}),
+      *encoded_list(0).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0, 0}}, {}),
+      *encoded_list(1).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0}}, {}),
+      *encoded_list(2).get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) {
   // ragged_tensor=
-  // [ [x, x, x],
+  // [ [1, 2, 3],
   //   [       ],
-  //   [x, x   ],
-  //   [x      ]]
+  //   [4, 5   ],
+  //   [6      ]]
   const std::vector<int64> batched_splits = {0, 3, 3, 5, 6};
   const std::vector<int> batched_values = {1, 2, 3, 4, 5, 6};
 
-  const std::vector<int> component_values_1 = {1, 2, 3};
-  const std::vector<int> component_values_3 = {4, 5};
-  const std::vector<int> component_values_4 = {6};
-
-  Tensor expected_values_1(DT_INT32, TensorShape({3}));
-  Tensor expected_values_2(DT_INT32, TensorShape({0}));
-  Tensor expected_values_3(DT_INT32, TensorShape({2}));
-  Tensor expected_values_4(DT_INT32, TensorShape({1}));
-
-  test::FillValues<int>(&expected_values_1, component_values_1);
-  test::FillValues<int>(&expected_values_3, component_values_3);
-  test::FillValues<int>(&expected_values_4, component_values_4);
-
   BuildEncodeRaggedTensorGraph<int, int64>({batched_splits}, TensorShape({6}),
                                            batched_values, true);
   TF_ASSERT_OK(RunOpKernel());
@@ -142,45 +139,28 @@ TEST_F(RaggedTensorToVariantKernelTest, 1DValuesRaggedRankOneInput) {
   const auto& encoded_list = GetOutput(0)->vec<Variant>();
   EXPECT_EQ(encoded_list.size(), 4);
 
-  const Variant& encoded_values_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_3 =
-      encoded_list(2).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_4 =
-      encoded_list(3).get<Tensor>()->vec<Variant>()(0);
-
-  test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
-                               expected_values_1);
-  test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
-                               expected_values_2);
-  test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
-                               expected_values_3);
-  test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(),
-                               expected_values_4);
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, {1, 2, 3}),
+      *encoded_list(0).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, {}),
+      *encoded_list(1).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, {4, 5}),
+      *encoded_list(2).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, {6}),
+      *encoded_list(3).get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) {
   // ragged_tensor=
-  // [[x, x],
-  //  [x, x],
-  //  [x, x]]
+  // [[1, 2],
+  //  [4, 5],
+  //  [6, 7]]
   const std::vector<int64> batched_splits = {0, 1, 2, 3};
   const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7};
 
-  const std::vector<int> component_values_1 = {1, 2};
-  const std::vector<int> component_values_2 = {4, 5};
-  const std::vector<int> component_values_3 = {6, 7};
-
-  Tensor expected_values_1(DT_INT32, TensorShape({1, 2}));
-  Tensor expected_values_2(DT_INT32, TensorShape({1, 2}));
-  Tensor expected_values_3(DT_INT32, TensorShape({1, 2}));
-
-  test::FillValues<int>(&expected_values_1, component_values_1);
-  test::FillValues<int>(&expected_values_2, component_values_2);
-  test::FillValues<int>(&expected_values_3, component_values_3);
-
   BuildEncodeRaggedTensorGraph<int, int64>(
       {batched_splits}, TensorShape({3, 2}), batched_values, true);
   TF_ASSERT_OK(RunOpKernel());
@@ -188,44 +168,25 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankOneInput) {
   const auto& encoded_list = GetOutput(0)->vec<Variant>();
   EXPECT_EQ(encoded_list.size(), 3);
 
-  const Variant& encoded_values_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_3 =
-      encoded_list(2).get<Tensor>()->vec<Variant>()(0);
-
-  test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
-                               expected_values_1);
-  test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
-                               expected_values_2);
-  test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
-                               expected_values_3);
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, {1, 2}, {1, 2}),
+      *encoded_list(0).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, {1, 2}, {4, 5}),
+      *encoded_list(1).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, {1, 2}, {6, 7}),
+      *encoded_list(2).get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) {
-  // ragged_tensor=[
-  // [ [[x, x], [x, x]],
-  //   [[x, x]        ] ]
+  // ragged_tensor=
+  // [ [[[1, 2], [4, 5]]],
+  //   [[[6, 7]]]         ]
   const std::vector<int64> batched_splits_1 = {0, 1, 2};
   const std::vector<int64> batched_splits_2 = {0, 2, 3};
   const std::vector<int> batched_values = {1, 2, 4, 5, 6, 7};
 
-  const std::vector<int64> component_splits_1_1 = {0, 2};
-  const std::vector<int64> component_splits_2_1 = {0, 1};
-  const std::vector<int> component_values_1 = {1, 2, 4, 5};
-  const std::vector<int> component_values_2 = {6, 7};
-
-  Tensor expected_splits_1_1(DT_INT64, TensorShape({2}));
-  Tensor expected_splits_2_1(DT_INT64, TensorShape({2}));
-  Tensor expected_values_1(DT_INT32, TensorShape({2, 2}));
-  Tensor expected_values_2(DT_INT32, TensorShape({1, 2}));
-
-  test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
-  test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
-  test::FillValues<int>(&expected_values_1, component_values_1);
-  test::FillValues<int>(&expected_values_2, component_values_2);
-
   BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
                                            TensorShape({3, 2}), batched_values,
                                            true);
@@ -234,23 +195,12 @@ TEST_F(RaggedTensorToVariantKernelTest, 2DBatchedValuesRankTwoInput) {
   const auto& encoded_list = GetOutput(0)->vec<Variant>();
   EXPECT_EQ(encoded_list.size(), 2);
 
-  const Variant& encoded_splits_1_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_splits_2_1 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(1);
-
-  test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
-                                 expected_splits_1_1);
-  test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
-                               expected_values_1);
-  test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
-                                 expected_splits_2_1);
-  test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
-                               expected_values_2);
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0, 2}}, {2, 2}, {1, 2, 4, 5}),
+      *encoded_list(0).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0, 1}}, {1, 2}, {6, 7}),
+      *encoded_list(1).get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
@@ -263,30 +213,6 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
   const std::vector<int64> batched_splits_2 = {0, 1, 3, 3, 8, 11, 11, 15};
   const std::vector<int> batched_values = {1, 2,  3,  4,  5,  6,  7, 8,
                                            9, 10, 11, 12, 13, 14, 15};
-  const std::vector<int64> component_splits_1_1 = {0, 1, 3, 3};
-  const std::vector<int64> component_splits_2_1 = {0};
-  const std::vector<int64> component_splits_3_1 = {0, 5, 8};
-  const std::vector<int64> component_splits_4_1 = {0, 0, 4};
-  const std::vector<int> component_values_1 = {1, 2, 3};
-  const std::vector<int> component_values_3 = {4, 5, 6, 7, 8, 9, 10, 11};
-  const std::vector<int> component_values_4 = {12, 13, 14, 15};
-
-  Tensor expected_splits_1_1(DT_INT64, TensorShape({4}));
-  Tensor expected_splits_2_1(DT_INT64, TensorShape({1}));
-  Tensor expected_splits_3_1(DT_INT64, TensorShape({3}));
-  Tensor expected_splits_4_1(DT_INT64, TensorShape({3}));
-  Tensor expected_values_1(DT_INT32, TensorShape({3}));
-  Tensor expected_values_2(DT_INT32, TensorShape({0}));
-  Tensor expected_values_3(DT_INT32, TensorShape({8}));
-  Tensor expected_values_4(DT_INT32, TensorShape({4}));
-
-  test::FillValues<int64>(&expected_splits_1_1, component_splits_1_1);
-  test::FillValues<int64>(&expected_splits_2_1, component_splits_2_1);
-  test::FillValues<int64>(&expected_splits_3_1, component_splits_3_1);
-  test::FillValues<int64>(&expected_splits_4_1, component_splits_4_1);
-  test::FillValues<int>(&expected_values_1, component_values_1);
-  test::FillValues<int>(&expected_values_3, component_values_3);
-  test::FillValues<int>(&expected_values_4, component_values_4);
 
   BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
                                            TensorShape({15}), batched_values,
@@ -296,39 +222,19 @@ TEST_F(RaggedTensorToVariantKernelTest, EmptyRowInBatchedInput) {
   const auto& encoded_list = GetOutput(0)->vec<Variant>();
   EXPECT_EQ(encoded_list.size(), 4);
 
-  const Variant& encoded_splits_1_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_splits_2_1 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_splits_3_1 =
-      encoded_list(2).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_3 =
-      encoded_list(2).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_splits_4_1 =
-      encoded_list(3).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_values_4 =
-      encoded_list(3).get<Tensor>()->vec<Variant>()(1);
-
-  test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
-                                 expected_splits_1_1);
-  test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
-                               expected_values_1);
-  test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
-                                 expected_splits_2_1);
-  test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
-                               expected_values_2);
-  test::ExpectTensorEqual<int64>(*encoded_splits_3_1.get<Tensor>(),
-                                 expected_splits_3_1);
-  test::ExpectTensorEqual<int>(*encoded_values_3.get<Tensor>(),
-                               expected_values_3);
-  test::ExpectTensorEqual<int64>(*encoded_splits_4_1.get<Tensor>(),
-                                 expected_splits_4_1);
-  test::ExpectTensorEqual<int>(*encoded_values_4.get<Tensor>(),
-                               expected_values_4);
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0, 1, 3, 3}}, {1, 2, 3}),
+      *encoded_list(0).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0}}, {}),
+      *encoded_list(1).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0, 5, 8}},
+                                          {4, 5, 6, 7, 8, 9, 10, 11}),
+      *encoded_list(2).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({{0, 0, 4}}, {12, 13, 14, 15}),
+      *encoded_list(3).get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
@@ -350,26 +256,6 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
                                                7, 8, 9, 12, 13, 14};
   const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4,
                                            5, 6, 7, 8, 9, 8, 9};
-  const std::vector<int64> component_split_1_1 = {0, 1, 3, 4, 5, 6};
-  const std::vector<int64> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7};
-  const std::vector<int64> component_split_2_1 = {0, 1, 2, 3, 4, 5};
-  const std::vector<int64> component_split_2_2 = {0, 1, 2, 5, 6, 7};
-  const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4};
-  const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9};
-
-  Tensor expected_splits_1_1(DT_INT64, TensorShape({6}));
-  Tensor expected_splits_1_2(DT_INT64, TensorShape({7}));
-  Tensor expected_splits_2_1(DT_INT64, TensorShape({6}));
-  Tensor expected_splits_2_2(DT_INT64, TensorShape({6}));
-  Tensor expected_values_1(DT_INT32, TensorShape({7}));
-  Tensor expected_values_2(DT_INT32, TensorShape({7}));
-
-  test::FillValues<int64>(&expected_splits_1_1, component_split_1_1);
-  test::FillValues<int64>(&expected_splits_1_2, component_split_1_2);
-  test::FillValues<int64>(&expected_splits_2_1, component_split_2_1);
-  test::FillValues<int64>(&expected_splits_2_2, component_split_2_2);
-  test::FillValues<int>(&expected_values_1, component_values_1);
-  test::FillValues<int>(&expected_values_2, component_values_2);
 
   BuildEncodeRaggedTensorGraph<int, int64>(
       {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}),
@@ -379,31 +265,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInput) {
   const auto& encoded_list = GetOutput(0)->vec<Variant>();
   EXPECT_EQ(encoded_list.size(), 2);
 
-  const Variant& encoded_splits_1_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_splits_1_2 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_values_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(2);
-  const Variant& encoded_splits_2_1 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_splits_2_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_values_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(2);
-
-  test::ExpectTensorEqual<int64>(*encoded_splits_1_1.get<Tensor>(),
-                                 expected_splits_1_1);
-  test::ExpectTensorEqual<int64>(*encoded_splits_1_2.get<Tensor>(),
-                                 expected_splits_1_2);
-  test::ExpectTensorEqual<int64>(*encoded_splits_2_1.get<Tensor>(),
-                                 expected_splits_2_1);
-  test::ExpectTensorEqual<int64>(*encoded_splits_2_2.get<Tensor>(),
-                                 expected_splits_2_2);
-  test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
-                               expected_values_1);
-  test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
-                               expected_values_2);
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>(
+          {{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}),
+      *encoded_list(0).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>(
+          {{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}),
+      *encoded_list(1).get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
@@ -424,28 +293,8 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
                                              7, 8, 9, 12, 13, 14};
   const std::vector<int> batched_values = {0, 1, 1, 2, 2, 3, 4,
                                            5, 6, 7, 8, 9, 8, 9};
-  const std::vector<int> component_split_1_1 = {0, 1, 3, 4, 5, 6};
-  const std::vector<int> component_split_1_2 = {0, 2, 3, 4, 5, 6, 7};
-  const std::vector<int> component_split_2_1 = {0, 1, 2, 3, 4, 5};
-  const std::vector<int> component_split_2_2 = {0, 1, 2, 5, 6, 7};
-  const std::vector<int> component_values_1 = {0, 1, 1, 2, 2, 3, 4};
-  const std::vector<int> component_values_2 = {5, 6, 7, 8, 9, 8, 9};
 
-  Tensor expected_splits_1_1(DT_INT32, TensorShape({6}));
-  Tensor expected_splits_1_2(DT_INT32, TensorShape({7}));
-  Tensor expected_splits_2_1(DT_INT32, TensorShape({6}));
-  Tensor expected_splits_2_2(DT_INT32, TensorShape({6}));
-  Tensor expected_values_1(DT_INT32, TensorShape({7}));
-  Tensor expected_values_2(DT_INT32, TensorShape({7}));
-
-  test::FillValues<int>(&expected_splits_1_1, component_split_1_1);
-  test::FillValues<int>(&expected_splits_1_2, component_split_1_2);
-  test::FillValues<int>(&expected_splits_2_1, component_split_2_1);
-  test::FillValues<int>(&expected_splits_2_2, component_split_2_2);
-  test::FillValues<int>(&expected_values_1, component_values_1);
-  test::FillValues<int>(&expected_values_2, component_values_2);
-
-  BuildEncodeRaggedTensorGraph<int, int>(
+  BuildEncodeRaggedTensorGraph<int, int32>(
       {batched_splits_1, batched_splits_2, batched_splits_3}, TensorShape({14}),
       batched_values, true);
   TF_ASSERT_OK(RunOpKernel());
@@ -453,31 +302,14 @@ TEST_F(RaggedTensorToVariantKernelTest, NonEmptyBatchedInputInt32Splits) {
   const auto& encoded_list = GetOutput(0)->vec<Variant>();
   EXPECT_EQ(encoded_list.size(), 2);
 
-  const Variant& encoded_splits_1_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_splits_1_2 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_values_1 =
-      encoded_list(0).get<Tensor>()->vec<Variant>()(2);
-  const Variant& encoded_splits_2_1 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_splits_2_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_values_2 =
-      encoded_list(1).get<Tensor>()->vec<Variant>()(2);
-
-  test::ExpectTensorEqual<int>(*encoded_splits_1_1.get<Tensor>(),
-                               expected_splits_1_1);
-  test::ExpectTensorEqual<int>(*encoded_splits_1_2.get<Tensor>(),
-                               expected_splits_1_2);
-  test::ExpectTensorEqual<int>(*encoded_splits_2_1.get<Tensor>(),
-                               expected_splits_2_1);
-  test::ExpectTensorEqual<int>(*encoded_splits_2_2.get<Tensor>(),
-                               expected_splits_2_2);
-  test::ExpectTensorEqual<int>(*encoded_values_1.get<Tensor>(),
-                               expected_values_1);
-  test::ExpectTensorEqual<int>(*encoded_values_2.get<Tensor>(),
-                               expected_values_2);
+  ExpectRaggedTensorVariantEqual<int, int32>(
+      CreateVariantFromRagged<int, int32>(
+          {{0, 1, 3, 4, 5, 6}, {0, 2, 3, 4, 5, 6, 7}}, {0, 1, 1, 2, 2, 3, 4}),
+      *encoded_list(0).get<RaggedTensorVariant>());
+  ExpectRaggedTensorVariantEqual<int, int32>(
+      CreateVariantFromRagged<int, int32>(
+          {{0, 1, 2, 3, 4, 5}, {0, 1, 2, 5, 6, 7}}, {5, 6, 7, 8, 9, 8, 9}),
+      *encoded_list(1).get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) {
@@ -491,33 +323,17 @@ TEST_F(RaggedTensorToVariantKernelTest, NonBatchInput) {
   const std::vector<int> batched_values = {1, 2,  3,  4,  5,  6,  7, 8,
                                            9, 10, 11, 12, 13, 14, 15};
 
-  Tensor batched_ragged_splits_1(DT_INT64, TensorShape({5}));
-  Tensor batched_ragged_splits_2(DT_INT64, TensorShape({8}));
-  Tensor batched_ragged_values(DT_INT32, TensorShape({15}));
-
-  test::FillValues<int64>(&batched_ragged_splits_1, batched_splits_1);
-  test::FillValues<int64>(&batched_ragged_splits_2, batched_splits_2);
-  test::FillValues<int>(&batched_ragged_values, batched_values);
-
   BuildEncodeRaggedTensorGraph<int, int64>({batched_splits_1, batched_splits_2},
                                            TensorShape({15}), batched_values,
                                            false);
   TF_ASSERT_OK(RunOpKernel());
 
   const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
-  const Variant& encoded_splits_1 =
-      encoded_scalar.get<Tensor>()->vec<Variant>()(0);
-  const Variant& encoded_splits_2 =
-      encoded_scalar.get<Tensor>()->vec<Variant>()(1);
-  const Variant& encoded_values =
-      encoded_scalar.get<Tensor>()->vec<Variant>()(2);
 
-  test::ExpectTensorEqual<int64>(*encoded_splits_1.get<Tensor>(),
-                                 batched_ragged_splits_1);
-  test::ExpectTensorEqual<int64>(*encoded_splits_2.get<Tensor>(),
-                                 batched_ragged_splits_2);
-  test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(),
-                               batched_ragged_values);
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({batched_splits_1, batched_splits_2},
+                                          batched_values),
+      *encoded_scalar.get<RaggedTensorVariant>());
 }
 
 TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) {
@@ -598,17 +414,14 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
 
 TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) {
   const std::vector<int> values = {1, 2, 3, 4, 5, 6};
-  Tensor expected_values(DT_INT32, TensorShape({6}));
-  test::FillValues<int>(&expected_values, values);
 
   BuildEncodeRaggedTensorGraph<int, int64>({}, TensorShape({6}), values, false);
   TF_ASSERT_OK(RunOpKernel());
 
   const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
-  const Variant& encoded_values =
-      encoded_scalar.get<Tensor>()->vec<Variant>()(0);
-
-  test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), expected_values);
+  ExpectRaggedTensorVariantEqual<int, int64>(
+      CreateVariantFromRagged<int, int64>({}, values),
+      *encoded_scalar.get<RaggedTensorVariant>());
 }
 
 }  // namespace
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc
new file mode 100644
index 00000000000..9466313819b
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_variant.cc
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include "tensorflow/core/kernels/ragged_tensor_variant.h"
+
+namespace tensorflow {
+
+string RaggedTensorVariant::TypeName() const { return "RaggedTensorVariant"; }
+
+string RaggedTensorVariant::DebugString() const {
+  return absl::StrCat(
+      "RaggedTensorVariant(dtype=", DataTypeString(values_.dtype()),
+      ", ragged_rank=", nested_splits_.size(), ", splits_dtype=",
+      DataTypeString(nested_splits_.empty() ? DT_INVALID
+                                            : nested_splits_.back().dtype()),
+      ")");
+}
+
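+// Encodes the variant as a flat list of tensors:
+// [splits_0, ..., splits_{ragged_rank-1}, values].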
+void RaggedTensorVariant::Encode(VariantTensorData* data) const {
+  data->set_type_name(TypeName());
+  for (const auto& splits : nested_splits_) {
+    *data->add_tensors() = splits;
+  }
+  *data->add_tensors() = values_;
+}
+
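+// Decodes from the same layout; requires at least one tensor (the values).
+// Any preceding tensors are interpreted as nested row splits, outermost
+// first.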
+bool RaggedTensorVariant::Decode(const VariantTensorData& data) {
+  if (data.tensors_size() < 1) {
+    return false;
+  }
+  nested_splits_.assign(data.tensors().begin(),
+                        std::prev(data.tensors().end()));
+  values_ = data.tensors().back();
+  return true;
+}
+
+namespace {
+
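+// Copies a RaggedTensorVariant between devices.  Only the values tensor goes
+// through the async `copy` function; the splits tensors are assigned directly
+// (see TODO below).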
+Status RaggedTensorVariantDeviceCopy(
+    const RaggedTensorVariant& from, RaggedTensorVariant* to,
+    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
+  TF_RETURN_IF_ERROR(copy(from.values(), to->mutable_values()));
+  // TODO(b/170415165) Should we use `copy` to move splits from device<->host?
+  *to->mutable_nested_splits() = from.nested_splits();
+  return Status::OK();
+}
+
+}  // namespace
+
+REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
+    ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, RaggedTensorVariant,
+    RaggedTensorVariantZerosLike<CPUDevice>);
+
+REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
+    ADD_VARIANT_BINARY_OP, DEVICE_CPU, RaggedTensorVariant,
+    RaggedTensorVariantBinaryAdd<CPUDevice>);
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(RaggedTensorVariant,
+                                       "RaggedTensorVariant");
+
+#define REGISTER_RAGGED_TENSOR_VARIANT_COPY(DIRECTION)  \
+  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+      RaggedTensorVariant, DIRECTION, RaggedTensorVariantDeviceCopy)
+
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
+REGISTER_RAGGED_TENSOR_VARIANT_COPY(
+    VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h
new file mode 100644
index 00000000000..730758a3e82
--- /dev/null
+++ b/tensorflow/core/kernels/ragged_tensor_variant.h
@@ -0,0 +1,110 @@
+#include "tensorflow/core/framework/tensor_key.h"
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
+#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
+
+#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
+
+namespace tensorflow {
+
+// Class used to store a RaggedTensor as a Variant scalar.
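+//
+// The variant holds the RaggedTensor's flat_values plus `ragged_rank()`
+// nested row-splits tensors.  A minimal sketch of constructing one (the
+// tensor names here are illustrative only):
+//
+//   RaggedTensorVariant rt(values_tensor, {row_splits_tensor});
+//   Variant v = rt;  // can now be stored in a DT_VARIANT Tensor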
+class RaggedTensorVariant {
+ public:
+  RaggedTensorVariant() {}
+  RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
+      : values_(std::move(values)), nested_splits_(nested_splits) {}
+
+  // Variant support methods.
+  string TypeName() const;
+  string DebugString() const;
+  void Encode(VariantTensorData* data) const;
+  bool Decode(const VariantTensorData& data);
+
+  // The flat_values of the RaggedTensor.
+  const Tensor& values() const { return values_; }
+  Tensor* mutable_values() { return &values_; }
+  void set_values(const Tensor& new_values) { values_ = new_values; }
+
+  // The nested row_splits of the RaggedTensor.
+  int ragged_rank() const { return nested_splits_.size(); }
+  const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
+  std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
+  const Tensor& splits(int i) const { return nested_splits_[i]; }
+  Tensor* mutable_splits(int i) { return &nested_splits_[i]; }
+  void set_nested_splits(const std::vector<Tensor>& nested_splits) {
+    nested_splits_ = nested_splits;
+  }
+  void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
+
+ private:
+  Tensor values_;
+  std::vector<Tensor> nested_splits_;
+};
+
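+// zeros_like for RaggedTensorVariant: keeps x's nested row splits and
+// zero-fills the values tensor.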
+template <typename Device>
+Status RaggedTensorVariantZerosLike(OpKernelContext* c,
+                                    const RaggedTensorVariant& x,
+                                    RaggedTensorVariant* y) {
+  y->set_nested_splits(x.nested_splits());
+  TF_RETURN_IF_ERROR(
+      ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
+  return Status::OK();
+}
+
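+// Elementwise addition of two RaggedTensorVariants.  Requires matching value
+// dtypes, ragged ranks, and row splits; the nested splits are reused and the
+// values tensors are added.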
+template <typename Device>
+Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
+                                    const RaggedTensorVariant& x,
+                                    const RaggedTensorVariant& y,
+                                    RaggedTensorVariant* out) {
+  if (x.values().dtype() != y.values().dtype()) {
+    return errors::InvalidArgument(
+        "Can't add RaggedTensorVariants of different dtypes. One is ",
+        DataTypeString(x.values().dtype()), " and the other is ",
+        DataTypeString(y.values().dtype()));
+  }
+  if (x.ragged_rank() != y.ragged_rank()) {
+    return errors::InvalidArgument(
+        "Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
+        x.ragged_rank(), " and the other is ", y.ragged_rank());
+  }
+  for (int i = 0; i < x.ragged_rank(); ++i) {
+    if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
+      return errors::InvalidArgument(
+          "Can't add RaggedTensorVariants with different row_splits.");
+    }
+  }
+  out->set_nested_splits(x.nested_splits());
+  TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
+                                              out->mutable_values()));
+  return Status::OK();
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc
index 44712bf7739..043ff469487 100644
--- a/tensorflow/core/ops/ragged_conversion_ops.cc
+++ b/tensorflow/core/ops/ragged_conversion_ops.cc
@@ -92,7 +92,8 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes(
 Status RaggedTensorToSparseShapeFn(InferenceContext* c);
 Status RaggedTensorToVariantShapeFn(InferenceContext* c);
 Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
-tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c);
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
+Status RaggedTensorToTensorShapeFn(InferenceContext* c);
 
 //==============================================================================
 // Registered Ops
@@ -129,6 +130,15 @@ REGISTER_OP("RaggedTensorFromVariant")
     .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn(RaggedTensorFromVariantShapeFn);
 
+REGISTER_OP("RaggedTensorToVariantGradient")
+    .Input("encoded_ragged_grad: variant")
+    .Input("row_splits: Tsplits")
+    .Input("dense_values_shape: int32")
+    .Output("dense_values_grad: Tvalues")
+    .Attr("Tvalues: type")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
+    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
+
 REGISTER_OP("RaggedTensorToTensor")
     .Attr("T: type")
     .Attr("Tindex: {int64, int32}")
@@ -201,6 +211,14 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
   return Status::OK();
 }
 
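+// Shape function for RaggedTensorToVariantGradient: the output gradient has
+// the shape described by the `dense_values_shape` input (input 2, treated as
+// unknown if it is a scalar).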
+Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
+  ShapeHandle shape;
+  TF_RETURN_IF_ERROR(
+      c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
+  c->set_output(0, shape);
+  return Status::OK();
+}
+
 Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
   int64 input_ragged_rank;
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index 309957a76a1..6bd517d7abc 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -510,6 +510,7 @@ py_test(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python:tensor_array_grad",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python/data/ops:dataset_ops",
diff --git a/tensorflow/python/ops/ragged/ragged_conversion_ops.py b/tensorflow/python/ops/ragged/ragged_conversion_ops.py
index e8c625ccc73..e915d1ecd61 100644
--- a/tensorflow/python/ops/ragged/ragged_conversion_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_conversion_ops.py
@@ -143,3 +143,42 @@ def to_sparse(rt_input, name=None):
 
 def from_sparse(st_input, name=None):
   return ragged_tensor.RaggedTensor.from_sparse(st_input, name)
+
+
+@ops.RegisterGradient("RaggedTensorFromVariant")
+def _ragged_tensor_from_variant_grad(op, *grads):
+  """Gradient for RaggedTensorFromVariant op."""
+
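+  # RaggedTensorFromVariant decodes a variant scalar/vector into
+  # [nested_splits..., dense_values].  Its gradient re-encodes the incoming
+  # dense_values gradient (grads[-1]) with the op's own output splits, giving
+  # a single variant gradient for the op's only input.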
+  variant_rank = op.inputs[0].shape.rank
+  if variant_rank == 0:
+    batched_input = False
+  elif variant_rank == 1:
+    batched_input = True
+  elif variant_rank is None:
+    batched_input = (op.get_attr("output_ragged_rank") > 0)
+  else:
+    # TODO(edloper): Add a batch_dims argument to RaggedTensorToVariant, so
+    # we can support this.
+    raise ValueError("Unable to compute gradient: RaggedTensorToVariant "
+                     "can currently only generate 0D or 1D output.")
+  return [
+      gen_ragged_conversion_ops.ragged_tensor_to_variant(
+          rt_nested_splits=op.outputs[:-1],
+          rt_dense_values=grads[-1],
+          batched_input=batched_input)
+  ]
+
+
+@ops.RegisterGradient("RaggedTensorToVariant")
+def _ragged_tensor_to_variant_grad(op, encoded_ragged_grad):
+  """Gradient for RaggedTensorToVariant op."""
+  dense_values = op.inputs[-1]
+  ragged_rank = len(op.inputs) - 1
+  row_splits = 0 if ragged_rank == 0 else op.inputs[0]
+  values_grad = gen_ragged_conversion_ops.ragged_tensor_to_variant_gradient(
+      encoded_ragged_grad=encoded_ragged_grad,
+      row_splits=row_splits,
+      dense_values_shape=array_ops.shape(dense_values),
+      Tvalues=op.inputs[-1].dtype)
+  result = [None] * ragged_rank + [values_grad]
+  return result
diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py
index 5f713fa0793..800272d0dd9 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor.py
@@ -2863,9 +2863,6 @@ def _get_optional_partition_dtype(values):
   return None
 
 
-ops.no_gradient("RaggedTensorToVariant")
-
-
 _SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)
 
 
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py
index d92cb9cec6c..a38c5527305 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py
@@ -18,10 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -30,8 +32,15 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_ragged_conversion_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import map_fn
+from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
@@ -1233,19 +1242,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       with self.assertRaises(errors.InvalidArgumentError):
         self.evaluate(factory(**kwargs))
 
+  #=============================================================================
+  # RaggedTensor Variant conversion
+  #=============================================================================
 
-#=============================================================================
-# RaggedTensor Variant conversion
-#=============================================================================
-
-  @parameterized.parameters(
+  @parameterized.named_parameters(
       {
+          'testcase_name': 'Shape_5_none',
           'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
           'ragged_rank': 1
       }, {
+          'testcase_name': 'Shape_4_none_2',
           'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
           'ragged_rank': 1
       }, {
+          'testcase_name': 'Shape_2_none_none',
           'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]],
           'ragged_rank': 2
       })
@@ -1432,6 +1443,131 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           output_ragged_rank=1,
           input_ragged_rank=1)
 
+  def _testRaggedVariantGradient(self, func, x, expected_grad):
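+    """Checks that the gradient of `func` at `x` equals `expected_grad`."""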
+    x = constant_op.constant(x)
+    if context.executing_eagerly():
+      with backprop.GradientTape() as t:
+        t.watch(x)
+        y = func(x)
+        g = t.gradient(y, x)
+    else:
+      y = func(x)
+      g = gradients_impl.gradients(ys=y, xs=x)[0]
+    self.assertAllClose(g, expected_grad)
+
+  def testRaggedVariantGradients(self):
+    def func(x):
+      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
+      rt2 = rt1 * [[10], [100], [1000]]
+      v = rt2._to_variant(batched_input=False)
+      rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
+      return rt3.flat_values
+
+    self._testRaggedVariantGradient(
+        func,
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [10., 10., 10., 10., 100., 100., 100., 1000.])
+
+  def testRaggedVariantGradientsBatched(self):
+    def func(x):
+      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
+      rt2 = rt1 * [[10], [100], [1000]]
+      v = rt2._to_variant(batched_input=True)
+      rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
+      return rt3.flat_values
+
+    self._testRaggedVariantGradient(
+        func,
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [10., 10., 10., 10., 100., 100., 100., 1000.])
+
+  def testRaggedVariantGradientsBatchedAndSliced(self):
+    def func(x, i):
+      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
+      rt2 = rt1 * [[10], [100], [1000]]
+      v_slice = rt2._to_variant(batched_input=True)[i]
+      return RaggedTensor._from_variant(v_slice, dtype=rt2.dtype,
+                                        output_ragged_rank=0)
+
+    self._testRaggedVariantGradient(
+        functools.partial(func, i=0),
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [10., 10., 10., 10., 0., 0., 0., 0.])
+    self._testRaggedVariantGradient(
+        functools.partial(func, i=1),
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [0., 0., 0., 0., 100., 100., 100., 0.])
+    self._testRaggedVariantGradient(
+        functools.partial(func, i=2),
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [0., 0., 0., 0., 0., 0., 0., 1000.])
+
+  def testRaggedVariantGradientsRaggedRank0(self):
+    def func(x):
+      x2 = x * 2
+      v = gen_ragged_conversion_ops.ragged_tensor_to_variant(
+          [], x2, batched_input=False)
+      return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0)
+
+    self._testRaggedVariantGradient(
+        func,
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
+
+  def testRaggedVariantGradientsRaggedRank3(self):
+    def func(x):
+      x2 = x * 2
+      rt1 = RaggedTensor.from_nested_row_splits(
+          x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8]))
+      v = rt1._to_variant(batched_input=False)
+      rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3)
+      return rt3.flat_values
+
+    self._testRaggedVariantGradient(
+        func,
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
+
+  def testRaggedVariantGradientsViaMapFn(self):
+    rt = RaggedTensor.from_row_splits(
+        values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8])
+
+    def func(x):
+
+      def transform_row(row):
+        return math_ops.sqrt(
+            math_ops.reduce_mean(math_ops.square(row * x), keepdims=True))
+
+      return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt))
+
+    self._testRaggedVariantGradient(func, 3.0, 14.653377)
+
+  def testRaggedVariantGradientsViaMapFnReduce(self):
+    def func(x):
+      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
+      return map_fn.map_fn(
+          math_ops.reduce_max, rt1,
+          fn_output_signature=tensor_spec.TensorSpec((), x.dtype))
+
+    self._testRaggedVariantGradient(
+        func,
+        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
+        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0])
+
+  def testRaggedVariantGradientsErrors(self):
+    if context.executing_eagerly():
+      return
+
+    rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2])
+    v1 = rt._to_variant()
+    v2 = array_ops.stack([array_ops.stack([v1])])
+    y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3)
+
+    with self.assertRaisesRegex(
+        ValueError, 'Unable to compute gradient: RaggedTensorToVariant '
+        'can currently only generate 0D or 1D output.'):
+      gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values)
+
   def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg):
     """Check that two numpy arrays are equal.
 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 2efd289c259..96be23b9e50 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -3212,6 +3212,10 @@ tf_module {
     name: "RaggedTensorToVariant"
     argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "RaggedTensorToVariantGradient"
+    argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "RandomCrop"
     argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 2efd289c259..96be23b9e50 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -3212,6 +3212,10 @@ tf_module {
     name: "RaggedTensorToVariant"
     argspec: "args=[\'rt_nested_splits\', \'rt_dense_values\', \'batched_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "RaggedTensorToVariantGradient"
+    argspec: "args=[\'encoded_ragged_grad\', \'row_splits\', \'dense_values_shape\', \'Tvalues\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "RandomCrop"
     argspec: "args=[\'image\', \'size\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], "