Add broadcasting support for tf.ragged.to_tensor()'s default_value argument.

PiperOrigin-RevId: 221828849
This commit is contained in:
A. Unique TensorFlower 2018-11-16 11:53:30 -08:00 committed by TensorFlower Gardener
parent 4070b1650c
commit 74660a6db3
2 changed files with 47 additions and 18 deletions

View File

@ -196,7 +196,7 @@ def to_tensor(rt_input, default_value=None, name=None):
Args:
rt_input: The input `RaggedTensor`.
default_value: Value to set for indices not specified in `rt_input`.
Defaults to zero. `default_value.shape` must be equal to
Defaults to zero. `default_value` must be broadcastable to
`rt_input.shape[rt_input.ragged_rank + 1:]`.
name: A name prefix for the returned tensors (optional).
@ -210,6 +210,9 @@ def to_tensor(rt_input, default_value=None, name=None):
rt_input, name='rt_input')
if not ragged_tensor.is_ragged(rt_input):
return rt_input # already dense
if default_value is not None:
default_value = ops.convert_to_tensor(
default_value, name='default_value', dtype=rt_input.dtype)
# If ragged_rank > 1, then recursively convert the ragged values into a
# `Tensor` before we proceed.
@ -217,6 +220,16 @@ def to_tensor(rt_input, default_value=None, name=None):
if ragged_tensor.is_ragged(values):
values = to_tensor(values, default_value)
# Tile the default value, if necessary.
if default_value is not None:
if values.shape.ndims is not None:
default_value.shape.with_rank_at_most(values.shape.ndims - 1)
if (values.shape.ndims is None or default_value.shape.ndims is None or
values.shape.ndims != default_value.shape.ndims + 1):
value_shape = array_ops.shape(values)[1:]
default_value = array_ops.broadcast_to(default_value, value_shape)
default_value.shape.assert_is_compatible_with(values.shape[1:])
# Get the expected dense shape ([nrows, ncols] + value_shape).
rt_row_lengths = [rt_input.row_splits[1:] - rt_input.row_splits[:-1]]
nrows = array_ops.shape(rt_input.row_splits, out_type=dtypes.int64)[0] - 1
@ -228,9 +241,6 @@ def to_tensor(rt_input, default_value=None, name=None):
# Build a default value if none was supplied.
if default_value is None:
default_value = array_ops.zeros(value_shape, dtype=values.dtype)
else:
default_value = ops.convert_to_tensor(
default_value, name='default_value', dtype=values.dtype)
default_value.shape.assert_is_compatible_with(values.shape[1:])
default_value.set_shape(values.shape[1:])

View File

@ -71,8 +71,30 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
[[1, 2], [0, 0], [3, 4]], #
[[0, 0], [0, 0], [0, 0]], #
[[5, 0], [0, 0], [0, 0]], #
[[6, 7], [8, 0], [0, 0]]
] #
[[6, 7], [8, 0], [0, 0]], #
]
},
{
'rt_input': [[[1, 2], [], [3, 4]], [], [[5]], [[6, 7], [8]]],
'default':
9,
'expected': [
[[1, 2], [9, 9], [3, 4]], #
[[9, 9], [9, 9], [9, 9]], #
[[5, 9], [9, 9], [9, 9]], #
[[6, 7], [8, 9], [9, 9]], #
]
},
{
'rt_input': [[[1], [2], [3]]],
'ragged_rank': 1,
'default': 0,
'expected': [[[1], [2], [3]]],
},
{
'rt_input': [[[[1], [2]], [], [[3]]]],
'default': 9,
'expected': [[[[1], [2]], [[9], [9]], [[3], [9]]]],
},
)
def testRaggedTensorToTensor(self,
@ -96,17 +118,13 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
{
'rt_input': [[1, 2, 3]],
'default': [0],
'error': (ValueError, r'Shapes \(1,\) and \(\) are incompatible'),
'error': (ValueError, r'Shape \(1,\) must have rank at most 0'),
},
{
'rt_input': [[[1], [2], [3]]],
'default': 0,
'error': (ValueError, r'Shapes \(\) and \(1,\) are incompatible'),
},
{
'rt_input': [[[[1], [2]], [], [[3]]]],
'default': 0,
'error': (ValueError, r'Shapes \(\) and \(1,\) are incompatible'),
'rt_input': [[[1, 2], [3, 4]], [[5, 6]]],
'ragged_rank': 1,
'default': [7, 8, 9],
'error': (ValueError, r'Shapes \(3,\) and \(2,\) are incompatible'),
},
{
'rt_input': [[1, 2, 3]],
@ -114,9 +132,10 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
'error': (TypeError, "Expected int32, got 'a' of type 'str' instead"),
},
)
def testError(self, rt_input, default, error):
rt = ragged.constant(rt_input)
self.assertRaisesRegexp(error[0], error[1], ragged.to_tensor, rt, default)
def testError(self, rt_input, default, error, ragged_rank=None):
rt = ragged.constant(rt_input, ragged_rank=ragged_rank)
with self.assertRaisesRegexp(error[0], error[1]):
ragged.to_tensor(rt, default)
if __name__ == '__main__':