From 7c224307e5ac16f48203b53b2f13ad9d502e4810 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 1 Jul 2020 16:33:24 -0700 Subject: [PATCH] Add pass config to turn on/off fusing tftext PiperOrigin-RevId: 319318564 Change-Id: I9374dfd3efb3da27ca87048d176421145ed96b3a --- tensorflow/compiler/mlir/lite/BUILD | 13 +++++ .../prepare_composite_functions_tf.cc | 11 ++-- .../compiler/mlir/lite/utils/tftext_utils.cc | 12 +++++ .../compiler/mlir/lite/utils/tftext_utils.h | 3 ++ .../mlir/lite/utils/tftext_utils_test.cc | 53 +++++++++++++++++++ 5 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 719293617aa..b1f44889964 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -273,6 +273,7 @@ cc_library( deps = [ ":tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:framework", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:StandardOps", @@ -280,6 +281,18 @@ cc_library( ], ) +tf_cc_test( + name = "tftext_utils_test", + size = "small", + srcs = ["utils/tftext_utils_test.cc"], + deps = [ + ":tftext_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "stateful_ops_utils", srcs = [ diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 56af68f6bbe..221e8c70cd7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -46,7 +46,7 @@ limitations under the License. // The cmd line flag to turn on/off Tf.Text API fusion. // NOLINTNEXTLINE -static llvm::cl::opt fuse_tftext( +static llvm::cl::opt fuse_tftext_flag( "tfl-fuse-tftext", llvm::cl::value_desc("bool"), llvm::cl::desc("Fuse TF.Text API ops when it's true"), llvm::cl::init(false)); @@ -194,9 +194,12 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func, OpBuilder builder(func.getBody()); if (failed(ConvertKerasLSTMLayer(func, &builder))) return signalPassFailure(); - } else if (fuse_tftext && attr.getValue().startswith(kTfTextAPIPRefix)) { - if (failed(ConvertTFTextAPI(func, attr.getValue()))) { - return signalPassFailure(); + } else if (fuse_tftext_flag || + IsTfTextRegistered(tensorflow::OpRegistry::Global())) { + if (attr.getValue().startswith(kTfTextAPIPRefix)) { + if (failed(ConvertTFTextAPI(func, attr.getValue()))) { + return signalPassFailure(); + } } } } diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index cb671c7cd70..2ed0891dc59 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -146,5 +146,17 @@ LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) { return failure(); } +bool IsTfTextRegistered(const tensorflow::OpRegistry* op_registery) { + const std::vector kTfTextOps = { + "WhitespaceTokenizeWithOffsets", + }; + for (const auto& iter : kTfTextOps) { + if (op_registery->LookUp(iter)) { + return true; + } + } + return false; +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h index 283e57c179a..c52ee019d8d 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h @@ -27,12 +27,15 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/core/framework/op.h" namespace mlir { namespace TFL { LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api); +bool IsTfTextRegistered(const tensorflow::OpRegistry* op_registery); + } // end namespace TFL } // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc new file mode 100644 index 00000000000..7d29264aaae --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" + +#include + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir { +namespace TFL { + +using tensorflow::OpRegistrationData; +using tensorflow::OpRegistry; +using tensorflow::Status; + +namespace { + +void Register(const std::string& op_name, OpRegistry* registry) { + registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status { + op_reg_data->op_def.set_name(op_name); + return Status::OK(); + }); +} + +} // namespace + +TEST(TfTextUtilsTest, TestTfTextRegistered) { + std::unique_ptr registry(new OpRegistry); + Register("WhitespaceTokenizeWithOffsets", registry.get()); + EXPECT_TRUE(IsTfTextRegistered(registry.get())); +} + +TEST(TfTextUtilsTest, TestTfTextNotRegistered) { + std::unique_ptr registry(new OpRegistry); + Register("Test", registry.get()); + EXPECT_FALSE(IsTfTextRegistered(registry.get())); +} +} // namespace TFL +} // namespace mlir