Add examples to top_k
PiperOrigin-RevId: 342970448 Change-Id: I6dc4ae785515c99c60890c5367cc79dd86c7fa88
This commit is contained in:
		
							parent
							
								
									e34e39a1f9
								
							
						
					
					
						commit
						2fbed3b58e
					
				@ -5203,13 +5203,37 @@ def top_k(input, k=1, sorted=True, name=None):  # pylint: disable=redefined-buil
 | 
			
		||||
  and outputs their values and indices as vectors.  Thus `values[j]` is the
 | 
			
		||||
  `j`-th largest entry in `input`, and its index is `indices[j]`.
 | 
			
		||||
 | 
			
		||||
  >>> result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
 | 
			
		||||
  ...                         k=3)
 | 
			
		||||
  >>> result.values.numpy()
 | 
			
		||||
  array([99, 98, 96], dtype=int32)
 | 
			
		||||
  >>> result.indices.numpy()
 | 
			
		||||
  array([5, 2, 9], dtype=int32)
 | 
			
		||||
 | 
			
		||||
  For matrices (resp. higher rank input), computes the top `k` entries in each
 | 
			
		||||
  row (resp. vector along the last dimension).  Thus,
 | 
			
		||||
 | 
			
		||||
      values.shape = indices.shape = input.shape[:-1] + [k]
 | 
			
		||||
  >>> input = tf.random.normal(shape=(3,4,5,6))
 | 
			
		||||
  >>> k = 2
 | 
			
		||||
  >>> values, indices  = tf.math.top_k(input, k=k)
 | 
			
		||||
  >>> values.shape.as_list()
 | 
			
		||||
  [3, 4, 5, 2]
 | 
			
		||||
  >>>
 | 
			
		||||
  >>> values.shape == indices.shape == input.shape[:-1] + [k]
 | 
			
		||||
  True
 | 
			
		||||
 | 
			
		||||
  The indices can be used to `gather` from a tensor who's shape matches `input`.
 | 
			
		||||
 | 
			
		||||
  >>> gathered_values = tf.gather(input, indices, batch_dims=-1)
 | 
			
		||||
  >>> assert tf.reduce_all(gathered_values == values)
 | 
			
		||||
 | 
			
		||||
  If two elements are equal, the lower-index element appears first.
 | 
			
		||||
 | 
			
		||||
  >>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
 | 
			
		||||
  ...                        k=3)
 | 
			
		||||
  >>> result.indices.numpy()
 | 
			
		||||
  array([0, 1, 3], dtype=int32)
 | 
			
		||||
 | 
			
		||||
  Args:
 | 
			
		||||
    input: 1-D or higher `Tensor` with last dimension at least `k`.
 | 
			
		||||
    k: 0-D `int32` `Tensor`.  Number of top elements to look for along the last
 | 
			
		||||
@ -5219,6 +5243,7 @@ def top_k(input, k=1, sorted=True, name=None):  # pylint: disable=redefined-buil
 | 
			
		||||
    name: Optional name for the operation.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A tuple with two named fields:
 | 
			
		||||
    values: The `k` largest elements along each last dimensional slice.
 | 
			
		||||
    indices: The indices of `values` within the last dimension of `input`.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user