Add examples to top_k

PiperOrigin-RevId: 342970448
Change-Id: I6dc4ae785515c99c60890c5367cc79dd86c7fa88
This commit is contained in:
Mark Daoust 2020-11-17 16:27:32 -08:00 committed by TensorFlower Gardener
parent e34e39a1f9
commit 2fbed3b58e

View File

@ -5203,13 +5203,37 @@ def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-buil
and outputs their values and indices as vectors. Thus `values[j]` is the
`j`-th largest entry in `input`, and its index is `indices[j]`.
>>> result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
... k=3)
>>> result.values.numpy()
array([99, 98, 96], dtype=int32)
>>> result.indices.numpy()
array([5, 2, 9], dtype=int32)
For matrices (resp. higher rank input), computes the top `k` entries in each
row (resp. vector along the last dimension). Thus,
values.shape = indices.shape = input.shape[:-1] + [k]
>>> input = tf.random.normal(shape=(3,4,5,6))
>>> k = 2
>>> values, indices = tf.math.top_k(input, k=k)
>>> values.shape.as_list()
[3, 4, 5, 2]
>>>
>>> values.shape == indices.shape == input.shape[:-1] + [k]
True
The indices can be used to `gather` from a tensor who's shape matches `input`.
>>> gathered_values = tf.gather(input, indices, batch_dims=-1)
>>> assert tf.reduce_all(gathered_values == values)
If two elements are equal, the lower-index element appears first.
>>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
... k=3)
>>> result.indices.numpy()
array([0, 1, 3], dtype=int32)
Args:
input: 1-D or higher `Tensor` with last dimension at least `k`.
k: 0-D `int32` `Tensor`. Number of top elements to look for along the last
@ -5219,6 +5243,7 @@ def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-buil
name: Optional name for the operation.
Returns:
A tuple with two named fields:
values: The `k` largest elements along each last dimensional slice.
indices: The indices of `values` within the last dimension of `input`.
"""