Add tf.strings.unicode_script, which detects the script of a Unicode code point based on standard ranges.

PiperOrigin-RevId: 214796357
This commit is contained in:
A. Unique TensorFlower 2018-09-27 10:31:36 -07:00 committed by TensorFlower Gardener
parent 3002b10e29
commit 334244be68
15 changed files with 290 additions and 0 deletions

View File

@ -0,0 +1,28 @@
op {
graph_op_name: "UnicodeScript"
endpoint {
name: "UnicodeScript"
}
in_arg {
name: "input"
description: <<END
A Tensor of int32 Unicode code points.
END
}
out_arg {
name: "output"
description: <<END
A Tensor of int32 script codes corresponding to each input code point.
END
}
summary: <<END
Determine the script codes of a given tensor of Unicode integer code points.
END
description: <<END
This operation converts Unicode code points to script codes corresponding to
each code point. Script codes correspond to International Components for
Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html.
Returns -1 (USCRIPT_INVALID_CODE) for invalid code points. The output shape
matches the input shape.
END
}

View File

@ -0,0 +1,6 @@
op {
graph_op_name: "UnicodeScript"
endpoint {
name: "strings.unicode_script"
}
}

View File

@ -4431,6 +4431,7 @@ cc_library(
":string_strip_op",
":string_to_hash_bucket_op",
":substr_op",
":unicode_script_op",
],
)
@ -5471,6 +5472,7 @@ filegroup(
"batch_kernels.*",
"regex_full_match_op.cc",
"regex_replace_op.cc",
"unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"mkl_*",
"xsmm_*",
@ -6565,6 +6567,16 @@ tf_kernel_library(
],
)
# Kernel library for the UnicodeScript op, compiled from unicode_script_op.cc.
# The "@icu//:common" dependency supplies the ICU headers and implementation
# used by the kernel.
tf_kernel_library(
name = "unicode_script_op",
srcs = ["unicode_script_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:string_ops_op_lib",
"@icu//:common",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

View File

@ -0,0 +1,53 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>

#include "unicode/errorcode.h" // TF:icu
#include "unicode/uscript.h" // TF:icu
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
class UnicodeScriptOp : public OpKernel {
public:
explicit UnicodeScriptOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input", &input_tensor));
const auto& input_flat = input_tensor->flat<int32>();
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("output", input_tensor->shape(),
&output_tensor));
auto output_flat = output_tensor->flat<int32>();
icu::ErrorCode status;
for (int i = 0; i < input_flat.size(); i++) {
UScriptCode script_code = uscript_getScript(input_flat(i), status);
if (status.isSuccess()) {
output_flat(i) = script_code;
} else {
output_flat(i) = -1;
status.reset();
}
}
}
};
// Registers UnicodeScriptOp as the CPU implementation of the "UnicodeScript"
// op.
REGISTER_KERNEL_BUILDER(Name("UnicodeScript").Device(DEVICE_CPU),
UnicodeScriptOp);
} // namespace tensorflow

View File

@ -244,4 +244,9 @@ REGISTER_OP("Substr")
return shape_inference::BroadcastBinaryOpShapeFn(c);
});
// UnicodeScript: takes one int32 tensor ("input") and produces one int32
// tensor ("output"). The shape function passes the input shape through
// unchanged.
REGISTER_OP("UnicodeScript")
.Input("input: int32")
.Output("output: int32")
.SetShapeFn(shape_inference::UnchangedShape);
} // namespace tensorflow

View File

