Add StaticRegexFullMatch, which can be used in place of RegexFullMatch when the regex pattern is fixed.
This allows the Op to perform the expensive regex compilation once upon creation instead of with each call to compute. RELNOTES: Performance improvements for regex full match operations. PiperOrigin-RevId: 211835278
This commit is contained in:
parent
025277a159
commit
ca5952670d
@ -0,0 +1,29 @@
|
||||
op {
|
||||
graph_op_name: "StaticRegexFullMatch"
|
||||
in_arg {
|
||||
name: "input"
|
||||
description: <<END
|
||||
A string tensor of the text to be processed.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
A bool tensor with the same shape as `input`.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "pattern"
|
||||
description: "The regular expression to match the input."
|
||||
}
|
||||
summary: "Check if the input matches the regex pattern."
|
||||
description: <<END
|
||||
The input is a string tensor of any shape. The pattern is the
|
||||
regular expression to be matched with every element of the input tensor.
|
||||
The boolean values (True or False) of the output tensor indicate
|
||||
if the input matches the regex pattern provided.
|
||||
|
||||
The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
|
||||
END
|
||||
visibility: HIDDEN
|
||||
}
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
|
||||
RegexFullMatchOp);
|
||||
|
||||
class StaticRegexFullMatchOp : public OpKernel {
|
||||
public:
|
||||
explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
string pattern;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
|
||||
re_ = MakeUnique<RE2>(pattern);
|
||||
OP_REQUIRES(ctx, re_->ok(),
|
||||
errors::InvalidArgument("Invalid pattern: ", pattern,
|
||||
", error: ", re_->error()));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* input_tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
|
||||
const auto& input_flat = input_tensor->flat<string>();
|
||||
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<bool>();
|
||||
for (size_t i = 0; i < input_flat.size(); ++i) {
|
||||
output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<RE2> re_;
|
||||
};
|
||||
|
||||
// Registers StaticRegexFullMatchOp as the CPU kernel for the
// "StaticRegexFullMatch" op.
REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
|
||||
StaticRegexFullMatchOp);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -56,6 +56,12 @@ REGISTER_OP("RegexFullMatch")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Declares the "StaticRegexFullMatch" op: one string input, one bool output,
// and the regular expression supplied as a string attr named `pattern`
// (rather than as a tensor input). Shape inference uses UnchangedShape.
REGISTER_OP("StaticRegexFullMatch")
|
||||
.Input("input: string")
|
||||
.Attr("pattern: string")
|
||||
.Output("output: bool")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("StringToHashBucketFast")
|
||||
.Input("input: string")
|
||||
.Output("output: int64")
|
||||
|
@ -779,6 +779,7 @@ tf_py_test(
|
||||
size = "small",
|
||||
srcs = ["regex_full_match_op_test.py"],
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
|
@ -18,37 +18,77 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import gen_string_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class RegexFullMatchOpTest(test.TestCase):
|
||||
@parameterized.parameters(
|
||||
(gen_string_ops.regex_full_match),
|
||||
(gen_string_ops.static_regex_full_match))
|
||||
class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testRegexFullMatch(self):
|
||||
def testRegexFullMatch(self, op):
|
||||
values = ["abaaba", "abcdabcde"]
|
||||
with self.test_session():
|
||||
input_vector = constant_op.constant(values, dtypes.string)
|
||||
matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
|
||||
input_tensor = constant_op.constant(values, dtypes.string)
|
||||
matched = op(input_tensor, "a.*a").eval()
|
||||
self.assertAllEqual([True, False], matched)
|
||||
|
||||
def testEmptyMatch(self):
|
||||
def testRegexFullMatchTwoDims(self, op):
|
||||
values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
|
||||
with self.test_session():
|
||||
input_tensor = constant_op.constant(values, dtypes.string)
|
||||
matched = op(input_tensor, "a.*a").eval()
|
||||
self.assertAllEqual([[True, False], [True, False]], matched)
|
||||
|
||||
def testEmptyMatch(self, op):
|
||||
values = ["abc", "1"]
|
||||
with self.test_session():
|
||||
input_vector = constant_op.constant(values, dtypes.string)
|
||||
matched = string_ops.regex_full_match(input_vector, "").eval()
|
||||
input_tensor = constant_op.constant(values, dtypes.string)
|
||||
matched = op(input_tensor, "").eval()
|
||||
self.assertAllEqual([False, False], matched)
|
||||
|
||||
def testInvalidPattern(self):
|
||||
def testInvalidPattern(self, op):
|
||||
values = ["abc", "1"]
|
||||
with self.test_session():
|
||||
input_vector = constant_op.constant(values, dtypes.string)
|
||||
input_tensor = constant_op.constant(values, dtypes.string)
|
||||
invalid_pattern = "A["
|
||||
matched = string_ops.regex_full_match(input_vector, invalid_pattern)
|
||||
matched = op(input_tensor, invalid_pattern)
|
||||
with self.assertRaisesOpError("Invalid pattern"):
|
||||
matched.eval()
|
||||
|
||||
|
||||
class RegexFullMatchOpTest(test.TestCase):
  """Checks which generated op `string_ops.regex_full_match` creates.

  Only the names of the created ops are inspected; nothing is evaluated.
  """

  def testRegexFullMatchDelegation(self):
    """With horizon 2018-11-01, both pattern kinds yield RegexFullMatch."""
    with compat.forward_compatibility_horizon(2018, 11, 1):
      with self.test_session():
        input_tensor = constant_op.constant("foo", dtypes.string)
        pattern = "[a-z]"
        op = string_ops.regex_full_match(input_tensor, pattern)
        self.assertTrue(op.name.startswith("RegexFullMatch"), op.name)

        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
        op_tensor = string_ops.regex_full_match(input_tensor, pattern_tensor)
        # Report the name of the op actually being checked on failure.
        self.assertTrue(
            op_tensor.name.startswith("RegexFullMatch"), op_tensor.name)

  def testStaticRegexFullMatchDelegation(self):
    """With horizon 2018-11-20, a Python string pattern yields the static op.

    A pattern passed as a tensor is still expected to yield RegexFullMatch.
    """
    with compat.forward_compatibility_horizon(2018, 11, 20):
      with self.test_session():
        input_tensor = constant_op.constant("foo", dtypes.string)
        pattern = "[a-z]*"
        op = string_ops.regex_full_match(input_tensor, pattern)
        self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)

        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
        op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
        # Report the name of the op actually being checked on failure.
        self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op_vec.name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -41,12 +41,41 @@ from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
