Add take_along_axis to numpy_ops
PiperOrigin-RevId: 316202647 Change-Id: I1c0253f72ec6f8b1f1d67e5547e1f64f9c65cb42
This commit is contained in:
parent
49633b387b
commit
619b0b276f
@ -978,13 +978,11 @@ def transpose(a, axes=None):
|
||||
|
||||
@np_utils.np_doc(np.swapaxes)
|
||||
def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring
|
||||
a = asarray(a)
|
||||
a = asarray(a).data
|
||||
|
||||
a_rank = array_ops.rank(a)
|
||||
if axis1 < 0:
|
||||
axis1 += a_rank
|
||||
if axis2 < 0:
|
||||
axis2 += a_rank
|
||||
axis1 = array_ops.where_v2(axis1 < 0, axis1 + a_rank, axis1)
|
||||
axis2 = array_ops.where_v2(axis2 < 0, axis2 + a_rank, axis2)
|
||||
|
||||
perm = math_ops.range(a_rank)
|
||||
perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
|
||||
@ -1646,3 +1644,65 @@ def sign(x, out=None, where=None, **kwargs): # pylint: disable=missing-docstrin
|
||||
result = math_ops.sign(x.data)
|
||||
|
||||
return np_utils.tensor_to_ndarray(result)
|
||||
|
||||
|
||||
# Note that np.take_along_axis may not be present in some supported versions of
|
||||
# numpy.
|
||||
@np_utils.np_doc(None, np_fun_name='take_along_axis')
|
||||
def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
|
||||
arr = asarray(arr)
|
||||
indices = asarray(indices)
|
||||
|
||||
if axis is None:
|
||||
return take_along_axis(arr.ravel(), indices, 0)
|
||||
|
||||
arr = arr.data
|
||||
indices = indices.data
|
||||
|
||||
rank = array_ops.rank(arr)
|
||||
axis = array_ops.where_v2(axis < 0, axis + rank, axis)
|
||||
|
||||
# Broadcast shapes to match, ensure that the axis of interest is not
|
||||
# broadcast.
|
||||
arr_shape_original = array_ops.shape(arr)
|
||||
indices_shape_original = array_ops.shape(indices)
|
||||
arr_shape = array_ops.tensor_scatter_update(
|
||||
arr_shape_original, [[axis]], [1])
|
||||
indices_shape = array_ops.tensor_scatter_update(
|
||||
indices_shape_original, [[axis]], [1])
|
||||
broadcasted_shape = array_ops.broadcast_dynamic_shape(
|
||||
arr_shape, indices_shape)
|
||||
arr_shape = array_ops.tensor_scatter_update(
|
||||
broadcasted_shape, [[axis]], [arr_shape_original[axis]])
|
||||
indices_shape = array_ops.tensor_scatter_update(
|
||||
broadcasted_shape, [[axis]], [indices_shape_original[axis]])
|
||||
arr = array_ops.broadcast_to(arr, arr_shape)
|
||||
indices = array_ops.broadcast_to(indices, indices_shape)
|
||||
|
||||
# Save indices shape so we can restore it later.
|
||||
possible_result_shape = indices.shape
|
||||
|
||||
# Correct indices since gather doesn't correctly handle negative indices.
|
||||
indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)
|
||||
|
||||
swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data
|
||||
|
||||
dont_move_axis_to_end = math_ops.equal(axis, rank - 1)
|
||||
arr = np_utils.cond(
|
||||
dont_move_axis_to_end, lambda: arr, lambda: swapaxes_(arr))
|
||||
indices = np_utils.cond(
|
||||
dont_move_axis_to_end, lambda: indices, lambda: swapaxes_(indices))
|
||||
|
||||
arr_shape = array_ops.shape(arr)
|
||||
arr = array_ops.reshape(arr, [-1, arr_shape[-1]])
|
||||
|
||||
indices_shape = array_ops.shape(indices)
|
||||
indices = array_ops.reshape(indices, [-1, indices_shape[-1]])
|
||||
|
||||
result = array_ops.gather(arr, indices, batch_dims=1)
|
||||
result = array_ops.reshape(result, indices_shape)
|
||||
result = np_utils.cond(
|
||||
dont_move_axis_to_end, lambda: result, lambda: swapaxes_(result))
|
||||
result.set_shape(possible_result_shape)
|
||||
|
||||
return np_utils.tensor_to_ndarray(result)
|
||||
|
Loading…
Reference in New Issue
Block a user