Add RaggedTensor support to tf.strings.format() and tf.print().
PiperOrigin-RevId: 316901841 Change-Id: I5fa78acc118557fcf43d1b805149172f8547c0e1
This commit is contained in:
parent
fc296acdc1
commit
cb264418f6
@ -1264,9 +1264,35 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
# Tests for tf.print and ragged_tensor_to_string with RaggedTensor inputs.
py_test(
    name = "ragged_print_op_test",
    srcs = ["ragged_print_op_test.py"],
    python_version = "PY3",
    srcs_version = "PY2AND3",
    deps = [
        # NOTE(review): presumably kept so the ragged dispatchers get
        # registered even though the test never imports it by name - confirm.
        ":ragged",  # fixdeps: keep
        ":ragged_factory_ops",
        ":ragged_string_ops",
        ":ragged_tensor",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:logging_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python/eager:def_function",
        "@absl_py//absl/testing:parameterized",
    ],
)
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import gen_bitwise_ops
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
@ -510,6 +511,7 @@ _RAGGED_DISPATCH_OPS = [
|
||||
['data', 'segment_ids']),
|
||||
(math_ops.unsorted_segment_sqrt_n, ragged_math_ops.segment_sqrt_n,
|
||||
['data', 'segment_ids']),
|
||||
(string_ops.string_format, ragged_string_ops.string_format, ['[inputs]']),
|
||||
(string_ops.reduce_join_v2, ragged_string_ops.reduce_join, ['inputs']),
|
||||
(math_ops.reduce_sum, ragged_math_ops.reduce_sum, ['input_tensor']),
|
||||
(math_ops.reduce_prod, ragged_math_ops.reduce_prod, ['input_tensor']),
|
||||
@ -549,7 +551,7 @@ def register_dispatchers():
|
||||
RaggedDispatcher(original_op, ragged_op, args).register(original_op)
|
||||
|
||||
|
||||
def _ragged_op_signature(op, ragged_args):
|
||||
def _ragged_op_signature(op, ragged_args, ragged_varargs=False):
|
||||
"""Returns a signature for the given op, marking ragged args in bold."""
|
||||
op_name = tf_export.get_canonical_name_for_symbol(op)
|
||||
argspec = tf_inspect.getfullargspec(op)
|
||||
@ -566,7 +568,10 @@ def _ragged_op_signature(op, ragged_args):
|
||||
|
||||
# Add varargs and keyword args
|
||||
if argspec.varargs:
|
||||
arg_names.append('*' + argspec.varargs)
|
||||
if ragged_varargs:
|
||||
arg_names.append('***' + argspec.varargs + '**')
|
||||
else:
|
||||
arg_names.append('*' + argspec.varargs)
|
||||
if argspec.varkw:
|
||||
arg_names.append('**' + argspec.varkw)
|
||||
|
||||
@ -597,6 +602,8 @@ def ragged_op_list(tf_version=1):
|
||||
arginfos = _get_arg_infos(op, ragged_args)
|
||||
ragged_args = [arginfo.position for arginfo in arginfos]
|
||||
lines.append(_ragged_op_signature(op, ragged_args))
|
||||
lines.append(
|
||||
_ragged_op_signature(logging_ops.print_v2, [], ragged_varargs=True))
|
||||
return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
|
||||
'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
|
||||
'\n'.join(sorted(lines)) + 'n')
|
||||
|
@ -142,8 +142,7 @@ BINARY_INT_OPS = [
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
class RaggedDispatchTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def assertSameShape(self, x, y):
|
||||
"""Checks that x and y have the same shape (including ragged shapes)."""
|
||||
@ -763,7 +762,12 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
|
||||
'axis': [0, -1]
|
||||
},
|
||||
expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]]))
|
||||
expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]])),
|
||||
dict(
|
||||
op=string_ops.string_format,
|
||||
kwargs={'template': 'Hi {}',
|
||||
'inputs': [ragged_factory_ops.constant_value([[1, 2], [3]])]},
|
||||
expected='Hi [[1, 2], [3]]'),
|
||||
])
|
||||
def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
|
||||
kwargs=None):
|
||||
@ -819,14 +823,14 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
'math.unsorted_segment_mean', 'math.unsorted_segment_min',
|
||||
'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n',
|
||||
'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv',
|
||||
'reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
|
||||
'math.reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
|
||||
'strings.join', 'strings.length', 'strings.reduce_join',
|
||||
'strings.regex_full_match', 'strings.regex_replace', 'strings.strip',
|
||||
'strings.substr', 'strings.to_hash_bucket_fast',
|
||||
'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
|
||||
'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
|
||||
'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse',
|
||||
'nn.dropout',
|
||||
'nn.dropout', 'strings.format', 'print'
|
||||
]
|
||||
|
||||
# Ops that should be listed as supported in v1 only.
|
||||
@ -838,15 +842,15 @@ class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
|
||||
v1_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=1)
|
||||
for element in supported_ops + supported_ops_v1:
|
||||
self.assertIn(element, v1_ragged_ops)
|
||||
self.assertIn('`tf.' + element + '`', v1_ragged_ops)
|
||||
for element in supported_ops_v2:
|
||||
self.assertNotIn(element, v1_ragged_ops)
|
||||
self.assertNotIn('`tf.' + element + '`', v1_ragged_ops)
|
||||
|
||||
v2_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=2)
|
||||
for element in supported_ops + supported_ops_v2:
|
||||
self.assertIn(element, v2_ragged_ops)
|
||||
self.assertIn('`tf.' + element + '`', v2_ragged_ops)
|
||||
for element in supported_ops_v1:
|
||||
self.assertNotIn(element, v2_ragged_ops)
|
||||
self.assertNotIn('`tf.' + element + '`', v2_ragged_ops)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
195
tensorflow/python/ops/ragged/ragged_print_op_test.py
Normal file
195
tensorflow/python/ops/ragged/ragged_print_op_test.py
Normal file
@ -0,0 +1,195 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.print with ragged tensors.
|
||||
|
||||
Note: ragged support for tf.print is implemented by RaggedPrintV2Dispatcher in
|
||||
ragged_dispatch.py.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import tempfile
|
||||
from absl.testing import parameterized
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
class RaggedPrintV2Test(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Checks the text written by `logging_ops.print_v2` for ragged inputs."""

  # Each `inputs` entry is a lambda so that the tensors are only constructed
  # once the test method itself calls it, not when the class body is evaluated.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters([
      dict(
          testcase_name='2d_int_values',
          inputs=lambda: [ragged_factory_ops.constant([[1, 2], [3]])],
          expected='[[1, 2], [3]]\n'),
      dict(
          testcase_name='3d_int_values',
          inputs=lambda: [ragged_factory_ops.constant([[[1, 2], [3]], [[4]]])],
          expected='[[[1, 2], [3]], [[4]]]\n'),
      dict(
          testcase_name='2d_str_values',
          inputs=lambda: [ragged_factory_ops.constant([['a', 'b'], ['c']])],
          expected="[['a', 'b'], ['c']]\n"),
      dict(
          testcase_name='2d_str_values_with_escaping',
          inputs=lambda: [ragged_factory_ops.constant([["a'b"], ['c"d']])],
          expected="[['a\\'b'], ['c\"d']]\n"),
      dict(
          testcase_name='two_ragged_values',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2], [3]]),
              ragged_factory_ops.constant([[5], [], [6, 7, 8]])
          ],
          expected='[[1, 2], [3]] [[5], [], [6, 7, 8]]\n'),
      dict(
          testcase_name='ragged_value_and_non_tensor_values',
          inputs=lambda: [
              'a', 5, True,
              ragged_factory_ops.constant([[1, 2], [3]]), 'c'
          ],
          expected='a 5 True [[1, 2], [3]] c\n'),
      dict(
          testcase_name='ragged_value_and_dense_value',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2], [3]]),
              constant_op.constant([[1, 2], [3, 4]])
          ],
          expected='[[1, 2], [3]] [[1 2]\n [3 4]]\n'),
      dict(
          testcase_name='ragged_value_and_sparse_value',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2], [3]]),
              sparse_ops.from_dense([[1]])
          ],
          expected=(
              '[[1, 2], [3]] '
              "'SparseTensor(indices=[[0 0]], values=[1], shape=[1 1])'\n")),
      dict(
          testcase_name='summarize_default',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10],
                                           [], [], [], [], [11, 12]])
          ],
          expected=('[[1, 2, 3, ..., 7, 8, 9], [10], [], '
                    '..., '
                    '[], [], [11, 12]]\n')),
      dict(
          testcase_name='summarize_2',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10],
                                           [], [], [], [], [11, 12]])
          ],
          summarize=2,
          expected='[[1, 2, ..., 8, 9], [10], ..., [], [11, 12]]\n'),
      dict(
          testcase_name='summarize_neg1',
          inputs=lambda: [
              ragged_factory_ops.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9], [10],
                                           [], [], [], [], [11, 12]])
          ],
          summarize=-1,
          expected=('[[1, 2, 3, 4, 5, 6, 7, 8, 9], [10], '
                    '[], [], [], [], [11, 12]]\n')),
  ])
  def testRaggedPrint(self, inputs, expected, summarize=None):
    """Prints `inputs` to a temporary file and compares the file contents.

    Args:
      inputs: The values to print, or a zero-argument callable returning them.
      expected: The exact text the print op is expected to write.
      summarize: Optional `summarize` argument forwarded to `print_v2`; when
        `None` the argument is omitted so the op's default applies.
    """
    if callable(inputs):
      inputs = inputs()
    with tempfile.TemporaryDirectory() as tmpdirname:
      path = os.path.join(tmpdirname, 'print_output')
      kwargs = {'output_stream': 'file://{}'.format(path)}
      if summarize is not None:
        kwargs.update(summarize=summarize)
      self.evaluate(logging_ops.print_v2(*inputs, **kwargs))
      # Read through a context manager so the handle is closed deterministically
      # (and before the temporary directory is cleaned up).
      with open(path, 'r') as output_file:
        actual = output_file.read()
      # Compare reprs so that whitespace differences are visible on failure.
      self.assertEqual(repr(actual), repr(expected))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
class RaggedToStringTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `ragged_string_ops.ragged_tensor_to_string`."""

  @parameterized.named_parameters([
      dict(
          testcase_name='2d_int',
          rt=[[1, 2], [], [3, 4, 5]],
          expected='[[1, 2], [], [3, 4, 5]]'),
      dict(
          testcase_name='2d_str',
          rt=[['a'], ['b'], ['c', 'd']],
          expected="[['a'], ['b'], ['c', 'd']]"),
      dict(
          testcase_name='3d_int',
          rt=[[[1, 2], []], [[3, 4, 5]]],
          expected='[[[1, 2], []], [[3, 4, 5]]]'),
      dict(
          testcase_name='escape',
          rt=[["a'b"], [r'c\d']],
          expected=r"[['a\'b'], ['c\\d']]"),
      dict(
          testcase_name='2d_empty',
          rt=[],
          ragged_rank=1,
          expected='[]'),
      dict(
          testcase_name='3d_empty',
          rt=[],
          ragged_rank=2,
          expected='[]'),
      dict(
          testcase_name='3d_rrank1',
          rt=[[[1, 2], [3, 4]], [], [[5, 6]]],
          ragged_rank=1,
          expected='[[[1, 2], [3, 4]], [], [[5, 6]]]'),
      dict(
          testcase_name='2d_empty_row',
          rt=[[]],
          ragged_rank=1,
          expected='[[]]'),
      dict(
          testcase_name='3d_empty_row',
          rt=[[]],
          ragged_rank=2,
          expected='[[]]'),
      dict(
          testcase_name='summarize_1',
          rt=[[1, 2, 3, 4, 5], [], [6], [7], [8, 9]],
          summarize=1,
          expected='[[1, ..., 5], ..., [8, 9]]'),
      dict(
          testcase_name='summarize_2',
          rt=[[1, 2, 3, 4, 5], [], [6], [7], [8, 9]],
          summarize=2,
          expected='[[1, 2, ..., 4, 5], [], ..., [7], [8, 9]]'),
  ])
  def testRaggedToString(self, rt, expected, summarize=None, ragged_rank=None):
    """Converts the nested list `rt` to a string and checks the result."""
    rt = ragged_factory_ops.constant(rt, ragged_rank=ragged_rank)
    actual = ragged_string_ops.ragged_tensor_to_string(rt, summarize=summarize)
    self.assertAllEqual(actual, expected)

  @parameterized.named_parameters([
      dict(
          testcase_name='maxelts_BadType',
          rt=[[1]],
          error="Expected summarize .*, got 'foo'",
          summarize='foo'),
      dict(
          testcase_name='maxelts_0',
          rt=[[1]],
          error='Expected summarize to be .*, got 0',
          summarize=0),
      dict(
          testcase_name='maxelts_Neg2',
          rt=[[1]],
          error='Expected summarize to be .*, got -2',
          summarize=-2),
  ])
  def testRaggedToStringErrors(self,
                               rt,
                               error,
                               summarize=None,
                               exception=ValueError):
    """Checks that an invalid `summarize` value raises `exception`."""
    rt = ragged_factory_ops.constant(rt)
    with self.assertRaisesRegex(exception, error):
      self.evaluate(
          ragged_string_ops.ragged_tensor_to_string(rt, summarize=summarize))

  def testRaggedToStringUnknownRank(self):
    """Checks the error raised when the input's static rank is unknown."""
    # The spec only sets ragged_rank; the failure this test expects is the
    # "rt.shape.rank is not None" check raised while the function is traced.
    unknown_rank_spec = ragged_tensor.RaggedTensorSpec(ragged_rank=1)

    @def_function.function(input_signature=[unknown_rank_spec])
    def convert(rt):
      return ragged_string_ops.ragged_tensor_to_string(rt)

    with self.assertRaisesRegex(
        ValueError, 'RaggedTensor to_string requires '
        'that rt.shape.rank is not None'):
      convert(ragged_factory_ops.constant([[1, 2], [3]]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
@ -18,10 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_string_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
@ -30,9 +33,14 @@ from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import compat as util_compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
map_fn_lib = LazyLoader("map_fn_lib", globals(),
|
||||
"tensorflow.python.ops.map_fn")
|
||||
|
||||
|
||||
@tf_export("strings.bytes_split")
|
||||
@dispatch.add_dispatch_support
|
||||
def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin
|
||||
@ -640,7 +648,7 @@ def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redef
|
||||
input, dtype=dtypes.string, name="input")
|
||||
|
||||
if input.shape.rank == 0:
|
||||
input = gen_array_ops.expand_dims(input, 0)
|
||||
input = array_ops.expand_dims(input, 0)
|
||||
|
||||
if result_type == "SparseTensor":
|
||||
if input.shape.rank == 1:
|
||||
@ -813,3 +821,108 @@ def ngrams(data,
|
||||
values=output, row_splits=output_splits, validate=False)
|
||||
return array_ops.reshape(output.flat_values,
|
||||
dense_shape) if to_tensor else output
|
||||
|
||||
|
||||
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  """Version of tf.strings.format that handles RaggedTensors.

  The template is split on `placeholder`; the literal fragments and the string
  form of each input are then interleaved and joined into one string.

  Args:
    template: A Python string containing one `placeholder` per input.
    inputs: A single Tensor or RaggedTensor, or a sequence of values. A lone
      tensor is treated as a one-element list.
    placeholder: The substring of `template` that marks where each input goes.
    summarize: Forwarded to the per-input string conversion to limit how many
      elements of each dimension are shown.
    name: Optional name for the enclosing name scope.

  Returns:
    A string tensor holding the formatted result. When there are no inputs
    this is simply `template` as a constant.

  Raises:
    ValueError: If the number of placeholders in `template` differs from the
      number of inputs.
  """
  is_single_value = (
      tensor_util.is_tensor(inputs) or ragged_tensor.is_ragged(inputs))
  if is_single_value:
    inputs = [inputs]

  literals = template.split(placeholder)
  num_placeholders = len(literals) - 1
  if len(inputs) != num_placeholders:
    raise ValueError("num placeholders in template and num inputs must match"
                     ": {} vs {}".format(num_placeholders, len(inputs)))

  with ops.name_scope(name, "StringFormat", [inputs]):
    pieces = [constant_op.constant(literals[0])]
    # Pair each input with the literal text that follows its placeholder.
    for value, trailing_literal in zip(inputs, literals[1:]):
      if ragged_tensor.is_ragged(value):
        rendered = ragged_tensor_to_string(value, summarize)
      else:
        rendered = string_ops.string_format(
            "{}", [value], summarize=summarize)
      pieces.append(rendered)
      pieces.append(constant_op.constant(trailing_literal))

    # With no inputs the template itself is the whole result; skip the join.
    if len(pieces) == 1:
      return pieces[0]
    return string_ops.reduce_join(pieces)
|
||||
|
||||
|
||||
def ragged_tensor_to_string(rt, summarize=None):
  """Renders an entire RaggedTensor as one scalar string tensor.

  Requires that `rt.shape.rank` is not `None`.

  Note: the whole `RaggedTensor` becomes a single string scalar. To convert
  each element separately, use `tf.strings.as_string(rt)` instead.

  >>> rt1 = tf.ragged.constant([[1, 2, 3], [4, 5]])
  >>> ragged_tensor_to_string(rt1).numpy()
  b'[[1, 2, 3], [4, 5]]'

  >>> rt2 = tf.ragged.constant([[['a'], ['b', 'c']], [['d', 'e', 'f'], []]])
  >>> ragged_tensor_to_string(rt2).numpy()
  b"[[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]"

  >>> rt3 = tf.ragged.constant([[1], [2, 3, 4, 5, 6], [], [], [7], [8, 9]])
  >>> ragged_tensor_to_string(rt3, summarize=2).numpy()
  b'[[1], [2, 3, ..., 5, 6], ..., [7], [8, 9]]'

  Args:
    rt: The RaggedTensor that should be converted to a string.
    summarize: If specified, then only the first and last `summarize` elements
      within each dimension are included in the string. If `-1` or `None`, then
      all elements are included.

  Returns:
    A scalar string Tensor.

  Raises:
    ValueError: If `summarize` is not `None`, `-1`, or a positive int; or if
      `rt.shape.rank` is `None`.
  """
  if summarize is not None and summarize != -1:
    # Anything other than the "show everything" markers must be a positive int.
    if not isinstance(summarize, int) or summarize <= 0:
      raise ValueError("Expected summarize to be -1 or a positive int, got %r" %
                       summarize)

  with ops.name_scope(None, "AsString", [rt]):
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
    if rt.shape.rank is None:
      raise ValueError("RaggedTensor to_string requires that rt.shape.rank "
                       "is not None.")

    # Convert all elements of `rt` to strings.
    flat_values = rt.flat_values
    if rt.dtype == dtypes.string:
      # Backslash-escape single quotes and backslashes, then wrap each value
      # in single quotes.
      escaped = string_ops.regex_replace(flat_values, r"(['\\])", r"\\\1")
      element_strings = "'" + escaped + "'"
    else:
      element_strings = string_ops.as_string(flat_values)
    str_t = rt.with_flat_values(element_strings)

    return _ragged_tensor_to_string(str_t, summarize)
|
||||
|
||||
|
||||
def _ragged_tensor_to_string(string_tensor, summarize):
  """Joins the elements of `string_tensor` into one bracketed scalar string.

  Tensors of rank greater than one are handled recursively: each row is first
  turned into its own bracketed string, and those row strings are then joined.

  Args:
    string_tensor: A potentially ragged tensor with dtype=string.
    summarize: Include only the first and last `summarize` elements of each
      dimension.  If `-1` or `None`, then include all elements.

  Returns:
    A scalar string Tensor.
  """
  if string_tensor.shape.rank == 1:
    pieces = string_tensor
  else:

    def _row_to_string(row):
      return _ragged_tensor_to_string(row, summarize)

    pieces = map_fn_lib.map_fn(
        _row_to_string,
        string_tensor,
        fn_output_signature=tensor_spec.TensorSpec(None, dtypes.string))

  if summarize not in (-1, None):

    def _keep_all():
      return pieces

    def _elide_middle():
      # Keep `summarize` leading and trailing pieces around a "..." marker.
      return array_ops.concat(
          [pieces[:summarize], ["..."], pieces[-summarize:]], axis=0)

    # Only elide when there are more than 2 * summarize pieces to show.
    pieces = control_flow_ops.cond(
        _nrows(string_tensor) <= 2 * summarize, _keep_all, _elide_middle)

  return "[" + string_ops.reduce_join(pieces, separator=", ") + "]"
|
||||
|
||||
|
||||
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the number of rows in `tensor` as an `out_type` value.

  Args:
    tensor: A `RaggedTensor` or a value accepted by `array_ops.shape`.
    out_type: The dtype requested for the result.

  Returns:
    `tensor.nrows(...)` for a `RaggedTensor`; otherwise the first entry of
    `array_ops.shape(tensor)`.
  """
  if not isinstance(tensor, ragged_tensor.RaggedTensor):
    return array_ops.shape(tensor, out_type=out_type)[0]
  return tensor.nrows(out_type=out_type)
|
||||
|
Loading…
x
Reference in New Issue
Block a user