PR #42549: Fix GatherV2 shape inference

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42549

Fixes #42522.

8a002f2269/tensorflow/core/kernels/gather_op.cc (L92-L107)
Copybara import of the project:

--
1b1979fccdfe3035aedbb17d919e1bdabb2cc3f1 by Tzu-Wei Sung <windqaq@gmail.com>:

Fix GatherV2 shape inference

Test against indices

Add 1

PiperOrigin-RevId: 332178331
Change-Id: I28eec64ec6b2b5d9f75866d3882ffdec904cd5eb
This commit is contained in:
A. Unique TensorFlower 2020-09-17 00:45:35 -07:00 committed by TensorFlower Gardener
parent e8e87bcd80
commit 151800d58a
2 changed files with 3 additions and 15 deletions

View File

@ -1219,15 +1219,9 @@ REGISTER_OP("GatherV2")
// Note, batch_dims can be negative.
int32 batch_dims;
TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
// -rank(indices) <= batch_dims <= rank(indices)
TF_RETURN_IF_ERROR(
c->WithRankAtLeast(indices_shape, std::abs(batch_dims), &unused));
if (batch_dims < 0) {
batch_dims += c->Rank(indices_shape);
}
// rank(params) > batch_dims
TF_RETURN_IF_ERROR(
c->WithRankAtLeast(params_shape, batch_dims + 1, &unused));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(
params_shape, batch_dims < 0 ? -batch_dims : batch_dims + 1,
&unused));
ShapeHandle params_outer_subshape;
TF_RETURN_IF_ERROR(

View File

@ -436,12 +436,6 @@ class GatherTest(test.TestCase, parameterized.TestCase):
params=[[10, 11, 12], [13, 14, 15]],
indices=[1, 0],
expected=[[11, 10], [14, 13]]),
dict( # 3D indices, batch_dims=-3, axis=1
batch_dims=-3,
axis=1,
params=[[0, 1, 2], [3, 4, 5]],
indices=[[[0, 1], [1, 0]]],
expected=[[[[0, 1], [1, 0]]], [[[3, 4], [4, 3]]]]),
])
@test_util.run_in_graph_and_eager_modes
def testBatchDims(self, params, indices, batch_dims, expected=None,