Support variant of whitespace tokenizer
PiperOrigin-RevId: 315522310 Change-Id: I44e3fa1563b5f46445602eb6495941c00d7ce4b8
This commit is contained in:
parent
4ff7e65477
commit
eb8f61f5f4
tensorflow/compiler/mlir/lite
File diff suppressed because it is too large
Load Diff
@ -67,27 +67,57 @@ inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
|
||||
}
|
||||
|
||||
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
||||
if (func.getNumResults() != 2) {
|
||||
return failure();
|
||||
}
|
||||
if (func.getNumArguments() != 1) {
|
||||
return failure();
|
||||
}
|
||||
// In the case of input tensor with 0 rank.
|
||||
// Whitespace tokenizer generates 1 output:
|
||||
// * String tensor for tokens.
|
||||
//
|
||||
// In the case of 1-D input tensor,
|
||||
// Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
|
||||
// * 1st output is the value of ragged tensor;
|
||||
// * 2nd output is the offset.
|
||||
//
|
||||
// In the case of batched input tesnor,
|
||||
// Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
|
||||
// * 1st output is the value of ragged tensor;
|
||||
// * 2nd output is the inner offset;
|
||||
// * 3rd output is the outer offset.
|
||||
auto input_type = getInputType(func, 0);
|
||||
if (!input_type || input_type.getRank() != 1 ||
|
||||
!input_type.getElementType().isa<mlir::TF::StringType>()) {
|
||||
return failure();
|
||||
if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() ||
|
||||
!input_type.hasRank()) {
|
||||
return func.emitError() << "Input should be a string tensor";
|
||||
}
|
||||
|
||||
const std::vector<int> kValidNumOfOutput = {1, 2, 3};
|
||||
if (input_type.getRank() >= kValidNumOfOutput.size()) {
|
||||
return func.emitError()
|
||||
<< "Unrecognized input rank: " << input_type.getRank();
|
||||
}
|
||||
if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
|
||||
return func.emitError()
|
||||
<< "Expect " << kValidNumOfOutput[input_type.getRank()]
|
||||
<< "output(s) when input has rank " << input_type.getRank();
|
||||
}
|
||||
|
||||
auto value_type = getResultType(func, 0);
|
||||
if (!value_type || value_type.getRank() != 1 ||
|
||||
if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 ||
|
||||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
|
||||
return failure();
|
||||
return func.emitError() << "1st output should be string tensor";
|
||||
}
|
||||
auto offset_type = getResultType(func, 1);
|
||||
if (offset_type.getRank() != 1 ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return failure();
|
||||
if (func.getNumResults() > 1) {
|
||||
auto offset_type = getResultType(func, 1);
|
||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return func.emitError() << "2nd output should be int64 tensor";
|
||||
}
|
||||
}
|
||||
if (func.getNumResults() > 2) {
|
||||
auto offset_type = getResultType(func, 2);
|
||||
if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return func.emitError() << "3rd output should be int64 tensor";
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -96,19 +126,12 @@ LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func,
|
||||
func.eraseBody();
|
||||
func.addEntryBlock();
|
||||
func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext()));
|
||||
|
||||
Value text = func.getArgument(0);
|
||||
auto output_type = func.getType().getResult(0);
|
||||
auto offset_type = func.getType().getResult(1);
|
||||
SmallVector<Type, 2> shape = {output_type, offset_type};
|
||||
ArrayRef<Type> output_types(shape);
|
||||
|
||||
OpBuilder builder(func.getBody());
|
||||
|
||||
auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types,
|
||||
ValueRange(text), api,
|
||||
emptyCustomOption(&builder));
|
||||
|
||||
auto op = builder.create<mlir::TFL::CustomOp>(
|
||||
func.getLoc(), func.getType().getResults(), ValueRange(text), api,
|
||||
emptyCustomOption(&builder));
|
||||
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
|
||||
return success();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user