diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op_test.cc
index 7337ebe4ba5..7d7b0e5a9ec 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op_test.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op_test.cc
@@ -527,10 +527,14 @@ TEST_F(RaggedTensorToTensorOpUnknownShapeTest, ValueRowIDs) {
   INFER_OK(*op_, "?;[6,2];?;[];[6]", "[?,?,2]");
   INFER_OK(*op_, "?;[6,2];[2];[];[6]", "[?,?,2]");
   INFER_OK(*op_, "?;[6,2,7];[2,7];[];[6]", "[?,?,2,7]");
-  INFER_ERROR("default_value_shape and value_shape do not match", *op_,
-              "?;[6,2];[3];[];[6]");
-  INFER_ERROR("default_value_shape and value_shape do not match", *op_,
-              "?;[6,2,1,2];[2,2];[];[6]");
+  INFER_ERROR(
+      "default_value.shape=[3] and rt_input.flat_values.shape=[6,2] "
+      "are incompatible",
+      *op_, "?;[6,2];[3];[];[6]");
+  INFER_ERROR(
+      "default_value.shape=[2,2] and rt_input.flat_values.shape="
+      "[6,2,1,2] are incompatible",
+      *op_, "?;[6,2,1,2];[2,2];[];[6]");
   INFER_ERROR("must be a vector", *op_, "?;[6];[];[];[3,6]");
   INFER_ERROR("must be a scalar", *op_, "?;[6];[];[7];[3]");
 }
