Add RaggedTensor support to tf.strings.format() and tf.print().

PiperOrigin-RevId: 316901841
Change-Id: I5fa78acc118557fcf43d1b805149172f8547c0e1
This commit is contained in:
Edward Loper 2020-06-17 09:19:04 -07:00 committed by TensorFlower Gardener
parent fc296acdc1
commit cb264418f6
5 changed files with 359 additions and 14 deletions

View File

@ -1264,9 +1264,35 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
"//tensorflow/python:constant_op",
":ragged_factory_ops",
":ragged_tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:tensor_shape",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
# Test target for ragged_print_op_test.py (RaggedTensor printing and
# string-conversion tests).
py_test(
    name = "ragged_print_op_test",
    srcs = ["ragged_print_op_test.py"],
    # NOTE(review): presumably PY3-only because the test relies on
    # tempfile.TemporaryDirectory -- confirm.
    python_version = "PY3",
    srcs_version = "PY2AND3",
    deps = [
        ":ragged", # fixdeps: keep
        ":ragged_factory_ops",
        ":ragged_string_ops",
        ":ragged_tensor",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:logging_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python/eager:def_function",
        "@absl_py//absl/testing:parameterized",
    ],
)

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
@ -510,6 +511,7 @@ _RAGGED_DISPATCH_OPS = [
['data', 'segment_ids']),
(math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
['data', 'segment_ids']),
(string_ops.string_format, ragged_string_ops.string_format, ['[inputs]']),
(string_ops.reduce_join_v2, ragged_string_ops.reduce_join, ['inputs']),
(math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
(math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
@ -549,7 +551,7 @@ def register_dispatchers():
RaggedDispatcher(original_op, ragged_op, args).register(original_op)
def _ragged_op_signature(op, ragged_args):
def _ragged_op_signature(op, ragged_args, ragged_varargs=False):
"""Returns a signature for the given op, marking ragged args in bold."""
op_name = tf_export.get_canonical_name_for_symbol(op)
argspec = tf_inspect.getfullargspec(op)
@ -566,7 +568,10 @@ def _ragged_op_signature(op, ragged_args):
# Add varargs and keyword args
if argspec.varargs:
arg_names.append('*' + argspec.varargs)
if ragged_varargs:
arg_names.append('***' + argspec.varargs + '**')
else:
arg_names.append('*' + argspec.varargs)
if argspec.varkw:
arg_names.append('**' + argspec.varkw)
@ -597,6 +602,8 @@ def ragged_op_list(tf_version=1):
arginfos = _get_arg_infos(op, ragged_args)
ragged_args = [arginfo.position for arginfo in arginfos]
lines.append(_ragged_op_signature(op, ragged_args))
lines.append(
_ragged_op_signature(logging_ops.print_v2, [], ragged_varargs=True))
return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
'\n'.join(sorted(lines)) + 'n')

View File

@ -142,8 +142,7 @@ BINARY_INT_OPS = [
# pylint: disable=g-complex-comprehension
@test_util.run_all_in_graph_and_eager_modes
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
class RaggedDispatchTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def assertSameShape(self, x, y):
"""Checks that x and y have the same shape (including ragged shapes)."""
@ -763,7 +762,12 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
'axis': [0, -1]
},
expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]]))
expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]])),
dict(
op=string_ops.string_format,
kwargs={'template': 'Hi {}',
'inputs': [ragged_factory_ops.constant_value([[1, 2], [3]])]},
expected='Hi [[1, 2], [3]]'),
])
def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
kwargs=None):
@ -819,14 +823,14 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
'math.unsorted_segment_mean', 'math.unsorted_segment_min',
'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n',
'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv',
'reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
'math.reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
'strings.join', 'strings.length', 'strings.reduce_join',
'strings.regex_full_match', 'strings.regex_replace', 'strings.strip',
'strings.substr', 'strings.to_hash_bucket_fast',
'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse',
'nn.dropout',
'nn.dropout', 'strings.format', 'print'
]
# Ops that should be listed as supported in v1 only.
@ -838,15 +842,15 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
v1_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=1)
for element in supported_ops + supported_ops_v1:
self.assertIn(element, v1_ragged_ops)
self.assertIn('`tf.' + element + '`', v1_ragged_ops)
for element in supported_ops_v2:
self.assertNotIn(element, v1_ragged_ops)
self.assertNotIn('`tf.' + element + '`', v1_ragged_ops)
v2_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=2)
for element in supported_ops + supported_ops_v2:
self.assertIn(element, v2_ragged_ops)
self.assertIn('`tf.' + element + '`', v2_ragged_ops)
for element in supported_ops_v1:
self.assertNotIn(element, v2_ragged_ops)
self.assertNotIn('`tf.' + element + '`', v2_ragged_ops)
if __name__ == '__main__':

View File

@ -0,0 +1,195 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.print with ragged tensors.
Note: ragged support for tf.print is implemented by RaggedPrintV2Dispatcher in
ragged_dispatch.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tempfile
from absl.testing import parameterized
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
@test_util.run_all_in_graph_and_eager_modes
class RaggedPrintV2Test(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Checks the text written by `logging_ops.print_v2` for ragged inputs."""

  # Each `inputs` entry is a zero-argument lambda that is called inside the
  # test body, so the tensors are built when the test runs rather than when
  # the parameter list is evaluated.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters([
      dict(
          testcase_name='2d_int_values',
          inputs=lambda: [ragged_factory_ops.constant([[1, 2], [3]])],
          expected='[[1, 2], [3]]\n'),
      dict(
          testcase_name='3d_int_values',
          inputs=lambda: [ragged_factory_ops.constant([[[1, 2], [3]], [[4]]])],
          expected='[[[1, 2], [3]], [[4]]]\n'),
      dict(
          testcase_name='2d_str_values',
          inputs=lambda: [ragged_factory_ops.constant([['a', 'b'], ['c']])],
          expected="[['a', 'b'], ['c']]\n"),
      dict(
          testcase_name='2d_str_values_with_escaping',
          inputs=lambda: [ragged_factory_ops.constant([["a'b"], ['c"d']])],
          expected="[['a\\'b'], ['c\"d']]\n"),
      dict(
          testcase_name='two_ragged_values',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2], [3]]),
              ragged_factory_ops.constant([[5], [], [6, 7, 8]])
          ],
          expected='[[1, 2], [3]] [[5], [], [6, 7, 8]]\n'),
      dict(
          testcase_name='ragged_value_and_non_tensor_values',
          inputs=lambda:
          ['a', 5, True,
           ragged_factory_ops.constant([[1, 2], [3]]), 'c'],
          expected='a 5 True [[1, 2], [3]] c\n'),
      dict(
          testcase_name='ragged_value_and_dense_value',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2], [3]]),
              constant_op.constant([[1, 2], [3, 4]])
          ],
          expected='[[1, 2], [3]] [[1 2]\n [3 4]]\n'),
      dict(
          testcase_name='ragged_value_and_sparse_value',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2], [3]]),
              sparse_ops.from_dense([[1]])
          ],
          expected=(
              '[[1, 2], [3]] '
              "'SparseTensor(indices=[[0 0]], values=[1], shape=[1 1])'\n")),
      dict(
          testcase_name='summarize_default',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], [
              ], [], [], [], [11, 12]])
          ],
          expected=('[[1, 2, 3, ..., 7, 8, 9], [10], [], '
                    '..., '
                    '[], [], [11, 12]]\n')),
      dict(
          testcase_name='summarize_2',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], [
              ], [], [], [], [11, 12]])
          ],
          summarize=2,
          expected='[[1, 2, ..., 8, 9], [10], ..., [], [11, 12]]\n'),
      dict(
          testcase_name='summarize_neg1',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], [
              ], [], [], [], [11, 12]])
          ],
          summarize=-1,
          expected=('[[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], '
                    '[], [], [], [], [11, 12]]\n')),
  ])
  def testRaggedPrint(self, inputs, expected, summarize=None):
    """Prints `inputs` to a temporary file and compares the file contents.

    Args:
      inputs: The values to print, or a zero-argument callable returning them.
      expected: The exact text the print op should write.
      summarize: Optional `summarize` value forwarded to `print_v2`; when
        `None` the argument is omitted so the op's default applies.
    """
    if callable(inputs):
      inputs = inputs()
    with tempfile.TemporaryDirectory() as tmpdirname:
      path = os.path.join(tmpdirname, 'print_output')
      kwargs = {'output_stream': 'file://{}'.format(path)}
      if summarize is not None:
        kwargs.update(summarize=summarize)
      self.evaluate(logging_ops.print_v2(*inputs, **kwargs))
      # Close the handle deterministically instead of relying on garbage
      # collection, so nothing is left open when the directory is removed.
      with open(path, 'r') as output_file:
        actual = output_file.read()
      # Compare reprs so that whitespace differences are visible on failure.
      self.assertEqual(repr(actual), repr(expected))
@test_util.run_all_in_graph_and_eager_modes
class RaggedToStringTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `ragged_string_ops.ragged_tensor_to_string`."""

  @parameterized.named_parameters([
      dict(
          testcase_name='2d_int',
          rt=[[1, 2], [], [3, 4, 5]],
          expected='[[1, 2], [], [3, 4, 5]]'),
      dict(
          testcase_name='2d_str',
          rt=[['a'], ['b'], ['c', 'd']],
          expected="[['a'], ['b'], ['c', 'd']]"),
      dict(
          testcase_name='3d_int',
          rt=[[[1, 2], []], [[3, 4, 5]]],
          expected='[[[1, 2], []], [[3, 4, 5]]]'),
      dict(
          testcase_name='escape',
          rt=[["a'b"], [r'c\d']],
          expected=r"[['a\'b'], ['c\\d']]"),
      dict(testcase_name='2d_empty', rt=[], ragged_rank=1, expected='[]'),
      dict(testcase_name='3d_empty', rt=[], ragged_rank=2, expected='[]'),
      dict(
          testcase_name='3d_rrank1',
          rt=[[[1, 2], [3, 4]], [], [[5, 6]]],
          ragged_rank=1,
          expected='[[[1, 2], [3, 4]], [], [[5, 6]]]'),
      dict(
          testcase_name='2d_empty_row',
          rt=[[]],
          ragged_rank=1,
          expected='[[]]'),
      dict(
          testcase_name='3d_empty_row',
          rt=[[]],
          ragged_rank=2,
          expected='[[]]'),
      dict(
          testcase_name='summarize_1',
          rt=[[1, 2, 3, 4, 5], [], [6], [7], [8, 9]],
          summarize=1,
          expected='[[1, ..., 5], ..., [8, 9]]'),
      dict(
          testcase_name='summarize_2',
          rt=[[1, 2, 3, 4, 5], [], [6], [7], [8, 9]],
          summarize=2,
          expected='[[1, 2, ..., 4, 5], [], ..., [7], [8, 9]]'),
  ])
  def testRaggedToString(self, rt, expected, summarize=None, ragged_rank=None):
    """Converts a ragged constant to a string and checks the result."""
    ragged = ragged_factory_ops.constant(rt, ragged_rank=ragged_rank)
    self.assertAllEqual(
        ragged_string_ops.ragged_tensor_to_string(ragged, summarize=summarize),
        expected)

  @parameterized.named_parameters([
      dict(
          testcase_name='maxelts_BadType',
          rt=[[1]],
          error="Expected summarize .*, got 'foo'",
          summarize='foo'),
      dict(
          testcase_name='maxelts_0',
          rt=[[1]],
          error='Expected summarize to be .*, got 0',
          summarize=0),
      dict(
          testcase_name='maxelts_Neg2',
          rt=[[1]],
          error='Expected summarize to be .*, got -2',
          summarize=-2),
  ])
  def testRaggedToStringErrors(self,
                               rt,
                               error,
                               summarize=None,
                               exception=ValueError):
    """Checks that invalid `summarize` values are rejected."""
    ragged = ragged_factory_ops.constant(rt)
    with self.assertRaisesRegex(exception, error):
      result = ragged_string_ops.ragged_tensor_to_string(
          ragged, summarize=summarize)
      self.evaluate(result)

  def testRaggedToStringUnknownRank(self):
    """Checks the error raised when the input's rank is not known."""
    unknown_rank_spec = ragged_tensor.RaggedTensorSpec(ragged_rank=1)

    @def_function.function(input_signature=[unknown_rank_spec])
    def to_string_fn(rt):
      return ragged_string_ops.ragged_tensor_to_string(rt)

    expected_error = ('RaggedTensor to_string requires '
                      'that rt.shape.rank is not None')
    with self.assertRaisesRegex(ValueError, expected_error):
      to_string_fn(ragged_factory_ops.constant([[1, 2], [3]]))
# Run the test cases in this file when it is executed as a script.
if __name__ == '__main__':
  googletest.main()

View File

@ -18,10 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_array_ops
@ -30,9 +33,14 @@ from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
map_fn_lib = LazyLoader("map_fn_lib", globals(),
"tensorflow.python.ops.map_fn")
@tf_export("strings.bytes_split")
@dispatch.add_dispatch_support
def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin
@ -640,7 +648,7 @@ def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redef
input, dtype=dtypes.string, name="input")
if input.shape.rank == 0:
input = gen_array_ops.expand_dims(input, 0)
input = array_ops.expand_dims(input, 0)
if result_type == "SparseTensor":
if input.shape.rank == 1:
@ -813,3 +821,108 @@ def ngrams(data,
values=output, row_splits=output_splits, validate=False)
return array_ops.reshape(output.flat_values,
dense_shape) if to_tensor else output
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  """Version of tf.strings.format that handles RaggedTensors.

  The template is split on `placeholder`; the literal pieces and the formatted
  inputs are then interleaved and joined into a single string.

  Args:
    template: A Python string containing one `placeholder` per input.
    inputs: A single tensor / RaggedTensor, or a sequence of values to format.
      A single tensor or RaggedTensor is treated as a one-element list.
    placeholder: The substring of `template` that marks where each input goes.
    summarize: Forwarded to `ragged_tensor_to_string` for ragged inputs and to
      `string_ops.string_format` for all other inputs.
    name: Optional name for the enclosing name scope.

  Returns:
    A string tensor holding the formatted result.

  Raises:
    ValueError: If the number of placeholders in `template` differs from the
      number of inputs.
  """
  if tensor_util.is_tensor(inputs) or ragged_tensor.is_ragged(inputs):
    inputs = [inputs]

  literals = template.split(placeholder)
  num_placeholders = len(literals) - 1
  if len(inputs) != num_placeholders:
    raise ValueError("num placeholders in template and num inputs must match"
                     ": {} vs {}".format(num_placeholders, len(inputs)))

  with ops.name_scope(name, "StringFormat", [inputs]):
    pieces = [constant_op.constant(literals[0])]
    # After the length check above, `literals[1:]` has exactly one trailing
    # literal per input.
    for value, trailing_literal in zip(inputs, literals[1:]):
      if ragged_tensor.is_ragged(value):
        formatted = ragged_tensor_to_string(value, summarize)
      else:
        formatted = string_ops.string_format(
            "{}", [value], summarize=summarize)
      pieces.append(formatted)
      pieces.append(constant_op.constant(trailing_literal))
    if len(pieces) > 1:
      return string_ops.reduce_join(pieces)
    # No inputs: the template itself is the whole result.
    return pieces[0]
def ragged_tensor_to_string(rt, summarize=None):
  """Returns a scalar string tensor with the contents of a RaggedTensor.

  Requires that `rt.shape.rank` is not `None`.

  Note: this converts the entire `RaggedTensor` into a single string scalar.
  If you want to convert individual elements, use `tf.strings.as_string(rt)`.

  >>> rt1 = tf.ragged.constant([[1, 2, 3], [4, 5]])
  >>> ragged_tensor_to_string(rt1).numpy()
  b'[[1, 2, 3], [4, 5]]'

  >>> rt2 = tf.ragged.constant([[['a'], ['b', 'c']], [['d', 'e', 'f'], []]])
  >>> ragged_tensor_to_string(rt2).numpy()
  b"[[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]"

  >>> rt3 = tf.ragged.constant([[1], [2, 3, 4, 5, 6], [], [], [7], [8, 9]])
  >>> ragged_tensor_to_string(rt3, summarize=2).numpy()
  b'[[1], [2, 3, ..., 5, 6], ..., [7], [8, 9]]'

  Args:
    rt: The RaggedTensor that should be converted to a string.
    summarize: If specified, then only the first and last `summarize` elements
      within each dimension are included in the string. If `-1` or `None`, then
      all elements are included.

  Returns:
    A scalar string tensor.

  Raises:
    ValueError: If `summarize` is not `None`, `-1`, or a positive int, or if
      `rt.shape.rank` is `None`.
  """
  if summarize is not None and summarize != -1:
    is_positive_int = isinstance(summarize, int) and summarize > 0
    if not is_positive_int:
      raise ValueError("Expected summarize to be -1 or a positive int, got %r" %
                       summarize)
  with ops.name_scope(None, "AsString", [rt]):
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
    if rt.shape.rank is None:
      raise ValueError("RaggedTensor to_string requires that rt.shape.rank "
                       "is not None.")
    # Convert all elements of `rt` to strings.
    flat = rt.flat_values
    if rt.dtype == dtypes.string:
      # Backslash-escape quotes and backslashes, then wrap each value in
      # single quotes.
      flat_strings = (
          "'" + string_ops.regex_replace(flat, r"(['\\])", r"\\\1") + "'")
    else:
      flat_strings = string_ops.as_string(flat)
    return _ragged_tensor_to_string(
        rt.with_flat_values(flat_strings), summarize)
def _ragged_tensor_to_string(string_tensor, summarize):
  """Returns a scalar string tensor with the contents of `string_tensor`.

  A rank-1 input is joined directly; for higher ranks each row is first
  converted to a string by a recursive call, and the row strings are joined.

  Args:
    string_tensor: A potentially ragged tensor with dtype=string.
    summarize: Include only the first and last `summarize` elements of each
      dimension.  If `-1` or `None`, then include all elements.

  Returns:
    A scalar string Tensor.
  """

  def _row_to_string(row):
    return _ragged_tensor_to_string(row, summarize)

  if string_tensor.shape.rank == 1:
    elements = string_tensor
  else:
    elements = map_fn_lib.map_fn(
        _row_to_string,
        string_tensor,
        fn_output_signature=tensor_spec.TensorSpec(None, dtypes.string))

  shown = elements
  if summarize not in (-1, None):

    def _keep_all():
      return elements

    def _keep_ends():
      # First and last `summarize` entries with an ellipsis marker between.
      return array_ops.concat(
          [elements[:summarize], ["..."], elements[-summarize:]], axis=0)

    # Only elide when there are more rows than the two kept ends combined.
    shown = control_flow_ops.cond(
        _nrows(string_tensor) <= 2 * summarize, _keep_all, _keep_ends)

  joined = string_ops.reduce_join(shown, separator=", ")
  return "[" + joined + "]"
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the size of the outermost dimension of `tensor`.

  Args:
    tensor: A `RaggedTensor` or a value accepted by `array_ops.shape`.
    out_type: The dtype requested for the result.

  Returns:
    `tensor.nrows(...)` for a `RaggedTensor`; otherwise the first entry of
    `array_ops.shape(tensor)`.
  """
  if not isinstance(tensor, ragged_tensor.RaggedTensor):
    return array_ops.shape(tensor, out_type=out_type)[0]
  return tensor.nrows(out_type=out_type)