Support variant of whitespace tokenizer

PiperOrigin-RevId: 315522310
Change-Id: I44e3fa1563b5f46445602eb6495941c00d7ce4b8
This commit is contained in:
A. Unique TensorFlower 2020-06-09 11:23:45 -07:00 committed by TensorFlower Gardener
parent 4ff7e65477
commit eb8f61f5f4
2 changed files with 3237 additions and 32 deletions
tensorflow/compiler/mlir/lite

File diff suppressed because it is too large Load Diff

View File

@ -67,27 +67,57 @@ inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
}
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
if (func.getNumResults() != 2) {
return failure();
}
if (func.getNumArguments() != 1) {
return failure();
}
// In the case of input tensor with 0 rank.
// Whitespace tokenizer generates 1 output:
// * String tensor for tokens.
//
// In the case of 1-D input tensor,
// Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
// * 1st output is the value of ragged tensor;
// * 2nd output is the offset.
//
// In the case of batched input tensor,
// Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
// * 1st output is the value of ragged tensor;
// * 2nd output is the inner offset;
// * 3rd output is the outer offset.
auto input_type = getInputType(func, 0);
if (!input_type || input_type.getRank() != 1 ||
!input_type.getElementType().isa<mlir::TF::StringType>()) {
return failure();
if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() ||
!input_type.hasRank()) {
return func.emitError() << "Input should be a string tensor";
}
const std::vector<int> kValidNumOfOutput = {1, 2, 3};
if (input_type.getRank() >= kValidNumOfOutput.size()) {
return func.emitError()
<< "Unrecognized input rank: " << input_type.getRank();
}
if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
return func.emitError()
<< "Expect " << kValidNumOfOutput[input_type.getRank()]
<< "output(s) when input has rank " << input_type.getRank();
}
auto value_type = getResultType(func, 0);
if (!value_type || value_type.getRank() != 1 ||
if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 ||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
return failure();
return func.emitError() << "1st output should be string tensor";
}
auto offset_type = getResultType(func, 1);
if (offset_type.getRank() != 1 ||
!offset_type.getElementType().isInteger(64)) {
return failure();
if (func.getNumResults() > 1) {
auto offset_type = getResultType(func, 1);
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
!offset_type.getElementType().isInteger(64)) {
return func.emitError() << "2nd output should be int64 tensor";
}
}
if (func.getNumResults() > 2) {
auto offset_type = getResultType(func, 2);
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
!offset_type.getElementType().isInteger(64)) {
return func.emitError() << "3rd output should be int64 tensor";
}
}
return success();
}
@ -96,19 +126,12 @@ LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func,
func.eraseBody();
func.addEntryBlock();
func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext()));
Value text = func.getArgument(0);
auto output_type = func.getType().getResult(0);
auto offset_type = func.getType().getResult(1);
SmallVector<Type, 2> shape = {output_type, offset_type};
ArrayRef<Type> output_types(shape);
OpBuilder builder(func.getBody());
auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types,
ValueRange(text), api,
emptyCustomOption(&builder));
auto op = builder.create<mlir::TFL::CustomOp>(
func.getLoc(), func.getType().getResults(), ValueRange(text), api,
emptyCustomOption(&builder));
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
return success();
}