Added tests to improve code coverage in ragged_string_ops.py

PiperOrigin-RevId: 299215270
Change-Id: I5dc5e574b8fdcf0713d96ebb55b846f1e38c5770
This commit is contained in:
Edward Loper 2020-03-05 15:54:53 -08:00 committed by TensorFlower Gardener
parent df00d7ebbf
commit 4569e70aa2
4 changed files with 155 additions and 2 deletions

View File

@ -22,6 +22,8 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_string_ops
@ -64,6 +66,16 @@ class StringsToBytesOpTest(test_util.TensorFlowTestCase,
result = ragged_string_ops.string_bytes_split(source)
self.assertAllEqual(expected, result)
def testUnknownInputRankError(self):
# Use a tf.function that erases shape information.
@def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
def f(v):
return ragged_string_ops.string_bytes_split(v)
with self.assertRaisesRegexp(ValueError,
'input must have a statically-known rank'):
f(['foo'])
if __name__ == '__main__':
test.main()

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
@ -33,7 +34,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
class StringSplitOpTest(test.TestCase):
class StringSplitOpTest(test.TestCase, parameterized.TestCase):
def testStringSplit(self):
strings = ["pigs on the wing", "animals"]
@ -169,6 +170,66 @@ class StringSplitOpTest(test.TestCase):
self.assertAllEqual(indices, [[0, 0], [1, 0], [2, 0]])
self.assertAllEqual(shape, [3, 1])
@parameterized.named_parameters([
dict(
testcase_name="RaggedResultType",
source=[b"pigs on the wing", b"animals"],
result_type="RaggedTensor",
expected=[[b"pigs", b"on", b"the", b"wing"], [b"animals"]]),
dict(
testcase_name="SparseResultType",
source=[b"pigs on the wing", b"animals"],
result_type="SparseTensor",
expected=sparse_tensor.SparseTensorValue(
[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]],
[b"pigs", b"on", b"the", b"wing", b"animals"], [2, 4])),
dict(
testcase_name="DefaultResultType",
source=[b"pigs on the wing", b"animals"],
expected=sparse_tensor.SparseTensorValue(
[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]],
[b"pigs", b"on", b"the", b"wing", b"animals"], [2, 4])),
dict(
testcase_name="BadResultType",
source=[b"pigs on the wing", b"animals"],
result_type="BouncyTensor",
error="result_type must be .*"),
dict(
testcase_name="WithSepAndAndSkipEmpty",
source=[b"+hello+++this+is+a+test"],
sep="+",
skip_empty=False,
result_type="RaggedTensor",
expected=[[b"", b"hello", b"", b"", b"this", b"is", b"a", b"test"]]),
dict(
testcase_name="WithDelimiter",
source=[b"hello world"],
delimiter="l",
result_type="RaggedTensor",
expected=[[b"he", b"o wor", b"d"]]),
])
def testRaggedStringSplitWrapper(self,
source,
sep=None,
skip_empty=True,
delimiter=None,
result_type="SparseTensor",
expected=None,
error=None):
if error is not None:
with self.assertRaisesRegexp(ValueError, error):
ragged_string_ops.string_split(source, sep, skip_empty, delimiter,
result_type)
if expected is not None:
result = ragged_string_ops.string_split(source, sep, skip_empty,
delimiter, result_type)
if isinstance(expected, sparse_tensor.SparseTensorValue):
self.assertAllEqual(result.indices, expected.indices)
self.assertAllEqual(result.values, expected.values)
self.assertAllEqual(result.dense_shape, expected.dense_shape)
else:
self.assertAllEqual(result, expected)
class StringSplitV2OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@ -385,6 +446,10 @@ class StringSplitV2OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertEqual(expected_sparse.dense_shape.tolist(),
self.evaluate(actual_sparse_v1.dense_shape).tolist())
def testSplitV1BadResultType(self):
with self.assertRaisesRegexp(ValueError, "result_type must be .*"):
ragged_string_ops.strings_split_v1("foo", result_type="BouncyTensor")
def _py_split(self, strings, **kwargs):
if isinstance(strings, compat.bytes_or_text_types):
# Note: str.split doesn't accept keyword args.

View File

@ -21,8 +21,10 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl as errors
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_string_ops
@ -285,6 +287,16 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
self.assertAllEqual(unicode_encode_op, expected_value)
def testUnknownInputRankError(self):
# Use a tf.function that erases shape information.
@def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
def f(v):
return ragged_string_ops.unicode_encode(v, "UTF-8")
with self.assertRaisesRegexp(
ValueError, "Rank of input_tensor must be statically known."):
f([72, 101, 108, 108, 111])
if __name__ == "__main__":
test.main()

View File

@ -18,16 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.platform import test
class StringNgramsTest(test_util.TensorFlowTestCase):
class StringNgramsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def test_unpadded_ngrams(self):
data = [[b"aa", b"bb", b"cc", b"dd"], [b"ee", b"ff"]]
@ -282,6 +286,66 @@ class StringNgramsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(constant_op.constant([], dtype=dtypes.string),
result.values)
@parameterized.parameters([
dict(
data=[b"a", b"z"],
ngram_width=2,
pad_values=5,
exception=TypeError,
error="pad_values must be a string, tuple of strings, or None."),
dict(
data=[b"a", b"z"],
ngram_width=2,
pad_values=[5, 3],
exception=TypeError,
error="pad_values must be a string, tuple of strings, or None."),
dict(
data=[b"a", b"z"],
ngram_width=2,
padding_width=0,
pad_values="X",
error="padding_width must be greater than 0."),
dict(
data=[b"a", b"z"],
ngram_width=2,
padding_width=1,
error="pad_values must be provided if padding_width is set."),
dict(
data=b"hello",
ngram_width=2,
padding_width=1,
pad_values="X",
error="Data must have rank>0"),
dict(
data=[b"hello", b"world"],
ngram_width=[1, 2, -1],
padding_width=1,
pad_values="X",
error="All ngram_widths must be greater than 0. Got .*"),
])
def test_error(self,
data,
ngram_width,
separator=" ",
pad_values=None,
padding_width=None,
preserve_short_sequences=False,
error=None,
exception=ValueError):
with self.assertRaisesRegexp(exception, error):
ragged_string_ops.ngrams(data, ngram_width, separator, pad_values,
padding_width, preserve_short_sequences)
def test_unknown_rank_error(self):
# Use a tf.function that erases shape information.
@def_function.function(
input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
def f(v):
return ragged_string_ops.ngrams(v, 2)
with self.assertRaisesRegexp(ValueError, "Rank of data must be known."):
f([b"foo", b"bar"])
if __name__ == "__main__":
test.main()