Added tests to improve code coverage in ragged_string_ops.py
PiperOrigin-RevId: 299215270 Change-Id: I5dc5e574b8fdcf0713d96ebb55b846f1e38c5770
This commit is contained in:
parent
df00d7ebbf
commit
4569e70aa2
@ -22,6 +22,8 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
@ -64,6 +66,16 @@ class StringsToBytesOpTest(test_util.TensorFlowTestCase,
|
||||
result = ragged_string_ops.string_bytes_split(source)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
def testUnknownInputRankError(self):
|
||||
# Use a tf.function that erases shape information.
|
||||
@def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
|
||||
def f(v):
|
||||
return ragged_string_ops.string_bytes_split(v)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'input must have a statically-known rank'):
|
||||
f(['foo'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
@ -33,7 +34,7 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class StringSplitOpTest(test.TestCase):
|
||||
class StringSplitOpTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testStringSplit(self):
|
||||
strings = ["pigs on the wing", "animals"]
|
||||
@ -169,6 +170,66 @@ class StringSplitOpTest(test.TestCase):
|
||||
self.assertAllEqual(indices, [[0, 0], [1, 0], [2, 0]])
|
||||
self.assertAllEqual(shape, [3, 1])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name="RaggedResultType",
|
||||
source=[b"pigs on the wing", b"animals"],
|
||||
result_type="RaggedTensor",
|
||||
expected=[[b"pigs", b"on", b"the", b"wing"], [b"animals"]]),
|
||||
dict(
|
||||
testcase_name="SparseResultType",
|
||||
source=[b"pigs on the wing", b"animals"],
|
||||
result_type="SparseTensor",
|
||||
expected=sparse_tensor.SparseTensorValue(
|
||||
[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]],
|
||||
[b"pigs", b"on", b"the", b"wing", b"animals"], [2, 4])),
|
||||
dict(
|
||||
testcase_name="DefaultResultType",
|
||||
source=[b"pigs on the wing", b"animals"],
|
||||
expected=sparse_tensor.SparseTensorValue(
|
||||
[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]],
|
||||
[b"pigs", b"on", b"the", b"wing", b"animals"], [2, 4])),
|
||||
dict(
|
||||
testcase_name="BadResultType",
|
||||
source=[b"pigs on the wing", b"animals"],
|
||||
result_type="BouncyTensor",
|
||||
error="result_type must be .*"),
|
||||
dict(
|
||||
testcase_name="WithSepAndAndSkipEmpty",
|
||||
source=[b"+hello+++this+is+a+test"],
|
||||
sep="+",
|
||||
skip_empty=False,
|
||||
result_type="RaggedTensor",
|
||||
expected=[[b"", b"hello", b"", b"", b"this", b"is", b"a", b"test"]]),
|
||||
dict(
|
||||
testcase_name="WithDelimiter",
|
||||
source=[b"hello world"],
|
||||
delimiter="l",
|
||||
result_type="RaggedTensor",
|
||||
expected=[[b"he", b"o wor", b"d"]]),
|
||||
])
|
||||
def testRaggedStringSplitWrapper(self,
|
||||
source,
|
||||
sep=None,
|
||||
skip_empty=True,
|
||||
delimiter=None,
|
||||
result_type="SparseTensor",
|
||||
expected=None,
|
||||
error=None):
|
||||
if error is not None:
|
||||
with self.assertRaisesRegexp(ValueError, error):
|
||||
ragged_string_ops.string_split(source, sep, skip_empty, delimiter,
|
||||
result_type)
|
||||
if expected is not None:
|
||||
result = ragged_string_ops.string_split(source, sep, skip_empty,
|
||||
delimiter, result_type)
|
||||
if isinstance(expected, sparse_tensor.SparseTensorValue):
|
||||
self.assertAllEqual(result.indices, expected.indices)
|
||||
self.assertAllEqual(result.values, expected.values)
|
||||
self.assertAllEqual(result.dense_shape, expected.dense_shape)
|
||||
else:
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
|
||||
class StringSplitV2OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@ -385,6 +446,10 @@ class StringSplitV2OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertEqual(expected_sparse.dense_shape.tolist(),
|
||||
self.evaluate(actual_sparse_v1.dense_shape).tolist())
|
||||
|
||||
def testSplitV1BadResultType(self):
|
||||
with self.assertRaisesRegexp(ValueError, "result_type must be .*"):
|
||||
ragged_string_ops.strings_split_v1("foo", result_type="BouncyTensor")
|
||||
|
||||
def _py_split(self, strings, **kwargs):
|
||||
if isinstance(strings, compat.bytes_or_text_types):
|
||||
# Note: str.split doesn't accept keyword args.
|
||||
|
@ -21,8 +21,10 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl as errors
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
@ -285,6 +287,16 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
def testUnknownInputRankError(self):
|
||||
# Use a tf.function that erases shape information.
|
||||
@def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
|
||||
def f(v):
|
||||
return ragged_string_ops.unicode_encode(v, "UTF-8")
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Rank of input_tensor must be statically known."):
|
||||
f([72, 101, 108, 108, 111])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -18,16 +18,20 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class StringNgramsTest(test_util.TensorFlowTestCase):
|
||||
class StringNgramsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def test_unpadded_ngrams(self):
|
||||
data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
|
||||
@ -282,6 +286,66 @@ class StringNgramsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(constant_op.constant([], dtype=dtypes.string),
|
||||
result.values)
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
data=[b"a", b"z"],
|
||||
ngram_width=2,
|
||||
pad_values=5,
|
||||
exception=TypeError,
|
||||
error="pad_values must be a string, tuple of strings, or None."),
|
||||
dict(
|
||||
data=[b"a", b"z"],
|
||||
ngram_width=2,
|
||||
pad_values=[5, 3],
|
||||
exception=TypeError,
|
||||
error="pad_values must be a string, tuple of strings, or None."),
|
||||
dict(
|
||||
data=[b"a", b"z"],
|
||||
ngram_width=2,
|
||||
padding_width=0,
|
||||
pad_values="X",
|
||||
error="padding_width must be greater than 0."),
|
||||
dict(
|
||||
data=[b"a", b"z"],
|
||||
ngram_width=2,
|
||||
padding_width=1,
|
||||
error="pad_values must be provided if padding_width is set."),
|
||||
dict(
|
||||
data=b"hello",
|
||||
ngram_width=2,
|
||||
padding_width=1,
|
||||
pad_values="X",
|
||||
error="Data must have rank>0"),
|
||||
dict(
|
||||
data=[b"hello", b"world"],
|
||||
ngram_width=[1, 2, -1],
|
||||
padding_width=1,
|
||||
pad_values="X",
|
||||
error="All ngram_widths must be greater than 0. Got .*"),
|
||||
])
|
||||
def test_error(self,
|
||||
data,
|
||||
ngram_width,
|
||||
separator=" ",
|
||||
pad_values=None,
|
||||
padding_width=None,
|
||||
preserve_short_sequences=False,
|
||||
error=None,
|
||||
exception=ValueError):
|
||||
with self.assertRaisesRegexp(exception, error):
|
||||
ragged_string_ops.ngrams(data, ngram_width, separator, pad_values,
|
||||
padding_width, preserve_short_sequences)
|
||||
|
||||
def test_unknown_rank_error(self):
|
||||
# Use a tf.function that erases shape information.
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
|
||||
def f(v):
|
||||
return ragged_string_ops.ngrams(v, 2)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Rank of data must be known."):
|
||||
f([b"foo", b"bar"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user