Fuse tensorflow_text.ngrams into a TFLite custom op
PiperOrigin-RevId: 323456482 Change-Id: Idfd446c371e8a4a4f82b6da730d02b0897d35a8a
This commit is contained in:
parent
c7e7f49228
commit
07e4db17ff
@ -270,6 +270,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
||||
"//tensorflow/core:framework",
|
||||
"@flatbuffers",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
@ -43,32 +45,35 @@ namespace TFL {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kNgrams[] = "tftext:Ngrams";
|
||||
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
|
||||
constexpr char kTFImplements[] = "tf._implements";
|
||||
|
||||
using mlir::TF::FuncAttr;
|
||||
using mlir::TF::StringType;
|
||||
|
||||
// Packs the serialized option bytes in `content` into an opaque elements
// attribute: a 1-D tensor of 8-bit integers with one element per byte of
// `content`, attached to the registered "tfl" dialect.
inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
                                       const std::string& content) {
  const int64_t num_bytes = static_cast<int64_t>(content.size());
  ShapedType buffer_type =
      RankedTensorType::get({num_bytes}, builder->getIntegerType(8));
  auto* tfl_dialect = builder->getContext()->getRegisteredDialect("tfl");
  // Pass an explicit length so bytes after an embedded NUL are not dropped.
  const StringRef raw_bytes(content.data(), content.size());
  return OpaqueElementsAttr::get(tfl_dialect, buffer_type, raw_bytes);
}
|
||||
|
||||
inline RankedTensorType getInputType(mlir::FuncOp func, int idx) {
|
||||
return func.getType()
|
||||
.getInput(idx)
|
||||
.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
inline TensorType GetInputType(FuncOp func, int idx) {
|
||||
return func.getType().getInput(idx).dyn_cast_or_null<TensorType>();
|
||||
}
|
||||
|
||||
inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
|
||||
return func.getType()
|
||||
.getResult(idx)
|
||||
.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
inline TensorType GetResultType(FuncOp func, int idx) {
|
||||
return func.getType().getResult(idx).dyn_cast_or_null<TensorType>();
|
||||
}
|
||||
|
||||
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
||||
inline bool RankEquals(const TensorType& type, int rank) {
|
||||
return type && type.hasRank() && type.getRank() == rank;
|
||||
}
|
||||
|
||||
LogicalResult VerifyWhitespaceTokenizer(FuncOp func) {
|
||||
// In the case of input tensor with 0 rank.
|
||||
// Whitespace tokenizer generates 1 output:
|
||||
// * String tensor for tokens.
|
||||
@ -83,8 +88,8 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
||||
// * 1st output is the value of ragged tensor;
|
||||
// * 2nd output is the inner offset;
|
||||
// * 3rd output is the outer offset.
|
||||
auto input_type = getInputType(func, 0);
|
||||
if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() ||
|
||||
auto input_type = GetInputType(func, 0);
|
||||
if (!input_type || !input_type.getElementType().isa<StringType>() ||
|
||||
!input_type.hasRank()) {
|
||||
return func.emitError() << "Input should be a string tensor";
|
||||
}
|
||||
@ -100,21 +105,21 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
||||
<< "output(s) when input has rank " << input_type.getRank();
|
||||
}
|
||||
|
||||
auto value_type = getResultType(func, 0);
|
||||
if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 ||
|
||||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
|
||||
auto value_type = GetResultType(func, 0);
|
||||
if (!RankEquals(value_type, 1) ||
|
||||
!value_type.getElementType().isa<StringType>()) {
|
||||
return func.emitError() << "1st output should be string tensor";
|
||||
}
|
||||
if (func.getNumResults() > 1) {
|
||||
auto offset_type = getResultType(func, 1);
|
||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
||||
auto offset_type = GetResultType(func, 1);
|
||||
if (!RankEquals(offset_type, 1) ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return func.emitError() << "2nd output should be int64 tensor";
|
||||
}
|
||||
}
|
||||
if (func.getNumResults() > 2) {
|
||||
auto offset_type = getResultType(func, 2);
|
||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
||||
auto offset_type = GetResultType(func, 2);
|
||||
if (!RankEquals(offset_type, 1) ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return func.emitError() << "3rd output should be int64 tensor";
|
||||
}
|
||||
@ -123,28 +128,159 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// Replaces the body of `func` with a single TFL custom op named `api`,
// followed by a return of that op's results.
//
// The custom op receives every argument of `func` as an operand and produces
// the function's declared result types. `attr` is re-attached to `func` under
// the "tf._implements" attribute name. The op is given an empty option
// buffer. Always returns success.
LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api,
                                         FuncAttr attr) {
  func.eraseBody();
  func.addEntryBlock();
  func.setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());

  // All arguments are forwarded via func.getArguments(), so no separate
  // handle to argument 0 is needed.
  const std::string empty_option_buffer;
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, empty_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
|
||||
|
||||
// Verifies that the signature of `func` is one the tftext:Ngrams conversion
// accepts. Emits an error on `func` and returns failure at the first
// violation; returns success otherwise.
//
// The inputs and outputs should be the same:
// * A string tensor for tokens/ragged tensor values.
// * Zero or more row_split tensors.
//
// Without row_splits, the values may have any rank, but the input and output
// ranks must match when both are known. With row_splits, the values must be
// 1D string tensors and every row_splits input/output a 1D int64 tensor.
LogicalResult VerifyNgrams(FuncOp func) {
  constexpr int kValues = 0;
  constexpr int kRowSplits = 1;

  auto func_type = func.getType();
  const auto num_inputs = func_type.getInputs().size();
  if (num_inputs != func_type.getResults().size()) {
    return func.emitError() << "Mismatched number of inputs and outputs.";
  }
  // A function with no inputs has no values tensor. Reject it here so the
  // checks below never index into an empty input/result list.
  if (num_inputs == 0) {
    return func.emitError()
           << "Input " << kValues << " should be a string tensor";
  }

  const int row_splits = static_cast<int>(num_inputs) - kRowSplits;
  if (row_splits == 0) {
    // Plain tensor case: only the values tensor is present.
    auto input_values = GetInputType(func, kValues);
    if (!input_values || !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!output_values || !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a string tensor";
    }

    // Ranks can only be compared when both are statically known.
    if (input_values.hasRank() && output_values.hasRank() &&
        input_values.getRank() != output_values.getRank()) {
      return func.emitError() << "Input " << kValues << " and output "
                              << kValues << " should have the same rank";
    }
  } else {
    // Ragged case: flat 1D values followed by one or more row_splits.
    auto input_values = GetInputType(func, kValues);
    if (!RankEquals(input_values, 1) ||
        !input_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Input " << kValues << " should be a 1D string tensor";
    }
    auto output_values = GetResultType(func, kValues);
    if (!RankEquals(output_values, 1) ||
        !output_values.getElementType().isa<StringType>()) {
      return func.emitError()
             << "Output " << kValues << " should be a 1D string tensor";
    }

    for (int i = 0; i < row_splits; ++i) {
      const int row_index = i + kRowSplits;
      auto input_row_splits = GetInputType(func, row_index);
      if (!RankEquals(input_row_splits, 1) ||
          !input_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Input " << row_index << " should be a 1D int64 tensor";
      }
      auto output_row_splits = GetResultType(func, row_index);
      if (!RankEquals(output_row_splits, 1) ||
          !output_row_splits.getElementType().isInteger(64)) {
        return func.emitError()
               << "Output " << row_index << " should be a 1D int64 tensor";
      }
    }
  }

  return success();
}
|
||||
|
||||
// Serializes the Ngrams attributes found in `attrs` into a flexbuffers map
// and stores the finished bytes in `custom_option_buffer`.
//
// The map holds four entries: "width" and "axis" (integers), and
// "string_separator" and "reduction_type" (strings). If any of them is absent
// from `attrs` or has the wrong attribute kind, an error is emitted on `func`,
// failure is returned, and `custom_option_buffer` is left unmodified.
LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs,
                                       std::string& custom_option_buffer) {
  // Look up and validate the attributes in the same order they are written,
  // so the first problem found is the one reported.
  auto width_attr = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
  if (!width_attr) {
    return func.emitError() << "'width' attribute is not set or not an integer";
  }

  auto separator_attr =
      attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
  if (!separator_attr) {
    return func.emitError()
           << "'string_separator' attribute is not set or not a string";
  }

  auto axis_attr = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
  if (!axis_attr) {
    return func.emitError() << "'axis' attribute is not set or not an integer";
  }

  auto reduction_attr =
      attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
  if (!reduction_attr) {
    return func.emitError()
           << "'reduction_type' attribute is not set or not a string";
  }

  // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
  // strings expect NUL terminated strings, so copy each value into an owned
  // std::string first.
  const std::string separator = separator_attr.getValue().str();
  const std::string reduction = reduction_attr.getValue().str();

  flexbuffers::Builder fbb;
  const size_t map_start = fbb.StartMap();
  fbb.Int("width", width_attr.getInt());
  fbb.String("string_separator", separator);
  fbb.Int("axis", axis_attr.getInt());
  fbb.String("reduction_type", reduction);
  fbb.EndMap(map_start);
  fbb.Finish();

  const auto& serialized = fbb.GetBuffer();
  custom_option_buffer.assign(serialized.begin(), serialized.end());
  return success();
}
|
||||
|
||||
// Replaces the body of `func` with a single TFL custom op named `api`,
// followed by a return of that op's results.
//
// The custom op's options are built from the attributes carried by `attr`.
// If they cannot be built, failure is returned and this function has not
// modified `func`. On success, `attr` is re-attached to `func` under the
// "tf._implements" attribute name and every argument of `func` is forwarded
// to the custom op.
LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
  // Build the option buffer before touching the function. Erasing the body
  // first would leave `func` with an empty, unterminated entry block whenever
  // an attribute is missing or has the wrong kind.
  std::string custom_option_buffer;
  if (failed(CreateNgramsCustomOption(func, attr.GetAttrs(),
                                      custom_option_buffer))) {
    return failure();
  }

  func.eraseBody();
  func.addEntryBlock();
  func.setAttr(kTFImplements, attr);
  OpBuilder builder(func.getBody());
  auto op = builder.create<CustomOp>(
      func.getLoc(), func.getType().getResults(), func.getArguments(), api,
      CustomOption(&builder, custom_option_buffer));
  builder.create<ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Entry point for fusing a TF.Text composite function into a TFL custom op.
//
// Selects the verifier/converter pair matching `api`
// ("tftext:WhitespaceTokenizer" or "tftext:Ngrams"). The converter runs only
// when verification succeeds, and its result is returned. Returns failure
// when verification fails or when `api` matches neither name.
LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
                               FuncAttr attr) {
  if (api == kWhitespaceTokenizer) {
    if (!succeeded(VerifyWhitespaceTokenizer(func))) {
      return failure();
    }
    return ConvertWhitespaceTokenizer(func, api, attr);
  }

  if (api == kNgrams) {
    if (!succeeded(VerifyNgrams(func))) {
      return failure();
    }
    return ConvertNgrams(func, api, attr);
  }

  // `api` names no supported TF.Text function.
  return failure();
}
|
||||
|
Loading…
Reference in New Issue
Block a user