tf.numpy: Change a bunch of ops to handle unknown shapes.

Fix logic in sort_ops to handle 1D values.

PiperOrigin-RevId: 316620637
Change-Id: Iedc2ba8aad7673bbe210661bb741bf0660f047aa
Authored by A. Unique TensorFlower on 2020-06-15 23:00:10 -07:00; committed by TensorFlower Gardener.
Parent: e74010b4e8
Commit: 42a734170d
5 changed files with 153 additions and 72 deletions.

File: np_array_ops.py

@@ -223,9 +223,10 @@ def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name

   Raises:
     ValueError: if `fill_value` can not be broadcast to shape `shape`.
   """
+  if not isinstance(shape, np_arrays.ndarray):
+    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
+  shape = atleast_1d(shape).data
   fill_value = asarray(fill_value, dtype=dtype)
-  if np_utils.isscalar(shape):
-    shape = array_ops.reshape(shape, [1])
   return np_arrays.tensor_to_ndarray(
       array_ops.broadcast_to(fill_value.data, shape))
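Note (editor's illustration, not part of the commit): `tf.broadcast_to` requires a 1-D shape vector, which `atleast_1d` now guarantees even when `shape` arrives as a scalar tensor of unknown value. A minimal eager-mode sketch of the underlying requirement:

    import tensorflow as tf

    shape = tf.constant(5)           # scalar shape argument
    shape = tf.reshape(shape, [-1])  # promote to 1-D: [5]
    print(tf.broadcast_to(3.0, shape).numpy())  # [3. 3. 3. 3. 3.]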
@@ -808,16 +809,21 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: d
 @np_utils.np_doc(np.std)
 def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
   return _reduce(
-      math_ops.reduce_std, a, axis=axis, dtype=None, keepdims=keepdims,
+      math_ops.reduce_std,
+      a,
+      axis=axis,
+      dtype=None,
+      keepdims=keepdims,
       promote_int=_TO_FLOAT)


 @np_utils.np_doc(np.ravel)
 def ravel(a):  # pylint: disable=missing-docstring
   a = asarray(a)
-  if a.ndim == 1:
-    return a
-  return np_utils.tensor_to_ndarray(array_ops.reshape(a.data, [-1]))
+  out = np_utils.cond(
+      math_ops.equal(a.ndim, 1), lambda: a.data,
+      lambda: array_ops.reshape(a.data, [-1]))
+  return np_utils.tensor_to_ndarray(out)


 setattr(np_arrays.ndarray, 'ravel', ravel)
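Note: `a.ndim` may now be a tensor (see the `ndim` property change below), so the Python-level `if a.ndim == 1` test is replaced with a graph-compatible conditional. A rough standalone sketch of the same pattern using public TF APIs (assumed equivalents of `np_utils.cond`):

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])  # unknown shape
    def ravel(a):
      # Branch on the runtime rank instead of a Python `if`.
      return tf.cond(tf.equal(tf.rank(a), 1),
                     lambda: a,
                     lambda: tf.reshape(a, [-1]))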
@@ -846,7 +852,8 @@ def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
   a = asarray(a).data
   original_shape = a._shape_as_list()  # pylint: disable=protected-access
   # Best effort recovery of the shape.
-  if original_shape is not None and None not in original_shape:
+  known_shape = original_shape is not None and None not in original_shape
+  if known_shape:
     if not original_shape:
       original_shape = (repeats,)
     else:
@@ -865,7 +872,8 @@ def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
   repeats = asarray(repeats).data
   result = array_ops.repeat(a, repeats, axis)
-  result.set_shape(original_shape)
+  if known_shape:
+    result.set_shape(original_shape)

   return np_utils.tensor_to_ndarray(result)
@@ -1287,7 +1295,13 @@ def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name


 @np_utils.np_doc(np.stack)
-def stack(arrays, axis=0):
+def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
+  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
+    arrays = asarray(arrays)
+    if axis == 0:
+      return arrays
+    else:
+      return swapaxes(arrays, 0, axis)
   arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
   unwrapped_arrays = [
       a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
@@ -1450,6 +1464,8 @@ def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-doc
 @np_utils.np_doc(np.tril)
 def tril(m, k=0):  # pylint: disable=missing-docstring
   m = asarray(m).data
+  if m.shape.ndims is None:
+    raise ValueError('Argument to tril should have known rank')
   m_shape = m.shape.as_list()

   if len(m_shape) < 2:
@@ -1470,6 +1486,8 @@ def tril(m, k=0):  # pylint: disable=missing-docstring
 @np_utils.np_doc(np.triu)
 def triu(m, k=0):  # pylint: disable=missing-docstring
   m = asarray(m).data
+  if m.shape.ndims is None:
+    raise ValueError('Argument to triu should have known rank')
   m_shape = m.shape.as_list()

   if len(m_shape) < 2:

File: np_arrays.py

@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 """ndarray class."""
+
+# pylint: disable=g-direct-tensorflow-import
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -151,13 +154,16 @@ def _slice_helper(tensor, slice_spec, var=None):
 def convert_to_tensor(value, dtype=None, dtype_hint=None):
   """Wrapper over `tf.convert_to_tensor`.

   Args:
     value: value to convert
     dtype: (optional) the type we would like it to be converted to.
-    dtype_hint: (optional) soft preference for the type we would like it to
-      be converted to. `tf.convert_to_tensor` will attempt to convert value
-      to this type first, but will not fail if conversion is not possible
-      falling back to inferring the type instead.
+    dtype_hint: (optional) soft preference for the type we would like it to be
+      converted to. `tf.convert_to_tensor` will attempt to convert value to this
+      type first, but will not fail if conversion is not possible falling back
+      to inferring the type instead.
+
+  Returns:
+    Value converted to tf.Tensor.
   """
   # A safer version of `tf.convert_to_tensor` to work around b/149876037.
   # TODO(wangpeng): Remove this function once the bug is fixed.
@@ -250,8 +256,12 @@ class ndarray(object):  # pylint: disable=invalid-name

   @property
   def shape(self):
-    """Returns a tuple of array dimensions."""
-    return self.data._shape_tuple()  # pylint: disable=protected-access
+    """Returns a tuple or tf.Tensor of array dimensions."""
+    shape = self.data.shape
+    if shape.is_fully_defined():
+      return tuple(shape.as_list())
+    else:
+      return array_ops.shape(self.data)

   @property
   def dtype(self):
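Note: this mirrors TF's usual static-vs-dynamic shape split. An illustrative sketch (not from the commit) of the two cases the property now distinguishes:

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def f(x):
      static = x.shape       # TensorShape([None, 3]): not fully defined
      dynamic = tf.shape(x)  # 1-D int32 tensor, exact values known at runtime
      return dynamic

When the static shape is fully defined, `ndarray.shape` can stay a plain tuple; otherwise it falls back to the dynamic `tf.shape` tensor.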
@@ -259,19 +269,30 @@ class ndarray(object):  # pylint: disable=invalid-name

   @property
   def ndim(self):
-    return self.data.shape.ndims
+    ndims = self.data.shape.ndims
+    if ndims is None:
+      return array_ops.rank(self.data)
+    else:
+      return ndims

   @property
   def size(self):
     """Returns the number of elements in the array."""
-    return np.prod(self.shape)
+    shape = self.shape
+    if isinstance(shape, ops.Tensor):
+      return array_ops.size(self.data)
+    else:
+      return np.prod(self.shape)

   @property
   def T(self):  # pylint: disable=invalid-name
     return self.transpose()

   def __len__(self):
-    if self.shape:
+    shape = self.shape
+    if isinstance(shape, ops.Tensor):
+      raise TypeError('len() of symbolic tensor undefined')
+    elif shape:
       return self.shape[0]
     else:
       raise TypeError('len() of unsized object.')
@@ -320,6 +341,8 @@ class ndarray(object):  # pylint: disable=invalid-name
     return tensor_to_ndarray(result_t)

   def __iter__(self):
+    if not isinstance(self.data, ops.EagerTensor):
+      raise TypeError('Iteration over symbolic tensor is not allowed')
     for i in range(self.shape[0]):
       result_t = self.data[i]
       yield tensor_to_ndarray(result_t)
@@ -356,6 +379,8 @@ class ndarray(object):  # pylint: disable=invalid-name
       ValueError: If the array does not have size 1.
     """
     # TODO(wangpeng): Handle graph mode
+    if not isinstance(self.data, ops.EagerTensor):
+      raise TypeError('Indexing using symbolic tensor is not allowed')
     return np.asscalar(self.data.numpy())

   def tolist(self):
@@ -384,5 +409,3 @@ def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):

 ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)

File: np_math_ops.py

@@ -50,9 +50,9 @@ def dot(a, b):  # pylint: disable=missing-docstring
             math_ops.equal(array_ops.rank(b), 0)),
         lambda: a * b,
         lambda: np_utils.cond(  # pylint: disable=g-long-lambda
-            math_ops.equal(array_ops.rank(b), 1), lambda: math_ops.tensordot(
-                a, b, axes=[[-1], [-1]]), lambda: math_ops.tensordot(
-                    a, b, axes=[[-1], [-2]])))
+            math_ops.equal(array_ops.rank(b), 1),
+            lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]),
+            lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]])))

   return _bin_op(f, a, b)
@@ -204,8 +204,8 @@ def matmul(x1, x2):  # pylint: disable=missing-docstring
     return np_utils.cond(
         math_ops.equal(array_ops.rank(x2), 1),
         lambda: math_ops.tensordot(x1, x2, axes=1),
-        lambda: np_utils.cond(
-            math_ops.equal(array_ops.rank(x1), 1),  # pylint: disable=g-long-lambda
+        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
+            math_ops.equal(array_ops.rank(x1), 1),
             lambda: math_ops.tensordot(  # pylint: disable=g-long-lambda
                 x1, x2, axes=[[0], [-2]]),
             lambda: math_ops.matmul(x1, x2)))
@@ -352,14 +352,30 @@ def hypot(x1, x2):
 def kron(a, b):  # pylint: disable=missing-function-docstring
   # pylint: disable=protected-access,g-complex-comprehension
   a, b = np_array_ops._promote_dtype(a, b)
-  ndim = max(a.ndim, b.ndim)
-  if a.ndim < ndim:
-    a = np_array_ops.reshape(a, np_array_ops._pad_left_to(ndim, a.shape))
-  if b.ndim < ndim:
-    b = np_array_ops.reshape(b, np_array_ops._pad_left_to(ndim, b.shape))
-  a_reshaped = np_array_ops.reshape(a, [i for d in a.shape for i in (d, 1)])
-  b_reshaped = np_array_ops.reshape(b, [i for d in b.shape for i in (1, d)])
-  out_shape = tuple(np.multiply(a.shape, b.shape))
+  t_a = np_utils.cond(
+      a.ndim < b.ndim,
+      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
+          a.data, np_array_ops._pad_left_to(b.ndim, a.shape)),
+      lambda: a.data)
+  t_b = np_utils.cond(
+      b.ndim < a.ndim,
+      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
+          b.data, np_array_ops._pad_left_to(a.ndim, b.shape)),
+      lambda: b.data)
+
+  def _make_shape(shape, prepend):
+    ones = array_ops.ones_like(shape)
+    if prepend:
+      shapes = [ones, shape]
+    else:
+      shapes = [shape, ones]
+    return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])
+
+  a_shape = array_ops.shape(t_a)
+  b_shape = array_ops.shape(t_b)
+  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
+  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
+  out_shape = a_shape * b_shape
   return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
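Note: `_make_shape` interleaves a shape with ones so that one operand becomes `[d0, 1, d1, 1, ...]` and the other `[1, e0, 1, e1, ...]`, letting broadcasting materialize the Kronecker blocks without Python-level shape arithmetic. A small sketch (editor's illustration) of just the interleaving trick:

    import tensorflow as tf

    shape = tf.constant([2, 3])
    ones = tf.ones_like(shape)
    # Stack column-wise and flatten: [2, 3] -> [2, 1, 3, 1]
    print(tf.reshape(tf.stack([shape, ones], axis=1), [-1]).numpy())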
@@ -454,7 +470,8 @@ def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
   if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
       not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
     raise ValueError('Arguments to gcd must be integers.')
-  shape = array_ops.broadcast_static_shape(x1.shape, x2.shape)
+  shape = array_ops.broadcast_dynamic_shape(
+      array_ops.shape(x1), array_ops.shape(x2))
   x1 = array_ops.broadcast_to(x1, shape)
   x2 = array_ops.broadcast_to(x2, shape)
   value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
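Note: `broadcast_static_shape` needs fully known static shapes, while `broadcast_dynamic_shape` computes the result from runtime shape tensors, which is why the swap makes `gcd` work with unknown shapes. Illustrative eager-mode sketch:

    import tensorflow as tf

    x1 = tf.zeros([2, 1])
    x2 = tf.zeros([3])
    shape = tf.broadcast_dynamic_shape(tf.shape(x1), tf.shape(x2))
    print(shape.numpy())  # [2 3]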
@@ -607,7 +624,7 @@ def signbit(x):

   def f(x):
     if x.dtype == dtypes.bool:
-      return array_ops.fill(x.shape, False)
+      return array_ops.fill(array_ops.shape(x), False)
     return x < 0

   return _scalar(f, x)
@@ -866,7 +883,11 @@ def square(x):
 def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring

   def f(a):
+    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
+    # TODO(agarwal): avoid depending on static rank.
     nd = a.shape.rank
+    if nd is None:
+      raise ValueError('diff currently requires known rank for input `a`')
     if (axis + nd if axis < 0 else axis) >= nd:
       raise ValueError('axis %s is out of bounds for array of dimension %s' %
                        (axis, nd))
@@ -887,8 +908,10 @@ def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring


 def _flip_args(f):
+
   def _f(a, b):
     return f(b, a)
+
   return _f
@@ -910,6 +933,7 @@ setattr(np_arrays.ndarray, '__rtruediv__', _flip_args(true_divide))

 def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
+  """Helper function for comparision."""
   dtype = np_utils.result_type(x1, x2)
   # Cast x1 and x2 to the result_type if needed.
   x1 = np_array_ops.array(x1, dtype=dtype)
@@ -953,12 +977,18 @@ def less_equal(x1, x2):

 @np_utils.np_doc(np.array_equal)
-def array_equal(a1, a2):
+def array_equal(a1, a2):  # pylint: disable=missing-function-docstring

-  def f(a1, a2):
-    if a1.shape != a2.shape:
-      return constant_op.constant(False)
-    return math_ops.reduce_all(math_ops.equal(a1, a2))
+  def f(x1, x2):
+    return np_utils.cond(
+        math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
+        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
+            np_utils.reduce_all(
+                math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))),
+            lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
+            lambda: constant_op.constant(False)),
+        lambda: constant_op.constant(False))

   return _comparison(f, a1, a2)
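Note: the nested conds guard each comparison so it only runs when the previous one holds: equal ranks first, then equal shapes, then elementwise values. A rough eager-mode sketch with public APIs (assuming `np_utils.cond` and `np_utils.reduce_all` behave like their `tf` counterparts):

    import tensorflow as tf

    def array_equal_sketch(x1, x2):
      return tf.cond(
          tf.equal(tf.rank(x1), tf.rank(x2)),
          lambda: tf.cond(
              tf.reduce_all(tf.equal(tf.shape(x1), tf.shape(x2))),
              lambda: tf.reduce_all(tf.equal(x1, x2)),
              lambda: tf.constant(False)),
          lambda: tf.constant(False))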
@@ -1001,7 +1031,13 @@ setattr(np_arrays.ndarray, '__ne__', not_equal)

 @np_utils.np_doc(np.linspace)
 def linspace(  # pylint: disable=missing-docstring
-    start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0):
+    start,
+    stop,
+    num=50,
+    endpoint=True,
+    retstep=False,
+    dtype=float,
+    axis=0):
   if dtype:
     dtype = np_utils.result_type(dtype)
   start = np_array_ops.array(start, dtype=dtype).data
@@ -1054,10 +1090,14 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint
   start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
   stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
   signflip = 1 - start_sign * stop_sign // 2
-  res = signflip * logspace(log10(signflip * start),
-                            log10(signflip * stop), num,
-                            endpoint=endpoint, base=10.0,
-                            dtype=computation_dtype, axis=0)
+  res = signflip * logspace(
+      log10(signflip * start),
+      log10(signflip * stop),
+      num,
+      endpoint=endpoint,
+      base=10.0,
+      dtype=computation_dtype,
+      axis=0)
   if axis != 0:
     res = np_array_ops.moveaxis(res, 0, axis)
   return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))

File: np_utils.py

@@ -47,7 +47,7 @@ def _canonicalize_axes(axes, rank):
     canonicalizer = (
         lambda axis: cond(axis < 0, lambda: axis + rank, lambda: axis))
   else:
-    canonicalizer = lambda axis: axis+rank if axis < 0 else axis
+    canonicalizer = lambda axis: axis + rank if axis < 0 else axis

   return [canonicalizer(axis) for axis in axes]
@@ -100,9 +100,16 @@ def finfo(dtype):

 def isscalar(val):
   """Returns whether `val` is a scalar value or scalar Tensor."""
-  if isinstance(val, (np.ndarray, np_arrays.ndarray, ops.Tensor)):
-    return len(val.shape) == 0  # pylint: disable=g-explicit-length-test
-  return np.isscalar(val)
+  if isinstance(val, np_arrays.ndarray):
+    val = val.data
+  if isinstance(val, ops.Tensor):
+    ndims = val.shape.ndims
+    if ndims is not None:
+      return ndims == 0
+    else:
+      return math_ops.equal(array_ops.rank(val), 0)
+  else:
+    return np.isscalar(val)


 # Can't use np_doc because np.result_type is a builtin function.
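Note: with statically unknown rank the scalar test cannot be answered at trace time, so `isscalar` may now return a boolean *tensor* rather than a Python bool. Sketch of the runtime fallback:

    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])  # unknown rank
    def is_scalar(val):
      return tf.equal(tf.rank(val), 0)  # a bool tensor, resolved at runtime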
# Can't use np_doc because np.result_type is a builtin function. # Can't use np_doc because np.result_type is a builtin function.
@ -119,8 +126,8 @@ def result_type(*arrays_and_dtypes):
def maybe_get_dtype(x): def maybe_get_dtype(x):
# Don't put np.ndarray in this list, because np.result_type looks at the # Don't put np.ndarray in this list, because np.result_type looks at the
# value (not just dtype) of np.ndarray to decide the result type. # value (not just dtype) of np.ndarray to decide the result type.
if isinstance(x, (np_arrays.ndarray, ops.Tensor, if isinstance(
indexed_slices.IndexedSlices)): x, (np_arrays.ndarray, ops.Tensor, indexed_slices.IndexedSlices)):
return _to_numpy_type(x.dtype) return _to_numpy_type(x.dtype)
elif isinstance(x, dtypes.DType): elif isinstance(x, dtypes.DType):
return _to_numpy_type(x) return _to_numpy_type(x)
@@ -277,8 +284,11 @@ def np_doc(np_fun, np_fun_name=None):
     #   for name in np_sig.parameters:
     #     if name not in sig.parameters:
     #       unsupported_params.append(name)
-    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name,
-                               unsupported_params=unsupported_params)
+    f.__doc__ = _np_doc_helper(
+        f,
+        np_fun,
+        np_fun_name=np_fun_name,
+        unsupported_params=unsupported_params)
     return f

   return decorator
@@ -287,9 +297,9 @@ def np_doc(np_fun, np_fun_name=None):
 def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
   """Helper to get docs."""
   if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
     # TODO(wangpeng): It looks like code snippets in numpy doc don't work
     # correctly with doctest. Fix that and remove the reformatting of the np_f
     # comment, here and below.
     return np_f.__doc__.replace('>>>', '>')
   assert np_f or np_fun_name
   if not np_fun_name:

File: sort_ops.py

@@ -134,7 +134,7 @@ def _sort_or_argsort(values, axis, direction, return_argsort):
   # Axis must be an integer, not a Tensor.
   axis = framework_ops.convert_to_tensor(axis, name='axis')
   axis_static = tensor_util.constant_value(axis)
-  if axis.shape.ndims != 0 or axis_static is None:
+  if axis.shape.ndims not in (None, 0) or axis_static is None:
     raise ValueError('axis must be a constant scalar')
   axis_static = int(axis_static)  # Avoids NumPy casting error
@@ -184,18 +184,8 @@ def _descending_sort(values, axis, return_argsort=False):
         name='transposition')
   else:
     # Generate the transposition array from the tensors.
-    transposition = array_ops.concat(
-        [
-            # Axes up to axis are unchanged.
-            math_ops.range(axis),
-            # Swap axis and rank - 1.
-            [rank - 1],
-            # Axes in [axis + 1, rank - 1) are unchanged.
-            math_ops.range(axis + 1, rank - 1),
-            # Swap axis and rank - 1.
-            [axis]
-        ],
-        axis=0)
+    transposition = array_ops.tensor_scatter_update(
+        math_ops.range(rank), [[axis], [rank - 1]], [rank - 1, axis])
   top_k_input = array_ops.transpose(values, transposition)
   values, indices = nn_ops.top_k(top_k_input, k)
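Note: the scatter form builds the same permutation as the old concat, swapping `axis` with the last axis of `[0, 1, ..., rank-1]`. Per the commit message, it also fixes the 1-D case: for `rank=1, axis=0` the old concat produced `[0, 0]`, an invalid transposition. Sketch using the public `tf.tensor_scatter_nd_update` (the `array_ops.tensor_scatter_update` above is assumed to be its internal alias):

    import tensorflow as tf

    rank, axis = 4, 1
    perm = tf.tensor_scatter_nd_update(
        tf.range(rank), [[axis], [rank - 1]], [rank - 1, axis])
    print(perm.numpy())  # [0 3 2 1]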