Fix repeat, and add tests
PiperOrigin-RevId: 302538059 Change-Id: Ie8a17b4fe5818d04260cc5a3f8868178b92e0014
This commit is contained in:
parent
8abb0d2992
commit
f446a69e5c
|
@ -21,6 +21,7 @@ import re
|
|||
import time
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
|
@ -1890,48 +1891,33 @@ class BatchGatherNdTest(test_util.TensorFlowTestCase):
|
|||
self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
|
||||
|
||||
|
||||
class RepeatTest(test_util.TensorFlowTestCase):
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRepeatScalar(self):
|
||||
with self.test_session():
|
||||
v_tf = array_ops.repeat(constant_op.constant(3), 4)
|
||||
v_np = np.repeat(3, 4)
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
@parameterized.parameters(
|
||||
(3, 4, None),
|
||||
([[1, 2], [3, 4]], 2, None),
|
||||
([[1, 2], [3, 4]], [1, 2], 0),
|
||||
([[1, 2], [3, 4]], [1, 2], 1),
|
||||
([[1, 2], [3, 4]], 3, 1),
|
||||
([[1, 2], [3, 4]], [1, 2, 3, 4], None),
|
||||
(np.ones([0, 4]), 0, 1),
|
||||
(np.ones([1, 2]), [2], None),
|
||||
)
|
||||
def testRepeat(self, array, repeats, axis):
|
||||
array = np.array(array)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRepeatMatrix(self):
|
||||
with self.test_session():
|
||||
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||
v_tf = array_ops.repeat(constant_op.constant(x), 2)
|
||||
v_np = np.repeat(x, 2)
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)] * 2)
|
||||
def repeat_fn(array, repeats):
|
||||
return array_ops.repeat(array, repeats, axis)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRepeatMatrixAxis0(self):
|
||||
with self.test_session():
|
||||
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||
v_tf = array_ops.repeat(
|
||||
constant_op.constant(x), constant_op.constant([1, 2]), axis=0)
|
||||
v_np = np.repeat(x, [1, 2], axis=0)
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRepeatMatrixAxis1(self):
|
||||
with self.test_session():
|
||||
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||
v_tf = array_ops.repeat(
|
||||
constant_op.constant(x), constant_op.constant(3), axis=1)
|
||||
v_np = np.repeat(x, 3, axis=1)
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testRepeatMatrixRepeatArray(self):
|
||||
with self.test_session():
|
||||
x = np.array([[1, 2], [3, 4]], dtype=np.int32)
|
||||
v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4])
|
||||
v_np = np.repeat(x, [1, 2, 3, 4])
|
||||
self.assertAllEqual(v_tf.eval(), v_np)
|
||||
v_tf = array_ops.repeat(constant_op.constant(array), repeats, axis)
|
||||
v_tf_fn = repeat_fn(
|
||||
constant_op.constant(array, dtype=dtypes.int32), repeats)
|
||||
v_np = np.repeat(array, repeats, axis)
|
||||
self.assertAllEqual(v_tf, v_np)
|
||||
self.assertAllEqual(v_tf_fn, v_np)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -2278,13 +2278,13 @@ class ParseSequenceExampleTest(test.TestCase):
|
|||
serialized=ops.convert_to_tensor(original.SerializeToString()),
|
||||
sequence_features=sequence_features),
|
||||
expected_err=(
|
||||
(errors_impl.OpError, ValueError),
|
||||
(errors_impl.InvalidArgumentError, ValueError),
|
||||
# Message for batch=true:
|
||||
"Feature b: values and partitions are not aligned"
|
||||
# Message for batch=false in graph mode:
|
||||
"|.* do not form a valid RaggedTensor"
|
||||
# Message for batch=false in eager mode:
|
||||
"|Dimensions 2 and 1 are not compatible"))
|
||||
"|Incompatible shapes"))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
|
|
|
@ -5511,17 +5511,24 @@ def repeat_with_axis(data, repeats, axis, name=None):
|
|||
# If `axis` is negative, then convert it to a positive value.
|
||||
axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
|
||||
|
||||
# If we know that `repeats` is a scalar, then we can just tile & reshape.
|
||||
if repeats.shape.num_elements() == 1:
|
||||
repeats = reshape(repeats, [])
|
||||
expanded = expand_dims(data, axis + 1)
|
||||
tiled = tile_one_dimension(expanded, axis + 1, repeats)
|
||||
result_shape = concat([
|
||||
data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
|
||||
],
|
||||
axis=0)
|
||||
return reshape(tiled, result_shape)
|
||||
|
||||
|
||||
# Check data Tensor shapes.
|
||||
if repeats.shape.ndims == 1:
|
||||
data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
|
||||
|
||||
# If we know that `repeats` is a scalar, then we can just tile & reshape.
|
||||
if repeats.shape.ndims == 0:
|
||||
expanded = expand_dims(data, axis + 1)
|
||||
tiled = tile_one_dimension(expanded, axis + 1, repeats)
|
||||
result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
|
||||
axis=0)
|
||||
return reshape(tiled, result_shape)
|
||||
repeats = broadcast_to(repeats, [data_shape[axis]])
|
||||
repeats_original = repeats
|
||||
|
||||
# Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
|
||||
if repeats.shape.ndims != axis + 1:
|
||||
|
@ -5552,7 +5559,11 @@ def repeat_with_axis(data, repeats, axis, name=None):
|
|||
if axis == 0:
|
||||
result = masked
|
||||
else:
|
||||
result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
|
||||
repeated_dim_size = gen_math_ops._sum(
|
||||
repeats_original,
|
||||
axis=gen_math_ops._range(0, rank(repeats_original), 1))
|
||||
result_shape = concat(
|
||||
[data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
|
||||
axis=0)
|
||||
result = reshape(masked, result_shape)
|
||||
|
||||
|
|
Loading…
Reference in New Issue