Add examples to top_k
PiperOrigin-RevId: 342970448 Change-Id: I6dc4ae785515c99c60890c5367cc79dd86c7fa88
This commit is contained in:
parent
e34e39a1f9
commit
2fbed3b58e
@ -5203,13 +5203,37 @@ def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-buil
|
||||
and outputs their values and indices as vectors. Thus `values[j]` is the
|
||||
`j`-th largest entry in `input`, and its index is `indices[j]`.
|
||||
|
||||
>>> result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
|
||||
... k=3)
|
||||
>>> result.values.numpy()
|
||||
array([99, 98, 96], dtype=int32)
|
||||
>>> result.indices.numpy()
|
||||
array([5, 2, 9], dtype=int32)
|
||||
|
||||
For matrices (resp. higher rank input), computes the top `k` entries in each
|
||||
row (resp. vector along the last dimension). Thus,
|
||||
|
||||
values.shape = indices.shape = input.shape[:-1] + [k]
|
||||
>>> input = tf.random.normal(shape=(3,4,5,6))
|
||||
>>> k = 2
|
||||
>>> values, indices = tf.math.top_k(input, k=k)
|
||||
>>> values.shape.as_list()
|
||||
[3, 4, 5, 2]
|
||||
>>>
|
||||
>>> values.shape == indices.shape == input.shape[:-1] + [k]
|
||||
True
|
||||
|
||||
The indices can be used to `gather` from a tensor who's shape matches `input`.
|
||||
|
||||
>>> gathered_values = tf.gather(input, indices, batch_dims=-1)
|
||||
>>> assert tf.reduce_all(gathered_values == values)
|
||||
|
||||
If two elements are equal, the lower-index element appears first.
|
||||
|
||||
>>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
|
||||
... k=3)
|
||||
>>> result.indices.numpy()
|
||||
array([0, 1, 3], dtype=int32)
|
||||
|
||||
Args:
|
||||
input: 1-D or higher `Tensor` with last dimension at least `k`.
|
||||
k: 0-D `int32` `Tensor`. Number of top elements to look for along the last
|
||||
@ -5219,6 +5243,7 @@ def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-buil
|
||||
name: Optional name for the operation.
|
||||
|
||||
Returns:
|
||||
A tuple with two named fields:
|
||||
values: The `k` largest elements along each last dimensional slice.
|
||||
indices: The indices of `values` within the last dimension of `input`.
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user