Add take_along_axis to numpy_ops

PiperOrigin-RevId: 316202647
Change-Id: I1c0253f72ec6f8b1f1d67e5547e1f64f9c65cb42
This commit is contained in:
Akshay Modi 2020-06-12 16:51:36 -07:00 committed by TensorFlower Gardener
parent 49633b387b
commit 619b0b276f

View File

@ -978,13 +978,11 @@ def transpose(a, axes=None):
@np_utils.np_doc(np.swapaxes)
def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring
a = asarray(a)
a = asarray(a).data
a_rank = array_ops.rank(a)
if axis1 < 0:
axis1 += a_rank
if axis2 < 0:
axis2 += a_rank
axis1 = array_ops.where_v2(axis1 < 0, axis1 + a_rank, axis1)
axis2 = array_ops.where_v2(axis2 < 0, axis2 + a_rank, axis2)
perm = math_ops.range(a_rank)
perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
@ -1646,3 +1644,65 @@ def sign(x, out=None, where=None, **kwargs): # pylint: disable=missing-docstrin
result = math_ops.sign(x.data)
return np_utils.tensor_to_ndarray(result)
# Note that np.take_along_axis may not be present in some supported versions of
# numpy.
@np_utils.np_doc(None, np_fun_name='take_along_axis')
def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
arr = asarray(arr)
indices = asarray(indices)
if axis is None:
return take_along_axis(arr.ravel(), indices, 0)
arr = arr.data
indices = indices.data
rank = array_ops.rank(arr)
axis = array_ops.where_v2(axis < 0, axis + rank, axis)
# Broadcast shapes to match, ensure that the axis of interest is not
# broadcast.
arr_shape_original = array_ops.shape(arr)
indices_shape_original = array_ops.shape(indices)
arr_shape = array_ops.tensor_scatter_update(
arr_shape_original, [[axis]], [1])
indices_shape = array_ops.tensor_scatter_update(
indices_shape_original, [[axis]], [1])
broadcasted_shape = array_ops.broadcast_dynamic_shape(
arr_shape, indices_shape)
arr_shape = array_ops.tensor_scatter_update(
broadcasted_shape, [[axis]], [arr_shape_original[axis]])
indices_shape = array_ops.tensor_scatter_update(
broadcasted_shape, [[axis]], [indices_shape_original[axis]])
arr = array_ops.broadcast_to(arr, arr_shape)
indices = array_ops.broadcast_to(indices, indices_shape)
# Save indices shape so we can restore it later.
possible_result_shape = indices.shape
# Correct indices since gather doesn't correctly handle negative indices.
indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)
swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data
dont_move_axis_to_end = math_ops.equal(axis, rank - 1)
arr = np_utils.cond(
dont_move_axis_to_end, lambda: arr, lambda: swapaxes_(arr))
indices = np_utils.cond(
dont_move_axis_to_end, lambda: indices, lambda: swapaxes_(indices))
arr_shape = array_ops.shape(arr)
arr = array_ops.reshape(arr, [-1, arr_shape[-1]])
indices_shape = array_ops.shape(indices)
indices = array_ops.reshape(indices, [-1, indices_shape[-1]])
result = array_ops.gather(arr, indices, batch_dims=1)
result = array_ops.reshape(result, indices_shape)
result = np_utils.cond(
dont_move_axis_to_end, lambda: result, lambda: swapaxes_(result))
result.set_shape(possible_result_shape)
return np_utils.tensor_to_ndarray(result)