Merge pull request #26517 from yongtang:8264-repeat-tagged
PiperOrigin-RevId: 264268362
This commit is contained in:
commit
9e1fc592d7
@ -1819,5 +1819,49 @@ class BatchGatherNdTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
|
self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
|
||||||
|
|
||||||
|
|
||||||
|
class RepeatTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testRepeatScalar(self):
|
||||||
|
with self.test_session():
|
||||||
|
v_tf = array_ops.repeat(constant_op.constant(3), 4)
|
||||||
|
v_np = np.repeat(3, 4)
|
||||||
|
self.assertAllEqual(v_tf.eval(), v_np)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testRepeatMatrix(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||||
|
v_tf = array_ops.repeat(constant_op.constant(x), 2)
|
||||||
|
v_np = np.repeat(x, 2)
|
||||||
|
self.assertAllEqual(v_tf.eval(), v_np)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testRepeatMatrixAxis0(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||||
|
v_tf = array_ops.repeat(
|
||||||
|
constant_op.constant(x), constant_op.constant([1, 2]), axis=0)
|
||||||
|
v_np = np.repeat(x, [1, 2], axis=0)
|
||||||
|
self.assertAllEqual(v_tf.eval(), v_np)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testRepeatMatrixAxis1(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||||
|
v_tf = array_ops.repeat(
|
||||||
|
constant_op.constant(x), constant_op.constant(3), axis=1)
|
||||||
|
v_np = np.repeat(x, 3, axis=1)
|
||||||
|
self.assertAllEqual(v_tf.eval(), v_np)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testRepeatMatrixRepeatArray(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||||
|
v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4])
|
||||||
|
v_np = np.repeat(x, [1, 2, 3, 4])
|
||||||
|
self.assertAllEqual(v_tf.eval(), v_np)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_lib.main()
|
test_lib.main()
|
||||||
|
@ -4715,3 +4715,213 @@ def fingerprint(data, method="farmhash64", name=None):
|
|||||||
fingerprint algorithm.
|
fingerprint algorithm.
|
||||||
"""
|
"""
|
||||||
return gen_array_ops.fingerprint(data, method, name)
|
return gen_array_ops.fingerprint(data, method, name)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
|
||||||
|
"""Converts the given value to an integer Tensor."""
|
||||||
|
tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
|
||||||
|
if tensor.dtype.is_integer:
|
||||||
|
tensor = gen_math_ops.cast(tensor, dtype)
|
||||||
|
else:
|
||||||
|
raise TypeError("%s must be an integer tensor; dtype=%s" %
|
||||||
|
(name, tensor.dtype))
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_positive_axis(axis, ndims):
|
||||||
|
"""Validate an `axis` parameter, and normalize it to be positive.
|
||||||
|
|
||||||
|
If `ndims` is known (i.e., not `None`), then check that `axis` is in the
|
||||||
|
range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
|
||||||
|
`axis + ndims` (otherwise).
|
||||||
|
If `ndims` is not known, and `axis` is positive, then return it as-is.
|
||||||
|
If `ndims` is not known, and `axis` is negative, then report an error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
axis: An integer constant
|
||||||
|
ndims: An integer constant, or `None`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The normalized `axis` value.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
|
||||||
|
`ndims is None`.
|
||||||
|
"""
|
||||||
|
if not isinstance(axis, int):
|
||||||
|
raise TypeError("axis must be an int; got %s" % type(axis).__name__)
|
||||||
|
if ndims is not None:
|
||||||
|
if 0 <= axis < ndims:
|
||||||
|
return axis
|
||||||
|
elif -ndims <= axis < 0:
|
||||||
|
return axis + ndims
|
||||||
|
else:
|
||||||
|
raise ValueError("axis=%s out of bounds: expected %s<=axis<%s" %
|
||||||
|
(axis, -ndims, ndims))
|
||||||
|
elif axis < 0:
|
||||||
|
raise ValueError("axis may only be negative if ndims is statically known.")
|
||||||
|
return axis
|
||||||
|
|
||||||
|
|
||||||
|
# This op is intended to exactly match the semantics of numpy.repeat, with
|
||||||
|
# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
|
||||||
|
# when axis is not specified. Rather than implement that special behavior, we
|
||||||
|
# simply make `axis` be a required argument.
|
||||||
|
#
|
||||||
|
# External (OSS) `tf.repeat` feature request:
|
||||||
|
# https://github.com/tensorflow/tensorflow/issues/8246
|
||||||
|
def repeat_with_axis(data, repeats, axis, name=None):
|
||||||
|
"""Repeats elements of `data`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: An `N`-dimensional tensor.
|
||||||
|
repeats: A 1-D integer tensor specifying how many times each element in
|
||||||
|
`axis` should be repeated. `len(repeats)` must equal `data.shape[axis]`.
|
||||||
|
Supports broadcasting from a scalar value.
|
||||||
|
axis: `int`. The axis along which to repeat values. Must be less than
|
||||||
|
`max(N, 1)`.
|
||||||
|
name: A name for the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor with `max(N, 1)` dimensions. Has the same shape as `data`,
|
||||||
|
except that dimension `axis` has size `sum(repeats)`.
|
||||||
|
#### Examples:
|
||||||
|
```python
|
||||||
|
>>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
|
||||||
|
['a', 'a', 'a', 'c', 'c']
|
||||||
|
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
|
||||||
|
[[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
|
||||||
|
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
|
||||||
|
[[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if not isinstance(axis, int):
|
||||||
|
raise TypeError("axis must be an int; got %s" % type(axis).__name__)
|
||||||
|
|
||||||
|
with ops.name_scope(name, "Repeat", [data, repeats]):
|
||||||
|
data = ops.convert_to_tensor(data, name="data")
|
||||||
|
repeats = convert_to_int_tensor(repeats, name="repeats")
|
||||||
|
repeats.shape.with_rank_at_most(1)
|
||||||
|
|
||||||
|
# If `data` is a scalar, then upgrade it to a vector.
|
||||||
|
data = _with_nonzero_rank(data)
|
||||||
|
data_shape = shape(data)
|
||||||
|
|
||||||
|
# If `axis` is negative, then convert it to a positive value.
|
||||||
|
axis = get_positive_axis(axis, data.shape.ndims)
|
||||||
|
|
||||||
|
# Check data Tensor shapes.
|
||||||
|
if repeats.shape.ndims == 1:
|
||||||
|
data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
|
||||||
|
|
||||||
|
# If we know that `repeats` is a scalar, then we can just tile & reshape.
|
||||||
|
if repeats.shape.ndims == 0:
|
||||||
|
expanded = expand_dims(data, axis + 1)
|
||||||
|
tiled = tile_one_dimension(expanded, axis + 1, repeats)
|
||||||
|
result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
|
||||||
|
axis=0)
|
||||||
|
return reshape(tiled, result_shape)
|
||||||
|
|
||||||
|
# Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
|
||||||
|
if repeats.shape.ndims != axis + 1:
|
||||||
|
repeats_shape = shape(repeats)
|
||||||
|
repeats_ndims = rank(repeats)
|
||||||
|
broadcast_shape = concat(
|
||||||
|
[data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
|
||||||
|
repeats = broadcast_to(repeats, broadcast_shape)
|
||||||
|
repeats.set_shape([None] * (axis + 1))
|
||||||
|
|
||||||
|
# Create a "sequence mask" based on `repeats`, where slices across `axis`
|
||||||
|
# contain one `True` value for each repetition. E.g., if
|
||||||
|
# `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
|
||||||
|
max_repeat = gen_math_ops.maximum(
|
||||||
|
0, gen_math_ops._max(repeats, _all_dimensions(repeats)))
|
||||||
|
mask = sequence_mask(repeats, max_repeat)
|
||||||
|
|
||||||
|
# Add a new dimension around each value that needs to be repeated, and
|
||||||
|
# then tile that new dimension to match the maximum number of repetitions.
|
||||||
|
expanded = expand_dims(data, axis + 1)
|
||||||
|
tiled = tile_one_dimension(expanded, axis + 1, max_repeat)
|
||||||
|
|
||||||
|
# Use `boolean_mask` to discard the extra repeated values. This also
|
||||||
|
# flattens all dimensions up through `axis`.
|
||||||
|
masked = boolean_mask(tiled, mask)
|
||||||
|
|
||||||
|
# Reshape the output tensor to add the outer dimensions back.
|
||||||
|
if axis == 0:
|
||||||
|
result = masked
|
||||||
|
else:
|
||||||
|
result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
|
||||||
|
axis=0)
|
||||||
|
result = reshape(masked, result_shape)
|
||||||
|
|
||||||
|
# Preserve shape information.
|
||||||
|
if data.shape.ndims is not None:
|
||||||
|
new_axis_size = 0 if repeats.shape[0] == 0 else None
|
||||||
|
result.set_shape(data.shape[:axis].concatenate(
|
||||||
|
[new_axis_size]).concatenate(data.shape[axis + 1:]))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def tile_one_dimension(data, axis, multiple):
|
||||||
|
"""Tiles a single dimension of a tensor."""
|
||||||
|
# Assumes axis is a nonnegative int.
|
||||||
|
if data.shape.ndims is not None:
|
||||||
|
multiples = [1] * data.shape.ndims
|
||||||
|
multiples[axis] = multiple
|
||||||
|
else:
|
||||||
|
ones_value = ones(rank(data), dtypes.int32)
|
||||||
|
multiples = concat([ones_value[:axis], [multiple], ones_value[axis + 1:]],
|
||||||
|
axis=0)
|
||||||
|
return tile(data, multiples)
|
||||||
|
|
||||||
|
|
||||||
|
def _with_nonzero_rank(data):
|
||||||
|
"""If `data` is scalar, then add a dimension; otherwise return as-is."""
|
||||||
|
if data.shape.ndims is not None:
|
||||||
|
if data.shape.ndims == 0:
|
||||||
|
return stack([data])
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
data_shape = shape(data)
|
||||||
|
data_ndims = rank(data)
|
||||||
|
return reshape(data, concat([[1], data_shape], axis=0)[-data_ndims:])
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("repeat")
|
||||||
|
def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin
|
||||||
|
"""Repeat elements of `input`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: An `N`-dimensional Tensor.
|
||||||
|
repeats: An 1-D `int` Tensor. The number of repetitions for each element.
|
||||||
|
repeats is broadcasted to fit the shape of the given axis. `len(repeats)`
|
||||||
|
must equal `input.shape[axis]` if axis is not None.
|
||||||
|
axis: An int. The axis along which to repeat values. By default (axis=None),
|
||||||
|
use the flattened input array, and return a flat output array.
|
||||||
|
name: A name for the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Tensor which has the same shape as `input`, except along the given axis.
|
||||||
|
If axis is None then the output array is flattened to match the flattened
|
||||||
|
input array.
|
||||||
|
#### Examples:
|
||||||
|
```python
|
||||||
|
>>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
|
||||||
|
['a', 'a', 'a', 'c', 'c']
|
||||||
|
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
|
||||||
|
[[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
|
||||||
|
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
|
||||||
|
[[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
|
||||||
|
>>> repeat(3, repeats=4)
|
||||||
|
[3, 3, 3, 3]
|
||||||
|
>>> repeat([[1,2], [3,4]], repeats=2)
|
||||||
|
[1, 1, 2, 2, 3, 3, 4, 4]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if axis is None:
|
||||||
|
input = reshape(input, [-1])
|
||||||
|
axis = 0
|
||||||
|
return repeat_with_axis(input, repeats, axis, name)
|
||||||
|
@ -21,59 +21,12 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import gen_ragged_math_ops
|
from tensorflow.python.ops import gen_ragged_math_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
|
||||||
|
|
||||||
def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
|
|
||||||
"""Converts the given value to an integer Tensor."""
|
|
||||||
tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
|
|
||||||
if tensor.dtype.is_integer:
|
|
||||||
tensor = math_ops.cast(tensor, dtype)
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
"%s must be an integer tensor; dtype=%s" % (name, tensor.dtype))
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def get_positive_axis(axis, ndims):
|
|
||||||
"""Validate an `axis` parameter, and normalize it to be positive.
|
|
||||||
|
|
||||||
If `ndims` is known (i.e., not `None`), then check that `axis` is in the
|
|
||||||
range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
|
|
||||||
`axis + ndims` (otherwise).
|
|
||||||
If `ndims` is not known, and `axis` is positive, then return it as-is.
|
|
||||||
If `ndims` is not known, and `axis` is negative, then report an error.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
axis: An integer constant
|
|
||||||
ndims: An integer constant, or `None`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The normalized `axis` value.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
|
|
||||||
`ndims is None`.
|
|
||||||
"""
|
|
||||||
if not isinstance(axis, int):
|
|
||||||
raise TypeError("axis must be an int; got %s" % type(axis).__name__)
|
|
||||||
if ndims is not None:
|
|
||||||
if 0 <= axis < ndims:
|
|
||||||
return axis
|
|
||||||
elif -ndims <= axis < 0:
|
|
||||||
return axis + ndims
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"axis=%s out of bounds: expected %s<=axis<%s" % (axis, -ndims, ndims))
|
|
||||||
elif axis < 0:
|
|
||||||
raise ValueError("axis may only be negative if ndims is statically known.")
|
|
||||||
return axis
|
|
||||||
|
|
||||||
|
|
||||||
def assert_splits_match(nested_splits_lists):
|
def assert_splits_match(nested_splits_lists):
|
||||||
"""Checks that the given splits lists are identical.
|
"""Checks that the given splits lists are identical.
|
||||||
@ -103,133 +56,10 @@ def assert_splits_match(nested_splits_lists):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# This op is intended to exactly match the semantics of numpy.repeat, with
|
# Note: imported here to avoid circular dependency of array_ops.
|
||||||
# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
|
get_positive_axis = array_ops.get_positive_axis
|
||||||
# when axis is not specified. Rather than implement that special behavior, we
|
convert_to_int_tensor = array_ops.convert_to_int_tensor
|
||||||
# simply make `axis` be a required argument.
|
repeat = array_ops.repeat_with_axis
|
||||||
#
|
|
||||||
# External (OSS) `tf.repeat` feature request:
|
|
||||||
# https://github.com/tensorflow/tensorflow/issues/8246
|
|
||||||
def repeat(data, repeats, axis, name=None):
|
|
||||||
"""Repeats elements of `data`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: An `N`-dimensional tensor.
|
|
||||||
repeats: A 1-D integer tensor specifying how many times each element in
|
|
||||||
`axis` should be repeated. `len(repeats)` must equal `data.shape[axis]`.
|
|
||||||
Supports broadcasting from a scalar value.
|
|
||||||
axis: `int`. The axis along which to repeat values. Must be less than
|
|
||||||
`max(N, 1)`.
|
|
||||||
name: A name for the operation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tensor with `max(N, 1)` dimensions. Has the same shape as `data`,
|
|
||||||
except that dimension `axis` has size `sum(repeats)`.
|
|
||||||
|
|
||||||
#### Examples:
|
|
||||||
```python
|
|
||||||
>>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
|
|
||||||
['a', 'a', 'a', 'c', 'c']
|
|
||||||
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
|
|
||||||
[[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
|
|
||||||
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
|
|
||||||
[[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not isinstance(axis, int):
|
|
||||||
raise TypeError("axis must be an int; got %s" % type(axis).__name__)
|
|
||||||
|
|
||||||
with ops.name_scope(name, "Repeat", [data, repeats]):
|
|
||||||
data = ops.convert_to_tensor(data, name="data")
|
|
||||||
repeats = convert_to_int_tensor(repeats, name="repeats")
|
|
||||||
repeats.shape.with_rank_at_most(1)
|
|
||||||
|
|
||||||
# If `data` is a scalar, then upgrade it to a vector.
|
|
||||||
data = _with_nonzero_rank(data)
|
|
||||||
data_shape = array_ops.shape(data)
|
|
||||||
|
|
||||||
# If `axis` is negative, then convert it to a positive value.
|
|
||||||
axis = get_positive_axis(axis, data.shape.ndims)
|
|
||||||
|
|
||||||
# Check data Tensor shapes.
|
|
||||||
if repeats.shape.ndims == 1:
|
|
||||||
data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
|
|
||||||
|
|
||||||
# If we know that `repeats` is a scalar, then we can just tile & reshape.
|
|
||||||
if repeats.shape.ndims == 0:
|
|
||||||
expanded = array_ops.expand_dims(data, axis + 1)
|
|
||||||
tiled = tile_one_dimension(expanded, axis + 1, repeats)
|
|
||||||
result_shape = array_ops.concat(
|
|
||||||
[data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
|
|
||||||
return array_ops.reshape(tiled, result_shape)
|
|
||||||
|
|
||||||
# Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
|
|
||||||
if repeats.shape.ndims != axis + 1:
|
|
||||||
repeats_shape = array_ops.shape(repeats)
|
|
||||||
repeats_ndims = array_ops.rank(repeats)
|
|
||||||
broadcast_shape = array_ops.concat(
|
|
||||||
[data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
|
|
||||||
repeats = array_ops.broadcast_to(repeats, broadcast_shape)
|
|
||||||
repeats.set_shape([None] * (axis + 1))
|
|
||||||
|
|
||||||
# Create a "sequence mask" based on `repeats`, where slices across `axis`
|
|
||||||
# contain one `True` value for each repetition. E.g., if
|
|
||||||
# `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
|
|
||||||
max_repeat = math_ops.maximum(0, math_ops.reduce_max(repeats))
|
|
||||||
mask = array_ops.sequence_mask(repeats, max_repeat)
|
|
||||||
|
|
||||||
# Add a new dimension around each value that needs to be repeated, and
|
|
||||||
# then tile that new dimension to match the maximum number of repetitions.
|
|
||||||
expanded = array_ops.expand_dims(data, axis + 1)
|
|
||||||
tiled = tile_one_dimension(expanded, axis + 1, max_repeat)
|
|
||||||
|
|
||||||
# Use `boolean_mask` to discard the extra repeated values. This also
|
|
||||||
# flattens all dimensions up through `axis`.
|
|
||||||
masked = array_ops.boolean_mask(tiled, mask)
|
|
||||||
|
|
||||||
# Reshape the output tensor to add the outer dimensions back.
|
|
||||||
if axis == 0:
|
|
||||||
result = masked
|
|
||||||
else:
|
|
||||||
result_shape = array_ops.concat(
|
|
||||||
[data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
|
|
||||||
result = array_ops.reshape(masked, result_shape)
|
|
||||||
|
|
||||||
# Preserve shape information.
|
|
||||||
if data.shape.ndims is not None:
|
|
||||||
new_axis_size = 0 if repeats.shape[0] == 0 else None
|
|
||||||
result.set_shape(data.shape[:axis].concatenate(
|
|
||||||
[new_axis_size]).concatenate(data.shape[axis + 1:]))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def tile_one_dimension(data, axis, multiple):
|
|
||||||
"""Tiles a single dimension of a tensor."""
|
|
||||||
# Assumes axis is a nonnegative int.
|
|
||||||
if data.shape.ndims is not None:
|
|
||||||
multiples = [1] * data.shape.ndims
|
|
||||||
multiples[axis] = multiple
|
|
||||||
else:
|
|
||||||
ones = array_ops.ones(array_ops.rank(data), dtypes.int32)
|
|
||||||
multiples = array_ops.concat([ones[:axis], [multiple], ones[axis + 1:]],
|
|
||||||
axis=0)
|
|
||||||
return array_ops.tile(data, multiples)
|
|
||||||
|
|
||||||
|
|
||||||
def _with_nonzero_rank(data):
|
|
||||||
"""If `data` is scalar, then add a dimension; otherwise return as-is."""
|
|
||||||
if data.shape.ndims is not None:
|
|
||||||
if data.shape.ndims == 0:
|
|
||||||
return array_ops.stack([data])
|
|
||||||
else:
|
|
||||||
return data
|
|
||||||
else:
|
|
||||||
data_shape = array_ops.shape(data)
|
|
||||||
data_ndims = array_ops.rank(data)
|
|
||||||
return array_ops.reshape(
|
|
||||||
data,
|
|
||||||
array_ops.concat([[1], data_shape], axis=0)[-data_ndims:])
|
|
||||||
|
|
||||||
|
|
||||||
def lengths_to_splits(lengths):
|
def lengths_to_splits(lengths):
|
||||||
|
@ -1916,6 +1916,10 @@ tf_module {
|
|||||||
name: "register_tensor_conversion_function"
|
name: "register_tensor_conversion_function"
|
||||||
argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
|
argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "repeat"
|
||||||
|
argspec: "args=[\'input\', \'repeats\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "report_uninitialized_variables"
|
name: "report_uninitialized_variables"
|
||||||
argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'report_uninitialized_variables\'], "
|
argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'report_uninitialized_variables\'], "
|
||||||
|
@ -900,6 +900,10 @@ tf_module {
|
|||||||
name: "register_tensor_conversion_function"
|
name: "register_tensor_conversion_function"
|
||||||
argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
|
argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "repeat"
|
||||||
|
argspec: "args=[\'input\', \'repeats\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "required_space_to_batch_paddings"
|
name: "required_space_to_batch_paddings"
|
||||||
argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user