Add support for batch_dims < 0
to Gather pfor converter.
PiperOrigin-RevId: 327902700 Change-Id: I5c4d0d2f007bc74f3165f701595f0ba0e3c5a5ba
This commit is contained in:
parent
ccc05be4cf
commit
83615c0ae7
@ -59,6 +59,11 @@ class ArrayTest(PForTestCase):
|
||||
outputs.append(array_ops.gather(y, [i, 1, 2], axis=2, batch_dims=1))
|
||||
outputs.append(array_ops.gather(y, [[2, i], [i, 1], [2, 1]],
|
||||
axis=-1, batch_dims=1))
|
||||
outputs.append(
|
||||
array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=2))
|
||||
outputs.append(array_ops.gather(y, [0, 1, 2], axis=1, batch_dims=-1))
|
||||
outputs.append(
|
||||
array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=-2))
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -2275,7 +2275,11 @@ def _convert_gather(pfor_input):
|
||||
# it must be picking up all the rows of param.
|
||||
return wrap(param, True)
|
||||
|
||||
if batch_dims > 0:
|
||||
if batch_dims != 0:
|
||||
# Convert `batch_dims` to its positive equivalent if necessary.
|
||||
batch_dims_pos = batch_dims
|
||||
if batch_dims < 0:
|
||||
batch_dims_pos += array_ops.rank(indices)
|
||||
# In order to maintain
|
||||
# indices.shape[:batch_dims] == params.shape[:batch_dims]
|
||||
# with stacked indices, we move the first dimension of `indices` to the
|
||||
@ -2283,8 +2287,9 @@ def _convert_gather(pfor_input):
|
||||
# inserted into the shape of `output` at the `axis` dimension, which is
|
||||
# then transposed to the front (below).
|
||||
order = array_ops.concat([
|
||||
(list(range(1, batch_dims + 1)) + [0]),
|
||||
math_ops.range(batch_dims + 1, array_ops.rank(indices))], axis=0)
|
||||
math_ops.range(1, batch_dims_pos + 1),
|
||||
[0],
|
||||
math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0)
|
||||
indices = array_ops.transpose(indices, order)
|
||||
|
||||
output = array_ops.gather(
|
||||
@ -2310,7 +2315,7 @@ def _convert_gather(pfor_input):
|
||||
output = array_ops.gather(
|
||||
param, indices,
|
||||
axis=array_ops.where(axis >= 0, axis + 1, axis),
|
||||
batch_dims=batch_dims + 1)
|
||||
batch_dims=(batch_dims + 1 if batch_dims >= 0 else batch_dims))
|
||||
return wrap(output, True)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user