Fix bug in RaggedGather kernel when indices
is a dense tensor with at least 3 dimensions. (The generated row_splits corresponding to indices were too short.)
PiperOrigin-RevId: 297603955 Change-Id: I70f72cb5611262796cac8b56ea1c343e1e1b7fdd
This commit is contained in:
parent
e92f30fa54
commit
66fb3dbbd7
@ -138,16 +138,16 @@ class RaggedGatherOpBase : public OpKernel {
|
|||||||
// Add `splits` that come from all but the last dimension of the dense
|
// Add `splits` that come from all but the last dimension of the dense
|
||||||
// Tensor `indices`. In particular, for each dimension D, we add a
|
// Tensor `indices`. In particular, for each dimension D, we add a
|
||||||
// splits tensor whose values are:
|
// splits tensor whose values are:
|
||||||
// range(splits.shape[D]*splits.shape[D+1] + 1, step=splits.shape[D+1])
|
// range(reduce_prod(splits.shape[:D]) + 1) * splits.shape[D+1]
|
||||||
// E.g., if indices.shape=[5, 3] then we will add a splits tensor
|
// E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
|
||||||
// [0, 3, 6, 9, 12, 15], since the outermost dimension has 5 elements,
|
// [0, 3, 6] # length=2+1, stride=3
|
||||||
// each of which contains 3 values.
|
// [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4
|
||||||
|
int nrows = 1;
|
||||||
for (int dim = 0; dim < indices_in.dims() - 1; ++dim) {
|
for (int dim = 0; dim < indices_in.dims() - 1; ++dim) {
|
||||||
int stride = indices_in.dim_size(dim + 1);
|
nrows *= indices_in.dim_size(dim);
|
||||||
int index = stride;
|
int row_length = indices_in.dim_size(dim + 1);
|
||||||
for (int i = 0; i < indices_in.dim_size(dim); ++i) {
|
for (int i = 1; i < nrows + 1; ++i) {
|
||||||
out_splits->at(dim).push_back(index);
|
out_splits->at(dim).push_back(i * row_length);
|
||||||
index += stride;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,6 +103,22 @@ class RaggedGatherOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
[[[b'g']], [[b'g']]]] # [p2, p2]]
|
[[[b'g']], [[b'g']]]] # [p2, p2]]
|
||||||
) # pyformat: disable
|
) # pyformat: disable
|
||||||
|
|
||||||
|
def test3DRaggedParamsAnd3DTensorIndices(self):
|
||||||
|
params = ragged_factory_ops.constant([[['a', 'b'], []], # p0
|
||||||
|
[['c', 'd'], ['e'], ['f']], # p1
|
||||||
|
[['g']] # p2
|
||||||
|
]) # pyformat: disable
|
||||||
|
indices = [[[1, 2], [0, 1], [2, 2]], [[0, 0], [1, 2], [0, 1]]]
|
||||||
|
self.assertAllEqual(
|
||||||
|
ragged_gather_ops.gather(params, indices),
|
||||||
|
[[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2],
|
||||||
|
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1],
|
||||||
|
[[[b'g']], [[b'g']]]], # [p2, p2]]
|
||||||
|
[[[[b'a', b'b'], []], [[b'a', b'b'], []]], # [[p0, p0],
|
||||||
|
[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [p1, p2],
|
||||||
|
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]]]] # [p0, p1]]
|
||||||
|
) # pyformat: disable
|
||||||
|
|
||||||
def testTensorParamsAnd4DRaggedIndices(self):
|
def testTensorParamsAnd4DRaggedIndices(self):
|
||||||
indices = ragged_factory_ops.constant(
|
indices = ragged_factory_ops.constant(
|
||||||
[[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
|
[[[[3, 4], [0, 6]], []], [[[2, 1], [1, 0]], [[2, 5]], [[2, 3]]],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user