From adea107afdaad2d93eb8f966b580703c743383c3 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Mar 2019 22:07:28 +0000 Subject: [PATCH 1/8] Add tf.repeat support equivalent to numpy.repeat (Most of the credit goes to RaggedTensor author, couldn't find the GitHub handle but the related commit is bad5d1a) This PR tries to address the issue raised in 8246 to have tf.repeat equivalent to numpy.repeat. Multiple attempts have been made in the past. My previous PR 15224 add the support with C++ which was closed, as a python implementation relying on existing ops is more desireable. This PR - moves the repeat implementation in ragged_util.py to array_ops.py so that it is possible to get exposed. Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 175 ++++++++++++++++++++ tensorflow/python/ops/ragged/ragged_util.py | 174 ------------------- 2 files changed, 175 insertions(+), 174 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 32c4bfc9d39..627d3715dbd 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4715,3 +4715,178 @@ def fingerprint(data, method="farmhash64", name=None): fingerprint algorithm. """ return gen_array_ops.fingerprint(data, method, name) + + +def convert_to_int_tensor(tensor, name, dtype=dtypes.int32): + """Converts the given value to an integer Tensor.""" + tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype) + if tensor.dtype.is_integer: + tensor = math_ops.cast(tensor, dtype) + else: + raise TypeError( + "%s must be an integer tensor; dtype=%s" % (name, tensor.dtype)) + return tensor + + +def get_positive_axis(axis, ndims): + """Validate an `axis` parameter, and normalize it to be positive. + + If `ndims` is known (i.e., not `None`), then check that `axis` is in the + range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or + `axis + ndims` (otherwise). + If `ndims` is not known, and `axis` is positive, then return it as-is. + If `ndims` is not known, and `axis` is negative, then report an error. + + Args: + axis: An integer constant + ndims: An integer constant, or `None` + + Returns: + The normalized `axis` value. + + Raises: + ValueError: If `axis` is out-of-bounds, or if `axis` is negative and + `ndims is None`. + """ + if not isinstance(axis, int): + raise TypeError("axis must be an int; got %s" % type(axis).__name__) + if ndims is not None: + if 0 <= axis < ndims: + return axis + elif -ndims <= axis < 0: + return axis + ndims + else: + raise ValueError( + "axis=%s out of bounds: expected %s<=axis<%s" % (axis, -ndims, ndims)) + elif axis < 0: + raise ValueError("axis may only be negative if ndims is statically known.") + return axis + + +# This op is intended to exactly match the semantics of numpy.repeat, with +# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior +# when axis is not specified. Rather than implement that special behavior, we +# simply make `axis` be a required argument. +# +# External (OSS) `tf.repeat` feature request: +# https://github.com/tensorflow/tensorflow/issues/8246 +def repeat(data, repeats, axis, name=None): + """Repeats elements of `data`. + + Args: + data: An `N`-dimensional tensor. + repeats: A 1-D integer tensor specifying how many times each element in + `axis` should be repeated. `len(repeats)` must equal `data.shape[axis]`. + Supports broadcasting from a scalar value. + axis: `int`. The axis along which to repeat values. Must be less than + `max(N, 1)`. + name: A name for the operation. + + Returns: + A tensor with `max(N, 1)` dimensions. Has the same shape as `data`, + except that dimension `axis` has size `sum(repeats)`. + + #### Examples: + ```python + >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0) + ['a', 'a', 'a', 'c', 'c'] + >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0) + [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]] + >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1) + [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]] + ``` + """ + if not isinstance(axis, int): + raise TypeError("axis must be an int; got %s" % type(axis).__name__) + + with ops.name_scope(name, "Repeat", [data, repeats]): + data = ops.convert_to_tensor(data, name="data") + repeats = convert_to_int_tensor(repeats, name="repeats") + repeats.shape.with_rank_at_most(1) + + # If `data` is a scalar, then upgrade it to a vector. + data = _with_nonzero_rank(data) + data_shape = array_ops.shape(data) + + # If `axis` is negative, then convert it to a positive value. + axis = get_positive_axis(axis, data.shape.ndims) + + # Check data Tensor shapes. + if repeats.shape.ndims == 1: + data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0]) + + # If we know that `repeats` is a scalar, then we can just tile & reshape. + if repeats.shape.ndims == 0: + expanded = array_ops.expand_dims(data, axis + 1) + tiled = tile_one_dimension(expanded, axis + 1, repeats) + result_shape = array_ops.concat( + [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0) + return array_ops.reshape(tiled, result_shape) + + # Broadcast the `repeats` tensor so rank(repeats) == axis + 1. + if repeats.shape.ndims != axis + 1: + repeats_shape = array_ops.shape(repeats) + repeats_ndims = array_ops.rank(repeats) + broadcast_shape = array_ops.concat( + [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0) + repeats = array_ops.broadcast_to(repeats, broadcast_shape) + repeats.set_shape([None] * (axis + 1)) + + # Create a "sequence mask" based on `repeats`, where slices across `axis` + # contain one `True` value for each repetition. E.g., if + # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`. + max_repeat = math_ops.maximum(0, math_ops.reduce_max(repeats)) + mask = array_ops.sequence_mask(repeats, max_repeat) + + # Add a new dimension around each value that needs to be repeated, and + # then tile that new dimension to match the maximum number of repetitions. + expanded = array_ops.expand_dims(data, axis + 1) + tiled = tile_one_dimension(expanded, axis + 1, max_repeat) + + # Use `boolean_mask` to discard the extra repeated values. This also + # flattens all dimensions up through `axis`. + masked = array_ops.boolean_mask(tiled, mask) + + # Reshape the output tensor to add the outer dimensions back. + if axis == 0: + result = masked + else: + result_shape = array_ops.concat( + [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0) + result = array_ops.reshape(masked, result_shape) + + # Preserve shape information. + if data.shape.ndims is not None: + new_axis_size = 0 if repeats.shape[0] == 0 else None + result.set_shape(data.shape[:axis].concatenate( + [new_axis_size]).concatenate(data.shape[axis + 1:])) + + return result + + +def tile_one_dimension(data, axis, multiple): + """Tiles a single dimension of a tensor.""" + # Assumes axis is a nonnegative int. + if data.shape.ndims is not None: + multiples = [1] * data.shape.ndims + multiples[axis] = multiple + else: + ones = array_ops.ones(array_ops.rank(data), dtypes.int32) + multiples = array_ops.concat([ones[:axis], [multiple], ones[axis + 1:]], + axis=0) + return array_ops.tile(data, multiples) + + +def _with_nonzero_rank(data): + """If `data` is scalar, then add a dimension; otherwise return as-is.""" + if data.shape.ndims is not None: + if data.shape.ndims == 0: + return array_ops.stack([data]) + else: + return data + else: + data_shape = array_ops.shape(data) + data_ndims = array_ops.rank(data) + return array_ops.reshape( + data, + array_ops.concat([[1], data_shape], axis=0)[-data_ndims:]) diff --git a/tensorflow/python/ops/ragged/ragged_util.py b/tensorflow/python/ops/ragged/ragged_util.py index 2c738e7cd29..aeea692036a 100644 --- a/tensorflow/python/ops/ragged/ragged_util.py +++ b/tensorflow/python/ops/ragged/ragged_util.py @@ -29,51 +29,6 @@ from tensorflow.python.ops import gen_ragged_math_ops from tensorflow.python.ops import math_ops -def convert_to_int_tensor(tensor, name, dtype=dtypes.int32): - """Converts the given value to an integer Tensor.""" - tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype) - if tensor.dtype.is_integer: - tensor = math_ops.cast(tensor, dtype) - else: - raise TypeError( - "%s must be an integer tensor; dtype=%s" % (name, tensor.dtype)) - return tensor - - -def get_positive_axis(axis, ndims): - """Validate an `axis` parameter, and normalize it to be positive. - - If `ndims` is known (i.e., not `None`), then check that `axis` is in the - range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or - `axis + ndims` (otherwise). - If `ndims` is not known, and `axis` is positive, then return it as-is. - If `ndims` is not known, and `axis` is negative, then report an error. - - Args: - axis: An integer constant - ndims: An integer constant, or `None` - - Returns: - The normalized `axis` value. - - Raises: - ValueError: If `axis` is out-of-bounds, or if `axis` is negative and - `ndims is None`. - """ - if not isinstance(axis, int): - raise TypeError("axis must be an int; got %s" % type(axis).__name__) - if ndims is not None: - if 0 <= axis < ndims: - return axis - elif -ndims <= axis < 0: - return axis + ndims - else: - raise ValueError( - "axis=%s out of bounds: expected %s<=axis<%s" % (axis, -ndims, ndims)) - elif axis < 0: - raise ValueError("axis may only be negative if ndims is statically known.") - return axis - def assert_splits_match(nested_splits_lists): """Checks that the given splits lists are identical. @@ -103,135 +58,6 @@ def assert_splits_match(nested_splits_lists): ] -# This op is intended to exactly match the semantics of numpy.repeat, with -# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior -# when axis is not specified. Rather than implement that special behavior, we -# simply make `axis` be a required argument. -# -# External (OSS) `tf.repeat` feature request: -# https://github.com/tensorflow/tensorflow/issues/8246 -def repeat(data, repeats, axis, name=None): - """Repeats elements of `data`. - - Args: - data: An `N`-dimensional tensor. - repeats: A 1-D integer tensor specifying how many times each element in - `axis` should be repeated. `len(repeats)` must equal `data.shape[axis]`. - Supports broadcasting from a scalar value. - axis: `int`. The axis along which to repeat values. Must be less than - `max(N, 1)`. - name: A name for the operation. - - Returns: - A tensor with `max(N, 1)` dimensions. Has the same shape as `data`, - except that dimension `axis` has size `sum(repeats)`. - - #### Examples: - ```python - >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0) - ['a', 'a', 'a', 'c', 'c'] - >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0) - [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]] - >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1) - [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]] - ``` - """ - if not isinstance(axis, int): - raise TypeError("axis must be an int; got %s" % type(axis).__name__) - - with ops.name_scope(name, "Repeat", [data, repeats]): - data = ops.convert_to_tensor(data, name="data") - repeats = convert_to_int_tensor(repeats, name="repeats") - repeats.shape.with_rank_at_most(1) - - # If `data` is a scalar, then upgrade it to a vector. - data = _with_nonzero_rank(data) - data_shape = array_ops.shape(data) - - # If `axis` is negative, then convert it to a positive value. - axis = get_positive_axis(axis, data.shape.ndims) - - # Check data Tensor shapes. - if repeats.shape.ndims == 1: - data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0]) - - # If we know that `repeats` is a scalar, then we can just tile & reshape. - if repeats.shape.ndims == 0: - expanded = array_ops.expand_dims(data, axis + 1) - tiled = tile_one_dimension(expanded, axis + 1, repeats) - result_shape = array_ops.concat( - [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0) - return array_ops.reshape(tiled, result_shape) - - # Broadcast the `repeats` tensor so rank(repeats) == axis + 1. - if repeats.shape.ndims != axis + 1: - repeats_shape = array_ops.shape(repeats) - repeats_ndims = array_ops.rank(repeats) - broadcast_shape = array_ops.concat( - [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0) - repeats = array_ops.broadcast_to(repeats, broadcast_shape) - repeats.set_shape([None] * (axis + 1)) - - # Create a "sequence mask" based on `repeats`, where slices across `axis` - # contain one `True` value for each repetition. E.g., if - # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`. - max_repeat = math_ops.maximum(0, math_ops.reduce_max(repeats)) - mask = array_ops.sequence_mask(repeats, max_repeat) - - # Add a new dimension around each value that needs to be repeated, and - # then tile that new dimension to match the maximum number of repetitions. - expanded = array_ops.expand_dims(data, axis + 1) - tiled = tile_one_dimension(expanded, axis + 1, max_repeat) - - # Use `boolean_mask` to discard the extra repeated values. This also - # flattens all dimensions up through `axis`. - masked = array_ops.boolean_mask(tiled, mask) - - # Reshape the output tensor to add the outer dimensions back. - if axis == 0: - result = masked - else: - result_shape = array_ops.concat( - [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0) - result = array_ops.reshape(masked, result_shape) - - # Preserve shape information. - if data.shape.ndims is not None: - new_axis_size = 0 if repeats.shape[0] == 0 else None - result.set_shape(data.shape[:axis].concatenate( - [new_axis_size]).concatenate(data.shape[axis + 1:])) - - return result - - -def tile_one_dimension(data, axis, multiple): - """Tiles a single dimension of a tensor.""" - # Assumes axis is a nonnegative int. - if data.shape.ndims is not None: - multiples = [1] * data.shape.ndims - multiples[axis] = multiple - else: - ones = array_ops.ones(array_ops.rank(data), dtypes.int32) - multiples = array_ops.concat([ones[:axis], [multiple], ones[axis + 1:]], - axis=0) - return array_ops.tile(data, multiples) - - -def _with_nonzero_rank(data): - """If `data` is scalar, then add a dimension; otherwise return as-is.""" - if data.shape.ndims is not None: - if data.shape.ndims == 0: - return array_ops.stack([data]) - else: - return data - else: - data_shape = array_ops.shape(data) - data_ndims = array_ops.rank(data) - return array_ops.reshape( - data, - array_ops.concat([[1], data_shape], axis=0)[-data_ndims:]) - - def lengths_to_splits(lengths): """Returns splits corresponding to the given lengths.""" return array_ops.concat([[0], math_ops.cumsum(lengths)], axis=-1) From bdd3ba466d863488e38a384be3ca28658fde26b3 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Mar 2019 22:19:06 +0000 Subject: [PATCH 2/8] Use functions in array_ops directly Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 45 +++++++++++++----------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 627d3715dbd..a7b3c6c5188 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4721,13 +4721,12 @@ def convert_to_int_tensor(tensor, name, dtype=dtypes.int32): """Converts the given value to an integer Tensor.""" tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype) if tensor.dtype.is_integer: - tensor = math_ops.cast(tensor, dtype) + tensor = gen_math_ops.cast(tensor, dtype) else: raise TypeError( "%s must be an integer tensor; dtype=%s" % (name, tensor.dtype)) return tensor - def get_positive_axis(axis, ndims): """Validate an `axis` parameter, and normalize it to be positive. @@ -4761,8 +4760,8 @@ def get_positive_axis(axis, ndims): elif axis < 0: raise ValueError("axis may only be negative if ndims is statically known.") return axis - - + + # This op is intended to exactly match the semantics of numpy.repeat, with # one exception: numpy.repeat has special (and somewhat non-intuitive) behavior # when axis is not specified. Rather than implement that special behavior, we @@ -4772,7 +4771,6 @@ def get_positive_axis(axis, ndims): # https://github.com/tensorflow/tensorflow/issues/8246 def repeat(data, repeats, axis, name=None): """Repeats elements of `data`. - Args: data: An `N`-dimensional tensor. repeats: A 1-D integer tensor specifying how many times each element in @@ -4781,11 +4779,9 @@ def repeat(data, repeats, axis, name=None): axis: `int`. The axis along which to repeat values. Must be less than `max(N, 1)`. name: A name for the operation. - Returns: A tensor with `max(N, 1)` dimensions. Has the same shape as `data`, except that dimension `axis` has size `sum(repeats)`. - #### Examples: ```python >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0) @@ -4806,7 +4802,7 @@ def repeat(data, repeats, axis, name=None): # If `data` is a scalar, then upgrade it to a vector. data = _with_nonzero_rank(data) - data_shape = array_ops.shape(data) + data_shape = shape(data) # If `axis` is negative, then convert it to a positive value. axis = get_positive_axis(axis, data.shape.ndims) @@ -4817,43 +4813,43 @@ def repeat(data, repeats, axis, name=None): # If we know that `repeats` is a scalar, then we can just tile & reshape. if repeats.shape.ndims == 0: - expanded = array_ops.expand_dims(data, axis + 1) + expanded = expand_dims(data, axis + 1) tiled = tile_one_dimension(expanded, axis + 1, repeats) - result_shape = array_ops.concat( + result_shape = concat( [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0) - return array_ops.reshape(tiled, result_shape) + return reshape(tiled, result_shape) # Broadcast the `repeats` tensor so rank(repeats) == axis + 1. if repeats.shape.ndims != axis + 1: - repeats_shape = array_ops.shape(repeats) - repeats_ndims = array_ops.rank(repeats) - broadcast_shape = array_ops.concat( + repeats_shape = shape(repeats) + repeats_ndims = rank(repeats) + broadcast_shape = concat( [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0) - repeats = array_ops.broadcast_to(repeats, broadcast_shape) + repeats = broadcast_to(repeats, broadcast_shape) repeats.set_shape([None] * (axis + 1)) # Create a "sequence mask" based on `repeats`, where slices across `axis` # contain one `True` value for each repetition. E.g., if # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`. - max_repeat = math_ops.maximum(0, math_ops.reduce_max(repeats)) - mask = array_ops.sequence_mask(repeats, max_repeat) + max_repeat = gen_math_ops.maximum(0, gen_math_ops._max(repeats, _all_dimensions(repeats))) + mask = sequence_mask(repeats, max_repeat) # Add a new dimension around each value that needs to be repeated, and # then tile that new dimension to match the maximum number of repetitions. - expanded = array_ops.expand_dims(data, axis + 1) + expanded = expand_dims(data, axis + 1) tiled = tile_one_dimension(expanded, axis + 1, max_repeat) # Use `boolean_mask` to discard the extra repeated values. This also # flattens all dimensions up through `axis`. - masked = array_ops.boolean_mask(tiled, mask) + masked = boolean_mask(tiled, mask) # Reshape the output tensor to add the outer dimensions back. if axis == 0: result = masked else: - result_shape = array_ops.concat( + result_shape = concat( [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0) - result = array_ops.reshape(masked, result_shape) + result = reshape(masked, result_shape) # Preserve shape information. if data.shape.ndims is not None: @@ -4863,7 +4859,6 @@ def repeat(data, repeats, axis, name=None): return result - def tile_one_dimension(data, axis, multiple): """Tiles a single dimension of a tensor.""" # Assumes axis is a nonnegative int. @@ -4871,10 +4866,10 @@ def tile_one_dimension(data, axis, multiple): multiples = [1] * data.shape.ndims multiples[axis] = multiple else: - ones = array_ops.ones(array_ops.rank(data), dtypes.int32) - multiples = array_ops.concat([ones[:axis], [multiple], ones[axis + 1:]], + ones_value = ones(rank(data), dtypes.int32) + multiples = ops.concat([ones_value[:axis], [multiple], ones[axis + 1:]], axis=0) - return array_ops.tile(data, multiples) + return tile(data, multiples) def _with_nonzero_rank(data): From f7c56dee4b59f6a7ccc9ad4748b5f5150ecd54f2 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Mar 2019 22:23:05 +0000 Subject: [PATCH 3/8] Change repeat to repeat_with_axis and add repeat (without axis) to comform to numpy.repeat Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 31 +++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index a7b3c6c5188..ef0c8801850 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4769,7 +4769,7 @@ def get_positive_axis(axis, ndims): # # External (OSS) `tf.repeat` feature request: # https://github.com/tensorflow/tensorflow/issues/8246 -def repeat(data, repeats, axis, name=None): +def repeat_with_axis(data, repeats, axis, name=None): """Repeats elements of `data`. Args: data: An `N`-dimensional tensor. @@ -4867,7 +4867,7 @@ def tile_one_dimension(data, axis, multiple): multiples[axis] = multiple else: ones_value = ones(rank(data), dtypes.int32) - multiples = ops.concat([ones_value[:axis], [multiple], ones[axis + 1:]], + multiples = concat([ones_value[:axis], [multiple], ones_value[axis + 1:]], axis=0) return tile(data, multiples) @@ -4876,12 +4876,29 @@ def _with_nonzero_rank(data): """If `data` is scalar, then add a dimension; otherwise return as-is.""" if data.shape.ndims is not None: if data.shape.ndims == 0: - return array_ops.stack([data]) + return stack([data]) else: return data else: - data_shape = array_ops.shape(data) - data_ndims = array_ops.rank(data) - return array_ops.reshape( + data_shape = shape(data) + data_ndims = rank(data) + return reshape( data, - array_ops.concat([[1], data_shape], axis=0)[-data_ndims:]) + concat([[1], data_shape], axis=0)[-data_ndims:]) + +def repeat(input, repeats, axis=None, name=None): + """Repeat elements of an array + Args: + input: A Tensor. + repeats: An 1-D `int` Tensor. The number of repetitions for each element. + repeats is broadcasted to fit the shape of the given axis + axis: An int. The axis along which to repeat values. By default, use the + flattened input array, and return a flat output array. + name: name of the op. + Returns: + A Tensor which has the same shape as a, except along the given axis. + """ + if axis is None: + input = reshape(input, [-1]) + axis = 0 + return repeat_with_axis(input, repeats, axis, name) From af717fd9ad051a0340fae5c67dacd46408bfb07f Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Mar 2019 22:23:50 +0000 Subject: [PATCH 4/8] Add test case for tf.repeat Signed-off-by: Yong Tang --- .../python/kernel_tests/array_ops_test.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 542833151ef..ff6d5d3d44d 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1819,5 +1819,42 @@ class BatchGatherNdTest(test_util.TensorFlowTestCase): self.assertEqual(None, tensor_shape.dimension_value(shape[0])) +class RepeatTest(test_util.TensorFlowTestCase): + + def testRepeatScalar(self): + with self.test_session(): + v_tf = array_ops.repeat(constant_op.constant(3), 4) + v_np = np.repeat(3, 4) + self.assertAllEqual(v_tf.eval(), v_np) + + def testRepeatMatrix(self): + with self.test_session(): + x = np.array([[1, 2], [3, 4]], dtype=np.int32) + v_tf = array_ops.repeat(constant_op.constant(x), 2) + v_np = np.repeat(x, 2) + self.assertAllEqual(v_tf.eval(), v_np) + + def testRepeatMatrixAxis0(self): + with self.test_session(): + x = np.array([[1, 2], [3, 4]], dtype=np.int32) + v_tf = array_ops.repeat(constant_op.constant(x), constant_op.constant([1, 2]), axis=0) + v_np = np.repeat(x, [1, 2], axis=0) + self.assertAllEqual(v_tf.eval(), v_np) + + def testRepeatMatrixAxis1(self): + with self.test_session(): + x = np.array([[1, 2], [3, 4]], dtype=np.int32) + v_tf = array_ops.repeat(constant_op.constant(x), constant_op.constant(3), axis=1) + v_np = np.repeat(x, 3, axis=1) + self.assertAllEqual(v_tf.eval(), v_np) + + def testRepeatMatrixRepeatArray(self): + with self.test_session(): + x = np.array([[1, 2], [3, 4]], dtype=np.int32) + v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4]) + v_np = np.repeat(x, [1, 2, 3, 4]) + self.assertAllEqual(v_tf.eval(), v_np) + + if __name__ == "__main__": test_lib.main() From 25d794ce99887d6ed4e9c10f495c72b67b86d696 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Mar 2019 22:24:25 +0000 Subject: [PATCH 5/8] Reexpose repeat related functions to ragged_util.py to avoid breaking test cases Signed-off-by: Yong Tang --- tensorflow/python/ops/ragged/ragged_util.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/python/ops/ragged/ragged_util.py b/tensorflow/python/ops/ragged/ragged_util.py index aeea692036a..260580c56f5 100644 --- a/tensorflow/python/ops/ragged/ragged_util.py +++ b/tensorflow/python/ops/ragged/ragged_util.py @@ -58,6 +58,11 @@ def assert_splits_match(nested_splits_lists): ] +# Note: imported here to avoid circular dependency of array_ops. +get_positive_axis = array_ops.get_positive_axis +convert_to_int_tensor = array_ops.convert_to_int_tensor +repeat = array_ops.repeat_with_axis + def lengths_to_splits(lengths): """Returns splits corresponding to the given lengths.""" return array_ops.concat([[0], math_ops.cumsum(lengths)], axis=-1) From 6744f5c233add6d2c713edd3e053bc42f1c3ca73 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Mar 2019 22:51:05 +0000 Subject: [PATCH 6/8] Expose tf.repeat in v1 and v2 and update api golden Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 1 + tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 ++++ tensorflow/tools/api/golden/v2/tensorflow.pbtxt | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ef0c8801850..04dd7681a50 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4886,6 +4886,7 @@ def _with_nonzero_rank(data): data, concat([[1], data_shape], axis=0)[-data_ndims:]) +@tf_export("repeat") def repeat(input, repeats, axis=None, name=None): """Repeat elements of an array Args: diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 47f82fbb05c..8d65cac65ca 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1916,6 +1916,10 @@ tf_module { name: "register_tensor_conversion_function" argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], " } + member_method { + name: "repeat" + argspec: "args=[\'input\', \'repeats\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "report_uninitialized_variables" argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'report_uninitialized_variables\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 63c70f8aeb4..71fa993690d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -900,6 +900,10 @@ tf_module { name: "register_tensor_conversion_function" argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], " } + member_method { + name: "repeat" + argspec: "args=[\'input\', \'repeats\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } member_method { name: "required_space_to_batch_paddings" argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " From d4ffb947ad6affa022d95a78b1f5da0ab5142dad Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Mar 2019 23:27:47 +0000 Subject: [PATCH 7/8] Fix sanity CI test issue Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/array_ops_test.py | 6 ++++-- tensorflow/python/ops/array_ops.py | 7 ++++--- tensorflow/python/ops/ragged/ragged_util.py | 1 - 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index ff6d5d3d44d..b87bab36992 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1837,14 +1837,16 @@ class RepeatTest(test_util.TensorFlowTestCase): def testRepeatMatrixAxis0(self): with self.test_session(): x = np.array([[1, 2], [3, 4]], dtype=np.int32) - v_tf = array_ops.repeat(constant_op.constant(x), constant_op.constant([1, 2]), axis=0) + v_tf = array_ops.repeat( + constant_op.constant(x), constant_op.constant([1, 2]), axis=0) v_np = np.repeat(x, [1, 2], axis=0) self.assertAllEqual(v_tf.eval(), v_np) def testRepeatMatrixAxis1(self): with self.test_session(): x = np.array([[1, 2], [3, 4]], dtype=np.int32) - v_tf = array_ops.repeat(constant_op.constant(x), constant_op.constant(3), axis=1) + v_tf = array_ops.repeat( + constant_op.constant(x), constant_op.constant(3), axis=1) v_np = np.repeat(x, 3, axis=1) self.assertAllEqual(v_tf.eval(), v_np) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 04dd7681a50..55b746f179d 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4831,7 +4831,8 @@ def repeat_with_axis(data, repeats, axis, name=None): # Create a "sequence mask" based on `repeats`, where slices across `axis` # contain one `True` value for each repetition. E.g., if # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`. - max_repeat = gen_math_ops.maximum(0, gen_math_ops._max(repeats, _all_dimensions(repeats))) + max_repeat = gen_math_ops.maximum( + 0, gen_math_ops._max(repeats, _all_dimensions(repeats))) mask = sequence_mask(repeats, max_repeat) # Add a new dimension around each value that needs to be repeated, and @@ -4868,7 +4869,7 @@ def tile_one_dimension(data, axis, multiple): else: ones_value = ones(rank(data), dtypes.int32) multiples = concat([ones_value[:axis], [multiple], ones_value[axis + 1:]], - axis=0) + axis=0) return tile(data, multiples) @@ -4887,7 +4888,7 @@ def _with_nonzero_rank(data): concat([[1], data_shape], axis=0)[-data_ndims:]) @tf_export("repeat") -def repeat(input, repeats, axis=None, name=None): +def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin """Repeat elements of an array Args: input: A Tensor. diff --git a/tensorflow/python/ops/ragged/ragged_util.py b/tensorflow/python/ops/ragged/ragged_util.py index 260580c56f5..cef1a3f5155 100644 --- a/tensorflow/python/ops/ragged/ragged_util.py +++ b/tensorflow/python/ops/ragged/ragged_util.py @@ -22,7 +22,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_ragged_math_ops From 6b1652e6a40d23a7d87ed7cec5d62e163ec039f2 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 18 Jun 2019 20:59:57 +0000 Subject: [PATCH 8/8] Update docstring of tf.repeats, and add additional examples based on review feedback Signed-off-by: Yong Tang --- tensorflow/python/ops/array_ops.py | 31 ++++++++++++++++----- tensorflow/python/ops/ragged/ragged_util.py | 1 - 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 55b746f179d..4e2fbea5a02 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -4889,16 +4889,33 @@ def _with_nonzero_rank(data): @tf_export("repeat") def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin - """Repeat elements of an array + """Repeat elements of `input` Args: - input: A Tensor. + input: An `N`-dimensional Tensor. repeats: An 1-D `int` Tensor. The number of repetitions for each element. - repeats is broadcasted to fit the shape of the given axis - axis: An int. The axis along which to repeat values. By default, use the - flattened input array, and return a flat output array. - name: name of the op. + repeats is broadcasted to fit the shape of the given axis. + `len(repeats)` must equal `input.shape[axis]` if axis is not None. + axis: An int. The axis along which to repeat values. By default + (axis=None), use the flattened input array, and return a flat output + array. + name: A name for the operation. Returns: - A Tensor which has the same shape as a, except along the given axis. + A Tensor which has the same shape as `input`, except along the given axis. + If axis is None then the output array is flattened to match the flattened + input array. + #### Examples: + ```python + >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0) + ['a', 'a', 'a', 'c', 'c'] + >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0) + [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]] + >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1) + [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]] + >>> repeat(3, repeats=4) + [3, 3, 3, 3] + >>> repeat([[1,2], [3,4]], repeats=2) + [1, 1, 2, 2, 3, 3, 4, 4] + ``` """ if axis is None: input = reshape(input, [-1]) diff --git a/tensorflow/python/ops/ragged/ragged_util.py b/tensorflow/python/ops/ragged/ragged_util.py index cef1a3f5155..0ae9f127d16 100644 --- a/tensorflow/python/ops/ragged/ragged_util.py +++ b/tensorflow/python/ops/ragged/ragged_util.py @@ -21,7 +21,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_ragged_math_ops