From ccd35296e8d7569cd2bc9d2ab01541ddb6903308 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Thu, 25 Apr 2019 11:18:25 -0700
Subject: [PATCH] Add TensorSpec support for CompositeTensors (such as
 SparseTensor and RaggedTensor).

PiperOrigin-RevId: 245272144
---
 tensorflow/python/BUILD                       | 26 +++++++
 tensorflow/python/eager/function.py           |  4 +-
 tensorflow/python/eager/function_test.py      | 12 +++
 tensorflow/python/framework/func_graph.py     |  2 -
 .../framework/indexed_slices_tensor_spec.py   | 65 ++++++++++++++++
 tensorflow/python/framework/ops.py            |  9 ++-
 tensorflow/python/framework/sparse_tensor.py  | 35 +++++----
 .../python/framework/sparse_tensor_spec.py    | 56 ++++++++++++++
 tensorflow/python/ops/ragged/BUILD            | 13 ++++
 tensorflow/python/ops/ragged/ragged_tensor.py | 33 +++++---
 .../python/ops/ragged/ragged_tensor_spec.py   | 77 +++++++++++++++++++
 11 files changed, 305 insertions(+), 27 deletions(-)
 create mode 100644 tensorflow/python/framework/indexed_slices_tensor_spec.py
 create mode 100644 tensorflow/python/framework/sparse_tensor_spec.py
 create mode 100644 tensorflow/python/ops/ragged/ragged_tensor_spec.py

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c64ba8617a3..ed1e33b7ae9 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -699,11 +699,13 @@ py_library(
         ":framework_for_generated_wrappers",
         ":function",
         ":graph_util",
+        ":indexed_slices_tensor_spec",
         ":lib",
         ":platform",
         ":pywrap_tensorflow",
         ":random_seed",
         ":sparse_tensor",
+        ":sparse_tensor_spec",
         ":tensor_spec",
         ":tensor_util",
         ":util",
@@ -1191,6 +1193,30 @@ py_library(
     ],
 )
 
+py_library(
+    name = "sparse_tensor_spec",
+    srcs = ["framework/sparse_tensor_spec.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":dtypes",
+        ":sparse_tensor",
+        ":tensor_shape",
+        ":tensor_spec",
+    ],
+)
+
+py_library(
+    name = "indexed_slices_tensor_spec",
+    srcs = ["framework/indexed_slices_tensor_spec.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":dtypes",
+        ":framework_ops",
+        ":sparse_tensor",
+        ":tensor_shape",
+    ],
+)
+
 py_library(
     name = "tensor_util",
     srcs = ["framework/tensor_util.py"],
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 4f0b3dec4e5..8574fa1382e 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1039,7 +1039,7 @@ class FunctionSpec(object):
 
       self._input_signature = tuple(input_signature)
       self._flat_input_signature = tuple(nest.flatten(input_signature,
-                                                      expand_composites=False))
+                                                      expand_composites=True))
 
   @property
   def fullargspec(self):
@@ -1656,7 +1656,7 @@ def register(func, *args, **kwargs):
 
 def validate_signature(signature):
   if any(not isinstance(arg, tensor_spec.TensorSpec)
-         for arg in nest.flatten(signature, expand_composites=False)):
+         for arg in nest.flatten(signature, expand_composites=True)):
     raise TypeError("Invalid input_signature %s; input_signature must be "
                     "a possibly nested sequence of TensorSpec objects.")
 
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 02130b053c5..9b8dc930946 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.framework import function as tf_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import sparse_tensor_spec
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_ops
@@ -66,6 +67,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.ops.ragged import ragged_tensor_spec
 from tensorflow.python.platform import test
 from tensorflow.python.training import training_ops
 from tensorflow.python.util import compat
