Added a function for nested_value_rowids, similar to nested_row_lengths.

This will help when handling value_rowids natively.

PiperOrigin-RevId: 257512332
This commit is contained in:
A. Unique TensorFlower 2019-07-10 17:04:40 -07:00 committed by TensorFlower Gardener
parent 19c530e4aa
commit 2246b38305
4 changed files with 54 additions and 2 deletions

View File

@ -980,6 +980,44 @@ class RaggedTensor(composite_tensor.CompositeTensor):
with ops.name_scope(name, "RaggedValueRowIds", [self]):
return segment_id_ops.row_splits_to_segment_ids(self.row_splits)
def nested_value_rowids(self, name=None):
"""Returns a tuple containing the value_rowids for all ragged dimensions.
`rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors
for
all ragged dimensions in `rt`, ordered from outermost to innermost. In
particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids`
where:
* `value_ids = ()` if `rt.values` is a `Tensor`.
* `value_ids = rt.values.nested_value_rowids` otherwise.
Args:
name: A name prefix for the returned tensors (optional).
Returns:
A `tuple` of 1-D integer `Tensor`s.
#### Example:
```python
>>> rt = ragged.constant([[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
>>> for i, ids in enumerate(rt.nested_value_rowids()):
... print('row ids for dimension %d: %s' % (i+1, ids))
row ids for dimension 1: [0]
row ids for dimension 2: [0, 0, 0, 2, 2]
row ids for dimension 3: [0, 0, 0, 0, 2, 2, 2, 3]
```
"""
with ops.name_scope(name, "RaggedNestedValueRowIds", [self]):
rt_nested_ids = [self.value_rowids()]
rt_values = self.values
while isinstance(rt_values, RaggedTensor):
rt_nested_ids.append(rt_values.value_rowids())
rt_values = rt_values.values
return tuple(rt_nested_ids)
def nrows(self, out_type=None, name=None):
"""Returns the number of rows in this ragged tensor.
@ -1106,8 +1144,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
def nested_row_lengths(self, name=None):
"""Returns a tuple containing the row_lengths for all ragged dimensions.
`rtnested_row_lengths()` is a tuple containing the `row_lengths` tensors for
all ragged dimensions in `rt`, ordered from outermost to innermost.
`rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors
for all ragged dimensions in `rt`, ordered from outermost to innermost.
Args:
name: A name prefix for the returned tensors (optional).

View File

@ -651,6 +651,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
self.assertLen(rt.nested_row_splits, 1)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
self.assertLen(rt.nested_value_rowids(), 1)
self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 2, 2, 2, 3, 4])
def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -685,6 +688,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertLen(rt.nested_row_splits, 2)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
self.assertLen(rt.nested_value_rowids(), 2)
self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 1, 3, 3])
self.assertAllEqual(rt.nested_value_rowids()[1], [0, 0, 2, 2, 2, 3, 4])
#=============================================================================
# RaggedTensor.shape

View File

@ -87,6 +87,10 @@ tf_class {
name: "nested_row_lengths"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "nested_value_rowids"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "nrows"
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@ -87,6 +87,10 @@ tf_class {
name: "nested_row_lengths"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "nested_value_rowids"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "nrows"
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "