Add a new "result_type" parameter to the TensorFlow 1.x endpoint of tf.strings.split, which controls whether the result is a SparseTensor or a RaggedTensor (default: SparseTensor). In TensorFlow 2.x, tf.strings.split takes no such parameter and always returns a RaggedTensor.

PiperOrigin-RevId: 242673690
This commit is contained in:
Edward Loper 2019-04-09 08:45:54 -07:00 committed by TensorFlower Gardener
parent 374d665e0f
commit 6cf4b5d3eb
11 changed files with 371 additions and 72 deletions

View File

@ -72,7 +72,7 @@ class TextEmbeddingModel(tf.train.Checkpoint):
normalized_sentences = tf.strings.regex_replace(
input=sentences, pattern=r"\pP", rewrite="")
normalized_sentences = tf.reshape(normalized_sentences, [-1])
sparse_tokens = tf.strings.split(normalized_sentences, " ")
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
# Deal with a corner case: there is one empty sentence.
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))

View File

@ -50,7 +50,7 @@ class TextRnnModel(tf.train.Checkpoint):
# splitting on spaces.
normalized_sentences = tf.strings.regex_replace(
input=sentences, pattern=r"\pP", rewrite="")
sparse_tokens = tf.strings.split(normalized_sentences, " ")
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
# Deal with a corner case: there is one empty sentence.
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))

View File

@ -1068,6 +1068,7 @@ tf_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:string_ops",
"//tensorflow/python/ops/ragged:ragged_string_ops",
],
)

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.platform import test
@ -170,12 +171,16 @@ class StringSplitV2OpTest(test.TestCase):
def testSplitV2(self):
strings = ["pigs on the wing", "animals"]
with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"])
self.assertAllEqual(shape, [2, 4])
tokens = string_ops.string_split_v2(strings)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"])
self.assertAllEqual(shape, [2, 4])
ragged_tokens = ragged_string_ops.string_split_v2(strings)
self.assertAllEqual(ragged_tokens.row_splits, [0, 4, 5])
self.assertAllEqual(ragged_tokens.values,
[b"pigs", b"on", b"the", b"wing", b"animals"])
def testSplitV2MultiCharSeparator(self):
# Match Python behavior:
@ -185,15 +190,20 @@ class StringSplitV2OpTest(test.TestCase):
# ['', '', '4', '5', '', '6', '']
strings = ["1<>2<>3", "<><>4<>5<><>6<>"]
with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep="<>")
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(
indices, [[0, 0], [0, 1], [0, 2],
[1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]])
self.assertAllEqual(values, [b"1", b"2", b"3",
b"", b"", b"4", b"5", b"", b"6", b""])
self.assertAllEqual(shape, [2, 7])
tokens = string_ops.string_split_v2(strings, sep="<>")
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices,
[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3],
[1, 4], [1, 5], [1, 6]])
self.assertAllEqual(
values, [b"1", b"2", b"3", b"", b"", b"4", b"5", b"", b"6", b""])
self.assertAllEqual(shape, [2, 7])
ragged_tokens = ragged_string_ops.string_split_v2(strings, sep="<>")
self.assertAllEqual(ragged_tokens.row_splits, [0, 3, 10])
self.assertAllEqual(
ragged_tokens.values,
[b"1", b"2", b"3", b"", b"", b"4", b"5", b"", b"6", b""])
def testSplitV2SimpleSeparator(self):
# Match Python behavior:
@ -203,30 +213,38 @@ class StringSplitV2OpTest(test.TestCase):
# ['1', '2', '', '3', '']
strings = ["1,2,3", "4,5,,6,"]
with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep=',')
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
[1, 0], [1, 1], [1, 2], [1, 3], [1, 4]])
self.assertAllEqual(values, [b"1", b"2", b"3",
b"4", b"5", b"", b"6", b""])
self.assertAllEqual(shape, [2, 5])
tokens = string_ops.string_split_v2(strings, sep=",")
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(
indices,
[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3], [1, 4]])
self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"", b"6", b""])
self.assertAllEqual(shape, [2, 5])
ragged_tokens = ragged_string_ops.string_split_v2(strings, sep=",")
self.assertAllEqual(ragged_tokens.row_splits, [0, 3, 8])
self.assertAllEqual(ragged_tokens.values,
[b"1", b"2", b"3", b"4", b"5", b"", b"6", b""])
def testSplitV2EmptySeparator(self):
# Match Python behavior:
# >>> '1 2 3'.split()
# ['1', '2', '3']
#>>> ' 1 2 3 '.split()
#['1', '2', '3']
# >>> ' 1 2 3 '.split()
# ['1', '2', '3']
strings = ["1 2 3", " 4 5 6 "]
with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
[1, 0], [1, 1], [1, 2]])
self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"])
self.assertAllEqual(shape, [2, 3])
tokens = string_ops.string_split_v2(strings)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices,
[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]])
self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"])
self.assertAllEqual(shape, [2, 3])
ragged_tokens = ragged_string_ops.string_split_v2(strings)
self.assertAllEqual(ragged_tokens.row_splits, [0, 3, 6])
self.assertAllEqual(ragged_tokens.values,
[b"1", b"2", b"3", b"4", b"5", b"6"])
def testSplitV2SimpleSeparatorMaxSplit(self):
# Match Python behavior:
@ -236,13 +254,16 @@ class StringSplitV2OpTest(test.TestCase):
# ['4', '5,,6,']
strings = ["1,2,3", "4,5,,6,"]
with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
[1, 0], [1, 1]])
self.assertAllEqual(values, [b"1", b"2,3", b"4", b"5,,6,"])
self.assertAllEqual(shape, [2, 2])
tokens = string_ops.string_split_v2(strings, sep=",", maxsplit=1)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]])
self.assertAllEqual(values, [b"1", b"2,3", b"4", b"5,,6,"])
self.assertAllEqual(shape, [2, 2])
ragged_tokens = ragged_string_ops.string_split_v2(
strings, sep=",", maxsplit=1)
self.assertAllEqual(ragged_tokens.row_splits, [0, 2, 4])
self.assertAllEqual(ragged_tokens.values, [b"1", b"2,3", b"4", b"5,,6,"])
def testSplitV2EmptySeparatorMaxSplit(self):
# Match Python behavior:
@ -252,13 +273,15 @@ class StringSplitV2OpTest(test.TestCase):
# ['4', '5 6 ']
strings = ["1 2 3", " 4 5 6 "]
with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, maxsplit=1)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
[1, 0], [1, 1]])
self.assertAllEqual(values, [b"1", b"2 3", b"4", b"5 6 "])
self.assertAllEqual(shape, [2, 2])
tokens = string_ops.string_split_v2(strings, maxsplit=1)
indices, values, shape = self.evaluate(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]])
self.assertAllEqual(values, [b"1", b"2 3", b"4", b"5 6 "])
self.assertAllEqual(shape, [2, 2])
ragged_tokens = ragged_string_ops.string_split_v2(strings, maxsplit=1)
self.assertAllEqual(ragged_tokens.row_splits, [0, 2, 4])
self.assertAllEqual(ragged_tokens.values, [b"1", b"2 3", b"4", b"5 6 "])
if __name__ == "__main__":

View File

@ -22,9 +22,11 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@ -400,3 +402,163 @@ def _unicode_decode(input, input_encoding, errors, replacement_char,
else:
return codepoints
@tf_export("strings.split", v1=[])
def string_split_v2(source, sep=None, maxsplit=-1):
  """Splits elements of `source` based on `sep` into a `RaggedTensor`.

  Let N be the size of `source` (typically N will be the batch size). Each
  element of `source` is split on `sep`, and row `i` of the result holds the
  tokens of `source[i]`.

  Example:

  ```python
  >>> tf.strings.split(['hello world', 'a b c'])
  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
  ```

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
  sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator; empty
  tokens are then ignored, so the result contains no empty strings at the
  start or end if the string has leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter string.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `RaggedTensor` of rank `2`: the strings split according to the delimiter.
  """
  # Delegate the splitting to the existing sparse op, then re-wrap: the first
  # column of the sparse indices is the row in `source`, so it can be used
  # directly as the ragged value_rowids.
  sparse_result = string_ops.string_split_v2(source, sep=sep, maxsplit=maxsplit)
  return ragged_tensor.RaggedTensor.from_value_rowids(
      values=sparse_result.values,
      value_rowids=sparse_result.indices[:, 0],
      nrows=sparse_result.dense_shape[0])
@tf_export(v1=["string_split"])
@deprecation.deprecated_args(None,
                             "delimiter is deprecated, please use sep instead.",
                             "delimiter")
def string_split(source, sep=None, skip_empty=True, delimiter=None,
                 result_type="SparseTensor"):  # pylint: disable=invalid-name
  """Splits elements of `source` based on `delimiter`.

  Let N be the size of `source` (typically N will be the batch size). Each
  element of `source` is split on `delimiter`, and the tokens are returned as
  either a `SparseTensor` or a `RaggedTensor` (selected by `result_type`).
  Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  Examples:

  ```python
  >>> tf.strings.split(['hello world', 'a b c'])
  tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
                  values=['hello', 'world', 'a', 'b', 'c']
                  dense_shape=[2, 3])

  >>> tf.strings.split(['hello world', 'a b c'], result_type="RaggedTensor")
  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
  ```

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
    to the delimiter. The first column of the indices corresponds to the row
    in `source` and the second column corresponds to the index of the split
    component in this row.
  """
  # Always split via the sparse op; only the packaging of the result differs.
  tokens = string_ops.string_split(
      source, sep=sep, skip_empty=skip_empty, delimiter=delimiter)
  if result_type == "RaggedTensor":
    # Row ids are exactly the first column of the sparse indices.
    return ragged_tensor.RaggedTensor.from_value_rowids(
        values=tokens.values,
        value_rowids=tokens.indices[:, 0],
        nrows=tokens.dense_shape[0])
  if result_type == "SparseTensor":
    return tokens
  raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit),
# but we need to add the result_type argument.
@tf_export(v1=["strings.split"])
def strings_split_v1(source, sep=None, maxsplit=-1, result_type="SparseTensor"):
  """Splits elements of `source` based on `sep`.

  Let N be the size of `source` (typically N will be the batch size). Each
  element of `source` is split on `sep`, and the tokens are returned as
  either a `SparseTensor` or a `RaggedTensor` (selected by `result_type`).
  Empty tokens are ignored.

  Examples:

  ```python
  >>> tf.strings.split(['hello world', 'a b c'])
  tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
                  values=['hello', 'world', 'a', 'b', 'c']
                  dense_shape=[2, 3])

  >>> tf.strings.split(['hello world', 'a b c'], result_type="RaggedTensor")
  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
  ```

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
  sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  # Split with the v2 sparse op, then package per result_type.
  tokens = string_ops.string_split_v2(
      source, sep=sep, maxsplit=maxsplit)
  if result_type == "RaggedTensor":
    # Row ids are exactly the first column of the sparse indices.
    return ragged_tensor.RaggedTensor.from_value_rowids(
        values=tokens.values,
        value_rowids=tokens.indices[:, 0],
        nrows=tokens.dense_shape[0])
  if result_type == "SparseTensor":
    return tokens
  raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")

View File

@ -192,10 +192,8 @@ def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
name=name)
@tf_export(v1=["string_split"])
@deprecation.deprecated_args(None,
"delimiter is deprecated, please use sep instead.",
"delimiter")
# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
# defines a wrapper for this function.
def string_split(source, sep=None, skip_empty=True, delimiter=None): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
@ -251,7 +249,8 @@ def string_split(source, sep=None, skip_empty=True, delimiter=None): # pylint:
return sparse_tensor.SparseTensor(indices, values, shape)
@tf_export("strings.split")
# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
# defines a wrapper for this function.
def string_split_v2(source, sep=None, maxsplit=-1):
"""Split elements of `source` based on `sep` into a `SparseTensor`.

View File

@ -2218,7 +2218,7 @@ tf_module {
}
member_method {
name: "string_split"
argspec: "args=[\'source\', \'sep\', \'skip_empty\', \'delimiter\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
argspec: "args=[\'source\', \'sep\', \'skip_empty\', \'delimiter\', \'result_type\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'SparseTensor\'], "
}
member_method {
name: "string_strip"

View File

@ -30,7 +30,7 @@ tf_module {
}
member_method {
name: "split"
argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], "
argspec: "args=[\'source\', \'sep\', \'maxsplit\', \'result_type\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\', \'SparseTensor\'], "
}
member_method {
name: "strip"

View File

@ -109,6 +109,7 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/tools/common:public_api",
"//tensorflow/tools/common:traverse",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import ast
import copy
import functools
import sys
@ -1595,6 +1596,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.nn.fractional_avg_pool": _pool_seed_transformer,
"tf.nn.fractional_max_pool": _pool_seed_transformer,
"tf.name_scope": _name_scope_transformer,
"tf.string_split": _string_split_transformer,
"tf.strings.split": _string_split_rtype_transformer,
"tf.estimator.DNNEstimator":
functools.partial(
_rename_if_arg_found_transformer,
@ -1667,12 +1670,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"takes input_layer_partitioner, so the call was converted to "
"compat.v1."
),
"tf.string_split": functools.partial(
_rename_if_arg_found_transformer, arg_name="skip_empty",
arg_ok_predicate=_is_ast_false, remove_if_ok=True,
message="tf.string_split's replacement no longer takes the "
"skip_empty argument. Since the argument was present, the call was "
"converted to compat.v1."),
"tf.device": functools.partial(
_rename_if_arg_found_transformer, arg_name="device_name",
arg_ok_predicate=_is_ast_str, remove_if_ok=False,
@ -2274,3 +2271,82 @@ def _name_scope_transformer(parent, node, full_name, name, logs):
logs.append((ast_edits.ERROR, node.lineno, node.col_offset,
"name_scope call with neither name nor default_name cannot be "
"converted properly."))
def _rename_to_compat_v1(node, full_name, logs, reason):
  """Rewrites a `tf.*` call node to use `tf.compat.v1.*`, logging `reason`."""
  # Only the leading "tf." is rewritten (count=1), so nested occurrences of
  # "tf." elsewhere in the dotted name are left alone.
  v1_name = full_name.replace("tf.", "tf.compat.v1.", 1)
  # Build a replacement name node that keeps the original context and source
  # location, then splice it in as the call's function expression.
  replacement = ast_edits.full_name_node(v1_name, node.func.ctx)
  ast.copy_location(replacement, node.func)
  logs.append((ast_edits.INFO, node.lineno, node.col_offset,
               "Renamed %r to %r: %s" % (full_name, v1_name, reason)))
  pasta.ast_utils.replace_child(node, node.func, replacement)
  return node
def _string_split_transformer(parent, node, full_name, name, logs):
  """Update tf.string_split arguments: skip_empty, sep, result_type."""
  # Handle the skip_empty parameter: tf.strings.split has no such argument,
  # so the call can only be upgraded when skip_empty is a literal False
  # (in which case the argument is simply dropped).
  skip_empty_kw = None
  for kw in node.keywords:
    if kw.arg == "skip_empty":
      skip_empty_kw = kw
      break
  if skip_empty_kw is not None:
    if not _is_ast_false(skip_empty_kw.value):
      return _rename_to_compat_v1(
          node, full_name, logs, "tf.string_split's replacement no longer "
          "takes the skip_empty argument.")
    logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                 "removed argument skip_empty for tf.string_split."))
    node.keywords.remove(skip_empty_kw)

  # Handle the sep parameter: the v2 semantics for an empty (or absent) sep
  # differ, so only a provably non-empty string literal is safe to upgrade.
  if not any(kw.arg in ("sep", "delimiter") and
             isinstance(kw.value, ast.Str) and kw.value.s != ""
             for kw in node.keywords):
    return _rename_to_compat_v1(
        node, full_name, logs,
        "The semantics for tf.string_split's sep parameter have changed when "
        "sep is the empty string.")

  # Finally, handle the result_type parameter like tf.strings.split does.
  return _string_split_rtype_transformer(parent, node, full_name, name, logs)
def _string_split_rtype_transformer(parent, node, full_name, name, logs):
  """Update tf.strings.split argument: result_type.

  Drops a literal result_type keyword, and appends `.to_sparse()` to the call
  unless result_type was "RaggedTensor" (the v2 op always returns a
  RaggedTensor). Falls back to compat.v1 when result_type is not a literal.
  """
  # Remove the "result_type" argument.
  need_to_sparse = True
  for i, kw in enumerate(node.keywords):
    if kw.arg == "result_type":
      if (isinstance(kw.value, ast.Str) and
          kw.value.s in ("RaggedTensor", "SparseTensor")):
        logs.append((ast_edits.INFO, node.lineno, node.col_offset,
                     "Removed argument result_type=%r for function %s" %
                     (kw.value.s, full_name or name)))
        node.keywords.pop(i)
        if kw.value.s == "RaggedTensor":
          # Caller explicitly asked for a RaggedTensor, which is what the
          # v2 op returns, so no .to_sparse() conversion is needed.
          need_to_sparse = False
      else:
        # result_type is not a recognized string literal, so its value can't
        # be determined statically; keep the old behavior via compat.v1.
        return _rename_to_compat_v1(
            node, full_name, logs,
            "%s no longer takes the result_type parameter." % full_name)
      break

  # If necessary, add a call to .to_sparse() to convert the output of
  # strings.split from a RaggedTensor to a SparseTensor.
  if need_to_sparse:
    if (isinstance(parent, ast.Attribute) and parent.attr == "to_sparse"):
      return  # Prevent infinite recursion (since child nodes are transformed)
    logs.append(
        (ast_edits.INFO, node.lineno, node.col_offset,
         "Adding call to RaggedTensor.to_sparse() to result of strings.split, "
         "since it now returns a RaggedTensor."))
    # Wrap the (copied) call in `<call>.to_sparse()`.
    node = ast.Attribute(value=copy.deepcopy(node), attr="to_sparse")
    try:
      node = ast.Call(node, [], [])
    except TypeError:
      # NOTE(review): presumably for older ast.Call signatures that require
      # explicit starargs/kwargs positional arguments — confirm.
      node = ast.Call(node, [], [], None, None)
    return node

View File

@ -22,6 +22,7 @@ import inspect
import os
import tempfile
from absl.testing import parameterized
import six
import tensorflow as tf
# OSS TF V2 import placeholder.
@ -75,7 +76,7 @@ def get_func_and_args_from_str(call_str):
return function_name, args
class TestUpgrade(test_util.TensorFlowTestCase):
class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
"""Test various APIs that have been changed in 2.0.
We also test whether a converted file is executable. test_file_v1_10.py
@ -85,6 +86,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
@classmethod
def setUpClass(cls):
super(TestUpgrade, cls).setUpClass()
cls.v2_symbols = {}
cls.v1_symbols = {}
if hasattr(tf.compat, "v2"):
@ -1512,19 +1514,54 @@ def _log_prob(self, x):
self.assertIn("name_scope call with neither name nor default_name",
errors[0])
def test_string_split(self):
text = "tf.string_split('test', delimiter=' ')"
expected_text = "tf.strings.split(source='test', sep=' ')"
@parameterized.parameters(
# Rename parameter: delimiter -> sep and add .to_sparse()
["tf.string_split('test', delimiter=' ')",
"tf.strings.split(source='test', sep=' ').to_sparse()"],
# Use compat.v1 for skip_empty parameter.
["tf.string_split('test', ' ', True)",
"tf.compat.v1.string_split(source='test', sep=' ', skip_empty=True)"],
["tf.string_split('test', ' ', skip_empty=False)",
"tf.strings.split(source='test', sep=' ').to_sparse()"],
# Split behavior for sep='' changed:
["tf.string_split(x)",
"tf.compat.v1.string_split(source=x)"],
["tf.string_split(x, '')",
"tf.compat.v1.string_split(source=x, sep='')"],
# If sep is a variable, we can't tell if it's empty:
["tf.string_split(x, sep)",
"tf.compat.v1.string_split(source=x, sep=sep)"],
# If sep is a non-empty string literal, then we don't need compat.v1.
["tf.string_split(x, 'non-empty-sep')",
"tf.strings.split(source=x, sep='non-empty-sep').to_sparse()"],
# Add to_sparse unless result_type is RaggedTensor:
["tf.string_split(x, ' ')",
"tf.strings.split(source=x, sep=' ').to_sparse()"],
["tf.string_split(x, ' ', result_type='SparseTensor')",
"tf.strings.split(source=x, sep=' ').to_sparse()"],
["tf.string_split(x, ' ', result_type='RaggedTensor')",
"tf.strings.split(source=x, sep=' ')"],
["tf.string_split(x, ' ', result_type=x)",
"tf.compat.v1.string_split(source=x, sep=' ', result_type=x)"],
) # pyformat: disable
def test_string_split(self, text, expected_text):
"""Tests for transforming from tf.string_split."""
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "tf.string_split('test', ' ', True)"
expected_text = "tf.compat.v1.string_split(source='test', sep=' ', skip_empty=True)" # pylint: disable=line-too-long
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)
text = "tf.string_split('test', ' ', skip_empty=False)"
expected_text = "tf.strings.split(source='test', sep=' ')" # pylint: disable=line-too-long
@parameterized.parameters(
# Add to_sparse unless result_type is RaggedTensor:
["tf.strings.split(x, sep)",
"tf.strings.split(x, sep).to_sparse()"],
["tf.strings.split(x, sep, result_type='SparseTensor')",
"tf.strings.split(x, sep).to_sparse()"],
["tf.strings.split(x, sep, result_type='RaggedTensor')",
"tf.strings.split(x, sep)"],
["tf.strings.split(x, sep, result_type=x)",
"tf.compat.v1.strings.split(x, sep, result_type=x)"],
) # pyformat: disable
def test_strings_split(self, text, expected_text):
"""Tests for transforming from tf.strings.split."""
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected_text, new_text)