Add StaticRegexFullMatch, which can be used in place of RegexFullMatch when the regex pattern is fixed.

This allows the op to perform the expensive regex compilation once, when the kernel is created, instead of on every call to Compute.

RELNOTES: Performance improvements for regex full match operations.
PiperOrigin-RevId: 211835278
A. Unique TensorFlower 2018-09-06 11:03:14 -07:00 committed by TensorFlower Gardener
parent 025277a159
commit ca5952670d
6 changed files with 151 additions and 12 deletions
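
As a rough usage sketch (not part of the commit; it assumes the TF 1.x `Session` API and the `tf.strings.regex_full_match` endpoint exported in the string_ops.py hunk below): a pattern supplied as a plain Python string can be routed to the new static op once the forward-compatibility window has passed, while a tensor-valued pattern keeps using the original RegexFullMatch op.

```python
import tensorflow as tf

strings = tf.constant(["abaaba", "abcdabcde"])

# Pattern known at graph-construction time: eligible for StaticRegexFullMatch
# (once the forward-compatibility horizon in string_ops.py has passed).
static_match = tf.strings.regex_full_match(strings, pattern="a.*a")

# Pattern only known at run time: stays on the original RegexFullMatch op,
# which has to recompile the regex on every execution.
pattern_t = tf.placeholder(tf.string, shape=[])
dynamic_match = tf.strings.regex_full_match(strings, pattern=pattern_t)

with tf.Session() as sess:
  print(sess.run(static_match))                         # [ True False]
  print(sess.run(dynamic_match, {pattern_t: "a.*a"}))   # [ True False]
```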

View File

@@ -0,0 +1,29 @@
op {
  graph_op_name: "StaticRegexFullMatch"
  in_arg {
    name: "input"
    description: <<END
A string tensor of the text to be processed.
END
  }
  out_arg {
    name: "output"
    description: <<END
A bool tensor with the same shape as `input`.
END
  }
  attr {
    name: "pattern"
    description: "The regular expression to match the input."
  }
  summary: "Check if the input matches the regex pattern."
  description: <<END
The input is a string tensor of any shape. The pattern is the
regular expression to be matched with every element of the input tensor.
The boolean values (True or False) of the output tensor indicate
if the input matches the regex pattern provided.

The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
END
  visibility: HIDDEN
}
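
Since the op is registered with `visibility: HIDDEN`, it has no standalone public endpoint; it is reached via the Python `regex_full_match` wrapper (see the string_ops.py hunk below) or directly through the generated `gen_string_ops` module, as the updated tests do. A minimal sketch of the documented behaviour, assuming the internal `gen_string_ops` module and a TF 1.x session:

```python
import tensorflow as tf
from tensorflow.python.ops import gen_string_ops  # internal module; the op is HIDDEN

# Element-wise full match against a fixed RE2 pattern; the output is a bool
# tensor with the same shape as the input.
inputs = tf.constant([["abaaba", "abcdabcde"],
                      ["acdcba", "ebcda"]])
matched = gen_string_ops.static_regex_full_match(inputs, pattern="a.*a")

with tf.Session() as sess:
  print(sess.run(matched))  # [[ True False]
                            #  [ True False]]
```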

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
@@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
                        RegexFullMatchOp);

class StaticRegexFullMatchOp : public OpKernel {
 public:
  explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    string pattern;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
    re_ = MakeUnique<RE2>(pattern);
    OP_REQUIRES(ctx, re_->ok(),
                errors::InvalidArgument("Invalid pattern: ", pattern,
                                        ", error: ", re_->error()));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* input_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
    const auto& input_flat = input_tensor->flat<string>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
                                             &output_tensor));
    auto output_flat = output_tensor->flat<bool>();
    for (size_t i = 0; i < input_flat.size(); ++i) {
      output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
    }
  }

 private:
  std::unique_ptr<RE2> re_;
};

REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
                        StaticRegexFullMatchOp);

}  // namespace tensorflow
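
The key difference from `RegexFullMatchOp` above is where the RE2 object is built: the `pattern` attr is compiled and validated once in the constructor, so `Compute` only calls `RE2::FullMatch` per element. A hypothetical micro-benchmark sketch of the effect (sizes and iteration counts invented for illustration, not from the commit):

```python
import time

import tensorflow as tf
from tensorflow.python.ops import gen_string_ops  # internal module

values = tf.constant(["abaaba"] * 1000)
# Static variant: the RE2 object is compiled once, in the kernel constructor.
static_op = gen_string_ops.static_regex_full_match(values, pattern="a.*a")
# Original variant: the pattern arrives as a tensor and is recompiled per Compute.
dynamic_op = gen_string_ops.regex_full_match(values, pattern=tf.constant("a.*a"))

with tf.Session() as sess:
  for name, op in [("static", static_op), ("dynamic", dynamic_op)]:
    sess.run(op)  # warm-up run
    start = time.time()
    for _ in range(100):
      sess.run(op)
    print(name, "100 runs:", time.time() - start, "seconds")
```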

View File

@@ -56,6 +56,12 @@ REGISTER_OP("RegexFullMatch")
      return Status::OK();
    });

REGISTER_OP("StaticRegexFullMatch")
    .Input("input: string")
    .Attr("pattern: string")
    .Output("output: bool")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("StringToHashBucketFast")
    .Input("input: string")
    .Output("output: int64")

View File

@@ -779,6 +779,7 @@ tf_py_test(
    size = "small",
    srcs = ["regex_full_match_op_test.py"],
    additional_deps = [
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",

View File

@@ -18,37 +18,77 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class RegexFullMatchOpTest(test.TestCase):
@parameterized.parameters(
    (gen_string_ops.regex_full_match),
    (gen_string_ops.static_regex_full_match))
class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):

  def testRegexFullMatch(self):
  def testRegexFullMatch(self, op):
    values = ["abaaba", "abcdabcde"]
    with self.test_session():
      input_vector = constant_op.constant(values, dtypes.string)
      matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
      input_tensor = constant_op.constant(values, dtypes.string)
      matched = op(input_tensor, "a.*a").eval()
      self.assertAllEqual([True, False], matched)

  def testEmptyMatch(self):
  def testRegexFullMatchTwoDims(self, op):
    values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
    with self.test_session():
      input_tensor = constant_op.constant(values, dtypes.string)
      matched = op(input_tensor, "a.*a").eval()
      self.assertAllEqual([[True, False], [True, False]], matched)

  def testEmptyMatch(self, op):
    values = ["abc", "1"]
    with self.test_session():
      input_vector = constant_op.constant(values, dtypes.string)
      matched = string_ops.regex_full_match(input_vector, "").eval()
      input_tensor = constant_op.constant(values, dtypes.string)
      matched = op(input_tensor, "").eval()
      self.assertAllEqual([False, False], matched)

  def testInvalidPattern(self):
  def testInvalidPattern(self, op):
    values = ["abc", "1"]
    with self.test_session():
      input_vector = constant_op.constant(values, dtypes.string)
      input_tensor = constant_op.constant(values, dtypes.string)
      invalid_pattern = "A["
      matched = string_ops.regex_full_match(input_vector, invalid_pattern)
      matched = op(input_tensor, invalid_pattern)
      with self.assertRaisesOpError("Invalid pattern"):
        matched.eval()


class RegexFullMatchOpTest(test.TestCase):

  def testRegexFullMatchDelegation(self):
    with compat.forward_compatibility_horizon(2018, 11, 1):
      with self.test_session():
        input_tensor = constant_op.constant("foo", dtypes.string)
        pattern = "[a-z]"
        op = string_ops.regex_full_match(input_tensor, pattern)
        self.assertTrue(op.name.startswith("RegexFullMatch"), op.name)

        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
        op_tensor = string_ops.regex_full_match(input_tensor, pattern_tensor)
        self.assertTrue(op_tensor.name.startswith("RegexFullMatch"), op.name)

  def testStaticRegexFullMatchDelegation(self):
    with compat.forward_compatibility_horizon(2018, 11, 20):
      with self.test_session():
        input_tensor = constant_op.constant("foo", dtypes.string)
        pattern = "[a-z]*"
        op = string_ops.regex_full_match(input_tensor, pattern)
        self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)

        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
        op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
        self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op.name)


if __name__ == "__main__":
  test.main()

View File

@@ -41,12 +41,41 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

# pylint: enable=wildcard-import


# pylint: disable=redefined-builtin
def regex_full_match(input, pattern, name=None):
  r"""Match elements of `input` with regex `pattern`.

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    name: Name of the op.

  Returns:
    bool `Tensor` of the same shape as `input` with match results.
  """
  # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
  if not compat.forward_compatible(2018, 11, 10):
    return gen_string_ops.regex_full_match(
        input=input, pattern=pattern, name=name)

  if isinstance(pattern, util_compat.bytes_or_text_types):
    # When `pattern` is static through the life of the op we can
    # use a version which performs the expensive regex compilation once at
    # creation time.
    return gen_string_ops.static_regex_full_match(
        input=input, pattern=pattern, name=name)
  return gen_string_ops.regex_full_match(
      input=input, pattern=pattern, name=name)

regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__

# Expose regex_full_match in strings namespace
tf_export("strings.regex_full_match")(regex_full_match)


def regex_replace(source, pattern, rewrite, replace_global=True):
  r"""Replace elements of `source` matching regex `pattern with `rewrite`.
  r"""Replace elements of `source` matching regex `pattern` with `rewrite`.

  Args:
    source: string `Tensor`, the source strings to process.

@@ -128,6 +157,7 @@ def string_split(source, delimiter=" ", skip_empty=True):  # pylint: disable=inv
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)


@tf_export("strings.split")
def string_split_v2(source, sep=None, maxsplit=-1):
  """Split elements of `source` based on `sep` into a `SparseTensor`.

@@ -170,7 +200,7 @@ def string_split_v2(source, sep=None, maxsplit=-1):
    second column corresponds to the index of the split component in this row.
  """
  if sep is None:
    sep = ''
    sep = ""
  sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)
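
With the wrapper above, the op type is chosen from the `pattern` argument. A short sketch of the resulting graph-construction behaviour, mirroring the delegation tests and assuming the forward-compatibility horizon is past 2018-11-10:

```python
import tensorflow as tf
from tensorflow.python.compat import compat
from tensorflow.python.ops import string_ops

with compat.forward_compatibility_horizon(2018, 11, 20):
  text = tf.constant("foo")
  # Python string pattern -> static kernel, regex compiled once at creation.
  static = string_ops.regex_full_match(text, "[a-z]*")
  # Tensor-valued pattern -> original kernel, pattern evaluated at run time.
  dynamic = string_ops.regex_full_match(text, tf.constant("[a-z]*"))
  print(static.op.type)   # StaticRegexFullMatch
  print(dynamic.op.type)  # RegexFullMatch
```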