Fuse tf.text WhitespaceTokenizer into a TFLite custom op

PiperOrigin-RevId: 312612112
Change-Id: Ia7142d64948a4e41f795ee1f64ecd004bcbf9be0
Author: A. Unique TensorFlower, 2020-05-20 21:16:15 -07:00 (committed by TensorFlower Gardener)
Parent: 1df42d1bf3
Commit: 41224dad54
5 changed files with 227 additions and 0 deletions
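Note on what the fusion produces: after this pass runs, the WhitespaceTokenizer composite function is replaced by a single tfl.custom op whose custom_code is "tftext:WhitespaceTokenizer", so a TFLite runtime has to resolve that name to a custom kernel. A minimal sketch of the consumer side, assuming a hypothetical Register_WHITESPACE_TOKENIZER() supplied by a TF.Text kernel library (see the registration sketch after tftext_utils.cc below); this is not part of the commit:

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Hypothetical kernel registration; not defined in this commit.
TfLiteRegistration* Register_WHITESPACE_TOKENIZER();

std::unique_ptr<tflite::Interpreter> BuildTokenizerInterpreter(
    const tflite::FlatBufferModel& model) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  // The registered name must match the custom_code emitted by the fusion pass.
  resolver.AddCustom("tftext:WhitespaceTokenizer",
                     Register_WHITESPACE_TOKENIZER());
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;
}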

tensorflow/compiler/mlir/lite/BUILD

@@ -260,6 +260,41 @@ cc_library(
],
)
cc_library(
name = "tftext_utils",
srcs = [
"utils/tftext_utils.cc",
],
hdrs = [
"utils/tftext_utils.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
tf_cc_test(
name = "tftext_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "stateful_ops_utils",
srcs = [
@@ -320,6 +355,7 @@ cc_library(
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
":tftext_utils",
":validators",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",

New MLIR FileCheck test for -tfl-fuse-tftext (new file)

@@ -0,0 +1,14 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s --dump-input-on-failure
module {
func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
%0 = "tf.op1"(%arg0) : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>)
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<?xi64>
%2:2 = "tf.op2"(%arg0, %1) : (tensor<1x!tf.string>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<?xi64>)
return %2#0, %2#1 : tensor<?x!tf.string>, tensor<?xi64>
}
// CHECK: func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
// CHECK: "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
}

tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc

@@ -41,15 +41,22 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
// The command line flag to enable or disable TF.Text API fusion.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> fuse_tftext(
"tfl-fuse-tftext", llvm::cl::value_desc("bool"),
llvm::cl::desc("Fuse TF.Text API ops when it's true"),
llvm::cl::init(false));
namespace mlir {
namespace TFL {
namespace {
constexpr char kTFAPIImplements[] = "tf.api_implements";
constexpr char kTfTextAPIPRefix[] = "tftext:";
// Abstracts the conversion of the embedded lookup composite function.
class ConvertEmbeddedLookupFunc {
@@ -187,6 +194,10 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
OpBuilder builder(func.getBody());
if (failed(ConvertKerasLSTMLayer(func, &builder)))
return signalPassFailure();
} else if (fuse_tftext && attr.getValue().startswith(kTfTextAPIPRefix)) {
if (failed(ConvertTFTextAPI(func, attr.getValue()))) {
return signalPassFailure();
}
}
}

tensorflow/compiler/mlir/lite/utils/tftext_utils.cc (new file)

@@ -0,0 +1,127 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kTFAPIImplements[] = "tf.api_implements";
inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) {
std::string content = "";
ShapedType type = RankedTensorType::get(
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
return OpaqueElementsAttr::get(
builder->getContext()->getRegisteredDialect("tfl"), type, content);
}
inline RankedTensorType getInputType(mlir::FuncOp func, int idx) {
return func.getType()
.getInput(idx)
.dyn_cast_or_null<mlir::RankedTensorType>();
}
inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
return func.getType()
.getResult(idx)
.dyn_cast_or_null<mlir::RankedTensorType>();
}
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
if (func.getNumResults() != 2) {
return failure();
}
if (func.getNumArguments() != 1) {
return failure();
}
auto input_type = getInputType(func, 0);
if (!input_type || input_type.getRank() != 1 ||
!input_type.getElementType().isa<mlir::TF::StringType>()) {
return failure();
}
auto value_type = getResultType(func, 0);
if (!value_type || value_type.getRank() != 1 ||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
return failure();
}
auto offset_type = getResultType(func, 1);
if (!offset_type || offset_type.getRank() != 1 ||
!offset_type.getElementType().isInteger(64)) {
return failure();
}
return success();
}
LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func,
llvm::StringRef api) {
func.eraseBody();
func.addEntryBlock();
func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext()));
Value text = func.getArgument(0);
auto output_type = func.getType().getResult(0);
auto offset_type = func.getType().getResult(1);
SmallVector<Type, 2> shape = {output_type, offset_type};
ArrayRef<Type> output_types(shape);
OpBuilder builder(func.getBody());
auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types,
ValueRange(text), api,
emptyCustomOption(&builder));
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
return success();
}
} // namespace
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) {
if (api.str() == kWhitespaceTokenizer) {
if (succeeded(VerifyWhitespaceTokenizer(func))) {
return ConvertWhitespaceTokenizer(func, api);
}
}
return failure();
}
} // namespace TFL
} // namespace mlir
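The pass above only emits the custom op; the tokenization kernel itself is outside this commit. Below is a minimal sketch of what a matching registration could look like, using the hypothetical name Register_WHITESPACE_TOKENIZER referenced earlier, with placeholder Prepare/Eval bodies rather than the real TF.Text implementation:

#include "tensorflow/lite/c/common.h"

namespace {

// Placeholder: a real kernel would resize the token and offset output tensors here.
TfLiteStatus WhitespaceTokenizerPrepare(TfLiteContext* context,
                                        TfLiteNode* node) {
  return kTfLiteOk;
}

// Placeholder: a real kernel would split the input strings on whitespace here.
TfLiteStatus WhitespaceTokenizerEval(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

}  // namespace

// Hypothetical registration; the symbol name is an assumption, not part of this commit.
TfLiteRegistration* Register_WHITESPACE_TOKENIZER() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 WhitespaceTokenizerPrepare,
                                 WhitespaceTokenizerEval};
  return &r;
}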

tensorflow/compiler/mlir/lite/utils/tftext_utils.h (new file)

@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines utilities used by TFLite transformation passes to
// fuse TF.Text composite functions into TFLite custom ops.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api);
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_