@@ -964,6 +966,16 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
        {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
       (sparse_tensor.SparseTensor,
        {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
+      (ragged_tensor.RaggedTensor.from_row_lengths,
+       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
+       [ragged_tensor_spec.ragged_tensor_spec([None, None], dtypes.int32)]),
+      (ragged_tensor.RaggedTensor.from_nested_row_lengths,
+       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
+       [ragged_tensor_spec.ragged_tensor_spec([None, None, None],
+                                              dtypes.int32)]),
+      (sparse_tensor.SparseTensor,
+       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
+       [sparse_tensor_spec.sparse_tensor_spec([None], dtypes.int32)]),
   ])  # pyformat: disable
   def testCompositeAsArgumentTensorWithDefun(self,
                                              factory_fn,
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index db33c5472e4..5311f6601ac 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -27,7 +27,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
 from tensorflow.python.eager.graph_only_ops import graph_placeholder
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
@@ -79,7 +78,6 @@ def convert_structure_to_signature(structure, arg_names=None):
     Identical structure that has TensorSpec objects instead of Tensors and
     UknownArgument instead of any unsupported types.
   """
-  structure = composite_tensor.replace_composites_with_components(structure)
   def encode_arg(arg, path):
     """A representation for this argument, for converting into signatures."""
     if isinstance(arg, ops.Tensor):
diff --git a/tensorflow/python/framework/indexed_slices_tensor_spec.py b/tensorflow/python/framework/indexed_slices_tensor_spec.py
new file mode 100644
index 00000000000..965e09231a9
--- /dev/null
+++ b/tensorflow/python/framework/indexed_slices_tensor_spec.py
@@ -0,0 +1,65 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorSpec factory for sparse tensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+
+
+def indexed_slices_tensor_spec(shape=None,
+                               dtype=dtypes.float32,
+                               num_slices=None,
+                               has_dense_shape=True,
+                               name=None):
+  """Returns a tensor specification for a IndexedSlices.
+
+  Returns an object which can be passed to `tf.function` (or other
+  functions that expect `TensorSpec`s) to specify shape constraints
+  for a `IndexedSlices` argument.
+
+  Args:
+    shape: The shape of the IndexedSlices, or `None` to allow any shape.
+      The returned specification object depends only on `shape[1:]`.
+    dtype: Data type of values in the IndexedSlices.
+    num_slices: Number of slices.  Default allows for any number of slices.
+    has_dense_shape: Indicates whether the IndexedSlices is expected to have a
+      `dense_shape` component.
+    name: Optional name prefix for the `TensorSpec`s.
+
+  Returns:
+    An object describing the `values`, `indices` and `dense_shape` tensors
+    that comprise the `IndexedSlices`.
+  """
+  dtype = dtypes.as_dtype(dtype)
+  shape = tensor_shape.TensorShape(shape)
+  num_slices = tensor_shape.Shape([num_slices])
+
+  values = tensor_spec.TensorSpec(
+      num_slices.concatenate(shape[1:]), dtype, name)
+  indices = tensor_spec.TensorSpec(num_slices, dtypes.int64,
+                                   ("%s.indices" % name) if name else None)
+  if has_dense_shape:
+    dense_shape = tensor_spec.TensorSpec([shape.ndims], dtypes.int64,
+                                         ("%s.dense_shape" %
+                                          name) if name else None)
+  else:
+    dense_shape = None
+  return ops.IndexedSlices(values, indices, dense_shape)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 1b1c19e9511..d23cfc77c94 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -62,8 +62,14 @@ from tensorflow.python.util import memory
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_stack
 from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.lazy_loader import LazyLoader
 from tensorflow.python.util.tf_export import tf_export
 
+# This is to avoid a circular dependency: ops -> tensor_spec -> ops
+tensor_spec = LazyLoader(
+    "tensor_spec", globals(),
+    "tensorflow.python.framework.tensor_spec")
+
 # Temporary global switches determining if we should enable the work-in-progress
 # calls to the C API. These will be removed once all functionality is supported.
 _USE_C_API = True
@@ -1686,7 +1692,8 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
 
   def __init__(self, values, indices, dense_shape=None):
     """Creates an `IndexedSlices`."""
-    _get_graph_from_inputs([values, indices, dense_shape])
+    if not isinstance(values, tensor_spec.TensorSpec):
+      _get_graph_from_inputs([values, indices, dense_shape])
     self._values = values
     self._indices = indices
     self._dense_shape = dense_shape
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index 834fa3338ba..e9199b1e661 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.util.tf_export import tf_export
 
@@ -114,21 +115,29 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
       values: A 1-D tensor of any type and shape `[N]`.
       dense_shape: A 1-D int64 tensor of shape `[ndims]`.
     """
-    with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
-      indices = ops.convert_to_tensor(
-          indices, name="indices", dtype=dtypes.int64)
-      # TODO(touts): Consider adding mutable_values() when 'values'
-      # is a VariableOp and updating users of SparseTensor.
-      values = ops.internal_convert_to_tensor(values, name="values")
-      dense_shape = ops.convert_to_tensor(
-          dense_shape, name="dense_shape", dtype=dtypes.int64)
+    if isinstance(indices, tensor_spec.TensorSpec):
+      if not isinstance(values, tensor_spec.TensorSpec):
+        raise TypeError("Expected values to be a TensorSpec")
+      if not isinstance(dense_shape, tensor_spec.TensorSpec):
+        raise TypeError("Expected dense_shape to be a TensorSpec")
+      if indices.dtype != dtypes.int64 or dense_shape.dtype != dtypes.int64:
+        raise TypeError("indices and dense_shape must have dtype=int64")
+    else:
+      with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
+        indices = ops.convert_to_tensor(
+            indices, name="indices", dtype=dtypes.int64)
+        # TODO(touts): Consider adding mutable_values() when 'values'
+        # is a VariableOp and updating users of SparseTensor.
+        values = ops.internal_convert_to_tensor(values, name="values")
+        dense_shape = ops.convert_to_tensor(
+            dense_shape, name="dense_shape", dtype=dtypes.int64)
     self._indices = indices
     self._values = values
     self._dense_shape = dense_shape
 
-    indices_shape = indices.get_shape().with_rank(2)
-    values_shape = values.get_shape().with_rank(1)
-    dense_shape_shape = dense_shape.get_shape().with_rank(1)
+    indices_shape = indices.shape.with_rank(2)
+    values_shape = values.shape.with_rank(1)
+    dense_shape_shape = dense_shape.shape.with_rank(1)
 
     # Assert number of rows in indices match the number of elements in values.
     indices_shape.dims[0].merge_with(values_shape.dims[0])
@@ -244,8 +253,8 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
     return [
         tensor_shape.TensorShape([None, rank]),  # indices
         tensor_shape.TensorShape([None]),  # values
-        tensor_shape.TensorShape([rank])
-    ]  # dense_shape
+        tensor_shape.TensorShape([rank])  # dense_shape
+    ]
 
   @property
   def _is_graph_tensor(self):
diff --git a/tensorflow/python/framework/sparse_tensor_spec.py b/tensorflow/python/framework/sparse_tensor_spec.py
new file mode 100644
index 00000000000..4c9f163384a
--- /dev/null
+++ b/tensorflow/python/framework/sparse_tensor_spec.py
@@ -0,0 +1,56 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorSpec factory for sparse tensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+
+
+def sparse_tensor_spec(shape=None,
+                       dtype=dtypes.float32,
+                       num_values=None,
+                       name=None):
+  """Returns a tensor specification for a SparseTensor.
+
+  Returns an object which can be passed to `tf.function` (or other
+  functions that expect `TensorSpec`s) to specify shape constraints
+  for a `SparseTensor` argument.
+
+  Args:
+    shape: The shape of the SparseTensor, or `None` to allow any shape. The
+      returned specification object depends only on `shape.ndims`.
+    dtype: Data type of values in the SparseTensor.
+    num_values: The number of values in the SparseTensor, or `None` to allow any
+      number of values.
+    name: Optional name prefix for the `TensorSpec`s.
+
+  Returns:
+    An object describing the `values`, `indices` and `dense_shape` tensors
+    that comprise the `SparseTensor`.
+  """
+  dtype = dtypes.as_dtype(dtype)
+  rank = tensor_shape.TensorShape(shape).rank
+  indices = tensor_spec.TensorSpec([num_values, rank], dtypes.int64,
+                                   ("%s.indices" % name) if name else None)
+  values = tensor_spec.TensorSpec([num_values], dtype, name)
+  dense_shape = tensor_spec.TensorSpec(
+      [rank], dtypes.int64, ("%s.dense_shape" % name) if name else None)
+  return sparse_tensor.SparseTensor(indices, values, dense_shape)
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index 3b7ea580e0f..d8fb74fae68 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -40,6 +40,7 @@ py_library(
         ":ragged_string_ops",
         ":ragged_tensor",
         ":ragged_tensor_shape",
+        ":ragged_tensor_spec",
         ":ragged_tensor_value",
         ":ragged_util",
         ":ragged_where_op",
@@ -434,6 +435,18 @@ py_library(
     ],
 )
 
+py_library(
+    name = "ragged_tensor_spec",
+    srcs = ["ragged_tensor_spec.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged_tensor",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_spec",
+    ],
+)
+
 #-------------------------------------------------------------------------------
 # RaggedTensor Tests
 #-------------------------------------------------------------------------------
diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py
index 561fd3bedc3..a30e6c13a77 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -241,16 +242,28 @@ class RaggedTensor(composite_tensor.CompositeTensor):
                        "of the factory methods instead (e.g., "
                        "RaggedTensor.from_row_lengths())")
 
