From f446a69e5c340c3698ee57ad1f1902885058dd5c Mon Sep 17 00:00:00 2001
From: Akshay Modi
Date: Mon, 23 Mar 2020 16:01:12 -0700
Subject: [PATCH] Fix repeat, and add tests

PiperOrigin-RevId: 302538059
Change-Id: Ie8a17b4fe5818d04260cc5a3f8868178b92e0014
---
 .../python/kernel_tests/array_ops_test.py    | 64 ++++++++-----------
 .../python/kernel_tests/parsing_ops_test.py  |  4 +-
 tensorflow/python/ops/array_ops.py           | 29 ++++++---
 3 files changed, 47 insertions(+), 50 deletions(-)

diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index b81ec5f36a8..ec3ed932996 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -21,6 +21,7 @@ import re
 import time
 import unittest
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
@@ -1890,48 +1891,33 @@ class BatchGatherNdTest(test_util.TensorFlowTestCase):
     self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
 
 
-class RepeatTest(test_util.TensorFlowTestCase):
+@test_util.run_all_in_graph_and_eager_modes
+class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
 
-  @test_util.run_deprecated_v1
-  def testRepeatScalar(self):
-    with self.test_session():
-      v_tf = array_ops.repeat(constant_op.constant(3), 4)
-      v_np = np.repeat(3, 4)
-      self.assertAllEqual(v_tf.eval(), v_np)
+  @parameterized.parameters(
+      (3, 4, None),
+      ([[1, 2], [3, 4]], 2, None),
+      ([[1, 2], [3, 4]], [1, 2], 0),
+      ([[1, 2], [3, 4]], [1, 2], 1),
+      ([[1, 2], [3, 4]], 3, 1),
+      ([[1, 2], [3, 4]], [1, 2, 3, 4], None),
+      (np.ones([0, 4]), 0, 1),
+      (np.ones([1, 2]), [2], None),
+  )
+  def testRepeat(self, array, repeats, axis):
+    array = np.array(array)
 
-  @test_util.run_deprecated_v1
-  def testRepeatMatrix(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(constant_op.constant(x), 2)
-      v_np = np.repeat(x, 2)
-      self.assertAllEqual(v_tf.eval(), v_np)
+    @def_function.function(
+        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)] * 2)
+    def repeat_fn(array, repeats):
+      return array_ops.repeat(array, repeats, axis)
 
-  @test_util.run_deprecated_v1
-  def testRepeatMatrixAxis0(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(
-          constant_op.constant(x), constant_op.constant([1, 2]), axis=0)
-      v_np = np.repeat(x, [1, 2], axis=0)
-      self.assertAllEqual(v_tf.eval(), v_np)
-
-  @test_util.run_deprecated_v1
-  def testRepeatMatrixAxis1(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(
-          constant_op.constant(x), constant_op.constant(3), axis=1)
-      v_np = np.repeat(x, 3, axis=1)
-      self.assertAllEqual(v_tf.eval(), v_np)
-
-  @test_util.run_deprecated_v1
-  def testRepeatMatrixRepeatArray(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4])
-      v_np = np.repeat(x, [1, 2, 3, 4])
-      self.assertAllEqual(v_tf.eval(), v_np)
+    v_tf = array_ops.repeat(constant_op.constant(array), repeats, axis)
+    v_tf_fn = repeat_fn(
+        constant_op.constant(array, dtype=dtypes.int32), repeats)
+    v_np = np.repeat(array, repeats, axis)
+    self.assertAllEqual(v_tf, v_np)
+    self.assertAllEqual(v_tf_fn, v_np)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 0aaead2fa2b..c94fd0fde49 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -2278,13 +2278,13 @@ class ParseSequenceExampleTest(test.TestCase):
             serialized=ops.convert_to_tensor(original.SerializeToString()),
             sequence_features=sequence_features),
         expected_err=(
-            (errors_impl.OpError, ValueError),
+            (errors_impl.InvalidArgumentError, ValueError),
             # Message for batch=true:
             "Feature b: values and partitions are not aligned"
            # Message for batch=false in graph mode:
             "|.* do not form a valid RaggedTensor"
             # Message for batch=false in eager mode:
-            "|Dimensions 2 and 1 are not compatible"))
+            "|Incompatible shapes"))
 
 
 @test_util.run_all_in_graph_and_eager_modes
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index cbb5db77801..d286c96ec4e 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -5511,17 +5511,24 @@ def repeat_with_axis(data, repeats, axis, name=None):
     # If `axis` is negative, then convert it to a positive value.
     axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
 
+    # If we know that `repeats` is a scalar, then we can just tile & reshape.
+    if repeats.shape.num_elements() == 1:
+      repeats = reshape(repeats, [])
+      expanded = expand_dims(data, axis + 1)
+      tiled = tile_one_dimension(expanded, axis + 1, repeats)
+      result_shape = concat([
+          data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
+      ],
+                            axis=0)
+      return reshape(tiled, result_shape)
+
+    # Check data Tensor shapes.
     if repeats.shape.ndims == 1:
       data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
 
-    # If we know that `repeats` is a scalar, then we can just tile & reshape.
-    if repeats.shape.ndims == 0:
-      expanded = expand_dims(data, axis + 1)
-      tiled = tile_one_dimension(expanded, axis + 1, repeats)
-      result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
-                            axis=0)
-      return reshape(tiled, result_shape)
+    repeats = broadcast_to(repeats, [data_shape[axis]])
+    repeats_original = repeats
 
     # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
     if repeats.shape.ndims != axis + 1:
@@ -5552,8 +5559,12 @@ def repeat_with_axis(data, repeats, axis, name=None):
     if axis == 0:
       result = masked
     else:
-      result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
-                            axis=0)
+      repeated_dim_size = gen_math_ops._sum(
+          repeats_original,
+          axis=gen_math_ops._range(0, rank(repeats_original), 1))
+      result_shape = concat(
+          [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
+          axis=0)
       result = reshape(masked, result_shape)
 
     # Preserve shape information.
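
Note on the reshape change above: for a scalar (or single-element) `repeats`, the op now repeats
elements by inserting a unit dimension after `axis`, tiling it, and reshaping with an explicit output
size of `repeats * data_shape[axis]` rather than -1; for a vector `repeats`, the repeated dimension's
size is the sum of `repeats`. The sketch below illustrates the same tile-and-reshape idea in plain
NumPy. It is illustrative only, not the TensorFlow implementation; `repeat_scalar_via_tile` is a
hypothetical helper name.

import numpy as np

def repeat_scalar_via_tile(data, repeats, axis):
  """Repeat each element `repeats` times along `axis` via expand -> tile -> reshape."""
  expanded = np.expand_dims(data, axis + 1)        # insert a unit dim right after `axis`
  reps = [1] * expanded.ndim
  reps[axis + 1] = repeats                         # tile only the inserted dim
  tiled = np.tile(expanded, reps)
  result_shape = list(data.shape)
  result_shape[axis] = repeats * data.shape[axis]  # explicit size instead of -1
  return tiled.reshape(result_shape)

x = np.array([[1, 2], [3, 4]])
assert np.array_equal(repeat_scalar_via_tile(x, 3, axis=1), np.repeat(x, 3, axis=1))
# With zero-sized inputs (as in the new (np.ones([0, 4]), 0, 1) test case), a -1 in the
# reshape cannot be inferred; an explicit size works.
assert repeat_scalar_via_tile(np.ones([0, 4]), 0, axis=1).shape == (0, 0)
# For a vector `repeats`, the repeated dimension's size is sum(repeats).
assert np.repeat(x, [1, 2], axis=0).shape[0] == sum([1, 2])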