Add a new "result_type" parameter to tf.strings.split, which controls whether the result is a SparseTensor or a RaggedTensor. Defaults to SparseTensor for TensorFlow 1.x, and RaggedTensor for TensorFlow 2.x.
PiperOrigin-RevId: 242673690
commit 6cf4b5d3eb (parent 374d665e0f)

Changed paths:
  tensorflow/examples/saved_model/integration_tests
  tensorflow/python
  tensorflow/tools/api/golden/v1
  tensorflow/tools/compatibility
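A minimal sketch of the behavior described above, assuming a TensorFlow 1.x build that includes this change (output types only; formatting approximate):

```python
import tensorflow as tf  # assumes a TF 1.x build with this change applied

sentences = tf.constant(["hello world", "a b c"])

# TF 1.x default: result_type="SparseTensor".
sparse = tf.strings.split(sentences, sep=" ")
ragged = tf.strings.split(sentences, sep=" ", result_type="RaggedTensor")

# In TF 2.x the same call returns a RaggedTensor directly; callers that
# still need a SparseTensor call .to_sparse() (see the migration hunks below).
```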
@@ -72,7 +72,7 @@ class TextEmbeddingModel(tf.train.Checkpoint):
     normalized_sentences = tf.strings.regex_replace(
         input=sentences, pattern=r"\pP", rewrite="")
     normalized_sentences = tf.reshape(normalized_sentences, [-1])
-    sparse_tokens = tf.strings.split(normalized_sentences, " ")
+    sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
 
     # Deal with a corner case: there is one empty sentence.
     sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
@@ -50,7 +50,7 @@ class TextRnnModel(tf.train.Checkpoint):
     # splitting on spaces.
     normalized_sentences = tf.strings.regex_replace(
         input=sentences, pattern=r"\pP", rewrite="")
-    sparse_tokens = tf.strings.split(normalized_sentences, " ")
+    sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
 
     # Deal with a corner case: there is one empty sentence.
     sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
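Extracted from the two hunks above, a runnable sketch of the migration (assuming a TF 2.x context, where tf.strings.split returns a RaggedTensor):

```python
import tensorflow as tf  # assumes TF 2.x behavior

sentences = tf.constant(["pigs on the wing", ""])
sparse_tokens = tf.strings.split(sentences, " ").to_sparse()
# Guard the corner case called out in the comment: an empty sentence yields
# an empty row, which fill_empty_rows pads with a default value of "".
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
```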
@@ -1068,6 +1068,7 @@ tf_py_test(
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:string_ops",
+        "//tensorflow/python/ops/ragged:ragged_string_ops",
     ],
 )
 
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import string_ops
+from tensorflow.python.ops.ragged import ragged_string_ops
 from tensorflow.python.platform import test
 
 
@@ -170,12 +171,16 @@ class StringSplitV2OpTest(test.TestCase):
   def testSplitV2(self):
     strings = ["pigs on the wing", "animals"]
 
-    with self.cached_session() as sess:
-      tokens = string_ops.string_split_v2(strings)
-      indices, values, shape = self.evaluate(tokens)
-      self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
-      self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"])
-      self.assertAllEqual(shape, [2, 4])
+    tokens = string_ops.string_split_v2(strings)
+    indices, values, shape = self.evaluate(tokens)
+    self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
+    self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"])
+    self.assertAllEqual(shape, [2, 4])
+
+    ragged_tokens = ragged_string_ops.string_split_v2(strings)
+    self.assertAllEqual(ragged_tokens.row_splits, [0, 4, 5])
+    self.assertAllEqual(ragged_tokens.values,
+                        [b"pigs", b"on", b"the", b"wing", b"animals"])
 
   def testSplitV2MultiCharSeparator(self):
     # Match Python behavior:
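For readers new to RaggedTensor: the row_splits/values pair asserted above encodes exactly the rows the sparse assertions describe. A plain-Python sketch of the decoding:

```python
# row_splits [0, 4, 5] partitions the flat values into values[0:4], values[4:5].
values = [b"pigs", b"on", b"the", b"wing", b"animals"]
row_splits = [0, 4, 5]
rows = [values[row_splits[i]:row_splits[i + 1]]
        for i in range(len(row_splits) - 1)]
assert rows == [[b"pigs", b"on", b"the", b"wing"], [b"animals"]]
```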
@@ -185,15 +190,20 @@ class StringSplitV2OpTest(test.TestCase):
     # ['', '', '4', '5', '', '6', '']
     strings = ["1<>2<>3", "<><>4<>5<><>6<>"]
 
-    with self.cached_session() as sess:
-      tokens = string_ops.string_split_v2(strings, sep="<>")
-      indices, values, shape = self.evaluate(tokens)
-      self.assertAllEqual(
-          indices, [[0, 0], [0, 1], [0, 2],
-                    [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]])
-      self.assertAllEqual(values, [b"1", b"2", b"3",
-                                   b"", b"", b"4", b"5", b"", b"6", b""])
-      self.assertAllEqual(shape, [2, 7])
+    tokens = string_ops.string_split_v2(strings, sep="<>")
+    indices, values, shape = self.evaluate(tokens)
+    self.assertAllEqual(indices,
+                        [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3],
+                         [1, 4], [1, 5], [1, 6]])
+    self.assertAllEqual(
+        values, [b"1", b"2", b"3", b"", b"", b"4", b"5", b"", b"6", b""])
+    self.assertAllEqual(shape, [2, 7])
+
+    ragged_tokens = ragged_string_ops.string_split_v2(strings, sep="<>")
+    self.assertAllEqual(ragged_tokens.row_splits, [0, 3, 10])
+    self.assertAllEqual(
+        ragged_tokens.values,
+        [b"1", b"2", b"3", b"", b"", b"4", b"5", b"", b"6", b""])
 
   def testSplitV2SimpleSeparator(self):
     # Match Python behavior:
@@ -203,30 +213,38 @@ class StringSplitV2OpTest(test.TestCase):
     # ['1', '2', '', '3', '']
     strings = ["1,2,3", "4,5,,6,"]
 
-    with self.cached_session() as sess:
-      tokens = string_ops.string_split_v2(strings, sep=',')
-      indices, values, shape = self.evaluate(tokens)
-      self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
-                                    [1, 0], [1, 1], [1, 2], [1, 3], [1, 4]])
-      self.assertAllEqual(values, [b"1", b"2", b"3",
-                                   b"4", b"5", b"", b"6", b""])
-      self.assertAllEqual(shape, [2, 5])
+    tokens = string_ops.string_split_v2(strings, sep=",")
+    indices, values, shape = self.evaluate(tokens)
+    self.assertAllEqual(
+        indices,
+        [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3], [1, 4]])
+    self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"", b"6", b""])
+    self.assertAllEqual(shape, [2, 5])
+
+    ragged_tokens = ragged_string_ops.string_split_v2(strings, sep=",")
+    self.assertAllEqual(ragged_tokens.row_splits, [0, 3, 8])
+    self.assertAllEqual(ragged_tokens.values,
+                        [b"1", b"2", b"3", b"4", b"5", b"", b"6", b""])
 
   def testSplitV2EmptySeparator(self):
     # Match Python behavior:
     # >>> '1 2 3'.split()
     # ['1', '2', '3']
-    #>>> ' 1 2 3 '.split()
-    #['1', '2', '3']
+    # >>> ' 1 2 3 '.split()
+    # ['1', '2', '3']
     strings = ["1 2 3", " 4 5 6 "]
 
-    with self.cached_session() as sess:
-      tokens = string_ops.string_split_v2(strings)
-      indices, values, shape = self.evaluate(tokens)
-      self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
-                                    [1, 0], [1, 1], [1, 2]])
-      self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"])
-      self.assertAllEqual(shape, [2, 3])
+    tokens = string_ops.string_split_v2(strings)
+    indices, values, shape = self.evaluate(tokens)
+    self.assertAllEqual(indices,
+                        [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]])
+    self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"])
+    self.assertAllEqual(shape, [2, 3])
+
+    ragged_tokens = ragged_string_ops.string_split_v2(strings)
+    self.assertAllEqual(ragged_tokens.row_splits, [0, 3, 6])
+    self.assertAllEqual(ragged_tokens.values,
+                        [b"1", b"2", b"3", b"4", b"5", b"6"])
 
   def testSplitV2SimpleSeparatorMaxSplit(self):
     # Match Python behavior:
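The comments in these tests cite Python's str.split; the semantics they pin down are runnable in plain Python:

```python
# An explicit separator preserves empty tokens between delimiters.
assert "1<>2<><>3".split("<>") == ["1", "2", "", "3"]
assert "4,5,,6,".split(",") == ["4", "5", "", "6", ""]
# No separator: runs of whitespace collapse, and leading/trailing
# whitespace produces no empty tokens.
assert " 1 2 3 ".split() == ["1", "2", "3"]
```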
@@ -236,13 +254,16 @@ class StringSplitV2OpTest(test.TestCase):
     # ['4', '5,,6,']
     strings = ["1,2,3", "4,5,,6,"]
 
-    with self.cached_session() as sess:
-      tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1)
-      indices, values, shape = self.evaluate(tokens)
-      self.assertAllEqual(indices, [[0, 0], [0, 1],
-                                    [1, 0], [1, 1]])
-      self.assertAllEqual(values, [b"1", b"2,3", b"4", b"5,,6,"])
-      self.assertAllEqual(shape, [2, 2])
+    tokens = string_ops.string_split_v2(strings, sep=",", maxsplit=1)
+    indices, values, shape = self.evaluate(tokens)
+    self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]])
+    self.assertAllEqual(values, [b"1", b"2,3", b"4", b"5,,6,"])
+    self.assertAllEqual(shape, [2, 2])
+
+    ragged_tokens = ragged_string_ops.string_split_v2(
+        strings, sep=",", maxsplit=1)
+    self.assertAllEqual(ragged_tokens.row_splits, [0, 2, 4])
+    self.assertAllEqual(ragged_tokens.values, [b"1", b"2,3", b"4", b"5,,6,"])
 
   def testSplitV2EmptySeparatorMaxSplit(self):
     # Match Python behavior:
@@ -252,13 +273,15 @@ class StringSplitV2OpTest(test.TestCase):
     # ['4', '5 6 ']
     strings = ["1 2 3", " 4 5 6 "]
 
-    with self.cached_session() as sess:
-      tokens = string_ops.string_split_v2(strings, maxsplit=1)
-      indices, values, shape = self.evaluate(tokens)
-      self.assertAllEqual(indices, [[0, 0], [0, 1],
-                                    [1, 0], [1, 1]])
-      self.assertAllEqual(values, [b"1", b"2 3", b"4", b"5 6 "])
-      self.assertAllEqual(shape, [2, 2])
+    tokens = string_ops.string_split_v2(strings, maxsplit=1)
+    indices, values, shape = self.evaluate(tokens)
+    self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]])
+    self.assertAllEqual(values, [b"1", b"2 3", b"4", b"5 6 "])
+    self.assertAllEqual(shape, [2, 2])
+
+    ragged_tokens = ragged_string_ops.string_split_v2(strings, maxsplit=1)
+    self.assertAllEqual(ragged_tokens.row_splits, [0, 2, 4])
+    self.assertAllEqual(ragged_tokens.values, [b"1", b"2 3", b"4", b"5 6 "])
 
 
 if __name__ == "__main__":
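As with the other tests, the maxsplit cases mirror Python's str.split; in plain Python:

```python
# At most `maxsplit` splits are performed; the remainder stays one token.
assert "1,2,3".split(",", 1) == ["1", "2,3"]
assert " 4 5 6 ".split(None, 1) == ["4", "5 6 "]
```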
@@ -22,9 +22,11 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_array_ops
 from tensorflow.python.ops.ragged import ragged_conversion_ops
 from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -400,3 +402,163 @@ def _unicode_decode(input, input_encoding, errors, replacement_char,
   else:
     return codepoints
 
+
+@tf_export("strings.split", v1=[])
+def string_split_v2(source, sep=None, maxsplit=-1):
+  """Split elements of `source` based on `sep` into a `RaggedTensor`.
+
+  Let N be the size of `source` (typically N will be the batch size). Split
+  each element of `source` based on `sep` and return a `RaggedTensor`
+  containing the split tokens. Empty tokens are ignored.
+
+  Example:
+
+  ```python
+  >>> tf.strings.split(['hello world', 'a b c'])
+  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
+  ```
+
+  If `sep` is given, consecutive delimiters are not grouped together and are
+  deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+  sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+  string, consecutive whitespace is regarded as a single separator, and the
+  result will contain no empty strings at the start or end if the string has
+  leading or trailing whitespace.
+
+  Note that the above mentioned behavior matches Python's str.split.
+
+  Args:
+    source: `1-D` string `Tensor`, the strings to split.
+    sep: `0-D` string `Tensor`, the delimiter string.
+    maxsplit: An `int`. If `maxsplit > 0`, limits the number of splits.
+
+  Raises:
+    ValueError: If sep is not a string.
+
+  Returns:
+    A `RaggedTensor` of rank `2`: the strings split according to the delimiter.
+  """
+  sparse_result = string_ops.string_split_v2(source, sep=sep, maxsplit=maxsplit)
+  return ragged_tensor.RaggedTensor.from_value_rowids(
+      values=sparse_result.values,
+      value_rowids=sparse_result.indices[:, 0],
+      nrows=sparse_result.dense_shape[0])
+
+
+@tf_export(v1=["string_split"])
+@deprecation.deprecated_args(None,
+                             "delimiter is deprecated, please use sep instead.",
+                             "delimiter")
+def string_split(source, sep=None, skip_empty=True, delimiter=None,
+                 result_type="SparseTensor"):  # pylint: disable=invalid-name
+  """Split elements of `source` based on `delimiter`.
+
+  Let N be the size of `source` (typically N will be the batch size). Split
+  each element of `source` based on `delimiter` and return a `SparseTensor`
+  or `RaggedTensor` containing the split tokens. Empty tokens are ignored.
+
+  If `sep` is an empty string, each element of the `source` is split
+  into individual strings, each containing one byte. (This includes splitting
+  multibyte sequences of UTF-8.) If `delimiter` contains multiple bytes, it is
+  treated as a set of delimiters, with each considered a potential split point.
+
+  Examples:
+
+  ```python
+  >>> tf.strings.split(['hello world', 'a b c'])
+  tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
+                  values=['hello', 'world', 'a', 'b', 'c'],
+                  dense_shape=[2, 3])
+
+  >>> tf.strings.split(['hello world', 'a b c'], result_type="RaggedTensor")
+  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
+  ```
+
+  Args:
+    source: `1-D` string `Tensor`, the strings to split.
+    sep: `0-D` string `Tensor`, the delimiter character; the string should
+      have length 0 or 1. Default is ' '.
+    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
+    delimiter: deprecated alias for `sep`.
+    result_type: The tensor type for the result: one of `"RaggedTensor"` or
+      `"SparseTensor"`.
+
+  Raises:
+    ValueError: If delimiter is not a string.
+
+  Returns:
+    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
+    to the delimiter. The first column of the indices corresponds to the row
+    in `source` and the second column corresponds to the index of the split
+    component in this row.
+  """
+  sparse_result = string_ops.string_split(
+      source, sep=sep, skip_empty=skip_empty, delimiter=delimiter)
+  if result_type == "SparseTensor":
+    return sparse_result
+  elif result_type == "RaggedTensor":
+    return ragged_tensor.RaggedTensor.from_value_rowids(
+        values=sparse_result.values,
+        value_rowids=sparse_result.indices[:, 0],
+        nrows=sparse_result.dense_shape[0])
+  else:
+    raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
+
+
+# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit),
+# but we need to add the result_type argument.
+@tf_export(v1=["strings.split"])
+def strings_split_v1(source, sep=None, maxsplit=-1, result_type="SparseTensor"):
+  """Split elements of `source` based on `sep`.
+
+  Let N be the size of `source` (typically N will be the batch size). Split
+  each element of `source` based on `sep` and return a `SparseTensor` or
+  `RaggedTensor` containing the split tokens. Empty tokens are ignored.
+
+  Examples:
+
+  ```python
+  >>> tf.strings.split(['hello world', 'a b c'])
+  tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
+                  values=['hello', 'world', 'a', 'b', 'c'],
+                  dense_shape=[2, 3])
+
+  >>> tf.strings.split(['hello world', 'a b c'], result_type="RaggedTensor")
+  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
+  ```
+
+  If `sep` is given, consecutive delimiters are not grouped together and are
+  deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+  sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+  string, consecutive whitespace is regarded as a single separator, and the
+  result will contain no empty strings at the start or end if the string has
+  leading or trailing whitespace.
+
+  Note that the above mentioned behavior matches Python's str.split.
+
+  Args:
+    source: `1-D` string `Tensor`, the strings to split.
+    sep: `0-D` string `Tensor`, the delimiter character.
+    maxsplit: An `int`. If `maxsplit > 0`, limits the number of splits.
+    result_type: The tensor type for the result: one of `"RaggedTensor"` or
+      `"SparseTensor"`.
+
+  Raises:
+    ValueError: If sep is not a string.
+
+  Returns:
+    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
+    to the delimiter. The first column of the indices corresponds to the row
+    in `source` and the second column corresponds to the index of the split
+    component in this row.
+  """
+  sparse_result = string_ops.string_split_v2(
+      source, sep=sep, maxsplit=maxsplit)
+  if result_type == "SparseTensor":
+    return sparse_result
+  elif result_type == "RaggedTensor":
+    return ragged_tensor.RaggedTensor.from_value_rowids(
+        values=sparse_result.values,
+        value_rowids=sparse_result.indices[:, 0],
+        nrows=sparse_result.dense_shape[0])
+  else:
+    raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
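All three wrappers above share one conversion: the SparseTensor produced by the existing split kernels is repackaged as a RaggedTensor whose row ids come from the first column of the sparse indices. A standalone sketch (assuming a TF 1.x build with this change, where tf.strings.split still returns a SparseTensor by default):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor

sparse = tf.strings.split(tf.constant(["a b", "c"]), sep=" ")  # SparseTensor
ragged = ragged_tensor.RaggedTensor.from_value_rowids(
    values=sparse.values,                 # flat tokens
    value_rowids=sparse.indices[:, 0],    # row each token belongs to
    nrows=sparse.dense_shape[0])          # preserve empty trailing rows
```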
@@ -192,10 +192,8 @@ def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
       name=name)
 
 
-@tf_export(v1=["string_split"])
-@deprecation.deprecated_args(None,
-                             "delimiter is deprecated, please use sep instead.",
-                             "delimiter")
+# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
+# defines a wrapper for this function.
 def string_split(source, sep=None, skip_empty=True, delimiter=None):  # pylint: disable=invalid-name
   """Split elements of `source` based on `delimiter` into a `SparseTensor`.
 
@@ -251,7 +249,8 @@ def string_split(source, sep=None, skip_empty=True, delimiter=None):  # pylint:
   return sparse_tensor.SparseTensor(indices, values, shape)
 
 
-@tf_export("strings.split")
+# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
+# defines a wrapper for this function.
 def string_split_v2(source, sep=None, maxsplit=-1):
   """Split elements of `source` based on `sep` into a `SparseTensor`.
 
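Taken together with the previous file, the public entry points now layer as follows; a hedged sketch (assuming a 1.x build that includes this change):

```python
import tensorflow as tf

# Raw ops in string_ops.py keep returning SparseTensor; the ragged wrappers
# own the public names and choose the result type.
sp1 = tf.compat.v1.string_split(["a b"], sep=" ")             # SparseTensor
sp2 = tf.compat.v1.strings.split(["a b"])                     # SparseTensor
rt = tf.compat.v1.strings.split(["a b"],
                                result_type="RaggedTensor")   # RaggedTensor
```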
@@ -2218,7 +2218,7 @@ tf_module {
   }
   member_method {
     name: "string_split"
-    argspec: "args=[\'source\', \'sep\', \'skip_empty\', \'delimiter\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
+    argspec: "args=[\'source\', \'sep\', \'skip_empty\', \'delimiter\', \'result_type\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'SparseTensor\'], "
   }
   member_method {
     name: "string_strip"
@@ -30,7 +30,7 @@ tf_module {
   }
   member_method {
     name: "split"
-    argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], "
+    argspec: "args=[\'source\', \'sep\', \'maxsplit\', \'result_type\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\', \'SparseTensor\'], "
   }
   member_method {
     name: "strip"
@@ -109,6 +109,7 @@ py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/tools/common:public_api",
         "//tensorflow/tools/common:traverse",
+        "@absl_py//absl/testing:parameterized",
        "@six_archive//:six",
     ],
 )
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import ast
+import copy
 import functools
 import sys
 
@@ -1595,6 +1596,8 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
         "tf.nn.fractional_avg_pool": _pool_seed_transformer,
         "tf.nn.fractional_max_pool": _pool_seed_transformer,
         "tf.name_scope": _name_scope_transformer,
+        "tf.string_split": _string_split_transformer,
+        "tf.strings.split": _string_split_rtype_transformer,
         "tf.estimator.DNNEstimator":
             functools.partial(
                 _rename_if_arg_found_transformer,
@@ -1667,12 +1670,6 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
             "takes input_layer_partitioner, so the call was converted to "
             "compat.v1."
         ),
-        "tf.string_split": functools.partial(
-            _rename_if_arg_found_transformer, arg_name="skip_empty",
-            arg_ok_predicate=_is_ast_false, remove_if_ok=True,
-            message="tf.string_split's replacement no longer takes the "
-            "skip_empty argument. Since the argument was present, the call was "
-            "converted to compat.v1."),
         "tf.device": functools.partial(
             _rename_if_arg_found_transformer, arg_name="device_name",
             arg_ok_predicate=_is_ast_str, remove_if_ok=False,
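These registrations drive the tf_upgrade_v2 conversion script. A hypothetical before/after (file names illustrative), consistent with the test expectations later in this commit:

```python
# Run as: tf_upgrade_v2 --infile model_v1.py --outfile model_v2.py

# model_v1.py (input):
tokens = tf.string_split(lines, delimiter=' ')

# model_v2.py (output): renamed, delimiter -> sep, and .to_sparse() appended
# because tf.strings.split now returns a RaggedTensor in 2.0.
tokens = tf.strings.split(source=lines, sep=' ').to_sparse()
```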
@@ -2274,3 +2271,82 @@ def _name_scope_transformer(parent, node, full_name, name, logs):
   logs.append((ast_edits.ERROR, node.lineno, node.col_offset,
                "name_scope call with neither name nor default_name cannot be "
                "converted properly."))
+
+
+def _rename_to_compat_v1(node, full_name, logs, reason):
+  new_name = full_name.replace("tf.", "tf.compat.v1.", 1)
+  logs.append((ast_edits.INFO, node.lineno, node.col_offset,
+               "Renamed %r to %r: %s" % (full_name, new_name, reason)))
+  new_name_node = ast_edits.full_name_node(new_name, node.func.ctx)
+  ast.copy_location(new_name_node, node.func)
+  pasta.ast_utils.replace_child(node, node.func, new_name_node)
+  return node
+
+
+def _string_split_transformer(parent, node, full_name, name, logs):
+  """Update tf.string_split arguments: skip_empty, sep, result_type."""
+  # Check the skip_empty parameter: if not false, then use compat.v1.
+  for i, kw in enumerate(node.keywords):
+    if kw.arg == "skip_empty":
+      if _is_ast_false(kw.value):
+        logs.append((ast_edits.INFO, node.lineno, node.col_offset,
+                     "removed argument skip_empty for tf.string_split."))
+        node.keywords.pop(i)
+        break
+      else:
+        return _rename_to_compat_v1(
+            node, full_name, logs, "tf.string_split's replacement no longer "
+            "takes the skip_empty argument.")
+
+  # Check the sep parameter: if it might be an empty string, then use
+  # compat.v1.
+  sep_is_nonempty_string = False
+  for i, kw in enumerate(node.keywords):
+    if ((kw.arg == "sep" or kw.arg == "delimiter") and
+        isinstance(kw.value, ast.Str) and kw.value.s != ""):
+      sep_is_nonempty_string = True
+  if not sep_is_nonempty_string:
+    return _rename_to_compat_v1(
+        node, full_name, logs,
+        "The semantics for tf.string_split's sep parameter have changed when "
+        "sep is the empty string.")
+
+  # Check the result_type parameter.
+  return _string_split_rtype_transformer(parent, node, full_name, name, logs)
+
+
+def _string_split_rtype_transformer(parent, node, full_name, name, logs):
+  """Update tf.strings.split argument: result_type."""
+  # Remove the "result_type" argument.
+  need_to_sparse = True
+  for i, kw in enumerate(node.keywords):
+    if kw.arg == "result_type":
+      if (isinstance(kw.value, ast.Str) and
+          kw.value.s in ("RaggedTensor", "SparseTensor")):
+        logs.append((ast_edits.INFO, node.lineno, node.col_offset,
+                     "Removed argument result_type=%r for function %s" %
+                     (kw.value.s, full_name or name)))
+        node.keywords.pop(i)
+        if kw.value.s == "RaggedTensor":
+          need_to_sparse = False
+      else:
+        return _rename_to_compat_v1(
+            node, full_name, logs,
+            "%s no longer takes the result_type parameter." % full_name)
+      break
+
+  # If necessary, add a call to .to_sparse() to convert the output of
+  # strings.split from a RaggedTensor to a SparseTensor.
+  if need_to_sparse:
+    if isinstance(parent, ast.Attribute) and parent.attr == "to_sparse":
+      return  # Prevent infinite recursion (since child nodes are transformed).
+    logs.append(
+        (ast_edits.INFO, node.lineno, node.col_offset,
+         "Adding call to RaggedTensor.to_sparse() to result of strings.split, "
+         "since it now returns a RaggedTensor."))
+    node = ast.Attribute(value=copy.deepcopy(node), attr="to_sparse")
+    try:
+      node = ast.Call(node, [], [])
+    except TypeError:
+      # Python 2's ast.Call also requires starargs and kwargs.
+      node = ast.Call(node, [], [], None, None)
+
+  return node
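The .to_sparse() wrapping at the end of _string_split_rtype_transformer, reduced to a stdlib-only sketch; the real transformer goes through pasta to preserve formatting, and ast.unparse requires Python 3.9+:

```python
import ast

tree = ast.parse("tf.strings.split(x, sep)", mode="eval")
call = tree.body
# Wrap the original call node in `<call>.to_sparse()`.
tree.body = ast.Call(
    func=ast.Attribute(value=call, attr="to_sparse", ctx=ast.Load()),
    args=[], keywords=[])
ast.fix_missing_locations(tree)
print(ast.unparse(tree))  # -> tf.strings.split(x, sep).to_sparse()
```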
@@ -22,6 +22,7 @@ import inspect
 import os
 import tempfile
 
+from absl.testing import parameterized
 import six
 import tensorflow as tf
 # OSS TF V2 import placeholder.
@@ -75,7 +76,7 @@ def get_func_and_args_from_str(call_str):
   return function_name, args
 
 
-class TestUpgrade(test_util.TensorFlowTestCase):
+class TestUpgrade(test_util.TensorFlowTestCase, parameterized.TestCase):
   """Test various APIs that have been changed in 2.0.
 
   We also test whether a converted file is executable. test_file_v1_10.py
@@ -85,6 +86,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
 
   @classmethod
   def setUpClass(cls):
     super(TestUpgrade, cls).setUpClass()
     cls.v2_symbols = {}
+    cls.v1_symbols = {}
     if hasattr(tf.compat, "v2"):
@@ -1512,19 +1514,54 @@ def _log_prob(self, x):
     self.assertIn("name_scope call with neither name nor default_name",
                   errors[0])
 
-  def test_string_split(self):
-    text = "tf.string_split('test', delimiter=' ')"
-    expected_text = "tf.strings.split(source='test', sep=' ')"
+  @parameterized.parameters(
+      # Rename parameter: delimiter -> sep and add .to_sparse():
+      ["tf.string_split('test', delimiter=' ')",
+       "tf.strings.split(source='test', sep=' ').to_sparse()"],
+      # Use compat.v1 for the skip_empty parameter:
+      ["tf.string_split('test', ' ', True)",
+       "tf.compat.v1.string_split(source='test', sep=' ', skip_empty=True)"],
+      ["tf.string_split('test', ' ', skip_empty=False)",
+       "tf.strings.split(source='test', sep=' ').to_sparse()"],
+      # Split behavior for sep='' changed:
+      ["tf.string_split(x)",
+       "tf.compat.v1.string_split(source=x)"],
+      ["tf.string_split(x, '')",
+       "tf.compat.v1.string_split(source=x, sep='')"],
+      # If sep is a variable, we can't tell if it's empty:
+      ["tf.string_split(x, sep)",
+       "tf.compat.v1.string_split(source=x, sep=sep)"],
+      # If sep is a non-empty string literal, then we don't need compat.v1:
+      ["tf.string_split(x, 'non-empty-sep')",
+       "tf.strings.split(source=x, sep='non-empty-sep').to_sparse()"],
+      # Add to_sparse unless result_type is RaggedTensor:
+      ["tf.string_split(x, ' ')",
+       "tf.strings.split(source=x, sep=' ').to_sparse()"],
+      ["tf.string_split(x, ' ', result_type='SparseTensor')",
+       "tf.strings.split(source=x, sep=' ').to_sparse()"],
+      ["tf.string_split(x, ' ', result_type='RaggedTensor')",
+       "tf.strings.split(source=x, sep=' ')"],
+      ["tf.string_split(x, ' ', result_type=x)",
+       "tf.compat.v1.string_split(source=x, sep=' ', result_type=x)"],
+  )  # pyformat: disable
+  def test_string_split(self, text, expected_text):
+    """Tests for transforming from tf.string_split."""
     _, _, _, new_text = self._upgrade(text)
     self.assertEqual(expected_text, new_text)
 
-    text = "tf.string_split('test', ' ', True)"
-    expected_text = "tf.compat.v1.string_split(source='test', sep=' ', skip_empty=True)"  # pylint: disable=line-too-long
-    _, _, _, new_text = self._upgrade(text)
-    self.assertEqual(expected_text, new_text)
-
-    text = "tf.string_split('test', ' ', skip_empty=False)"
-    expected_text = "tf.strings.split(source='test', sep=' ')"  # pylint: disable=line-too-long
+  @parameterized.parameters(
+      # Add to_sparse unless result_type is RaggedTensor:
+      ["tf.strings.split(x, sep)",
+       "tf.strings.split(x, sep).to_sparse()"],
+      ["tf.strings.split(x, sep, result_type='SparseTensor')",
+       "tf.strings.split(x, sep).to_sparse()"],
+      ["tf.strings.split(x, sep, result_type='RaggedTensor')",
+       "tf.strings.split(x, sep)"],
+      ["tf.strings.split(x, sep, result_type=x)",
+       "tf.compat.v1.strings.split(x, sep, result_type=x)"],
+  )  # pyformat: disable
+  def test_strings_split(self, text, expected_text):
+    """Tests for transforming from tf.strings.split."""
     _, _, _, new_text = self._upgrade(text)
     self.assertEqual(expected_text, new_text)
 
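For reference, the absl parameterized pattern these rewritten tests rely on, as a self-contained sketch (test names and data hypothetical):

```python
from absl.testing import absltest, parameterized


class SplitExampleTest(parameterized.TestCase):

  @parameterized.parameters(
      ["a,b", ["a", "b"]],
      ["a,,b", ["a", "", "b"]],
  )
  def test_split(self, text, expected):
    # Each parameter list is unpacked into the test method's arguments.
    self.assertEqual(expected, text.split(","))


if __name__ == "__main__":
  absltest.main()
```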