Make RegexFullMatchOp and RegexReplaceOp cache RE2 objects.
PiperOrigin-RevId: 326038690 Change-Id: Ia3b0177e38a40c1514c08d8821b992aca60cd0ac
This commit is contained in:
parent
ee2c2d1781
commit
b4297ce0a8
@ -20,6 +20,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -28,6 +30,8 @@ class RegexFullMatchOp : public OpKernel {
|
||||
public:
|
||||
// Constructs the kernel by forwarding `ctx` to the OpKernel base class.
// The constructor body is empty.
explicit RegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
// Explicitly declared destructor override with an empty body.
~RegexFullMatchOp() override {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* input_tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
|
||||
@ -39,19 +43,43 @@ class RegexFullMatchOp : public OpKernel {
|
||||
errors::InvalidArgument("Pattern must be scalar, but received ",
|
||||
pattern_tensor->shape().DebugString()));
|
||||
const string pattern = pattern_tensor->flat<tstring>()(0);
|
||||
const RE2 match(pattern);
|
||||
OP_REQUIRES(ctx, match.ok(),
|
||||
std::shared_ptr<RE2> regex = CachedRE2(pattern);
|
||||
OP_REQUIRES(ctx, regex->ok(),
|
||||
errors::InvalidArgument("Invalid pattern: ", pattern,
|
||||
", error: ", match.error()));
|
||||
", error: ", regex->error()));
|
||||
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<bool>();
|
||||
for (size_t i = 0; i < input_flat.size(); ++i) {
|
||||
output_flat(i) = RE2::FullMatch(input_flat(i), match);
|
||||
output_flat(i) = RE2::FullMatch(input_flat(i), *regex);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns a shared RE2 built from `pattern`, reusing the single object held
// in `regex_` when that object's pattern() equals `pattern`.
//
// On a miss, a new RE2 is constructed and swapped into `regex_`, replacing
// whatever was cached before (the cache holds one entry).
//
// The RE2's ok() state is not inspected here, so the object is cached and
// returned regardless; the caller checks regex->ok() on the result.
//
// The cache is not re-checked after taking the exclusive lock, so concurrent
// callers that both miss may each construct their own RE2. Each caller still
// gets back the object it built, because the return happens under the lock.
std::shared_ptr<RE2> CachedRE2(const string& pattern) {
|
||||
{
|
||||
// Cache lookup. NOTE(review): tf_shared_lock is presumably a reader lock
// that lets lookups proceed concurrently -- confirm in platform/mutex.h.
tf_shared_lock l(mu_);
|
||||
if (regex_ != nullptr && regex_->pattern() == pattern) {
|
||||
return regex_;
|
||||
}
|
||||
}
|
||||
// Construct the new RE2 object before acquiring the lock.
|
||||
auto regex = std::make_shared<RE2>(pattern);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// Swap instead of assigning so that we destruct the old
|
||||
// RE2 object (when necessary) after releasing the lock.
|
||||
// After the swap, `regex_` holds the new object and the local `regex`
// holds the previous cache entry (possibly null).
regex_.swap(regex);
|
||||
return regex_;
|
||||
}
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RegexFullMatchOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
|
||||
|
@ -20,6 +20,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -29,7 +31,7 @@ namespace {
|
||||
// Context requirements:
|
||||
// - "input" string Tensor at input_index=0
|
||||
// - "output" string Tensor at output_index=0
|
||||
Status InternalCompute(const RE2& match, const string& rewrite,
|
||||
Status InternalCompute(const RE2& regex, const string& rewrite,
|
||||
const bool replace_global, OpKernelContext* ctx) {
|
||||
const Tensor* input_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
|
||||
@ -52,9 +54,9 @@ Status InternalCompute(const RE2& match, const string& rewrite,
|
||||
// accept std::string.
|
||||
string buf = output_flat(i);
|
||||
if (replace_global) {
|
||||
RE2::GlobalReplace(&buf, match, rewrite);
|
||||
RE2::GlobalReplace(&buf, regex, rewrite);
|
||||
} else {
|
||||
RE2::Replace(&buf, match, rewrite);
|
||||
RE2::Replace(&buf, regex, rewrite);
|
||||
}
|
||||
output_flat(i) = std::move(buf);
|
||||
}
|
||||
@ -68,6 +70,8 @@ class RegexReplaceOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
|
||||
}
|
||||
|
||||
// Explicitly declared destructor override with an empty body.
~RegexReplaceOp() override {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor* pattern_tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
|
||||
@ -75,10 +79,10 @@ class RegexReplaceOp : public OpKernel {
|
||||
errors::InvalidArgument("Pattern must be scalar, but received ",
|
||||
pattern_tensor->shape().DebugString()));
|
||||
const string& pattern = pattern_tensor->scalar<tstring>()();
|
||||
const RE2 match(pattern);
|
||||
OP_REQUIRES(ctx, match.ok(),
|
||||
std::shared_ptr<RE2> regex = CachedRE2(pattern);
|
||||
OP_REQUIRES(ctx, regex->ok(),
|
||||
errors::InvalidArgument("Invalid pattern: ", pattern,
|
||||
", error: ", match.error()));
|
||||
", error: ", regex->error()));
|
||||
|
||||
const Tensor* rewrite_tensor;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("rewrite", &rewrite_tensor));
|
||||
@ -86,11 +90,33 @@ class RegexReplaceOp : public OpKernel {
|
||||
errors::InvalidArgument("Rewrite must be scalar, but received ",
|
||||
rewrite_tensor->shape().DebugString()));
|
||||
const string& rewrite = rewrite_tensor->scalar<tstring>()();
|
||||
OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx));
|
||||
OP_REQUIRES_OK(ctx, InternalCompute(*regex, rewrite, replace_global_, ctx));
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns a shared RE2 built from `pattern`, reusing the single object held
// in `regex_` when that object's pattern() equals `pattern`.
//
// On a miss, a new RE2 is constructed and swapped into `regex_`, replacing
// whatever was cached before (the cache holds one entry).
//
// The RE2's ok() state is not inspected here, so the object is cached and
// returned regardless; the caller checks regex->ok() on the result.
//
// The cache is not re-checked after taking the exclusive lock, so concurrent
// callers that both miss may each construct their own RE2. Each caller still
// gets back the object it built, because the return happens under the lock.
std::shared_ptr<RE2> CachedRE2(const string& pattern) {
|
||||
{
|
||||
// Cache lookup. NOTE(review): tf_shared_lock is presumably a reader lock
// that lets lookups proceed concurrently -- confirm in platform/mutex.h.
tf_shared_lock l(mu_);
|
||||
if (regex_ != nullptr && regex_->pattern() == pattern) {
|
||||
return regex_;
|
||||
}
|
||||
}
|
||||
// Construct the new RE2 object before acquiring the lock.
|
||||
auto regex = std::make_shared<RE2>(pattern);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// Swap instead of assigning so that we destruct the old
|
||||
// RE2 object (when necessary) after releasing the lock.
|
||||
// After the swap, `regex_` holds the new object and the local `regex`
// holds the previous cache entry (possibly null).
regex_.swap(regex);
|
||||
return regex_;
|
||||
}
|
||||
}
|
||||
|
||||
bool replace_global_;
|
||||
mutex mu_;
|
||||
std::shared_ptr<RE2> regex_ TF_GUARDED_BY(mu_);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RegexReplaceOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
|
||||
@ -101,11 +127,11 @@ class StaticRegexReplaceOp : public OpKernel {
|
||||
explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
string pattern;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
|
||||
re_ = MakeUnique<RE2>(pattern);
|
||||
OP_REQUIRES(ctx, re_->ok(),
|
||||
errors::InvalidArgument("Invalid pattern: ", pattern,
|
||||
", error: ", re_->error()));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
|
||||
}
|
||||
|
||||
@ -115,8 +141,8 @@ class StaticRegexReplaceOp : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
string rewrite_str_;
|
||||
std::unique_ptr<RE2> re_;
|
||||
string rewrite_str_;
|
||||
bool replace_global_;
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user