Fuse tensorflow_text.ngrams into a TFLite custom op

PiperOrigin-RevId: 323456482
Change-Id: Idfd446c371e8a4a4f82b6da730d02b0897d35a8a
This commit is contained in:
A. Unique TensorFlower 2020-07-27 15:42:20 -07:00 committed by TensorFlower Gardener
parent c7e7f49228
commit 07e4db17ff
3 changed files with 3602 additions and 3224 deletions

View File

@ -270,6 +270,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/core:framework",
"@flatbuffers",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",

File diff suppressed because it is too large Load Diff

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -43,32 +45,35 @@ namespace TFL {
namespace {
constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kTFImplements[] = "tf._implements";
using mlir::TF::FuncAttr;
using mlir::TF::StringType;
inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) {
std::string content = "";
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
const std::string& content) {
ShapedType type = RankedTensorType::get(
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
return OpaqueElementsAttr::get(
builder->getContext()->getRegisteredDialect("tfl"), type, content);
builder->getContext()->getRegisteredDialect("tfl"), type,
StringRef(content.data(), content.size()));
}
inline RankedTensorType getInputType(mlir::FuncOp func, int idx) {
return func.getType()
.getInput(idx)
.dyn_cast_or_null<mlir::RankedTensorType>();
inline TensorType GetInputType(FuncOp func, int idx) {
return func.getType().getInput(idx).dyn_cast_or_null<TensorType>();
}
inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
return func.getType()
.getResult(idx)
.dyn_cast_or_null<mlir::RankedTensorType>();
inline TensorType GetResultType(FuncOp func, int idx) {
return func.getType().getResult(idx).dyn_cast_or_null<TensorType>();
}
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
inline bool RankEquals(const TensorType& type, int rank) {
return type && type.hasRank() && type.getRank() == rank;
}
LogicalResult VerifyWhitespaceTokenizer(FuncOp func) {
// In the case of input tensor with 0 rank.
// Whitespace tokenizer generates 1 output:
// * String tensor for tokens.
@ -83,8 +88,8 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
// * 1st output is the value of ragged tensor;
// * 2nd output is the inner offset;
// * 3rd output is the outer offset.
auto input_type = getInputType(func, 0);
if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() ||
auto input_type = GetInputType(func, 0);
if (!input_type || !input_type.getElementType().isa<StringType>() ||
!input_type.hasRank()) {
return func.emitError() << "Input should be a string tensor";
}
@ -100,21 +105,21 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
<< "output(s) when input has rank " << input_type.getRank();
}
auto value_type = getResultType(func, 0);
if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 ||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
auto value_type = GetResultType(func, 0);
if (!RankEquals(value_type, 1) ||
!value_type.getElementType().isa<StringType>()) {
return func.emitError() << "1st output should be string tensor";
}
if (func.getNumResults() > 1) {
auto offset_type = getResultType(func, 1);
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
auto offset_type = GetResultType(func, 1);
if (!RankEquals(offset_type, 1) ||
!offset_type.getElementType().isInteger(64)) {
return func.emitError() << "2nd output should be int64 tensor";
}
}
if (func.getNumResults() > 2) {
auto offset_type = getResultType(func, 2);
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
auto offset_type = GetResultType(func, 2);
if (!RankEquals(offset_type, 1) ||
!offset_type.getElementType().isInteger(64)) {
return func.emitError() << "3rd output should be int64 tensor";
}
@ -123,28 +128,159 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
return success();
}
LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func, llvm::StringRef api,
LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api,
FuncAttr attr) {
func.eraseBody();
func.addEntryBlock();
func.setAttr(kTFImplements, attr);
Value text = func.getArgument(0);
OpBuilder builder(func.getBody());
auto op = builder.create<mlir::TFL::CustomOp>(
func.getLoc(), func.getType().getResults(), ValueRange(text), api,
emptyCustomOption(&builder));
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
std::string empty_option_buffer;
auto op = builder.create<CustomOp>(
func.getLoc(), func.getType().getResults(), func.getArguments(), api,
CustomOption(&builder, empty_option_buffer));
builder.create<ReturnOp>(func.getLoc(), op.getResults());
return success();
}
LogicalResult VerifyNgrams(FuncOp func) {
  // Verifies the signature of a tftext:Ngrams composite function.
  // The inputs and outputs should be the same:
  // * A string tensor for tokens/ragged tensor values.
  // * Zero or more row_split tensors.
  constexpr int kValues = 0;
  constexpr int kRowSplits = 1;

  if (func.getType().getInputs().size() != func.getType().getResults().size()) {
    return func.emitError() << "Mismatched number of inputs and outputs.";
  }

  // Guard against an empty signature: with no inputs, `row_splits` below
  // underflows to -1 and `GetInputType(func, kValues)` would index an empty
  // input list (out of bounds).
  if (func.getType().getInputs().empty()) {
    return func.emitError() << "Expected at least one input and one output";
  }

  int row_splits = func.getType().getInputs().size() - kRowSplits;
  if (row_splits == 0) {
    // Dense case: a single values tensor in and out. Ranks are only compared
    // when both are known.
    auto input_values = GetInputType(func, kValues);
    if (!input_values || !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!output_values || !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a string tensor";
    }
    if (input_values.hasRank() && output_values.hasRank() &&
        input_values.getRank() != output_values.getRank()) {
      return func.emitError() << "Input " << kValues << " and output "
                              << kValues << " should have the same rank";
    }
  } else {
    // Ragged case: a 1D string values tensor plus `row_splits` 1D int64
    // row-split tensors, mirrored between inputs and outputs.
    auto input_values = GetInputType(func, kValues);
    if (!RankEquals(input_values, 1) ||
        !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a 1D string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!RankEquals(output_values, 1) ||
        !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a 1D string tensor";
    }
    for (int i = 0; i < row_splits; ++i) {
      const int row_index = i + kRowSplits;
      auto input_row_splits = GetInputType(func, row_index);
      if (!RankEquals(input_row_splits, 1) ||
          !input_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Input " << row_index << " should be a 1D int64 tensor";
      }
      auto output_row_splits = GetResultType(func, row_index);
      if (!RankEquals(output_row_splits, 1) ||
          !output_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Output " << row_index << " should be a 1D int64 tensor";
      }
    }
  }
  return success();
}
LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs,
                                       std::string& custom_option_buffer) {
  // Serializes the ngram attributes (width, string_separator, axis,
  // reduction_type) into a flexbuffer map that the TFLite custom op kernel
  // reads at runtime. Emits an error on `func` if any attribute is missing or
  // has the wrong type.
  flexbuffers::Builder fbb;
  const size_t start_map = fbb.StartMap();

  auto width_attr = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
  if (!width_attr) {
    return func.emitError() << "'width' attribute is not set or not an integer";
  }
  fbb.Int("width", width_attr.getInt());

  auto separator_attr =
      attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
  if (!separator_attr) {
    return func.emitError()
           << "'string_separator' attribute is not set or not a string";
  }
  // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
  // strings expect NUL terminated strings. StringRef::str() produces an owned,
  // NUL-terminated std::string copy.
  fbb.String("string_separator", separator_attr.getValue().str());

  auto axis_attr = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
  if (!axis_attr) {
    return func.emitError() << "'axis' attribute is not set or not an integer";
  }
  fbb.Int("axis", axis_attr.getInt());

  auto reduction_attr =
      attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
  if (!reduction_attr) {
    return func.emitError()
           << "'reduction_type' attribute is not set or not a string";
  }
  // Same NUL-termination concern as string_separator above.
  fbb.String("reduction_type", reduction_attr.getValue().str());

  fbb.EndMap(start_map);
  fbb.Finish();
  custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
  return success();
}
LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
  // Replaces the body of `func` with a single TFL::CustomOp named `api`,
  // carrying the serialized ngram attributes as its custom options. The
  // function keeps its original signature; `attr` is recorded under
  // tf._implements so the provenance of the fused op is preserved.
  func.eraseBody();
  func.addEntryBlock();
  func.setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());

  std::string option_buffer;
  if (failed(CreateNgramsCustomOption(func, attr.GetAttrs(), option_buffer))) {
    return failure();
  }

  auto loc = func.getLoc();
  auto custom_op = builder.create<CustomOp>(
      loc, func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, option_buffer));
  builder.create<ReturnOp>(loc, custom_op.getResults());
  return success();
}
} // namespace
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api,
LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
FuncAttr attr) {
if (api.str() == kWhitespaceTokenizer) {
if (succeeded(VerifyWhitespaceTokenizer(func))) {
return ConvertWhitespaceTokenizer(func, api, attr);
}
} else if (api.str() == kNgrams) {
if (succeeded(VerifyNgrams(func))) {
return ConvertNgrams(func, api, attr);
}
}
return failure();
}