Added a function for nested_value_rowids, similar to nested_row_lengths.
This will help when handling value_rowids natively. PiperOrigin-RevId: 257512332
This commit is contained in:
parent
19c530e4aa
commit
2246b38305
@ -980,6 +980,44 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
with ops.name_scope(name, "RaggedValueRowIds", [self]):
|
||||
return segment_id_ops.row_splits_to_segment_ids(self.row_splits)
|
||||
|
||||
def nested_value_rowids(self, name=None):
|
||||
"""Returns a tuple containing the value_rowids for all ragged dimensions.
|
||||
|
||||
`rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors
|
||||
for
|
||||
all ragged dimensions in `rt`, ordered from outermost to innermost. In
|
||||
particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids`
|
||||
where:
|
||||
|
||||
* `value_ids = ()` if `rt.values` is a `Tensor`.
|
||||
* `value_ids = rt.values.nested_value_rowids` otherwise.
|
||||
|
||||
Args:
|
||||
name: A name prefix for the returned tensors (optional).
|
||||
|
||||
Returns:
|
||||
A `tuple` of 1-D integer `Tensor`s.
|
||||
|
||||
#### Example:
|
||||
|
||||
```python
|
||||
>>> rt = ragged.constant([[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
|
||||
>>> for i, ids in enumerate(rt.nested_value_rowids()):
|
||||
... print('row ids for dimension %d: %s' % (i+1, ids))
|
||||
row ids for dimension 1: [0]
|
||||
row ids for dimension 2: [0, 0, 0, 2, 2]
|
||||
row ids for dimension 3: [0, 0, 0, 0, 2, 2, 2, 3]
|
||||
```
|
||||
|
||||
"""
|
||||
with ops.name_scope(name, "RaggedNestedValueRowIds", [self]):
|
||||
rt_nested_ids = [self.value_rowids()]
|
||||
rt_values = self.values
|
||||
while isinstance(rt_values, RaggedTensor):
|
||||
rt_nested_ids.append(rt_values.value_rowids())
|
||||
rt_values = rt_values.values
|
||||
return tuple(rt_nested_ids)
|
||||
|
||||
def nrows(self, out_type=None, name=None):
|
||||
"""Returns the number of rows in this ragged tensor.
|
||||
|
||||
@ -1106,8 +1144,8 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
def nested_row_lengths(self, name=None):
|
||||
"""Returns a tuple containing the row_lengths for all ragged dimensions.
|
||||
|
||||
`rtnested_row_lengths()` is a tuple containing the `row_lengths` tensors for
|
||||
all ragged dimensions in `rt`, ordered from outermost to innermost.
|
||||
`rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors
|
||||
for all ragged dimensions in `rt`, ordered from outermost to innermost.
|
||||
|
||||
Args:
|
||||
name: A name prefix for the returned tensors (optional).
|
||||
|
||||
@ -651,6 +651,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
|
||||
self.assertLen(rt.nested_row_splits, 1)
|
||||
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
|
||||
self.assertLen(rt.nested_value_rowids(), 1)
|
||||
|
||||
self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 2, 2, 2, 3, 4])
|
||||
|
||||
def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -685,6 +688,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertLen(rt.nested_row_splits, 2)
|
||||
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
|
||||
self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
|
||||
self.assertLen(rt.nested_value_rowids(), 2)
|
||||
self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 1, 3, 3])
|
||||
self.assertAllEqual(rt.nested_value_rowids()[1], [0, 0, 2, 2, 2, 3, 4])
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor.shape
|
||||
|
||||
@ -87,6 +87,10 @@ tf_class {
|
||||
name: "nested_row_lengths"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "nested_value_rowids"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "nrows"
|
||||
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
||||
@ -87,6 +87,10 @@ tf_class {
|
||||
name: "nested_row_lengths"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "nested_value_rowids"
|
||||
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "nrows"
|
||||
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user