Merge pull request #25859 from yongtang:25857-strings.lower
PiperOrigin-RevId: 246425483
This commit is contained in:
commit
f106893959
@ -0,0 +1,3 @@
|
||||
# API definition entry for the "StringLower" graph op; only the op name is set.
op {
|
||||
graph_op_name: "StringLower"
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
# API definition entry for the "StringUpper" graph op; only the op name is set.
op {
|
||||
graph_op_name: "StringUpper"
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
# API definition for the "StringLower" graph op that declares the endpoint
# name "strings.lower" for it.
op {
|
||||
graph_op_name: "StringLower"
|
||||
endpoint {
|
||||
name: "strings.lower"
|
||||
}
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
# API definition for the "StringUpper" graph op that declares the endpoint
# name "strings.upper" for it.
op {
|
||||
graph_op_name: "StringUpper"
|
||||
endpoint {
|
||||
name: "strings.upper"
|
||||
}
|
||||
}
|
@ -5066,9 +5066,11 @@ cc_library(
|
||||
":string_format_op",
|
||||
":string_join_op",
|
||||
":string_length_op",
|
||||
":string_lower_op",
|
||||
":string_split_op",
|
||||
":string_strip_op",
|
||||
":string_to_hash_bucket_op",
|
||||
":string_upper_op",
|
||||
":substr_op",
|
||||
":unicode_ops",
|
||||
":unicode_script_op",
|
||||
@ -5204,6 +5206,24 @@ tf_kernel_library(
|
||||
deps = STRING_DEPS,
|
||||
)
|
||||
|
||||
# Kernel library target "string_lower_op". In addition to the shared
# STRING_DEPS list it depends on Abseil strings and ICU common.
tf_kernel_library(
|
||||
name = "string_lower_op",
|
||||
prefix = "string_lower_op",
|
||||
deps = STRING_DEPS + [
|
||||
"@com_google_absl//absl/strings",
|
||||
"@icu//:common",
|
||||
],
|
||||
)
|
||||
|
||||
# Kernel library target "string_upper_op". In addition to the shared
# STRING_DEPS list it depends on Abseil strings and ICU common.
tf_kernel_library(
|
||||
name = "string_upper_op",
|
||||
prefix = "string_upper_op",
|
||||
deps = STRING_DEPS + [
|
||||
"@com_google_absl//absl/strings",
|
||||
"@icu//:common",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "substr_op",
|
||||
prefix = "substr_op",
|
||||
@ -6193,6 +6213,8 @@ filegroup(
|
||||
"batch_kernels.*",
|
||||
"regex_full_match_op.cc",
|
||||
"regex_replace_op.cc",
|
||||
"string_lower_op.cc", # Requires ICU for unicode.
|
||||
"string_upper_op.cc", # Requires ICU for unicode.
|
||||
"unicode_ops.cc",
|
||||
"unicode_script_op.cc",
|
||||
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
|
||||
|
72
tensorflow/core/kernels/string_lower_op.cc
Normal file
72
tensorflow/core/kernels/string_lower_op.cc
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/string_ops.cc.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "unicode/unistr.h" // TF:icu
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class StringLowerOp : public OpKernel {
|
||||
public:
|
||||
explicit StringLowerOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("encoding", &encoding_));
|
||||
OP_REQUIRES(context, encoding_.empty() || encoding_ == "utf-8",
|
||||
errors::InvalidArgument(
|
||||
"only utf-8 or '' (no encoding) is supported, received ",
|
||||
encoding_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* input_tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
|
||||
Tensor* output_tensor;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor));
|
||||
|
||||
const auto input = input_tensor->flat<string>();
|
||||
auto output = output_tensor->flat<string>();
|
||||
|
||||
if (encoding_.empty()) {
|
||||
for (int64 i = 0; i < input.size(); ++i) {
|
||||
StringPiece entry(input(i));
|
||||
output(i) = absl::AsciiStrToLower(entry);
|
||||
}
|
||||
} else {
|
||||
// The validation of utf-8 has already been done in GetAttr above.
|
||||
for (int64 i = 0; i < input.size(); ++i) {
|
||||
icu::UnicodeString us(input(i).c_str(), "UTF-8");
|
||||
us.toLower();
|
||||
us.toUTF8String(output(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string encoding_;
|
||||
};
|
||||
|
||||
// Register StringLowerOp as the CPU kernel for the "StringLower" op.
REGISTER_KERNEL_BUILDER(Name("StringLower").Device(DEVICE_CPU), StringLowerOp);
|
||||
|
||||
} // namespace tensorflow
|
71
tensorflow/core/kernels/string_upper_op.cc
Normal file
71
tensorflow/core/kernels/string_upper_op.cc
Normal file
@ -0,0 +1,71 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/string_ops.cc.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "unicode/unistr.h" // TF:icu
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class StringUpperOp : public OpKernel {
|
||||
public:
|
||||
explicit StringUpperOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("encoding", &encoding_));
|
||||
OP_REQUIRES(context, encoding_.empty() || encoding_ == "utf-8",
|
||||
errors::InvalidArgument(
|
||||
"only utf-8 or '' (no encoding) is supported, received ",
|
||||
encoding_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* input_tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
|
||||
Tensor* output_tensor;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor));
|
||||
|
||||
const auto input = input_tensor->flat<string>();
|
||||
auto output = output_tensor->flat<string>();
|
||||
if (encoding_.empty()) {
|
||||
for (int64 i = 0; i < input.size(); ++i) {
|
||||
StringPiece entry(input(i));
|
||||
output(i) = absl::AsciiStrToUpper(entry);
|
||||
}
|
||||
} else {
|
||||
// The validation of utf-8 has already been done in GetAttr above.
|
||||
for (int64 i = 0; i < input.size(); ++i) {
|
||||
icu::UnicodeString us(input(i).c_str(), "UTF-8");
|
||||
us.toUpper();
|
||||
us.toUTF8String(output(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string encoding_;
|
||||
};
|
||||
|
||||
// Register StringUpperOp as the CPU kernel for the "StringUpper" op.
REGISTER_KERNEL_BUILDER(Name("StringUpper").Device(DEVICE_CPU), StringUpperOp);
|
||||
|
||||
} // namespace tensorflow
|
@ -206,6 +206,18 @@ REGISTER_OP("StringSplitV2")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Op definition for "StringLower": one string input, one string output, and a
// string attr "encoding" that defaults to ''. Shape inference is delegated to
// shape_inference::UnchangedShape.
REGISTER_OP("StringLower")
|
||||
.Input("input: string")
|
||||
.Output("output: string")
|
||||
.Attr("encoding: string =''")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// Op definition for "StringUpper": one string input, one string output, and a
// string attr "encoding" that defaults to ''. Shape inference is delegated to
// shape_inference::UnchangedShape.
REGISTER_OP("StringUpper")
|
||||
.Input("input: string")
|
||||
.Output("output: string")
|
||||
.Attr("encoding: string =''")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("StringStrip")
|
||||
.Input("input: string")
|
||||
.Output("output: string")
|
||||
|
@ -1118,6 +1118,34 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Small Python test target that runs string_lower_op_test.py.
tf_py_test(
|
||||
name = "string_lower_op_test",
|
||||
size = "small",
|
||||
srcs = ["string_lower_op_test.py"],
|
||||
additional_deps = [
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:string_ops",
|
||||
],
|
||||
)
|
||||
|
||||
# Small Python test target that runs string_upper_op_test.py.
tf_py_test(
|
||||
name = "string_upper_op_test",
|
||||
size = "small",
|
||||
srcs = ["string_upper_op_test.py"],
|
||||
additional_deps = [
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:string_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "substr_op_test",
|
||||
size = "small",
|
||||
|
56
tensorflow/python/kernel_tests/string_lower_op_test.py
Normal file
56
tensorflow/python/kernel_tests/string_lower_op_test.py
Normal file
@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for string_lower_op."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class StringLowerOpTest(test.TestCase):
  """Test cases for tf.strings.lower."""

  def test_string_lower(self):
    """Lower-cases a flat list of ASCII strings with the default encoding."""
    strings = ["Pigs on The Wing", "aNimals"]

    with self.cached_session():
      output = string_ops.string_lower(strings)
      output = self.evaluate(output)
      self.assertAllEqual(output, [b"pigs on the wing", b"animals"])

  def test_string_lower_2d(self):
    """Lower-cases a 2-D input; whitespace and punctuation stay untouched."""
    strings = [["pigS on THE wIng", "aniMals"], [" hello ", "\n\tWorld! \r \n"]]

    with self.cached_session():
      output = string_ops.string_lower(strings)
      output = self.evaluate(output)
      self.assertAllEqual(output, [[b"pigs on the wing", b"animals"],
                                   [b" hello ", b"\n\tworld! \r \n"]])

  def test_string_lower_unicode(self):
    """Lower-cases non-ASCII letters when encoding="utf-8" is requested."""
    # Previously named test_string_upper_unicode (copy-paste slip); this case
    # exercises string_lower, so the name now matches its siblings.
    strings = [["ÓÓSSCHLOË"]]
    with self.cached_session():
      output = string_ops.string_lower(strings, encoding="utf-8")
      output = self.evaluate(output)
      # output: "óósschloë"
      self.assertAllEqual(output, [[b"\xc3\xb3\xc3\xb3sschlo\xc3\xab"]])
|
||||
|
||||
|
||||
# Run the test cases in this file when it is executed as a script.
if __name__ == "__main__":
|
||||
test.main()
|
56
tensorflow/python/kernel_tests/string_upper_op_test.py
Normal file
56
tensorflow/python/kernel_tests/string_upper_op_test.py
Normal file
@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for string_upper_op."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class StringUpperOpTest(test.TestCase):
  """Checks string_ops.string_upper on ASCII and UTF-8 inputs."""

  def test_string_upper(self):
    # Default encoding: a flat list of mixed-case ASCII strings.
    inputs = ["Pigs on The Wing", "aNimals"]
    expected = [b"PIGS ON THE WING", b"ANIMALS"]

    with self.cached_session():
      result = self.evaluate(string_ops.string_upper(inputs))
      self.assertAllEqual(result, expected)

  def test_string_upper_2d(self):
    # Default encoding: a 2-D input that includes whitespace and punctuation.
    inputs = [["pigS on THE wIng", "aniMals"], [" hello ", "\n\tWorld! \r \n"]]
    expected = [[b"PIGS ON THE WING", b"ANIMALS"],
                [b" HELLO ", b"\n\tWORLD! \r \n"]]

    with self.cached_session():
      result = self.evaluate(string_ops.string_upper(inputs))
      self.assertAllEqual(result, expected)

  def test_string_upper_unicode(self):
    # encoding="utf-8": accented letters are upper-cased as well.
    inputs = [["óósschloë"]]
    # UTF-8 bytes of "ÓÓSSCHLOË".
    expected = [[b"\xc3\x93\xc3\x93SSCHLO\xc3\x8b"]]

    with self.cached_session():
      result = self.evaluate(string_ops.string_upper(inputs, encoding="utf-8"))
      self.assertAllEqual(result, expected)
|
||||
|
||||
|
||||
# Run the test cases in this file when it is executed as a script.
if __name__ == "__main__":
|
||||
test.main()
|
@ -3872,6 +3872,10 @@ tf_module {
|
||||
name: "StringLength"
|
||||
argspec: "args=[\'input\', \'unit\', \'name\'], varargs=None, keywords=None, defaults=[\'BYTE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StringLower"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StringSplit"
|
||||
argspec: "args=[\'input\', \'delimiter\', \'skip_empty\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
@ -3900,6 +3904,10 @@ tf_module {
|
||||
name: "StringToNumber"
|
||||
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StringUpper"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Sub"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -20,6 +20,10 @@ tf_module {
|
||||
name: "length"
|
||||
argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lower"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_join"
|
||||
argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\', \'None\', \'None\'], "
|
||||
@ -88,4 +92,8 @@ tf_module {
|
||||
name: "unicode_transcode"
|
||||
argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "upper"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -3872,6 +3872,10 @@ tf_module {
|
||||
name: "StringLength"
|
||||
argspec: "args=[\'input\', \'unit\', \'name\'], varargs=None, keywords=None, defaults=[\'BYTE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StringLower"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StringSplit"
|
||||
argspec: "args=[\'input\', \'delimiter\', \'skip_empty\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
@ -3900,6 +3904,10 @@ tf_module {
|
||||
name: "StringToNumber"
|
||||
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StringUpper"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Sub"
|
||||
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -20,6 +20,10 @@ tf_module {
|
||||
name: "length"
|
||||
argspec: "args=[\'input\', \'unit\', \'name\'], varargs=None, keywords=None, defaults=[\'BYTE\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "lower"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_join"
|
||||
argspec: "args=[\'inputs\', \'axis\', \'keepdims\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\'], "
|
||||
@ -88,4 +92,8 @@ tf_module {
|
||||
name: "unicode_transcode"
|
||||
argspec: "args=[\'input\', \'input_encoding\', \'output_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "upper"
|
||||
argspec: "args=[\'input\', \'encoding\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user