Fuse tensorflow_text.ngrams into a TFLite custom op
PiperOrigin-RevId: 323456482 Change-Id: Idfd446c371e8a4a4f82b6da730d02b0897d35a8a
This commit is contained in:
		
							parent
							
								
									c7e7f49228
								
							
						
					
					
						commit
						07e4db17ff
					
				| @ -270,6 +270,7 @@ cc_library( | ||||
|         "//tensorflow/compiler/mlir/tensorflow", | ||||
|         "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", | ||||
|         "//tensorflow/core:framework", | ||||
|         "@flatbuffers", | ||||
|         "@llvm-project//llvm:Support", | ||||
|         "@llvm-project//mlir:IR", | ||||
|         "@llvm-project//mlir:StandardOps", | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -15,6 +15,7 @@ limitations under the License. | ||||
| 
 | ||||
| #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" | ||||
| 
 | ||||
| #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
 | ||||
| #include "llvm/ADT/ArrayRef.h" | ||||
| #include "llvm/ADT/None.h" | ||||
| #include "llvm/ADT/SmallVector.h" | ||||
| @ -28,6 +29,7 @@ limitations under the License. | ||||
| #include "mlir/IR/Identifier.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/Location.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/Matchers.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/OpDefinition.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/Operation.h"  // from @llvm-project
 | ||||
| #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 | ||||
| @ -43,32 +45,35 @@ namespace TFL { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
// tf._implements attribute values identifying the TF.Text composite
// functions this pass knows how to fuse into TFLite custom ops.
constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
// Name of the function attribute that records the original composite op.
constexpr char kTFImplements[] = "tf._implements";

using mlir::TF::FuncAttr;
using mlir::TF::StringType;
| 
 | ||||
| inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) { | ||||
|   std::string content = ""; | ||||
| inline OpaqueElementsAttr CustomOption(OpBuilder* builder, | ||||
|                                        const std::string& content) { | ||||
|   ShapedType type = RankedTensorType::get( | ||||
|       {static_cast<int64_t>(content.size())}, builder->getIntegerType(8)); | ||||
|   return OpaqueElementsAttr::get( | ||||
|       builder->getContext()->getRegisteredDialect("tfl"), type, content); | ||||
|       builder->getContext()->getRegisteredDialect("tfl"), type, | ||||
|       StringRef(content.data(), content.size())); | ||||
| } | ||||
| 
 | ||||
| inline RankedTensorType getInputType(mlir::FuncOp func, int idx) { | ||||
|   return func.getType() | ||||
|       .getInput(idx) | ||||
|       .dyn_cast_or_null<mlir::RankedTensorType>(); | ||||
| inline TensorType GetInputType(FuncOp func, int idx) { | ||||
|   return func.getType().getInput(idx).dyn_cast_or_null<TensorType>(); | ||||
| } | ||||
| 
 | ||||
| inline RankedTensorType getResultType(mlir::FuncOp func, int idx) { | ||||
|   return func.getType() | ||||
|       .getResult(idx) | ||||
|       .dyn_cast_or_null<mlir::RankedTensorType>(); | ||||
| inline TensorType GetResultType(FuncOp func, int idx) { | ||||
|   return func.getType().getResult(idx).dyn_cast_or_null<TensorType>(); | ||||
| } | ||||
| 
 | ||||
| LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { | ||||
| inline bool RankEquals(const TensorType& type, int rank) { | ||||
|   return type && type.hasRank() && type.getRank() == rank; | ||||
| } | ||||
| 
 | ||||
| LogicalResult VerifyWhitespaceTokenizer(FuncOp func) { | ||||
|   // In the case of input tensor with 0 rank.
 | ||||
|   // Whitespace tokenizer generates 1 output:
 | ||||
|   // * String tensor for tokens.
 | ||||
| @ -83,8 +88,8 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { | ||||
|   // * 1st output is the value of ragged tensor;
 | ||||
|   // * 2nd output is the inner offset;
 | ||||
|   // * 3rd output is the outer offset.
 | ||||
|   auto input_type = getInputType(func, 0); | ||||
|   if (!input_type || !input_type.getElementType().isa<mlir::TF::StringType>() || | ||||
|   auto input_type = GetInputType(func, 0); | ||||
|   if (!input_type || !input_type.getElementType().isa<StringType>() || | ||||
|       !input_type.hasRank()) { | ||||
|     return func.emitError() << "Input should be a string tensor"; | ||||
|   } | ||||
| @ -100,21 +105,21 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { | ||||
|            << "output(s) when input has rank " << input_type.getRank(); | ||||
|   } | ||||
| 
 | ||||
|   auto value_type = getResultType(func, 0); | ||||
|   if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 || | ||||
|       !value_type.getElementType().isa<mlir::TF::StringType>()) { | ||||
|   auto value_type = GetResultType(func, 0); | ||||
|   if (!RankEquals(value_type, 1) || | ||||
|       !value_type.getElementType().isa<StringType>()) { | ||||
|     return func.emitError() << "1st output should be string tensor"; | ||||
|   } | ||||
|   if (func.getNumResults() > 1) { | ||||
|     auto offset_type = getResultType(func, 1); | ||||
|     if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 || | ||||
|     auto offset_type = GetResultType(func, 1); | ||||
|     if (!RankEquals(offset_type, 1) || | ||||
|         !offset_type.getElementType().isInteger(64)) { | ||||
|       return func.emitError() << "2nd output should be int64 tensor"; | ||||
|     } | ||||
|   } | ||||
|   if (func.getNumResults() > 2) { | ||||
|     auto offset_type = getResultType(func, 2); | ||||
|     if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 || | ||||
|     auto offset_type = GetResultType(func, 2); | ||||
|     if (!RankEquals(offset_type, 1) || | ||||
|         !offset_type.getElementType().isInteger(64)) { | ||||
|       return func.emitError() << "3rd output should be int64 tensor"; | ||||
|     } | ||||
| @ -123,28 +128,159 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func, llvm::StringRef api, | ||||
| LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api, | ||||
|                                          FuncAttr attr) { | ||||
|   func.eraseBody(); | ||||
|   func.addEntryBlock(); | ||||
|   func.setAttr(kTFImplements, attr); | ||||
|   Value text = func.getArgument(0); | ||||
|   OpBuilder builder(func.getBody()); | ||||
| 
 | ||||
|   auto op = builder.create<mlir::TFL::CustomOp>( | ||||
|       func.getLoc(), func.getType().getResults(), ValueRange(text), api, | ||||
|       emptyCustomOption(&builder)); | ||||
|   builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults()); | ||||
|   std::string empty_option_buffer; | ||||
|   auto op = builder.create<CustomOp>( | ||||
|       func.getLoc(), func.getType().getResults(), func.getArguments(), api, | ||||
|       CustomOption(&builder, empty_option_buffer)); | ||||
|   builder.create<ReturnOp>(func.getLoc(), op.getResults()); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| LogicalResult VerifyNgrams(FuncOp func) { | ||||
|   // The inputs and outputs should be the same:
 | ||||
|   // * A string tensor for tokens/ragged tensor values.
 | ||||
|   // * Zero or more row_split tensors.
 | ||||
|   constexpr int kValues = 0; | ||||
|   constexpr int kRowSplits = 1; | ||||
| 
 | ||||
|   if (func.getType().getInputs().size() != func.getType().getResults().size()) { | ||||
|     return func.emitError() << "Mismatched number of inputs and outputs."; | ||||
|   } | ||||
| 
 | ||||
|   int row_splits = func.getType().getInputs().size() - kRowSplits; | ||||
|   if (row_splits == 0) { | ||||
|     auto input_values = GetInputType(func, kValues); | ||||
|     if (!input_values || !input_values.getElementType().isa<StringType>()) { | ||||
|       return func.emitError() | ||||
|              << "Input " << kValues << " should be a string tensor"; | ||||
|     } | ||||
|     auto output_values = GetResultType(func, kValues); | ||||
|     if (!output_values || !output_values.getElementType().isa<StringType>()) { | ||||
|       return func.emitError() | ||||
|              << "Output " << kValues << " should be a string tensor"; | ||||
|     } | ||||
| 
 | ||||
|     if (input_values.hasRank() && output_values.hasRank() && | ||||
|         input_values.getRank() != output_values.getRank()) { | ||||
|       return func.emitError() << "Input " << kValues << " and output " | ||||
|                               << kValues << " should have the same rank"; | ||||
|     } | ||||
|   } else { | ||||
|     auto input_values = GetInputType(func, kValues); | ||||
|     if (!RankEquals(input_values, 1) || | ||||
|         !input_values.getElementType().isa<StringType>()) { | ||||
|       return func.emitError() | ||||
|              << "Input " << kValues << " should be a 1D string tensor"; | ||||
|     } | ||||
|     auto output_values = GetResultType(func, kValues); | ||||
|     if (!RankEquals(output_values, 1) || | ||||
|         !output_values.getElementType().isa<StringType>()) { | ||||
|       return func.emitError() | ||||
|              << "Output " << kValues << " should be a 1D string tensor"; | ||||
|     } | ||||
| 
 | ||||
|     for (int i = 0; i < row_splits; ++i) { | ||||
|       const int row_index = i + kRowSplits; | ||||
|       auto input_row_splits = GetInputType(func, row_index); | ||||
|       if (!RankEquals(input_row_splits, 1) || | ||||
|           !input_row_splits.getElementType().isInteger(64)) { | ||||
|         return func.emitError() | ||||
|                << "Input " << row_index << " should be a 1D int64 tensor"; | ||||
|       } | ||||
|       auto output_row_splits = GetResultType(func, row_index); | ||||
|       if (!RankEquals(output_row_splits, 1) || | ||||
|           !output_row_splits.getElementType().isInteger(64)) { | ||||
|         return func.emitError() | ||||
|                << "Output " << row_index << " should be a 1D int64 tensor"; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs, | ||||
|                                        std::string& custom_option_buffer) { | ||||
|   flexbuffers::Builder fbb; | ||||
|   size_t start_map = fbb.StartMap(); | ||||
| 
 | ||||
|   auto width = attrs.get("width").dyn_cast_or_null<IntegerAttr>(); | ||||
|   if (!width) { | ||||
|     return func.emitError() << "'width' attribute is not set or not an integer"; | ||||
|   } | ||||
|   fbb.Int("width", width.getInt()); | ||||
| 
 | ||||
|   auto string_separator = | ||||
|       attrs.get("string_separator").dyn_cast_or_null<StringAttr>(); | ||||
|   if (!string_separator) { | ||||
|     return func.emitError() | ||||
|            << "'string_separator' attribute is not set or not a string"; | ||||
|   } | ||||
|   // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
 | ||||
|   // strings expect NUL terminated strings.
 | ||||
|   std::string string_separator_str(string_separator.getValue().data(), | ||||
|                                    string_separator.getValue().size()); | ||||
|   fbb.String("string_separator", string_separator_str); | ||||
| 
 | ||||
|   auto axis = attrs.get("axis").dyn_cast_or_null<IntegerAttr>(); | ||||
|   if (!axis) { | ||||
|     return func.emitError() << "'axis' attribute is not set or not an integer"; | ||||
|   } | ||||
|   fbb.Int("axis", axis.getInt()); | ||||
| 
 | ||||
|   auto reduction_type = | ||||
|       attrs.get("reduction_type").dyn_cast_or_null<StringAttr>(); | ||||
|   if (!reduction_type) { | ||||
|     return func.emitError() | ||||
|            << "'reduction_type' attribute is not set or not a string"; | ||||
|   } | ||||
|   // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
 | ||||
|   // strings expect NUL terminated strings.
 | ||||
|   std::string reduction_type_str(reduction_type.getValue().data(), | ||||
|                                  reduction_type.getValue().size()); | ||||
|   fbb.String("reduction_type", reduction_type_str); | ||||
| 
 | ||||
|   fbb.EndMap(start_map); | ||||
|   fbb.Finish(); | ||||
|   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end()); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) { | ||||
|   func.eraseBody(); | ||||
|   func.addEntryBlock(); | ||||
|   func.setAttr(kTFImplements, attr); | ||||
|   OpBuilder builder(func.getBody()); | ||||
|   std::string custom_option_buffer; | ||||
|   if (failed(CreateNgramsCustomOption(func, attr.GetAttrs(), | ||||
|                                       custom_option_buffer))) { | ||||
|     return failure(); | ||||
|   } | ||||
|   auto op = builder.create<CustomOp>( | ||||
|       func.getLoc(), func.getType().getResults(), func.getArguments(), api, | ||||
|       CustomOption(&builder, custom_option_buffer)); | ||||
|   builder.create<ReturnOp>(func.getLoc(), op.getResults()); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api, | ||||
| LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api, | ||||
|                                FuncAttr attr) { | ||||
|   if (api.str() == kWhitespaceTokenizer) { | ||||
|     if (succeeded(VerifyWhitespaceTokenizer(func))) { | ||||
|       return ConvertWhitespaceTokenizer(func, api, attr); | ||||
|     } | ||||
|   } else if (api.str() == kNgrams) { | ||||
|     if (succeeded(VerifyNgrams(func))) { | ||||
|       return ConvertNgrams(func, api, attr); | ||||
|     } | ||||
|   } | ||||
|   return failure(); | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user