def regex_full_match(input, pattern, name=None):
  r"""Checks each element of `input` for a full match against `pattern`.

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    name: Name of the op.

  Returns:
    bool `Tensor` of the same shape as `input` with match results.
  """
  # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
  # The static op is chosen only once the forward-compatibility date has been
  # reached AND `pattern` is a Python string: a pattern that is fixed for the
  # life of the op lets the expensive regex compilation happen once at
  # creation time. The compatibility check is evaluated first, and the type
  # check is skipped when it fails.
  use_static_op = (
      compat.forward_compatible(2018, 11, 10) and
      isinstance(pattern, util_compat.bytes_or_text_types))
  if use_static_op:
    matcher = gen_string_ops.static_regex_full_match
  else:
    matcher = gen_string_ops.regex_full_match
  return matcher(input=input, pattern=pattern, name=name)
|
||||
|
||||
regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
|
||||
|
||||
# Expose regex_full_match in strings namespace
|
||||
tf_export("strings.regex_full_match")(regex_full_match)
|
||||
|
||||
|
||||
def regex_replace(source, pattern, rewrite, replace_global=True):
|
||||
r"""Replace elements of `source` matching regex `pattern with `rewrite`.
|
||||
r"""Replace elements of `source` matching regex `pattern` with `rewrite`.
|
||||
|
||||
Args:
|
||||
source: string `Tensor`, the source strings to process.
|
||||
@ -128,6 +157,7 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv
|
||||
shape.set_shape([2])
|
||||
return sparse_tensor.SparseTensor(indices, values, shape)
|
||||
|
||||
|
||||
@tf_export("strings.split")
|
||||
def string_split_v2(source, sep=None, maxsplit=-1):
|
||||
"""Split elements of `source` based on `sep` into a `SparseTensor`.
|
||||
@ -170,7 +200,7 @@ def string_split_v2(source, sep=None, maxsplit=-1):
|
||||
second column corresponds to the index of the split component in this row.
|
||||
"""
|
||||
if sep is None:
|
||||
sep = ''
|
||||
sep = ""
|
||||
sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
|
||||
source = ops.convert_to_tensor(source, dtype=dtypes.string)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user