Make RegexFullMatchOp and RegexReplaceOp cache RE2 objects.

PiperOrigin-RevId: 326038690
Change-Id: Ia3b0177e38a40c1514c08d8821b992aca60cd0ac
This commit is contained in:
Paul Wankadia 2020-08-11 09:39:30 -07:00 committed by TensorFlower Gardener
parent ee2c2d1781
commit b4297ce0a8
2 changed files with 67 additions and 13 deletions

View File

@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@ -28,6 +30,8 @@ class RegexFullMatchOp : public OpKernel {
public:
explicit RegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~RegexFullMatchOp() override {}
void Compute(OpKernelContext* ctx) override {
const Tensor* input_tensor;
OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
@ -39,19 +43,43 @@ class RegexFullMatchOp : public OpKernel {
errors::InvalidArgument("Pattern must be scalar, but received ",
pattern_tensor->shape().DebugString()));
const string pattern = pattern_tensor->flat<tstring>()(0);
const RE2 match(pattern);
OP_REQUIRES(ctx, match.ok(),
std::shared_ptr<RE2> regex = CachedRE2(pattern);
OP_REQUIRES(ctx, regex->ok(),
errors::InvalidArgument("Invalid pattern: ", pattern,
", error: ", match.error()));
", error: ", regex->error()));
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
&output_tensor));
auto output_flat = output_tensor->flat<bool>();
for (size_t i = 0; i < input_flat.size(); ++i) {
output_flat(i) = RE2::FullMatch(input_flat(i), match);
output_flat(i) = RE2::FullMatch(input_flat(i), *regex);
}
}
private:
// Returns an RE2 object built from `pattern`, reusing the cached regex_
// when its pattern() string equals `pattern`; otherwise builds a new RE2
// and makes it the cached one.
//
// This function does not call ok() on the RE2 it returns: an RE2 built from
// a pattern that fails to compile is cached and returned like any other, so
// the caller has to check ok() itself.
//
// regex_ is only read or written while mu_ is held. The result is returned
// as a shared_ptr copy, so the caller's object stays alive even if a later
// call swaps a different RE2 into regex_.
std::shared_ptr<RE2> CachedRE2(const string& pattern) {
{
// Cache lookup, done under a tf_shared_lock on mu_.
tf_shared_lock l(mu_);
if (regex_ != nullptr && regex_->pattern() == pattern) {
return regex_;
}
}
// Construct the new RE2 object before acquiring the lock.
// No lock is held between the lookup above and the swap below, and the
// cache is not re-checked after locking: if another thread installed a
// regex in the meantime, it is simply replaced by this one.
auto regex = std::make_shared<RE2>(pattern);
{
mutex_lock l(mu_);
// Swap instead of assigning so that we destruct the old
// RE2 object (when necessary) after releasing the lock.
// After the swap, `regex` holds the previous cached pointer (possibly
// null). It is declared outside this scope, so it is destroyed at
// function exit, after `l` has gone out of scope.
regex_.swap(regex);
return regex_;
}
}
mutex mu_;
std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(RegexFullMatchOp);
};
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),

View File

