Fix repeat, and add tests
PiperOrigin-RevId: 302538059
Change-Id: Ie8a17b4fe5818d04260cc5a3f8868178b92e0014
parent 8abb0d2992
commit f446a69e5c
@@ -21,6 +21,7 @@ import re
 import time
 import unittest
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
@@ -1890,48 +1891,33 @@ class BatchGatherNdTest(test_util.TensorFlowTestCase):
     self.assertEqual(None, tensor_shape.dimension_value(shape[0]))
 
 
-class RepeatTest(test_util.TensorFlowTestCase):
-
-  @test_util.run_deprecated_v1
-  def testRepeatScalar(self):
-    with self.test_session():
-      v_tf = array_ops.repeat(constant_op.constant(3), 4)
-      v_np = np.repeat(3, 4)
-      self.assertAllEqual(v_tf.eval(), v_np)
-
-  @test_util.run_deprecated_v1
-  def testRepeatMatrix(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(constant_op.constant(x), 2)
-      v_np = np.repeat(x, 2)
-      self.assertAllEqual(v_tf.eval(), v_np)
-
-  @test_util.run_deprecated_v1
-  def testRepeatMatrixAxis0(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(
-          constant_op.constant(x), constant_op.constant([1, 2]), axis=0)
-      v_np = np.repeat(x, [1, 2], axis=0)
-      self.assertAllEqual(v_tf.eval(), v_np)
-
-  @test_util.run_deprecated_v1
-  def testRepeatMatrixAxis1(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(
-          constant_op.constant(x), constant_op.constant(3), axis=1)
-      v_np = np.repeat(x, 3, axis=1)
-      self.assertAllEqual(v_tf.eval(), v_np)
-
-  @test_util.run_deprecated_v1
-  def testRepeatMatrixRepeatArray(self):
-    with self.test_session():
-      x = np.array([[1, 2], [3, 4]], dtype=np.int32)
-      v_tf = array_ops.repeat(constant_op.constant(x), [1, 2, 3, 4])
-      v_np = np.repeat(x, [1, 2, 3, 4])
-      self.assertAllEqual(v_tf.eval(), v_np)
+@test_util.run_all_in_graph_and_eager_modes
+class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  @parameterized.parameters(
+      (3, 4, None),
+      ([[1, 2], [3, 4]], 2, None),
+      ([[1, 2], [3, 4]], [1, 2], 0),
+      ([[1, 2], [3, 4]], [1, 2], 1),
+      ([[1, 2], [3, 4]], 3, 1),
+      ([[1, 2], [3, 4]], [1, 2, 3, 4], None),
+      (np.ones([0, 4]), 0, 1),
+      (np.ones([1, 2]), [2], None),
+  )
+  def testRepeat(self, array, repeats, axis):
+    array = np.array(array)
+
+    @def_function.function(
+        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)] * 2)
+    def repeat_fn(array, repeats):
+      return array_ops.repeat(array, repeats, axis)
+
+    v_tf = array_ops.repeat(constant_op.constant(array), repeats, axis)
+    v_tf_fn = repeat_fn(
+        constant_op.constant(array, dtype=dtypes.int32), repeats)
+    v_np = np.repeat(array, repeats, axis)
+    self.assertAllEqual(v_tf, v_np)
+    self.assertAllEqual(v_tf_fn, v_np)
 
 
 if __name__ == "__main__":
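Every parameterized case above asserts that TensorFlow's repeat matches np.repeat, both eagerly and through a tf.function with a shape-unknown input signature. A minimal standalone sketch of one such case, using the public tf.repeat wrapper (TensorFlow 2.x eager mode assumed) rather than the internal array_ops module:

import numpy as np
import tensorflow as tf

x = np.array([[1, 2], [3, 4]], dtype=np.int32)
# Repeat row 0 once and row 1 twice, mirroring the ([[1, 2], [3, 4]], [1, 2], 0) case.
print(tf.repeat(x, repeats=[1, 2], axis=0).numpy())  # [[1 2] [3 4] [3 4]]
print(np.repeat(x, [1, 2], axis=0))                  # same values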
@@ -2278,13 +2278,13 @@ class ParseSequenceExampleTest(test.TestCase):
             serialized=ops.convert_to_tensor(original.SerializeToString()),
             sequence_features=sequence_features),
         expected_err=(
-            (errors_impl.OpError, ValueError),
+            (errors_impl.InvalidArgumentError, ValueError),
             # Message for batch=true:
             "Feature b: values and partitions are not aligned"
             # Message for batch=false in graph mode:
             "|.* do not form a valid RaggedTensor"
             # Message for batch=false in eager mode:
-            "|Dimensions 2 and 1 are not compatible"))
+            "|Incompatible shapes"))
 
 
   @test_util.run_all_in_graph_and_eager_modes
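This hunk only narrows the expected error class and updates the eager-mode message. Since InvalidArgumentError derives from OpError, the new expectation is strictly tighter than the old one; a quick check, assuming a TensorFlow install:

from tensorflow.python.framework import errors_impl

# InvalidArgumentError is a subclass of OpError, so any test that previously
# accepted an OpError still passes when this specific subclass is raised.
assert issubclass(errors_impl.InvalidArgumentError, errors_impl.OpError)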
@@ -5511,17 +5511,24 @@ def repeat_with_axis(data, repeats, axis, name=None):
     # If `axis` is negative, then convert it to a positive value.
     axis = get_positive_axis(axis, data.shape.rank, ndims_name="rank(data)")
 
+    # If we know that `repeats` is a scalar, then we can just tile & reshape.
+    if repeats.shape.num_elements() == 1:
+      repeats = reshape(repeats, [])
+      expanded = expand_dims(data, axis + 1)
+      tiled = tile_one_dimension(expanded, axis + 1, repeats)
+      result_shape = concat([
+          data_shape[:axis], [repeats * data_shape[axis]], data_shape[axis + 1:]
+      ],
+                            axis=0)
+      return reshape(tiled, result_shape)
+
     # Check data Tensor shapes.
     if repeats.shape.ndims == 1:
       data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])
 
-    # If we know that `repeats` is a scalar, then we can just tile & reshape.
-    if repeats.shape.ndims == 0:
-      expanded = expand_dims(data, axis + 1)
-      tiled = tile_one_dimension(expanded, axis + 1, repeats)
-      result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
-                            axis=0)
-      return reshape(tiled, result_shape)
+    repeats = broadcast_to(repeats, [data_shape[axis]])
+    repeats_original = repeats
 
     # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
     if repeats.shape.ndims != axis + 1:
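The scalar fast path now triggers whenever `repeats` is statically known to hold a single element, and it builds the result shape as repeats * data_shape[axis] instead of a -1 placeholder, so the reshape is well defined even for empty inputs (one of the new test cases covers np.ones([0, 4])). A rough NumPy-only sketch of the same expand/tile/reshape trick, with a hypothetical helper name and not the TensorFlow implementation itself:

import numpy as np

def repeat_scalar_sketch(data, repeats, axis):
  # Insert a length-1 axis right after `axis`, tile only that new axis
  # `repeats` times, then collapse it back into `axis`.
  expanded = np.expand_dims(data, axis + 1)
  reps = [1] * expanded.ndim
  reps[axis + 1] = repeats
  tiled = np.tile(expanded, reps)
  shape = list(data.shape)
  shape[axis] = data.shape[axis] * repeats  # known output size, no -1 needed
  return tiled.reshape(shape)

x = np.array([[1, 2], [3, 4]])
assert (repeat_scalar_sketch(x, 3, 1) == np.repeat(x, 3, axis=1)).all()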
@@ -5552,7 +5559,11 @@ def repeat_with_axis(data, repeats, axis, name=None):
     if axis == 0:
       result = masked
     else:
-      result_shape = concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
-                            axis=0)
+      repeated_dim_size = gen_math_ops._sum(
+          repeats_original,
+          axis=gen_math_ops._range(0, rank(repeats_original), 1))
+      result_shape = concat(
+          [data_shape[:axis], [repeated_dim_size], data_shape[axis + 1:]],
+          axis=0)
       result = reshape(masked, result_shape)
 
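In the non-scalar branch, the size of the repeated dimension is now computed explicitly as the sum of the (broadcast) repeats instead of being inferred with -1. A short NumPy illustration of that shape arithmetic, not the TensorFlow code itself:

import numpy as np

x = np.arange(6).reshape(2, 3)
repeats = np.array([1, 2])                  # one entry per row along axis 0
out = np.repeat(x, repeats, axis=0)
assert out.shape[0] == repeats.sum()        # 3 = 1 + 2, no -1 inference

# Summing repeats also covers the empty case, where -1 inference would be
# ambiguous because the tensor has zero elements:
empty = np.ones([0, 4])
assert np.repeat(empty, 0, axis=1).shape == (0, 0)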