Add support for batch_dims < 0 to Gather pfor converter.

PiperOrigin-RevId: 327902700
Change-Id: I5c4d0d2f007bc74f3165f701595f0ba0e3c5a5ba
This commit is contained in:
A. Unique TensorFlower 2020-08-21 17:37:52 -07:00 committed by TensorFlower Gardener
parent ccc05be4cf
commit 83615c0ae7
2 changed files with 14 additions and 4 deletions

View File

@ -59,6 +59,11 @@ class ArrayTest(PForTestCase):
outputs.append(array_ops.gather(y, [i, 1, 2], axis=2, batch_dims=1))
outputs.append(array_ops.gather(y, [[2, i], [i, 1], [2, 1]],
axis=-1, batch_dims=1))
outputs.append(
array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=2))
outputs.append(array_ops.gather(y, [0, 1, 2], axis=1, batch_dims=-1))
outputs.append(
array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=-2))
return outputs

View File

@ -2275,7 +2275,11 @@ def _convert_gather(pfor_input):
# it must be picking up all the rows of param.
return wrap(param, True)
if batch_dims > 0:
if batch_dims != 0:
# Convert `batch_dims` to its positive equivalent if necessary.
batch_dims_pos = batch_dims
if batch_dims < 0:
batch_dims_pos += array_ops.rank(indices)
# In order to maintain
# indices.shape[:batch_dims] == params.shape[:batch_dims]
# with stacked indices, we move the first dimension of `indices` to the
@ -2283,8 +2287,9 @@ def _convert_gather(pfor_input):
# inserted into the shape of `output` at the `axis` dimension, which is
# then transposed to the front (below).
order = array_ops.concat([
(list(range(1, batch_dims + 1)) + [0]),
math_ops.range(batch_dims + 1, array_ops.rank(indices))], axis=0)
math_ops.range(1, batch_dims_pos + 1),
[0],
math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0)
indices = array_ops.transpose(indices, order)
output = array_ops.gather(
@ -2310,7 +2315,7 @@ def _convert_gather(pfor_input):
output = array_ops.gather(
param, indices,
axis=array_ops.where(axis >= 0, axis + 1, axis),
batch_dims=batch_dims + 1)
batch_dims=(batch_dims + 1 if batch_dims >= 0 else batch_dims))
return wrap(output, True)