@ -1097,6 +1097,18 @@ tf_py_test(
],
)
# Small Python test target for the unicode_script op, run from
# unicode_script_op_test.py.
tf_py_test(
name = "unicode_script_op_test",
size = "small",
srcs = ["unicode_script_op_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:string_ops",
],
)
cuda_py_test(
name = "topk_op_test",
size = "small",

View File

@ -0,0 +1,57 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================
"""Functional tests for UnicodeScript op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
class UnicodeScriptOpTest(test.TestCase):
  """Checks the script codes returned by string_ops.unicode_script."""

  def testValidScripts(self):
    """Known code points map to their expected script codes."""
    # (code point, expected script code) pairs.
    cases = [
        (ord("a"), 25),  # USCRIPT_LATIN (LATN)
        (0x0411, 8),  # CYRILLIC CAPITAL LETTER BE -> USCRIPT_CYRILLIC (CYRL)
        (0x82b8, 17),  # CJK UNIFIED IDEOGRAPH-82B8 -> USCRIPT_HAN (HANI)
        (ord(","), 0),  # USCRIPT_COMMON (ZYYY)
    ]
    code_points = [code_point for code_point, _ in cases]
    expected_scripts = [script for _, script in cases]
    with self.cached_session():
      code_point_tensor = constant_op.constant(code_points, dtypes.int32)
      actual_scripts = string_ops.unicode_script(code_point_tensor).eval()
      self.assertAllEqual(actual_scripts, expected_scripts)

  def testInvalidScript(self):
    """A negative value and a too-large value both yield -1."""
    bad_code_points = [-100, 0xffffff]
    with self.cached_session():
      code_point_tensor = constant_op.constant(bad_code_points, dtypes.int32)
      actual_scripts = string_ops.unicode_script(code_point_tensor).eval()
      self.assertAllEqual(actual_scripts, [-1, -1])
if __name__ == "__main__":
test.main()

View File

@ -48,4 +48,8 @@ tf_module {
name: "to_number"
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "unicode_script"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -48,4 +48,8 @@ tf_module {
name: "to_number"
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "unicode_script"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -125,6 +125,7 @@ genrule(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
"@icu//:icu4c/LICENSE",
"@jpeg//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",
@ -192,6 +193,7 @@ genrule(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
"@icu//:icu4j/main/shared/licenses/LICENSE",
"@jpeg//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",

View File

@ -153,6 +153,7 @@ filegroup(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
"@icu//:icu4c/LICENSE",
"@jpeg//:LICENSE.md",
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",

View File

@ -21,9 +21,11 @@ load(
"def_file_filter_configure",
)
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
load("//third_party/icu:workspace.bzl", icu = "repo")
# Invokes the `repo` macros loaded above from the flatbuffers and icu
# workspace.bzl files (aliased as `flatbuffers` and `icu`).
def initialize_third_party():
flatbuffers()
icu()
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.

1
third_party/icu/BUILD vendored Normal file
View File

@ -0,0 +1 @@
# This empty BUILD file is required to make Bazel treat this directory as a package.

88
third_party/icu/BUILD.bazel vendored Normal file
View File

@ -0,0 +1,88 @@
# Targets in this package are public unless they declare their own visibility.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # ICU/Unicode license; see icu4c/LICENSE
# License files made available to other packages by label.
exports_files([
"icu4c/LICENSE",
"icu4j/main/shared/licenses/LICENSE",
])
# Header-only target: exposes the headers under icu4c/source/common/unicode/
# and adds icu4c/source/common to the include path, so they are included as
# "unicode/<name>.h".
cc_library(
name = "headers",
hdrs = glob(["icu4c/source/common/unicode/*.h"]),
includes = [
"icu4c/source/common",
],
deps = [
],
)
# Public target for ICU "common": exposes the same unicode/*.h headers as
# ":headers" and pulls in the implementation through the package-private
# ":icuuc" target.
cc_library(
name = "common",
hdrs = glob(["icu4c/source/common/unicode/*.h"]),
includes = [
"icu4c/source/common",
],
deps = [
":icuuc",
],
)
# Implementation target: compiles the C/C++ sources directly under
# icu4c/source/common together with the stubdata sources. Visibility is
# private, so code outside this package reaches it through ":common".
cc_library(
name = "icuuc",
srcs = glob(
[
"icu4c/source/common/*.c",
"icu4c/source/common/*.cpp",
"icu4c/source/stubdata/*.cpp",
],
),
hdrs = glob([
"icu4c/source/common/*.h",
]),
# Flags applied on every platform, followed by per-platform additions chosen
# by the ":android", ":apple" and ":windows" config_settings in this file.
copts = [
"-DU_COMMON_IMPLEMENTATION",
"-DU_HAVE_STD_ATOMICS",
] + select({
":android": [
"-fdata-sections",
"-DGOOGLE_VENDOR_SRC_BRANCH",
"-DU_HAVE_NL_LANGINFO_CODESET=0",
"-Wno-deprecated-declarations",
],
":apple": [
"-DGOOGLE_VENDOR_SRC_BRANCH",
"-Wno-shorten-64-to-32",
"-Wno-unused-variable",
],
":windows": [
"/utf-8",
"/DLOCALE_ALLOW_NEUTRAL_NAMES=0",
],
"//conditions:default": [],
}),
tags = ["requires-rtti"],
visibility = [
"//visibility:private",
],
deps = [
":headers",
],
)
# Platform selectors referenced by the copts select() of ":icuuc".

# Matches builds whose crosstool_top is //external:android/crosstool.
config_setting(
name = "android",
values = {"crosstool_top": "//external:android/crosstool"},
)
# Matches builds with cpu set to "darwin".
config_setting(
name = "apple",
values = {"cpu": "darwin"},
)
# Matches builds with cpu set to "x64_windows".
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
)

15
third_party/icu/workspace.bzl vendored Normal file
View File

@ -0,0 +1,15 @@
"""Loads a lightweight subset of the ICU library for Unicode processing."""
load("//third_party:repo.bzl", "third_party_http_archive")
# Declares the external repository named "icu" from the ICU release-62-1
# source archive. The archive is listed under two URLs (a mirror.bazel.build
# copy and the GitHub original), is pinned by sha256, has its top-level
# "icu-release-62-1" directory stripped, and is built with
# //third_party/icu:BUILD.bazel.
def repo():
third_party_http_archive(
name = "icu",
strip_prefix = "icu-release-62-1",
sha256 = "e15ffd84606323cbad5515bf9ecdf8061cc3bf80fb883b9e6aa162e485aa9761",
urls = [
"https://mirror.bazel.build/github.com/unicode-org/icu/archive/release-62-1.tar.gz",
"https://github.com/unicode-org/icu/archive/release-62-1.tar.gz",
],
build_file = "//third_party/icu:BUILD.bazel",
)