Add pass config to turn on/off fusing tftext

PiperOrigin-RevId: 319318564
Change-Id: I9374dfd3efb3da27ca87048d176421145ed96b3a
This commit is contained in:
A. Unique TensorFlower 2020-07-01 16:33:24 -07:00 committed by TensorFlower Gardener
parent f0f90ef9c0
commit 7c224307e5
5 changed files with 88 additions and 4 deletions

View File

@ -273,6 +273,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:framework",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
@ -280,6 +281,18 @@ cc_library(
],
)
tf_cc_test(
name = "tftext_utils_test",
size = "small",
srcs = ["utils/tftext_utils_test.cc"],
deps = [
":tftext_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "stateful_ops_utils",
srcs = [

View File

@ -46,7 +46,7 @@ limitations under the License.
// The cmd line flag to turn on/off Tf.Text API fusion.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> fuse_tftext(
static llvm::cl::opt<bool> fuse_tftext_flag(
"tfl-fuse-tftext", llvm::cl::value_desc("bool"),
llvm::cl::desc("Fuse TF.Text API ops when it's true"),
llvm::cl::init(false));
@ -194,9 +194,12 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
OpBuilder builder(func.getBody());
if (failed(ConvertKerasLSTMLayer(func, &builder)))
return signalPassFailure();
} else if (fuse_tftext && attr.getValue().startswith(kTfTextAPIPRefix)) {
if (failed(ConvertTFTextAPI(func, attr.getValue()))) {
return signalPassFailure();
} else if (fuse_tftext_flag ||
IsTfTextRegistered(tensorflow::OpRegistry::Global())) {
if (attr.getValue().startswith(kTfTextAPIPRefix)) {
if (failed(ConvertTFTextAPI(func, attr.getValue()))) {
return signalPassFailure();
}
}
}
}

View File

@ -146,5 +146,17 @@ LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) {
return failure();
}
bool IsTfTextRegistered(const tensorflow::OpRegistry* op_registery) {
const std::vector<std::string> kTfTextOps = {
"WhitespaceTokenizeWithOffsets",
};
for (const auto& iter : kTfTextOps) {
if (op_registery->LookUp(iter)) {
return true;
}
}
return false;
}
} // namespace TFL
} // namespace mlir

View File

@ -27,12 +27,15 @@ limitations under the License.
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/core/framework/op.h"
namespace mlir {
namespace TFL {
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api);
bool IsTfTextRegistered(const tensorflow::OpRegistry* op_registery);
} // end namespace TFL
} // end namespace mlir

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include <memory>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace mlir {
namespace TFL {
using tensorflow::OpRegistrationData;
using tensorflow::OpRegistry;
using tensorflow::Status;
namespace {
void Register(const std::string& op_name, OpRegistry* registry) {
registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status {
op_reg_data->op_def.set_name(op_name);
return Status::OK();
});
}
} // namespace
TEST(TfTextUtilsTest, TestTfTextRegistered) {
std::unique_ptr<OpRegistry> registry(new OpRegistry);
Register("WhitespaceTokenizeWithOffsets", registry.get());
EXPECT_TRUE(IsTfTextRegistered(registry.get()));
}
TEST(TfTextUtilsTest, TestTfTextNotRegistered) {
std::unique_ptr<OpRegistry> registry(new OpRegistry);
Register("Test", registry.get());
EXPECT_FALSE(IsTfTextRegistered(registry.get()));
}
} // namespace TFL
} // namespace mlir