diff --git a/tensorflow/core/kernels/ragged_gather_op.cc b/tensorflow/core/kernels/ragged_gather_op.cc index 623b848a656..88c0d1ebd69 100644 --- a/tensorflow/core/kernels/ragged_gather_op.cc +++ b/tensorflow/core/kernels/ragged_gather_op.cc @@ -138,16 +138,16 @@ class RaggedGatherOpBase : public OpKernel { // Add `splits` that come from all but the last dimension of the dense // Tensor `indices`. In particular, for each dimension D, we add a // splits tensor whose values are: - // range(splits.shape[D]*splits.shape[D+1] + 1, step=splits.shape[D+1]) - // E.g., if indices.shape=[5, 3] then we will add a splits tensor - // [0, 3, 6, 9, 12, 15], since the outermost dimension has 5 elements, - // each of which contains 3 values. + // range(reduce_prod(splits.shape[:D]) + 1) * splits.shape[D+1] + // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors: + // [0, 3, 6] # length=2+1, stride=3 + // [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4 + int nrows = 1; for (int dim = 0; dim < indices_in.dims() - 1; ++dim) { - int stride = indices_in.dim_size(dim + 1); - int index = stride; - for (int i = 0; i < indices_in.dim_size(dim); ++i) { - out_splits->at(dim).push_back(index); - index += stride; + nrows *= indices_in.dim_size(dim); + int row_length = indices_in.dim_size(dim + 1); + for (int i = 1; i < nrows + 1; ++i) { + out_splits->at(dim).push_back(i * row_length); } } diff --git a/tensorflow/python/ops/ragged/ragged_gather_op_test.py b/tensorflow/python/ops/ragged/ragged_gather_op_test.py index 8138a10b6c7..99f6316c26c 100644 --- a/tensorflow/python/ops/ragged/ragged_gather_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_gather_op_test.py @@ -103,6 +103,22 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): [[[b'g']], [[b'g']]]] # [p2, p2]] ) # pyformat: disable + def test3DRaggedParamsAnd3DTensorIndices(self): + params = ragged_factory_ops.constant([[['a', 'b'], []], # p0 + [['c', 'd'], ['e'], ['f']], # p1 + [['g']] # p2 + ]) # pyformat: disable + indices = [[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]] + self.assertAllEqual( + ragged_gather_ops.gather(params, indices), + [[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2], + [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1], + [[[b'g']], [[b'g']]]], # [p2, p2]] + [[[[b'a', b'b'], []], [[b'a', b'b'], []]], # [[p0, p0], + [[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [p1, p2], + [[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]]]] # [p0, p1]] + ) # pyformat: disable + def testTensorParamsAnd4DRaggedIndices(self): indices = ragged_factory_ops.constant( [[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],