Add "unit" attribute to the string length op, which controls how "string length" is defined:

* BYTE: The number of bytes in each string.  (Default)
* UTF8_CHAR: The number of UTF-8 encoded Unicode code points in each string.

RELNOTES: Add option to calculate string length in Unicode characters
PiperOrigin-RevId: 214478470
This commit is contained in:
A. Unique TensorFlower 2018-09-25 11:56:33 -07:00 committed by TensorFlower Gardener
parent df93001523
commit d5c5df164c
12 changed files with 193 additions and 8 deletions

View File

@ -253,6 +253,7 @@ tensorflow/core/kernels/strided_slice_op_inst_5.cc
tensorflow/core/kernels/strided_slice_op_inst_6.cc
tensorflow/core/kernels/strided_slice_op_inst_7.cc
tensorflow/core/kernels/string_join_op.cc
tensorflow/core/kernels/string_util.cc
tensorflow/core/kernels/tensor_array.cc
tensorflow/core/kernels/tensor_array_ops.cc
tensorflow/core/kernels/tile_functor_cpu.cc

View File

@ -1,5 +1,15 @@
op {
graph_op_name: "StringLength"
attr {
name: "unit"
description: <<END
The unit that is counted to compute string length. One of: `"BYTE"` (for
the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8
encoded Unicode code points in each string). Results are undefined
if `unit=UTF8_CHAR` and the `input` strings do not contain structurally
valid UTF-8.
END
}
in_arg {
name: "input"
description: <<END

View File

@ -1,6 +1,4 @@
op {
graph_op_name: "StringLength"
endpoint {
name: "strings.length"
}
visibility: HIDDEN
}

View File

@ -4434,8 +4434,16 @@ cc_library(
],
)
cc_library(
name = "string_util",
srcs = ["string_util.cc"],
hdrs = ["string_util.h"],
deps = ["//tensorflow/core:lib"],
)
STRING_DEPS = [
":bounds_check",
":string_util",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -5166,6 +5174,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
"string_util.h",
"tensor_array.h",
"tile_functor.h",
"tile_ops_cpu_impl.h",
@ -5334,6 +5343,7 @@ filegroup(
"spectrogram_op.cc",
"stack_ops.cc",
"string_join_op.cc",
"string_util.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",

View File

@ -14,13 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/string_util.h"
namespace tensorflow {
namespace {
// Kernel for the "StringLength" op: writes the length of each input string
// into an int32 output of the same shape. The "unit" attr selects whether
// length is counted in bytes or in UTF-8 code points.
class StringLengthOp : public OpKernel {
public:
using OpKernel::OpKernel;
// Reads the "unit" attr ("BYTE" or "UTF8_CHAR") and caches the parsed
// CharUnit so Compute() does not have to re-parse it per call.
explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
string unit;
OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel {
auto src = input.flat<string>();
auto dst = output->flat<int32>();
// NOTE(review): this is a rendered diff; the next two lines are the
// pre-change byte-counting loop that the switch below replaced.
for (int n = 0; n < src.size(); ++n) {
dst(n) = src(n).size();
// Dispatch on the counting unit once per tensor, not once per element.
switch (unit_) {
case CharUnit::BYTE:
for (int n = 0; n < src.size(); ++n) {
dst(n) = src(n).size();
}
break;
case CharUnit::UTF8_CHAR:
for (int n = 0; n < src.size(); ++n) {
dst(n) = UTF8StrLen(src(n));
}
break;
}
}
private:
// Counting unit; defaults to BYTE, matching the op's attr default.
CharUnit unit_ = CharUnit::BYTE;
};
REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),

View File

@ -0,0 +1,63 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/string_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace {
// True iff `x` is a UTF-8 continuation (trail) byte, i.e. its top two bits
// are "10" (byte values 0x80-0xBF).
inline bool IsTrailByte(char x) {
  return (static_cast<unsigned char>(x) & 0xC0) == 0x80;
}
}  // namespace
namespace tensorflow {
// Sets `encoding` based on `str`. Only "UTF8" is currently supported;
// any other value yields an InvalidArgument status.
Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) {
  if (str == "UTF8") {
    *encoding = UnicodeEncoding::UTF8;
  } else {
    // Fixed copy-paste bug: the message previously listed "BYTE", which is
    // a CharUnit value, not a UnicodeEncoding.
    return errors::InvalidArgument(strings::StrCat(
        "Invalid encoding \"", str, "\": Should be one of: UTF8"));
  }
  return Status::OK();
}
// Maps the attr string `str` onto a CharUnit stored in `*unit`.
// Accepts exactly "BYTE" and "UTF8_CHAR"; anything else yields an
// InvalidArgument status.
Status ParseCharUnit(const string& str, CharUnit* unit) {
  if (str == "BYTE") {
    *unit = CharUnit::BYTE;
    return Status::OK();
  }
  if (str == "UTF8_CHAR") {
    *unit = CharUnit::UTF8_CHAR;
    return Status::OK();
  }
  return errors::InvalidArgument(strings::StrCat(
      "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR"));
}
// Returns the number of Unicode code points in a UTF-8 encoded string by
// counting every byte that is not a continuation (trail) byte.
// Result may be incorrect if the input string is not valid UTF-8.
// Fixed: the parameter was named `string`, shadowing the `string` type
// inside the function body; renamed to `str`.
int32 UTF8StrLen(const string& str) {
  const int32 byte_size = str.size();
  const char* ptr = str.data();
  const char* const end = ptr + byte_size;
  int32 trail_count = 0;
  while (ptr < end) {
    trail_count += IsTrailByte(*ptr++) ? 1 : 0;
  }
  // Code points = total bytes minus continuation bytes.
  return byte_size - trail_count;
}
} // namespace tensorflow

View File

@ -0,0 +1,45 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Enumeration for unicode encodings. Used by ops such as
// tf.strings.unicode_encode and tf.strings.unicode_decode.
// TODO(edloper): Add support for:
// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE
enum class UnicodeEncoding { UTF8 };
// Enumeration for character units. Used by string ops such as
// tf.strings.length and tf.substr.
// TODO(edloper): Add support for: UTF32_CHAR, etc.
enum class CharUnit { BYTE, UTF8_CHAR };
// Sets `encoding` based on `str`. Returns an InvalidArgument status if
// `str` does not name a supported encoding.
Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
// Sets `unit` value based on `str` ("BYTE" or "UTF8_CHAR"). Returns an
// InvalidArgument status for any other value.
Status ParseCharUnit(const string& str, CharUnit* unit);
// Returns the number of Unicode characters in a UTF-8 string.
// Result may be incorrect if the input string is not valid UTF-8.
// NOTE(review): the parameter is named `string`, shadowing the type name;
// consider renaming it (e.g. to `str`) for clarity.
int32 UTF8StrLen(const string& string);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_

View File

@ -203,6 +203,7 @@ REGISTER_OP("StringStrip")
// StringLength: elementwise string length. The "unit" attr selects counting
// in bytes ("BYTE", the default) or UTF-8 code points ("UTF8_CHAR"); the
// int32 output has the same shape as the input.
REGISTER_OP("StringLength")
.Input("input: string")
.Output("output: int32")
.Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("EncodeBase64")

View File

@ -32,6 +32,33 @@ class StringLengthOpTest(test.TestCase):
values = sess.run(lengths)
self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]])
def testUnit(self):
  """Checks the `unit` attr: BYTE vs. UTF8_CHAR counting, and bad values."""
  # u"H\xc3llo" is 5 code points / 6 UTF-8 bytes (U+00C3 takes 2 bytes);
  # u"\U0001f604" is 1 code point / 4 UTF-8 bytes.
  code_point_strings = [u"H\xc3llo", u"\U0001f604"]
  encoded = [s.encode("utf-8") for s in code_point_strings]
  with self.test_session() as sess:
    byte_lengths = string_ops.string_length(encoded, unit="BYTE")
    char_lengths = string_ops.string_length(encoded, unit="UTF8_CHAR")
    self.assertAllEqual(sess.run(byte_lengths), [6, 4])
    self.assertAllEqual(sess.run(char_lengths), [5, 1])
    # An unsupported unit must be rejected at graph-construction time.
    with self.assertRaisesRegexp(
        ValueError, "Attr 'unit' of 'StringLength' Op passed string 'XYZ' "
        'not in: "BYTE", "UTF8_CHAR"'):
      string_ops.string_length(encoded, unit="XYZ")
def testLegacyPositionalName(self):
  """`name` must remain usable as the second positional argument."""
  # Code written before the 'unit' parameter existed may pass 'name'
  # positionally; verify that such calls still work.
  nested_strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
  lengths = string_ops.string_length(nested_strings, "some_name")
  with self.test_session():
    self.assertAllEqual(lengths.eval(), [[[1, 2], [3, 4], [5, 6]]])
if __name__ == "__main__":
test.main()

View File

@ -36,10 +36,12 @@ from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
# pylint: disable=g-bad-import-order
from tensorflow.python.ops.gen_string_ops import *
from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import
@ -328,6 +330,17 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
# Backwards-compatibility wrapper: callers that predate the 'unit' attr may
# pass 'name' as the second positional argument, so 'name' must stay ahead
# of 'unit' in the signature.
@tf_export("strings.length")
def string_length(input, name=None, unit="BYTE"):
  return gen_string_ops.string_length(input=input, unit=unit, name=name)
string_length.__doc__ = gen_string_ops.string_length.__doc__
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")

View File

@ -10,7 +10,7 @@ tf_module {
}
member_method {
name: "length"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"

View File

@ -10,7 +10,7 @@ tf_module {
}
member_method {
name: "length"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"