@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@ -29,7 +31,7 @@ namespace {
// Context requirements:
// - "input" string Tensor at input_index=0
// - "output" string Tensor at output_index=0
Status InternalCompute(const RE2& match, const string& rewrite,
Status InternalCompute(const RE2& regex, const string& rewrite,
const bool replace_global, OpKernelContext* ctx) {
const Tensor* input_tensor;
TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
@ -52,9 +54,9 @@ Status InternalCompute(const RE2& match, const string& rewrite,
// accept std::string.
string buf = output_flat(i);
if (replace_global) {
RE2::GlobalReplace(&buf, match, rewrite);
RE2::GlobalReplace(&buf, regex, rewrite);
} else {
RE2::Replace(&buf, match, rewrite);
RE2::Replace(&buf, regex, rewrite);
}
output_flat(i) = std::move(buf);
}
@ -68,6 +70,8 @@ class RegexReplaceOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
}
~RegexReplaceOp() override {}
void Compute(OpKernelContext* ctx) override {
const Tensor* pattern_tensor;
OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
@ -75,10 +79,10 @@ class RegexReplaceOp : public OpKernel {
errors::InvalidArgument("Pattern must be scalar, but received ",
pattern_tensor->shape().DebugString()));
const string& pattern = pattern_tensor->scalar<tstring>()();
const RE2 match(pattern);
OP_REQUIRES(ctx, match.ok(),
std::shared_ptr<RE2> regex = CachedRE2(pattern);
OP_REQUIRES(ctx, regex->ok(),
errors::InvalidArgument("Invalid pattern: ", pattern,
", error: ", match.error()));
", error: ", regex->error()));
const Tensor* rewrite_tensor;
OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
@ -86,11 +90,33 @@ class RegexReplaceOp : public OpKernel {
errors::InvalidArgument("Rewrite must be scalar, but received ",
rewrite_tensor->shape().DebugString()));
const string& rewrite = rewrite_tensor->scalar<tstring>()();
OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx));
OP_REQUIRES_OK(ctx, InternalCompute(*regex, rewrite, replace_global_, ctx));
}
private:
// Returns an RE2 object built from `pattern`. If the cached regex_ was built
// from an identical pattern string it is returned as is; otherwise a new RE2
// is constructed, stored in regex_, and returned.
//
// Validity is not checked here: ok() is never called in this function, so
// an RE2 whose pattern failed to compile is stored and returned too. The
// caller is responsible for testing ok() on the result.
//
// All accesses to regex_ happen with mu_ held. Because a shared_ptr copy is
// returned, the caller keeps its RE2 alive even if regex_ is replaced by a
// subsequent call.
std::shared_ptr<RE2> CachedRE2(const string& pattern) {
{
// Check the cache while holding a tf_shared_lock on mu_.
tf_shared_lock l(mu_);
if (regex_ != nullptr && regex_->pattern() == pattern) {
return regex_;
}
}
// Construct the new RE2 object before acquiring the lock.
// The lock is dropped between the check above and the swap below, and
// there is no second check under the exclusive lock, so whatever is in
// regex_ at that point is replaced unconditionally.
auto regex = std::make_shared<RE2>(pattern);
{
mutex_lock l(mu_);
// Swap instead of assigning so that we destruct the old
// RE2 object (when necessary) after releasing the lock.
// `regex` now owns the previously cached pointer (which may be null).
// Being declared in the enclosing scope, it outlives `l` and is
// destroyed only when the function returns.
regex_.swap(regex);
return regex_;
}
}
bool replace_global_;
mutex mu_;
std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(RegexReplaceOp);
};
REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
@ -101,11 +127,11 @@ class StaticRegexReplaceOp : public OpKernel {
// Reads the "pattern", "rewrite" and "replace_global" attrs from the
// construction context and builds the RE2 for the pattern once, storing it
// in re_. If the resulting RE2 is not ok(), an InvalidArgument error that
// includes the pattern and the RE2 error text is raised via OP_REQUIRES.
//
// NOTE(review): the "rewrite" attr is read twice below, once before and once
// after the RE2 is built. This looks like the removed and the added form of
// the same line shown together by the diff rendering -- confirm against the
// actual file before relying on the order.
explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
string pattern;
OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
re_ = MakeUnique<RE2>(pattern);
OP_REQUIRES(ctx, re_->ok(),
errors::InvalidArgument("Invalid pattern: ", pattern,
", error: ", re_->error()));
OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
}
@ -115,8 +141,8 @@ class StaticRegexReplaceOp : public OpKernel {
}
private:
string rewrite_str_;
std::unique_ptr<RE2> re_;
string rewrite_str_;
bool replace_global_;
};