Add broadcasting support for tf.ragged.to_tensor()'s default_value argument.
PiperOrigin-RevId: 221828849
This commit is contained in:
parent
4070b1650c
commit
74660a6db3
tensorflow/python/ops/ragged
@ -196,7 +196,7 @@ def to_tensor(rt_input, default_value=None, name=None):
|
||||
Args:
|
||||
rt_input: The input `RaggedTensor`.
|
||||
default_value: Value to set for indices not specified in `rt_input`.
|
||||
Defaults to zero. `default_value.shape` must be equal to
|
||||
Defaults to zero. `default_value` must be broadcastable to
|
||||
`rt_input.shape[rt_input.ragged_rank + 1:]`.
|
||||
name: A name prefix for the returned tensors (optional).
|
||||
|
||||
@ -210,6 +210,9 @@ def to_tensor(rt_input, default_value=None, name=None):
|
||||
rt_input, name='rt_input')
|
||||
if not ragged_tensor.is_ragged(rt_input):
|
||||
return rt_input # already dense
|
||||
if default_value is not None:
|
||||
default_value = ops.convert_to_tensor(
|
||||
default_value, name='default_value', dtype=rt_input.dtype)
|
||||
|
||||
# If ragged_rank > 1, then recursively convert the ragged values into a
|
||||
# `Tensor` before we proceed.
|
||||
@ -217,6 +220,16 @@ def to_tensor(rt_input, default_value=None, name=None):
|
||||
if ragged_tensor.is_ragged(values):
|
||||
values = to_tensor(values, default_value)
|
||||
|
||||
# Tile the default value, if necessary.
|
||||
if default_value is not None:
|
||||
if values.shape.ndims is not None:
|
||||
default_value.shape.with_rank_at_most(values.shape.ndims - 1)
|
||||
if (values.shape.ndims is None or default_value.shape.ndims is None or
|
||||
values.shape.ndims != default_value.shape.ndims + 1):
|
||||
value_shape = array_ops.shape(values)[1:]
|
||||
default_value = array_ops.broadcast_to(default_value, value_shape)
|
||||
default_value.shape.assert_is_compatible_with(values.shape[1:])
|
||||
|
||||
# Get the expected dense shape ([nrows, ncols] + value_shape).
|
||||
rt_row_lengths = [rt_input.row_splits[1:] - rt_input.row_splits[:-1]]
|
||||
nrows = array_ops.shape(rt_input.row_splits, out_type=dtypes.int64)[0] - 1
|
||||
@ -228,9 +241,6 @@ def to_tensor(rt_input, default_value=None, name=None):
|
||||
# Build a default value if none was supplied.
|
||||
if default_value is None:
|
||||
default_value = array_ops.zeros(value_shape, dtype=values.dtype)
|
||||
else:
|
||||
default_value = ops.convert_to_tensor(
|
||||
default_value, name='default_value', dtype=values.dtype)
|
||||
default_value.shape.assert_is_compatible_with(values.shape[1:])
|
||||
default_value.set_shape(values.shape[1:])
|
||||
|
||||
|
@ -71,8 +71,30 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
|
||||
[[1, 2], [0, 0], [3, 4]], #
|
||||
[[0, 0], [0, 0], [0, 0]], #
|
||||
[[5, 0], [0, 0], [0, 0]], #
|
||||
[[6, 7], [8, 0], [0, 0]]
|
||||
] #
|
||||
[[6, 7], [8, 0], [0, 0]], #
|
||||
]
|
||||
},
|
||||
{
|
||||
'rt_input': [[[1, 2], [], [3, 4]], [], [[5]], [[6, 7], [8]]],
|
||||
'default':
|
||||
9,
|
||||
'expected': [
|
||||
[[1, 2], [9, 9], [3, 4]], #
|
||||
[[9, 9], [9, 9], [9, 9]], #
|
||||
[[5, 9], [9, 9], [9, 9]], #
|
||||
[[6, 7], [8, 9], [9, 9]], #
|
||||
]
|
||||
},
|
||||
{
|
||||
'rt_input': [[[1], [2], [3]]],
|
||||
'ragged_rank': 1,
|
||||
'default': 0,
|
||||
'expected': [[[1], [2], [3]]],
|
||||
},
|
||||
{
|
||||
'rt_input': [[[[1], [2]], [], [[3]]]],
|
||||
'default': 9,
|
||||
'expected': [[[[1], [2]], [[9], [9]], [[3], [9]]]],
|
||||
},
|
||||
)
|
||||
def testRaggedTensorToTensor(self,
|
||||
@ -96,17 +118,13 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
|
||||
{
|
||||
'rt_input': [[1, 2, 3]],
|
||||
'default': [0],
|
||||
'error': (ValueError, r'Shapes \(1,\) and \(\) are incompatible'),
|
||||
'error': (ValueError, r'Shape \(1,\) must have rank at most 0'),
|
||||
},
|
||||
{
|
||||
'rt_input': [[[1], [2], [3]]],
|
||||
'default': 0,
|
||||
'error': (ValueError, r'Shapes \(\) and \(1,\) are incompatible'),
|
||||
},
|
||||
{
|
||||
'rt_input': [[[[1], [2]], [], [[3]]]],
|
||||
'default': 0,
|
||||
'error': (ValueError, r'Shapes \(\) and \(1,\) are incompatible'),
|
||||
'rt_input': [[[1, 2], [3, 4]], [[5, 6]]],
|
||||
'ragged_rank': 1,
|
||||
'default': [7, 8, 9],
|
||||
'error': (ValueError, r'Shapes \(3,\) and \(2,\) are incompatible'),
|
||||
},
|
||||
{
|
||||
'rt_input': [[1, 2, 3]],
|
||||
@ -114,9 +132,10 @@ class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
|
||||
'error': (TypeError, "Expected int32, got 'a' of type 'str' instead"),
|
||||
},
|
||||
)
|
||||
def testError(self, rt_input, default, error):
|
||||
rt = ragged.constant(rt_input)
|
||||
self.assertRaisesRegexp(error[0], error[1], ragged.to_tensor, rt, default)
|
||||
def testError(self, rt_input, default, error, ragged_rank=None):
|
||||
rt = ragged.constant(rt_input, ragged_rank=ragged_rank)
|
||||
with self.assertRaisesRegexp(error[0], error[1]):
|
||||
ragged.to_tensor(rt, default)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user