diff --git a/tensorflow/core/ops/ragged_to_dense_util.cc b/tensorflow/core/ops/ragged_to_dense_util.cc
index 246f72494fc..ecb95e163ab 100644
--- a/tensorflow/core/ops/ragged_to_dense_util.cc
+++ b/tensorflow/core/ops/ragged_to_dense_util.cc
@@ -91,10 +91,11 @@ tensorflow::Status CombineRaggedTensorToTensorShapes(
   }
   // At this point, value_shape and output_shape have known ranks.
   if (ragged_rank + value_shape.dim_size() != output_shape->dim_size()) {
-    return InvalidArgument("Value shape (", value_shape.DebugString(),
-                           "), ragged_rank(", ragged_rank, ") and shape(",
-                           shape.DebugString(),
-                           ") do not have a consistent number of dimensions");
+    return InvalidArgument(
+        "rt_input.shape and shape=", TensorShape::DebugString(shape),
+        " are incompatible: rt_input.rank = ",
+        ragged_rank + value_shape.dim_size(),
+        " but shape.rank = ", output_shape->dim_size());
   }
 
   for (int i = 1; i < value_shape.dim_size(); ++i) {
@@ -105,7 +106,11 @@ tensorflow::Status CombineRaggedTensorToTensorShapes(
     if (value_dim.size() >= 0) {
       if (output_shape_dim->size() >= 0) {
         if (output_shape_dim->size() != value_dim.size()) {
-          return InvalidArgument("Value and shape dimension are inconsistent.");
+          return InvalidArgument(
+              "rt_input.shape and shape=", TensorShape::DebugString(shape),
+              " are incompatible: rt_input.shape[", i + ragged_rank,
+              "] = ", value_dim.size(), " but shape[", i + ragged_rank,
+              "] = ", output_shape_dim->size());
         }
       } else {
         output_shape_dim->set_size(value_dim.size());
@@ -132,28 +137,29 @@ tensorflow::Status ValidateDefaultValueShape(
     return tensorflow::Status::OK();
   }
 
-  if (default_value_shape.dim_size() > value_shape.dim_size()) {
-    // TODO(martinz): This constraint is unnecessary. The
-    // default value could have as many dimensions as shape. If there is a
-    // discrepancy, it will be picked up when we broadcast the default value.
-    // For now, I'll relax the constraint only slightly.
+  int default_ndims = default_value_shape.dim_size();
+  int values_ndims = value_shape.dim_size();
+  if (default_ndims >= values_ndims) {
     return InvalidArgument(
-        "default_value_shape must have no more dimensions than the value. "
-        "default_value_shape: ",
-        default_value_shape.DebugString(),
-        " default_value_shape.dim_size(): ", default_value_shape.dim_size(),
-        " value_shape: ", value_shape.DebugString(),
-        " value_shape.dim_size(): ", value_shape.dim_size());
+        "default_value.shape=", TensorShape::DebugString(default_value_shape),
+        " and rt_input.flat_values.shape=",
+        TensorShape::DebugString(value_shape),
+        " are incompatible: default_value.rank = ", default_ndims,
+        "  must be less than rt_input.flat_values.rank = ", values_ndims);
   }
-  for (int i = 0;
-       i < std::min(default_value_shape.dim_size(), value_shape.dim_size() - 1);
-       ++i) {
-    if (default_value_shape.dim(i).size() >= 0 &&
-        value_shape.dim(i + 1).size() >= 0 &&
-        default_value_shape.dim(i).size() != 1 &&
-        default_value_shape.dim(i).size() != value_shape.dim(i + 1).size()) {
+  for (int i = 0; i < std::min(default_ndims, values_ndims - 1); ++i) {
+    int default_dim = default_value_shape.dim(i).size();
+    int value_dim = value_shape.dim(i + 1).size();
+    if (default_dim >= 0 && value_dim >= 0 && default_dim != 1 &&
+        default_dim != value_dim) {
       return InvalidArgument(
-          "default_value_shape and value_shape do not match on dimension ", i);
+          "default_value.shape=", TensorShape::DebugString(default_value_shape),
+          " and rt_input.flat_values.shape=",
+          TensorShape::DebugString(value_shape),
+          " are incompatible: default_value.shape[",
+          i - default_value_shape.dim_size(), "] = ", default_dim,
+          " but rt_input.flat_values.shape[",
+          i - default_value_shape.dim_size(), "] = ", value_dim);
     }
   }
   return tensorflow::Status::OK();
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index 88c541b5f77..627bf9c2a1a 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -1,4 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark")
 
 package(
     default_visibility = [
@@ -669,6 +670,11 @@ py_test(
     ],
 )
 
+tf_py_logged_benchmark(
+    name = "ragged_to_tensor_op_benchmark",
+    target = "//tensorflow/python/ops/ragged:ragged_to_tensor_op_test",
+)
+
 py_test(
     name = "ragged_segment_op_test",
     srcs = ["ragged_segment_op_test.py"],
diff --git a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py
index 0b994af2ff7..882fb01dee8 100644
--- a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py
@@ -18,20 +18,70 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import random
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.ragged import ragged_conversion_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
+from tensorflow.python.platform import benchmark
 from tensorflow.python.platform import googletest
+from tensorflow.python.util import nest
+
+
+def make_placeholder(t):
+  return array_ops.placeholder_with_default(t, None)
+
+
+def rebuild_ragged_tensor_with_value_rowids(rt, feed_dict=None, sess=None):
+  """Returns a copy of `rt`, built using `from_value_rowids`.
+
+  This ensures that RaggedTensor._cached_value_rowids is populated, which
+  triggers a different code-path for converting ragged tensors to tensors.
+
+  If `feed_dict` and `sess` are specified, then build the new `RaggedTensor`
+  using placeholder tensors, and populate a feed dictionary that can be used
+  to feed the placeholders.
+
+  Args:
+    rt: The RaggedTensor to copy.
+    feed_dict: If specified, then build the new `RaggedTensor` using
+      placeholders, and populate this dict with entries to feed those
+      placeholders.
+    sess: A session used to evaluate tensors; required if feed_dict is
+      specified.
+
+  Returns:
+    A copy of `rt`, built using `from_value_rowids`.
+  """
+  if isinstance(rt, ragged_tensor.RaggedTensor):
+    values = rebuild_ragged_tensor_with_value_rowids(rt.values, feed_dict, sess)
+    rowids = rt.value_rowids()
+    nrows = rt.nrows()
+    if feed_dict is not None:
+      rowids_ph = make_placeholder(rowids)
+      nrows_ph = make_placeholder(nrows)
+      feed_dict[rowids_ph] = sess.run(rowids)
+      feed_dict[nrows_ph] = sess.run(nrows)
+      rowids, nrows = rowids_ph, nrows_ph
+    return ragged_tensor.RaggedTensor.from_value_rowids(values, rowids, nrows)
+  else:
+    if feed_dict is not None:
+      rt_ph = make_placeholder(rt)
+      feed_dict[rt_ph] = sess.run(rt)
+      rt = rt_ph
+    return rt
 
 
 @test_util.run_all_in_graph_and_eager_modes
@@ -212,30 +262,77 @@ class RaggedTensorToTensorOpNewTest(test_util.TensorFlowTestCase,
                                ragged_rank=None,
                                default=None,
                                expected_shape=None):
-    rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
-    dt = ragged_conversion_ops.ragged_to_dense(rt, default_value=default)
+    rt1 = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
+    dt1 = ragged_conversion_ops.ragged_to_dense(rt1, default_value=default)
+    rt2 = rebuild_ragged_tensor_with_value_rowids(rt1)
+    dt2 = ragged_conversion_ops.ragged_to_dense(rt2, default_value=default)
 
-    self.assertIsInstance(dt, ops.Tensor)
-    self.assertEqual(rt.dtype, dt.dtype)
-    self.assertTrue(dt.shape.is_compatible_with(rt.shape))
-    if expected_shape is not None:
-      expected = np.ndarray(expected_shape, buffer=np.array(expected))
-    self.assertAllEqual(dt, expected)
+    for (rt, dt) in [(rt1, dt1), (rt2, dt2)]:
+      self.assertIsInstance(dt, ops.Tensor)
+      self.assertEqual(rt.dtype, dt.dtype)
+      self.assertTrue(dt.shape.is_compatible_with(rt.shape))
+      if expected_shape is not None:
+        expected = np.ndarray(expected_shape, buffer=np.array(expected))
+      self.assertAllEqual(dt, expected)
 
-  @parameterized.parameters(
+  @parameterized.parameters([
       {
           'rt_input': [[1, 2, 3]],
           'default': 'a',
-          'error': (TypeError, '.*'),
-      }, {
+          'error_type': TypeError,
+          'error': r"Expected int32 passed to parameter 'default_value'|"
+                   r"Cannot convert 'a' to EagerTensor of dtype int32",
+      },
+      {
           'rt_input': [[1, 2, 3]],
-          'default': 'b',
-          'error': (TypeError, '.*'),
-      })
-  def testError(self, rt_input, default, error, ragged_rank=None):
+          'default': [0],
+          'error': r'default_value\.shape=\[1\] and '
+                   r'rt_input\.flat_values\.shape=\[3\] are incompatible: '
+                   r'default_value\.rank = 1  must be less than '
+                   r'rt_input\.flat_values\.rank = 1'
+      },
+      {
+          'rt_input': [[[1, 2], [3, 4]], [[5, 6]]],
+          'ragged_rank': 1,
+          'default': [7, 8, 9],
+          'error': r'default_value\.shape=\[3\] and '
+                   r'rt_input\.flat_values\.shape=\[3,2\] are incompatible: '
+                   r'default_value\.shape\[-1\] = 3 but '
+                   r'rt_input\.flat_values\.shape\[-1\] = 2'
+      },
+      {
+          'rt_input': [[1, 2, 3]],
+          'shape': [3, 3, 3],
+          'error': r'rt_input\.shape and shape=\[.,.,.\] are incompatible: '
+                   r'rt_input\.rank = 2 but shape\.rank = 3'
+      },
+      {
+          'rt_input': [[[1, 2, 3]]],
+          'ragged_rank': 1,
+          'shape': [1, 1, 4],
+          'error': r'rt_input\.shape and shape=\[1,1,4\] are incompatible: '
+                   r'rt_input\.shape\[2\] = 3 but shape\[2\] = 4'
+      },
+  ])
+  def testError(self,
+                rt_input,
+                error,
+                error_type=(ValueError, errors.InvalidArgumentError),
+                default=None,
+                ragged_rank=None,
+                shape=None):
+
     rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
-    with self.assertRaisesRegexp(error[0], error[1]):
-      ragged_conversion_ops.ragged_to_dense(rt, default_value=default)
+    with self.assertRaisesRegexp(error_type, error):
+      self.evaluate(
+          ragged_conversion_ops.ragged_to_dense(
+              rt, default_value=default, shape=shape))
+    rt_placeholder = nest.map_structure(
+        make_placeholder, rt, expand_composites=True)
+    with self.assertRaisesRegexp(error_type, error):
+      self.evaluate(
+          ragged_conversion_ops.ragged_to_dense(
+              rt_placeholder, default_value=default, shape=shape))
 
 
 @test_util.run_all_in_graph_and_eager_modes
@@ -454,7 +551,7 @@ class RaggedToTensorOpAdditionalTests(test_util.TensorFlowTestCase):
     input_data = ragged_factory_ops.constant([[[[1, 2], [3, 4]]], []],
                                              ragged_rank=1)
     # This placeholder has a 2 x 1 dimension.
-    default_value = array_ops.placeholder_with_default([[5], [6]], shape=None)
+    default_value = make_placeholder([[5], [6]])
     actual = ragged_conversion_ops.ragged_to_dense(
         input_data, default_value=default_value)
     expected = [[[[1, 2], [3, 4]]], [[[5, 5], [6, 6]]]]
@@ -489,5 +586,145 @@ class RaggedToTensorOpAdditionalTests(test_util.TensorFlowTestCase):
     self.assertAllEqual(actual, [[3, 3, 3], [3, 3, 3]])
 
 
+class RaggedToDenseBenchmark(googletest.Benchmark):
+
+  # Configurations to test.  See `run_benchmark` for config param docs.
+  CONFIGS = [
+      {'shape': [10, 10]},
+      {'shape': [10, 1000]},
+      {'shape': [1000, 10]},
+      {'shape': [1000, 10], 'fill': [1, 0.95]},  # Mostly full.
+      {'shape': [1000, 10], 'fill': [1, 0.05]},  # Mostly empty.
+      {'shape': [1000, 10], 'dtype': dtypes.string},
+      {'shape': [1000, 10], 'dtype': dtypes.int64},
+      {'shape': [100, 100]},
+      {'shape': [50, 50, 32]},
+      {'shape': [100, 100, 100], 'min_iters': 100},
+      {'shape': [1000, 1000], 'min_iters': 100},
+      {'shape': [10, 10, 10, 10, 10]},
+      {'shape': [10, 10, 10, 10, 10], 'ragged_rank': 1},
+      {'shape': [10, 10, 10, 10, 10], 'ragged_rank': 2},
+      {'shape': [50, 50, 32], 'ragged_rank': 1, 'default_shape': [32]},
+      {'shape': [200, 50, 32], 'ragged_rank': 1, 'default_shape': [32]}
+  ]  # pyformat: disable
+
+  def run_benchmark(self,
+                    shape=(100, 100),
+                    ragged_rank=None,
+                    dtype=dtypes.float32,
+                    fill=None,
+                    default_shape=(),
+                    output_shape=None,
+                    min_iters=1000):
+    """Run a benchmark with the specified configuraiton parameters.
+
+    Args:
+      shape: Bounding box for the input ragged tensor.
+      ragged_rank: Ragged rank for the input ragged tensor.  Defauts to
+        `len(shape)-1`.
+      dtype: Data type for the input ragged tensor.
+      fill: How full each dimension should be (0-1).  Corresponds 1:1 with
+        `shape`.  Defaults to 0.8 for each dimension.
+      default_shape: Shape for the default (padding) value.
+      output_shape: Output shape -- ragged tensor will be padded or cropped to
+        this shape.
+      min_iters: Minimum iterations for benchmark.
+    """
+    if ragged_rank is None:
+      ragged_rank = len(shape) - 1
+    if fill is None:
+      fill = [0.8 for _ in shape]
+
+    # Build the inputs for the op.
+    rt_input = self._generateRaggedTensor(shape, ragged_rank, dtype, fill)
+    default_value = constant_op.constant(
+        self._generateRaggedTensor(default_shape, 0, dtype), dtype=dtype)
+
+    mbs = np.prod(shape) / (2**20)
+    with session.Session(config=benchmark.benchmark_config()) as sess:
+      extras = {
+          'shape': shape,
+          'ragged_rank': ragged_rank,
+          'dtype': dtype,
+          'fill': fill,
+          'default_shape': default_shape
+      }
+      rt = ragged_factory_ops.constant(rt_input, dtype, ragged_rank=ragged_rank)
+
+      # Inputs for with_splits:
+      splits_rt_placeholder = ragged_factory_ops.placeholder(
+          dtype, ragged_rank, shape[ragged_rank + 1:])
+      splits_feed_dict = {splits_rt_placeholder: sess.run(rt)}
+
+      # Inputs for with_rowids:
+      rowids_feed_dict = {}
+      rowids_rt_placeholder = rebuild_ragged_tensor_with_value_rowids(
+          rt, rowids_feed_dict, sess)
+
+      # Common arguments for benchmarks:
+      run_op_benchmark_kwargs = dict(
+          sess=sess,
+          store_memory_usage=True,
+          min_iters=min_iters,
+          burn_iters=max(5, min_iters // 10),
+          mbs=mbs,
+          extras=extras)
+
+      ragged_to_dense_with_splits = ragged_conversion_ops.ragged_to_dense(
+          splits_rt_placeholder, default_value=default_value)
+      self.run_op_benchmark(
+          op_or_tensor=ragged_to_dense_with_splits.op,
+          name='ragged_to_dense_with_splits',
+          feed_dict=splits_feed_dict,
+          **run_op_benchmark_kwargs)
+
+      ragged_to_tensor_with_splits = splits_rt_placeholder.to_tensor(
+          default_value=default_value)
+      self.run_op_benchmark(
+          op_or_tensor=ragged_to_tensor_with_splits.op,
+          name='ragged_to_tensor_with_splits',
+          feed_dict=splits_feed_dict,
+          **run_op_benchmark_kwargs)
+
+      ragged_to_dense_with_rowids = ragged_conversion_ops.ragged_to_dense(
+          rowids_rt_placeholder, default_value=default_value)
+      self.run_op_benchmark(
+          op_or_tensor=ragged_to_dense_with_rowids.op,
+          name='ragged_to_dense_with_rowids',
+          feed_dict=rowids_feed_dict,
+          **run_op_benchmark_kwargs)
+
+      ragged_to_tensor_with_rowids = rowids_rt_placeholder.to_tensor(
+          default_value=default_value)
+      self.run_op_benchmark(
+          op_or_tensor=ragged_to_tensor_with_rowids.op,
+          name='ragged_to_tensor_with_rowids',
+          feed_dict=rowids_feed_dict,
+          **run_op_benchmark_kwargs)
+
+  def _generateRaggedTensor(self, shape, ragged_rank, dtype, fill=None, axis=0):
+    if axis == len(shape):
+      value = random.random()
+      if dtype == dtypes.string:
+        value = str(value)
+      if dtype.is_integer:
+        value = int(value * 1000)
+      return value
+    if axis == 0 or axis > ragged_rank:
+      slice_size = shape[axis]
+    else:
+      slice_size = (np.random.geometric(fill[axis], shape[axis]) == 1).sum()
+    return [
+        self._generateRaggedTensor(shape, ragged_rank, dtype, fill, axis + 1)
+        for _ in range(slice_size)
+    ]
+
+  def benchmark_ragged_to_dense(self):
+    random.seed(5)
+    for config in self.CONFIGS:
+      self.run_benchmark(**config)
+
+
 if __name__ == '__main__':
   googletest.main()
+