Fix bug in RaggedGather kernel when indices
is a dense tensor with at least 3 dimensions. (The generated row_splits corresponding to indices were too short.)
PiperOrigin-RevId: 297603955 Change-Id: I70f72cb5611262796cac8b56ea1c343e1e1b7fdd
This commit is contained in:
parent
e92f30fa54
commit
66fb3dbbd7
@ -138,16 +138,16 @@ class RaggedGatherOpBase : public OpKernel {
|
||||
// Add `splits` that come from all but the last dimension of the dense
|
||||
// Tensor `indices`. In particular, for each dimension D, we add a
|
||||
// splits tensor whose values are:
|
||||
// range(splits.shape[D]*splits.shape[D+1] + 1, step=splits.shape[D+1])
|
||||
// E.g., if indices.shape=[5, 3] then we will add a splits tensor
|
||||
// [0, 3, 6, 9, 12, 15], since the outermost dimension has 5 elements,
|
||||
// each of which contains 3 values.
|
||||
// range(reduce_prod(splits.shape[:D]) + 1) * splits.shape[D+1]
|
||||
// E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
|
||||
// [0, 3, 6] # length=2+1, stride=3
|
||||
// [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4
|
||||
int nrows = 1;
|
||||
for (int dim = 0; dim < indices_in.dims() - 1; ++dim) {
|
||||
int stride = indices_in.dim_size(dim + 1);
|
||||
int index = stride;
|
||||
for (int i = 0; i < indices_in.dim_size(dim); ++i) {
|
||||
out_splits->at(dim).push_back(index);
|
||||
index += stride;
|
||||
nrows *= indices_in.dim_size(dim);
|
||||
int row_length = indices_in.dim_size(dim + 1);
|
||||
for (int i = 1; i < nrows + 1; ++i) {
|
||||
out_splits->at(dim).push_back(i * row_length);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -103,6 +103,22 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
[[[b'g']], [[b'g']]]] # [p2, p2]]
|
||||
) # pyformat: disable
|
||||
|
||||
def test3DRaggedParamsAnd3DTensorIndices(self):
|
||||
params = ragged_factory_ops.constant([[['a', 'b'], []], # p0
|
||||
[['c', 'd'], ['e'], ['f']], # p1
|
||||
[['g']] # p2
|
||||
]) # pyformat: disable
|
||||
indices = [[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]]
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2],
|
||||
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1],
|
||||
[[[b'g']], [[b'g']]]], # [p2, p2]]
|
||||
[[[[b'a', b'b'], []], [[b'a', b'b'], []]], # [[p0, p0],
|
||||
[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [p1, p2],
|
||||
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]]]] # [p0, p1]]
|
||||
) # pyformat: disable
|
||||
|
||||
def testTensorParamsAnd4DRaggedIndices(self):
|
||||
indices = ragged_factory_ops.constant(
|
||||
[[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user