Update docstring of tf.repeats, and add additional examples based on review feedback

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2019-06-18 20:59:57 +00:00
parent d4ffb947ad
commit 6b1652e6a4
2 changed files with 24 additions and 8 deletions

View File

@ -4889,16 +4889,33 @@ def _with_nonzero_rank(data):
@tf_export("repeat")
def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin
"""Repeat elements of an array
"""Repeat elements of `input`
Args:
input: A Tensor.
input: An `N`-dimensional Tensor.
repeats: An 1-D `int` Tensor. The number of repetitions for each element.
repeats is broadcasted to fit the shape of the given axis
axis: An int. The axis along which to repeat values. By default, use the
flattened input array, and return a flat output array.
name: name of the op.
repeats is broadcasted to fit the shape of the given axis.
`len(repeats)` must equal `input.shape[axis]` if axis is not None.
axis: An int. The axis along which to repeat values. By default
(axis=None), use the flattened input array, and return a flat output
array.
name: A name for the operation.
Returns:
A Tensor which has the same shape as a, except along the given axis.
A Tensor which has the same shape as `input`, except along the given axis.
If axis is None then the output array is flattened to match the flattened
input array.
#### Examples:
```python
>>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
['a', 'a', 'a', 'c', 'c']
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
[[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
[[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
>>> repeat(3, repeats=4)
[3, 3, 3, 3]
>>> repeat([[1,2], [3,4]], repeats=2)
[1, 1, 2, 2, 3, 3, 4, 4]
```
"""
if axis is None:
input = reshape(input, [-1])

View File

@ -21,7 +21,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops