From cb264418f6be7a8ce4fcfc6ee19a60404e428162 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Wed, 17 Jun 2020 09:19:04 -0700 Subject: [PATCH] Add tf.strings.format() and tf.print() to support RaggedTensors. PiperOrigin-RevId: 316901841 Change-Id: I5fa78acc118557fcf43d1b805149172f8547c0e1 --- tensorflow/python/ops/ragged/BUILD | 28 ++- .../python/ops/ragged/ragged_dispatch.py | 11 +- .../python/ops/ragged/ragged_dispatch_test.py | 22 +- .../python/ops/ragged/ragged_print_op_test.py | 195 ++++++++++++++++++ .../python/ops/ragged/ragged_string_ops.py | 117 ++++++++++- 5 files changed, 359 insertions(+), 14 deletions(-) create mode 100644 tensorflow/python/ops/ragged/ragged_print_op_test.py diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index b2a02b82454..95e5602a246 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -1264,9 +1264,35 @@ py_test( srcs_version = "PY2AND3", deps = [ ":ragged_array_ops", - "//tensorflow/python:constant_op", + ":ragged_factory_ops", + ":ragged_tensor", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_shape", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "ragged_print_op_test", + srcs = ["ragged_print_op_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":ragged", # fixdeps: keep + ":ragged_factory_ops", + ":ragged_string_ops", + ":ragged_tensor", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:logging_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_ops", + "//tensorflow/python/eager:def_function", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py index f13bed07ba0..5c9388b8677 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_bitwise_ops +from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import parsing_ops @@ -510,6 +511,7 @@ _RAGGED_DISPATCH_OPS = [ ['data', 'segment_ids']), (math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n, ['data', 'segment_ids']), + (string_ops.string_format, ragged_string_ops.string_format, ['[inputs]']), (string_ops.reduce_join_v2, ragged_string_ops.reduce_join, ['inputs']), (math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']), (math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']), @@ -549,7 +551,7 @@ def register_dispatchers(): RaggedDispatcher(original_op, ragged_op, args).register(original_op) -def _ragged_op_signature(op, ragged_args): +def _ragged_op_signature(op, ragged_args, ragged_varargs=False): """Returns a signature for the given op, marking ragged args in bold.""" op_name = tf_export.get_canonical_name_for_symbol(op) argspec = tf_inspect.getfullargspec(op) @@ -566,7 +568,10 @@ def _ragged_op_signature(op, ragged_args): # Add varargs and keyword args if argspec.varargs: - arg_names.append('*' + argspec.varargs) + if ragged_varargs: + arg_names.append('***' + argspec.varargs + '**') + else: + arg_names.append('*' + argspec.varargs) if argspec.varkw: arg_names.append('**' + argspec.varkw) @@ -597,6 +602,8 @@ def ragged_op_list(tf_version=1): arginfos = _get_arg_infos(op, ragged_args) ragged_args = [arginfo.position for arginfo in arginfos] lines.append(_ragged_op_signature(op, ragged_args)) + lines.append( + _ragged_op_signature(logging_ops.print_v2, [], ragged_varargs=True)) return ('\n\n### Additional ops that support `RaggedTensor`\n\n' 'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' + '\n'.join(sorted(lines)) + 'n') diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py index 60d9f6c8713..193e329e18a 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py @@ -142,8 +142,7 @@ BINARY_INT_OPS = [ # pylint: disable=g-complex-comprehension @test_util.run_all_in_graph_and_eager_modes -class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, - parameterized.TestCase): +class RaggedDispatchTest(test_util.TensorFlowTestCase, parameterized.TestCase): def assertSameShape(self, x, y): """Checks that x and y have the same shape (including ragged shapes).""" @@ -763,7 +762,12 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, 'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]), 'axis': [0, -1] }, - expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]])) + expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]])), + dict( + op=string_ops.string_format, + kwargs={'template': 'Hi {}', + 'inputs': [ragged_factory_ops.constant_value([[1, 2], [3]])]}, + expected='Hi [[1, 2], [3]]'), ]) def testRaggedDispatch(self, op, expected, args=(), result_is_list=False, kwargs=None): @@ -819,14 +823,14 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, 'math.unsorted_segment_mean', 'math.unsorted_segment_min', 'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n', 'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv', - 'reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string', + 'math.reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string', 'strings.join', 'strings.length', 'strings.reduce_join', 'strings.regex_full_match', 'strings.regex_replace', 'strings.strip', 'strings.substr', 'strings.to_hash_bucket_fast', 'strings.to_hash_bucket_strong', 'strings.to_hash_bucket', 'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv', 'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse', - 'nn.dropout', + 'nn.dropout', 'strings.format', 'print' ] # Ops that should be listed as supported in v1 only. @@ -838,15 +842,15 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase, v1_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=1) for element in supported_ops + supported_ops_v1: - self.assertIn(element, v1_ragged_ops) + self.assertIn('`tf.' + element + '`', v1_ragged_ops) for element in supported_ops_v2: - self.assertNotIn(element, v1_ragged_ops) + self.assertNotIn('`tf.' + element + '`', v1_ragged_ops) v2_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=2) for element in supported_ops + supported_ops_v2: - self.assertIn(element, v2_ragged_ops) + self.assertIn('`tf.' + element + '`', v2_ragged_ops) for element in supported_ops_v1: - self.assertNotIn(element, v2_ragged_ops) + self.assertNotIn('`tf.' + element + '`', v2_ragged_ops) if __name__ == '__main__': diff --git a/tensorflow/python/ops/ragged/ragged_print_op_test.py b/tensorflow/python/ops/ragged/ragged_print_op_test.py new file mode 100644 index 00000000000..2b612d463d0 --- /dev/null +++ b/tensorflow/python/ops/ragged/ragged_print_op_test.py @@ -0,0 +1,195 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.print with ragged tensors. + +Note: ragged support for tf.print is implemented by RaggedPrintV2Dispatcher in +ragged_dispatch.py. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path +import tempfile +from absl.testing import parameterized +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_string_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.platform import googletest + + +@test_util.run_all_in_graph_and_eager_modes +class RaggedPrintV2Test(test_util.TensorFlowTestCase, parameterized.TestCase): + + # pylint: disable=g-long-lambda + @parameterized.named_parameters([ + dict( + testcase_name='2d_int_values', + inputs=lambda: [ragged_factory_ops.constant([[1, 2], [3]])], + expected='[[1, 2], [3]]\n'), + dict( + testcase_name='3d_int_values', + inputs=lambda: [ragged_factory_ops.constant([[[1, 2], [3]], [[4]]])], + expected='[[[1, 2], [3]], [[4]]]\n'), + dict( + testcase_name='2d_str_values', + inputs=lambda: [ragged_factory_ops.constant([['a', 'b'], ['c']])], + expected="[['a', 'b'], ['c']]\n"), + dict( + testcase_name='2d_str_values_with_escaping', + inputs=lambda: [ragged_factory_ops.constant([["a'b"], ['c"d']])], + expected="[['a\\'b'], ['c\"d']]\n"), + dict( + testcase_name='two_ragged_values', + inputs=lambda: [ + ragged_factory_ops.constant([[1, 2], [3]]), + ragged_factory_ops.constant([[5], [], [6, 7, 8]]) + ], + expected='[[1, 2], [3]] [[5], [], [6, 7, 8]]\n'), + dict( + testcase_name='ragged_value_and_non_tensor_values', + inputs=lambda: + ['a', 5, True, + ragged_factory_ops.constant([[1, 2], [3]]), 'c'], + expected='a 5 True [[1, 2], [3]] c\n'), + dict( + testcase_name='ragged_value_and_dense_value', + inputs=lambda: [ + ragged_factory_ops.constant([[1, 2], [3]]), + constant_op.constant([[1, 2], [3, 4]]) + ], + expected='[[1, 2], [3]] [[1 2]\n [3 4]]\n'), + dict( + testcase_name='ragged_value_and_sparse_value', + inputs=lambda: [ + ragged_factory_ops.constant([[1, 2], [3]]), + sparse_ops.from_dense([[1]]) + ], + expected=( + '[[1, 2], [3]] ' + "'SparseTensor(indices=[[0 0]], values=[1], shape=[1 1])'\n")), + dict( + testcase_name='summarize_default', + inputs=lambda: [ + ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], [ + ], [], [], [], [11, 12]]) + ], + expected=('[[1, 2, 3, ..., 7, 8, 9], [10], [], ' + '..., ' + '[], [], [11, 12]]\n')), + dict( + testcase_name='summarize_2', + inputs=lambda: [ + ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], [ + ], [], [], [], [11, 12]]) + ], + summarize=2, + expected='[[1, 2, ..., 8, 9], [10], ..., [], [11, 12]]\n'), + dict( + testcase_name='summarize_neg1', + inputs=lambda: [ + ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], [ + ], [], [], [], [11, 12]]) + ], + summarize=-1, + expected=('[[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], ' + '[], [], [], [], [11, 12]]\n')), + ]) + def testRaggedPrint(self, inputs, expected, summarize=None): + if callable(inputs): + inputs = inputs() + with tempfile.TemporaryDirectory() as tmpdirname: + path = os.path.join(tmpdirname, 'print_output') + kwargs = {'output_stream': 'file://{}'.format(path)} + if summarize is not None: + kwargs.update(summarize=summarize) + self.evaluate(logging_ops.print_v2(*inputs, **kwargs)) + actual = open(path, 'r').read() + self.assertEqual(repr(actual), repr(expected)) + + +@test_util.run_all_in_graph_and_eager_modes +class RaggedToStringTest(test_util.TensorFlowTestCase, parameterized.TestCase): + + @parameterized.named_parameters([ + ('2d_int', [[1, 2], [], [3, 4, 5]], '[[1, 2], [], [3, 4, 5]]'), + ('2d_str', [['a'], ['b'], ['c', 'd']], "[['a'], ['b'], ['c', 'd']]"), + ('3d_int', [[[1, 2], []], [[3, 4, 5]]], '[[[1, 2], []], [[3, 4, 5]]]'), + ('escape', [["a'b"], [r'c\d']], r"[['a\'b'], ['c\\d']]"), + dict(testcase_name='2d_empty', rt=[], ragged_rank=1, expected='[]'), + dict(testcase_name='3d_empty', rt=[], ragged_rank=2, expected='[]'), + dict( + testcase_name='3d_rrank1', + rt=[[[1, 2], [3, 4]], [], [[5, 6]]], + ragged_rank=1, + expected='[[[1, 2], [3, 4]], [], [[5, 6]]]'), + dict( + testcase_name='2d_empty_row', rt=[[]], ragged_rank=1, + expected='[[]]'), + dict( + testcase_name='3d_empty_row', rt=[[]], ragged_rank=2, + expected='[[]]'), + dict( + testcase_name='summarize_1', + rt=[[1, 2, 3, 4, 5], [], [6], [7], [8, 9]], + summarize=1, + expected='[[1, ..., 5], ..., [8, 9]]'), + dict( + testcase_name='summarize_2', + rt=[[1, 2, 3, 4, 5], [], [6], [7], [8, 9]], + summarize=2, + expected='[[1, 2, ..., 4, 5], [], ..., [7], [8, 9]]'), + ]) + def testRaggedToString(self, rt, expected, summarize=None, ragged_rank=None): + rt = ragged_factory_ops.constant(rt, ragged_rank=ragged_rank) + actual = ragged_string_ops.ragged_tensor_to_string(rt, summarize=summarize) + self.assertAllEqual(actual, expected) + + @parameterized.named_parameters([ + ('maxelts_BadType', [[1]], "Expected summarize .*, got 'foo'", 'foo'), + ('maxelts_0', [[1]], 'Expected summarize to be .*, got 0', 0), + ('maxelts_Neg2', [[1]], 'Expected summarize to be .*, got -2', -2), + ]) + def testRaggedToStringErrors(self, + rt, + error, + summarize=None, + exception=ValueError): + rt = ragged_factory_ops.constant(rt) + with self.assertRaisesRegex(exception, error): + self.evaluate( + ragged_string_ops.ragged_tensor_to_string(rt, summarize=summarize)) + + def testRaggedToStringUnknownRank(self): + + @def_function.function( + input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=1)]) + def f(rt): + return ragged_string_ops.ragged_tensor_to_string(rt) + + with self.assertRaisesRegex( + ValueError, 'RaggedTensor to_string requires ' + 'that rt.shape.rank is not None'): + f(ragged_factory_ops.constant([[1, 2], [3]])) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/python/ops/ragged/ragged_string_ops.py b/tensorflow/python/ops/ragged/ragged_string_ops.py index 0d9c4d506f3..0ac23c298ba 100755 --- a/tensorflow/python/ops/ragged/ragged_string_ops.py +++ b/tensorflow/python/ops/ragged/ragged_string_ops.py @@ -18,10 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_string_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops.ragged import ragged_array_ops @@ -30,9 +33,14 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import compat as util_compat from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch +from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.tf_export import tf_export +map_fn_lib = LazyLoader("map_fn_lib", globals(), + "tensorflow.python.ops.map_fn") + + @tf_export("strings.bytes_split") @dispatch.add_dispatch_support def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin @@ -640,7 +648,7 @@ def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redef input, dtype=dtypes.string, name="input") if input.shape.rank == 0: - input = gen_array_ops.expand_dims(input, 0) + input = array_ops.expand_dims(input, 0) if result_type == "SparseTensor": if input.shape.rank == 1: @@ -813,3 +821,108 @@ def ngrams(data, values=output, row_splits=output_splits, validate=False) return array_ops.reshape(output.flat_values, dense_shape) if to_tensor else output + + +def string_format(template, inputs, placeholder="{}", summarize=3, name=None): + """Version of tf.strings.format that handles RaggedTensors.""" + if tensor_util.is_tensor(inputs) or ragged_tensor.is_ragged(inputs): + inputs = [inputs] + + split_template = template.split(placeholder) + if len(inputs) != len(split_template) - 1: + raise ValueError("num placeholders in template and num inputs must match" + ": {} vs {}".format(len(split_template) - 1, len(inputs))) + + with ops.name_scope(name, "StringFormat", [inputs]): + output_pieces = [constant_op.constant(split_template[0])] + for i, input in enumerate(inputs): + if ragged_tensor.is_ragged(input): + output_pieces.append(ragged_tensor_to_string(input, summarize)) + else: + output_pieces.append(string_ops.string_format( + "{}", [input], summarize=summarize)) + output_pieces.append(constant_op.constant(split_template[i + 1])) + if len(output_pieces) == 1: + return output_pieces[0] + else: + return string_ops.reduce_join(output_pieces) + + +def ragged_tensor_to_string(rt, summarize=None): + """Returns a scalar string tensor with the contents of a RaggedTensor. + + Requires that `rt.shape.rank` is not `None`. + + Note: this converts the entire `RaggedTensor` into a single string scalar. + If you want to convert individual elements, use `tf.strings.as_string(rt)`. + + >>> rt1 = tf.ragged.constant([[1, 2, 3], [4, 5]]) + >>> ragged_tensor_to_string(rt1).numpy() + b'[[1, 2, 3], [4, 5]]' + + >>> rt2 = tf.ragged.constant([[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]) + >>> ragged_tensor_to_string(rt2).numpy() + b"[[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]" + + >>> rt3 = tf.ragged.constant([[1], [2, 3, 4, 5, 6], [], [], [7], [8, 9]]) + >>> ragged_tensor_to_string(rt3, summarize=2).numpy() + b'[[1], [2, 3, ..., 5, 6], ..., [7], [8, 9]]' + + Args: + rt: The RaggedTensor that should be converted to a string. + summarize: If specified, then only the first and last `summarize` elements + within each dimension are included in the string. If `-1` or `None`, then + all elements are included. + """ + if (summarize is not None and summarize != -1 and + not (isinstance(summarize, int) and summarize > 0)): + raise ValueError("Expected summarize to be -1 or a positive int, got %r" % + summarize) + with ops.name_scope(None, "AsString", [rt]): + rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt) + if rt.shape.rank is None: + raise ValueError("RaggedTensor to_string requires that rt.shape.rank " + "is not None.") + # Convert all elements of `rt` to strings. + if rt.dtype == dtypes.string: + escaped = string_ops.regex_replace(rt.flat_values, r"(['\\])", r"\\\1") + str_t = rt.with_flat_values("'" + escaped + "'") + else: + str_t = rt.with_flat_values(string_ops.as_string(rt.flat_values)) + + return _ragged_tensor_to_string(str_t, summarize) + + +def _ragged_tensor_to_string(string_tensor, summarize): + """Returns a scalar string tensor with the contents of `string_tensor`. + + Args: + string_tensor: A potentially ragged tensor with dtype=string. + summarize: Include only the first and last `summarize` elements of each + dimension. If `-1` or `None`, then include all elements. + + Returns: + A scalar string Tensor. + """ + if string_tensor.shape.rank == 1: + pieces = string_tensor + else: + pieces = map_fn_lib.map_fn( + lambda s: _ragged_tensor_to_string(s, summarize), + string_tensor, + fn_output_signature=tensor_spec.TensorSpec(None, dtypes.string)) + if summarize not in (-1, None): + pieces = control_flow_ops.cond( + _nrows(string_tensor) <= 2 * summarize, + lambda: pieces, + lambda: array_ops.concat( # pylint: disable=g-long-lambda + [pieces[:summarize], ["..."], pieces[-summarize:]], + axis=0)) + return "[" + string_ops.reduce_join(pieces, separator=", ") + "]" + + +def _nrows(tensor, out_type=dtypes.int32): + if isinstance(tensor, ragged_tensor.RaggedTensor): + return tensor.nrows(out_type=out_type) + else: + return array_ops.shape(tensor, out_type=out_type)[0]