-    # Validate the arguments.
-    if not isinstance(values, (RaggedTensor, ops.Tensor)):
-      raise TypeError("values must be a Tensor or RaggedTensor.")
-    if not isinstance(row_splits, ops.Tensor):
-      raise TypeError("Row-partitioning argument must be a Tensor.")
-    if row_splits.dtype not in (dtypes.int32, dtypes.int64):
-      raise ValueError("Row-partitioning argument must be int32 or int64")
-    values.shape.with_rank_at_least(1)
+    is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec)
+    if is_tensor_spec:
+      if not (isinstance(values, tensor_spec.TensorSpec) or
+              (isinstance(values, RaggedTensor) and
+               isinstance(values.row_splits, tensor_spec.TensorSpec))):
+        raise TypeError("Expected values to be a TensorSpec, got %r" % values)
+    else:
+      # Validate the arguments.
+      if not isinstance(row_splits, ops.Tensor):
+        raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
+                        row_splits)
+      if not isinstance(values, (RaggedTensor, ops.Tensor)):
+        raise TypeError("values must be a Tensor or RaggedTensor, got %r" %
+                        values)
+      if row_splits.dtype not in (dtypes.int32, dtypes.int64):
+        raise ValueError("Row-partitioning argument must be int32 or int64")
+
+    # Validate shapes & dtypes.
     row_splits.shape.assert_has_rank(1)
