diff --git a/tensorflow/python/ops/ragged/ragged_conversion_ops.py b/tensorflow/python/ops/ragged/ragged_conversion_ops.py index 3ec246ccaf1..0385be02d43 100644 --- a/tensorflow/python/ops/ragged/ragged_conversion_ops.py +++ b/tensorflow/python/ops/ragged/ragged_conversion_ops.py @@ -196,7 +196,7 @@ def to_tensor(rt_input, default_value=None, name=None): Args: rt_input: The input `RaggedTensor`. default_value: Value to set for indices not specified in `rt_input`. - Defaults to zero. `default_value.shape` must be equal to + Defaults to zero. `default_value` must be broadcastable to `rt_input.shape[rt_input.ragged_rank + 1:]`. name: A name prefix for the returned tensors (optional). @@ -210,6 +210,9 @@ def to_tensor(rt_input, default_value=None, name=None): rt_input, name='rt_input') if not ragged_tensor.is_ragged(rt_input): return rt_input # already dense + if default_value is not None: + default_value = ops.convert_to_tensor( + default_value, name='default_value', dtype=rt_input.dtype) # If ragged_rank > 1, then recursively convert the ragged values into a # `Tensor` before we proceed. @@ -217,6 +220,16 @@ def to_tensor(rt_input, default_value=None, name=None): if ragged_tensor.is_ragged(values): values = to_tensor(values, default_value) + # Tile the default value, if necessary. + if default_value is not None: + if values.shape.ndims is not None: + default_value.shape.with_rank_at_most(values.shape.ndims - 1) + if (values.shape.ndims is None or default_value.shape.ndims is None or + values.shape.ndims != default_value.shape.ndims + 1): + value_shape = array_ops.shape(values)[1:] + default_value = array_ops.broadcast_to(default_value, value_shape) + default_value.shape.assert_is_compatible_with(values.shape[1:]) + # Get the expected dense shape ([nrows, ncols] + value_shape). rt_row_lengths = [rt_input.row_splits[1:] - rt_input.row_splits[:-1]] nrows = array_ops.shape(rt_input.row_splits, out_type=dtypes.int64)[0] - 1 @@ -228,9 +241,6 @@ def to_tensor(rt_input, default_value=None, name=None): # Build a default value if none was supplied. if default_value is None: default_value = array_ops.zeros(value_shape, dtype=values.dtype) - else: - default_value = ops.convert_to_tensor( - default_value, name='default_value', dtype=values.dtype) default_value.shape.assert_is_compatible_with(values.shape[1:]) default_value.set_shape(values.shape[1:]) diff --git a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py index 0ccc214a9c7..688676e46c6 100644 --- a/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_to_tensor_op_test.py @@ -71,8 +71,30 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase, [[1, 2], [0, 0], [3, 4]], # [[0, 0], [0, 0], [0, 0]], # [[5, 0], [0, 0], [0, 0]], # - [[6, 7], [8, 0], [0, 0]] - ] # + [[6, 7], [8, 0], [0, 0]], # + ] + }, + { + 'rt_input': [[[1, 2], [], [3, 4]], [], [[5]], [[6, 7], [8]]], + 'default': + 9, + 'expected': [ + [[1, 2], [9, 9], [3, 4]], # + [[9, 9], [9, 9], [9, 9]], # + [[5, 9], [9, 9], [9, 9]], # + [[6, 7], [8, 9], [9, 9]], # + ] + }, + { + 'rt_input': [[[1], [2], [3]]], + 'ragged_rank': 1, + 'default': 0, + 'expected': [[[1], [2], [3]]], + }, + { + 'rt_input': [[[[1], [2]], [], [[3]]]], + 'default': 9, + 'expected': [[[[1], [2]], [[9], [9]], [[3], [9]]]], }, ) def testRaggedTensorToTensor(self, @@ -96,17 +118,13 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase, { 'rt_input': [[1, 2, 3]], 'default': [0], - 'error': (ValueError, r'Shapes \(1,\) and \(\) are incompatible'), + 'error': (ValueError, r'Shape \(1,\) must have rank at most 0'), }, { - 'rt_input': [[[1], [2], [3]]], - 'default': 0, - 'error': (ValueError, r'Shapes \(\) and \(1,\) are incompatible'), - }, - { - 'rt_input': [[[[1], [2]], [], [[3]]]], - 'default': 0, - 'error': (ValueError, r'Shapes \(\) and \(1,\) are incompatible'), + 'rt_input': [[[1, 2], [3, 4]], [[5, 6]]], + 'ragged_rank': 1, + 'default': [7, 8, 9], + 'error': (ValueError, r'Shapes \(3,\) and \(2,\) are incompatible'), }, { 'rt_input': [[1, 2, 3]], @@ -114,9 +132,10 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase, 'error': (TypeError, "Expected int32, got 'a' of type 'str' instead"), }, ) - def testError(self, rt_input, default, error): - rt = ragged.constant(rt_input) - self.assertRaisesRegexp(error[0], error[1], ragged.to_tensor, rt, default) + def testError(self, rt_input, default, error, ragged_rank=None): + rt = ragged.constant(rt_input, ragged_rank=ragged_rank) + with self.assertRaisesRegexp(error[0], error[1]): + ragged.to_tensor(rt, default) if __name__ == '__main__':