Fix bug in RaggedGather kernel when indices is a dense tensor with at least 3 dimensions. (The generated row_splits corresponding to indices were too short.)

PiperOrigin-RevId: 297603955
Change-Id: I70f72cb5611262796cac8b56ea1c343e1e1b7fdd
This commit is contained in:
Edward Loper 2020-02-27 08:30:52 -08:00 committed by TensorFlower Gardener
parent e92f30fa54
commit 66fb3dbbd7
2 changed files with 25 additions and 9 deletions

View File

@ -138,16 +138,16 @@ class RaggedGatherOpBase : public OpKernel {
// Add `splits` that come from all but the last dimension of the dense
// Tensor `indices`. In particular, for each dimension D, we add a
// splits tensor whose values are:
// range(splits.shape[D]*splits.shape[D+1] + 1, step=splits.shape[D+1])
// E.g., if indices.shape=[5, 3] then we will add a splits tensor
// [0, 3, 6, 9, 12, 15], since the outermost dimension has 5 elements,
// each of which contains 3 values.
// range(reduce_prod(splits.shape[:D]) + 1) * splits.shape[D+1]
// E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
// [0, 3, 6] # length=2+1, stride=3
// [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4
int nrows = 1;
for (int dim = 0; dim < indices_in.dims() - 1; ++dim) {
int stride = indices_in.dim_size(dim + 1);
int index = stride;
for (int i = 0; i < indices_in.dim_size(dim); ++i) {
out_splits->at(dim).push_back(index);
index += stride;
nrows *= indices_in.dim_size(dim);
int row_length = indices_in.dim_size(dim + 1);
for (int i = 1; i < nrows + 1; ++i) {
out_splits->at(dim).push_back(i * row_length);
}
}

View File

@ -103,6 +103,22 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
[[[b'g']], [[b'g']]]] # [p2, p2]]
) # pyformat: disable
def test3DRaggedParamsAnd3DTensorIndices(self):
params = ragged_factory_ops.constant([[['a', 'b'], []], # p0
[['c', 'd'], ['e'], ['f']], # p1
[['g']] # p2
]) # pyformat: disable
indices = [[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]]
self.assertAllEqual(
ragged_gather_ops.gather(params, indices),
[[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2],
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1],
[[[b'g']], [[b'g']]]], # [p2, p2]]
[[[[b'a', b'b'], []], [[b'a', b'b'], []]], # [[p0, p0],
[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [p1, p2],
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]]]] # [p0, p1]]
) # pyformat: disable
def testTensorParamsAnd4DRaggedIndices(self):
indices = ragged_factory_ops.constant(
[[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],