-    row_splits.set_shape([None])
+    values.shape.with_rank_at_least(1)
+    if not is_tensor_spec:
+      row_splits.set_shape([None])
     if isinstance(values, RaggedTensor):
       assert row_splits.dtype == values.row_splits.dtype
 
@@ -431,6 +444,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
       raise TypeError("validate must have type bool")
     if isinstance(row_splits, (list, tuple)) and not row_splits:
       raise ValueError("row_splits tensor may not be empty.")
+    if isinstance(row_splits, tensor_spec.TensorSpec):
+      return cls(values=values, row_splits=row_splits, internal=True)
 
     with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
       values, row_splits = cls._convert_values_and_row_partition(
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_spec.py b/tensorflow/python/ops/ragged/ragged_tensor_spec.py
new file mode 100644
index 00000000000..9da282cea3c
--- /dev/null
+++ b/tensorflow/python/ops/ragged/ragged_tensor_spec.py
@@ -0,0 +1,77 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorSpec factory for ragged tensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops.ragged import ragged_tensor
+
+
+def ragged_tensor_spec(shape=None, dtype=dtypes.float32,
+                       ragged_rank=None, row_splits_dtype=dtypes.int64,
+                       name=None):
+  """Returns a tensor specification for a RaggedTensor.
+
+  Returns an object which can be passed to `tf.function` (or other
+  functions that expect `TensorSpec`s) to specify shape constraints
+  for a `RaggedTensor` argument.
+
+  Args:
+    shape: The shape of the RaggedTensor, or `None` to allow any shape.
+    dtype: Data type of values in the RaggedTensor.
+    ragged_rank: Python integer, the ragged rank of the RaggedTensor
+      to be described.  Defaults to `shape.ndims - 1`.
+    row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
+      One of `tf.int32` or `tf.int64`.
+    name: Optional name prefix for the `TensorSpec`s.
+
+  Returns:
+    An object describing the `flat_values` and `nested_row_splits` tensors
+    that comprise the `RaggedTensor`.
+  """
+  dtype = dtypes.as_dtype(dtype)
+  shape = tensor_shape.TensorShape(shape)
+  if ragged_rank is None:
+    if shape.ndims is None:
+      raise ValueError("Must specify ragged_rank or a shape with known rank.")
+    ragged_rank = shape.ndims - 1
+  elif not isinstance(ragged_rank, int):
+    raise TypeError("ragged_rank must be an int")
+  if ragged_rank == 0:
+    return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=name)
+
+  result = tensor_spec.TensorSpec(
+      tensor_shape.TensorShape([None]).concatenate(shape[ragged_rank + 1:]),
+      dtype, name)
+
+  for i in range(ragged_rank - 1, 0, -1):
+    splits = tensor_spec.TensorSpec(
+        [None], row_splits_dtype,
+        "%s.row_splits_%d" % (name, i) if name else None)
+    result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
+
+  outer_dim = tensor_shape.dimension_at_index(shape, 0)
+  splits_shape = [None if outer_dim is None else outer_dim + 1]
+  splits = tensor_spec.TensorSpec(
+      splits_shape, row_splits_dtype,
+      "%s.row_splits_0" % name if name else None)
+  result = ragged_tensor.RaggedTensor.from_row_splits(result, splits)
+
+  return result