PR #42549: Fix GatherV2 shape inference
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/42549
Fixes #42522.
8a002f2269/tensorflow/core/kernels/gather_op.cc (L92-L107)
Copybara import of the project:
--
1b1979fccdfe3035aedbb17d919e1bdabb2cc3f1 by Tzu-Wei Sung <windqaq@gmail.com>:
Fix GatherV2 shape inference
Test against indices
Add 1
PiperOrigin-RevId: 332178331
Change-Id: I28eec64ec6b2b5d9f75866d3882ffdec904cd5eb
This commit is contained in:
parent
e8e87bcd80
commit
151800d58a
@ -1219,15 +1219,9 @@ REGISTER_OP("GatherV2")
|
||||
// Note, batch_dims can be negative.
|
||||
int32 batch_dims;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
|
||||
// -rank(indices) <= batch_dims <= rank(indices)
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRankAtLeast(indices_shape, std::abs(batch_dims), &unused));
|
||||
if (batch_dims < 0) {
|
||||
batch_dims += c->Rank(indices_shape);
|
||||
}
|
||||
// rank(params) > batch_dims
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRankAtLeast(params_shape, batch_dims + 1, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(
|
||||
params_shape, batch_dims < 0 ? -batch_dims : batch_dims + 1,
|
||||
&unused));
|
||||
|
||||
ShapeHandle params_outer_subshape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -436,12 +436,6 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
||||
params=[[10, 11, 12], [13, 14, 15]],
|
||||
indices=[1, 0],
|
||||
expected=[[11, 10], [14, 13]]),
|
||||
dict( # 3D indices, batch_dims=-3, axis=1
|
||||
batch_dims=-3,
|
||||
axis=1,
|
||||
params=[[0, 1, 2], [3, 4, 5]],
|
||||
indices=[[[0, 1], [1, 0]]],
|
||||
expected=[[[[0, 1], [1, 0]]], [[[3, 4], [4, 3]]]]),
|
||||
])
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testBatchDims(self, params, indices, batch_dims, expected=None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user