Fix repeat, and add tests

PiperOrigin-RevId: 302538059
Change-Id: Ie8a17b4fe5818d04260cc5a3f8868178b92e0014
This commit is contained in:
Akshay Modi 2020-03-23 16:01:12 -07:00 committed by TensorFlower Gardener
parent 8abb0d2992
commit f446a69e5c
3 changed files with 47 additions and 50 deletions

View File

@ -21,6 +21,7 @@ import re
import time import time
import unittest import unittest
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
@ -1890,48 +1891,33 @@ class BatchGatherNdTest(test_util.TensorFlowTestCase):
self.assertEqual(None, tensor_shape.dimension_value(shape[0])) self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
class RepeatTest(test_util.TensorFlowTestCase): @test_util.run_all_in_graph_and_eager_modes
class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@test_util.run_deprecated_v1 @parameterized.parameters(
def testRepeatScalar(self): (3, 4, None),
with self.test_session(): ([[1, 2], [3, 4]], 2, None),
v_tf = array_ops.repeat(constant_op.constant(3), 4) ([[1, 2], [3, 4]], [1, 2], 0),
v_np = np.repeat(3, 4) ([[1, 2], [3, 4]], [1, 2], 1),
self.assertAllEqual(v_tf.eval(), v_np) ([[1, 2], [3, 4]], 3, 1),
([[1, 2], [3, 4]], [1, 2, 3, 4], None),
(np.ones([0, 4]), 0, 1),
(np.ones([1, 2]), [2], None),
)
def testRepeat(self, array, repeats, axis):
array = np.array(array)
@test_util.run_deprecated_v1 @def_function.function(
def testRepeatMatrix(self): input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)] * 2)
with self.test_session(): def repeat_fn(array, repeats):
x = np.array([[1, 2], [3, 4]], dtype=np.int32) return array_ops.repeat(array, repeats, axis)
v_tf = array_ops.repeat(constant_op.constant(x), 2)
v_np = np.repeat(x, 2)
self.assertAllEqual(v_tf.eval(), v_np)
@test_util.run_deprecated_v1 v_tf = array_ops.repeat(constant_op.constant(array), repeats, axis)
def testRepeatMatrixAxis0(self): v_tf_fn = repeat_fn(
with self.test_session(): constant_op.constant(array, dtype=dtypes.int32), repeats)
x = np.array([[1, 2], [3, 4]], dtype=np.int32) v_np = np.repeat(array, repeats, axis)
v_tf = array_ops.repeat( self.assertAllEqual(v_tf, v_np)
constant_op.constant(x), constant_op.constant([1, 2]), axis=0) self.assertAllEqual(v_tf_fn, v_np)
v_np = np.repeat(x, [1, 2], axis=0)
self.assertAllEqual(v_tf.eval(), v_np)
@test_util.run_deprecated_v1
def testRepeatMatrixAxis1(self):
with self.test_session():
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
v_tf = array_ops.repeat(
constant_op.constant(x), constant_op.constant(3), axis=1)
v_np = np.repeat(x, 3, axis=1)
self.assertAllEqual(v_tf.eval(), v_np)
@test_util.run_deprecated_v1
def testRepeatMatrixRepeatArray(self):
with self.test_session():
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4])
v_np = np.repeat(x, [1, 2, 3, 4])
self.assertAllEqual(v_tf.eval(), v_np)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2278,13 +2278,13 @@ class ParseSequenceExampleTest(test.TestCase):
serialized=ops.convert_to_tensor(original.SerializeToString()), serialized=ops.convert_to_tensor(original.SerializeToString()),
sequence_features=sequence_features), sequence_features=sequence_features),
expected_err=( expected_err=(
(errors_impl.OpError, ValueError), (errors_impl.InvalidArgumentError, ValueError),
# Message for batch=true: # Message for batch=true:
"Feature b: values and partitions are not aligned" "Feature b: values and partitions are not aligned"
# Message for batch=false in graph mode: # Message for batch=false in graph mode:
"|.* do not form a valid RaggedTensor" "|.* do not form a valid RaggedTensor"
# Message for batch=false in eager mode: # Message for batch=false in eager mode:
"|Dimensions 2 and 1 are not compatible")) "|Incompatible shapes"))
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes

View File

@ -5511,17 +5511,24 @@ def repeat_with_axis(data, repeats, axis, name=None):
# If `axis` is negative, then convert it to a positive value. # If `axis` is negative, then convert it to a positive value.
axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)") axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
# If we know that `repeats` is a scalar, then we can just tile & reshape.
if repeats.shape.num_elements() == 1:
repeats = reshape(repeats, [])
expanded = expand_dims(data, axis + 1)
tiled = tile_one_dimension(expanded, axis + 1, repeats)
result_shape = concat([
data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
],
axis=0)
return reshape(tiled, result_shape)
# Check data Tensor shapes. # Check data Tensor shapes.
if repeats.shape.ndims == 1: if repeats.shape.ndims == 1:
data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0]) data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
# If we know that `repeats` is a scalar, then we can just tile & reshape. repeats = broadcast_to(repeats, [data_shape[axis]])
if repeats.shape.ndims == 0: repeats_original = repeats
expanded = expand_dims(data, axis + 1)
tiled = tile_one_dimension(expanded, axis + 1, repeats)
result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
axis=0)
return reshape(tiled, result_shape)
# Broadcast the `repeats` tensor so rank(repeats) == axis + 1. # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
if repeats.shape.ndims != axis + 1: if repeats.shape.ndims != axis + 1:
@ -5552,7 +5559,11 @@ def repeat_with_axis(data, repeats, axis, name=None):
if axis == 0: if axis == 0:
result = masked result = masked
else: else:
result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]], repeated_dim_size = gen_math_ops._sum(
repeats_original,
axis=gen_math_ops._range(0, rank(repeats_original), 1))
result_shape = concat(
[data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
axis=0) axis=0)
result = reshape(masked, result_shape) result = reshape(masked, result_shape)