diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 10ba3bbeb31..23d5d9612f7 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1710,6 +1710,9 @@ Optional Translator::Translate( const std::unordered_set& select_user_tf_ops, const std::unordered_set& tags, OpOrArgNameMapper* op_or_arg_name_mapper) { + OpOrArgLocNameMapper default_op_or_arg_name_mapper; + if (!op_or_arg_name_mapper) + op_or_arg_name_mapper = &default_op_or_arg_name_mapper; if (!UpdateEntryFunction(module)) return llvm::None; if (!IsValidTFLiteMlirModule(module)) return llvm::None; Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, @@ -1942,69 +1945,23 @@ BufferOffset Translator::BuildSparsityParameters( } // namespace -// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer -// format. Returns false on success. -// +namespace tflite { // TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting // the following: // // * Quantization // * Ops with variable tensors // -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - OpOrArgNameMapper* op_or_arg_name_mapper) { - return MlirToFlatBufferTranslateFunction( - module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{}, - op_or_arg_name_mapper); -} - -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops) { - OpOrArgLocNameMapper op_or_arg_name_mapper; - return MlirToFlatBufferTranslateFunction( - module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{}, - &op_or_arg_name_mapper); -} - -bool tflite::MlirToFlatBufferTranslateFunction( - mlir::ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const std::unordered_set& saved_model_tags) { - OpOrArgLocNameMapper op_or_arg_name_mapper; - return MlirToFlatBufferTranslateFunction( - module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, saved_model_tags, - &op_or_arg_name_mapper); -} - -bool tflite::MlirToFlatBufferTranslateFunction( - mlir::ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const std::unordered_set& saved_model_tags, - OpOrArgNameMapper* op_or_arg_name_mapper) { - std::unordered_set select_user_tf_ops; - return MlirToFlatBufferTranslateFunction( - module, serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags, - op_or_arg_name_mapper); -} - -bool tflite::MlirToFlatBufferTranslateFunction( - ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const std::unordered_set& select_user_tf_ops, - const std::unordered_set& saved_model_tags, - tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) { +bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, + const FlatbufferExportOptions& options, + std::string* serialized_flatbuffer) { auto maybe_translated = Translator::Translate( - module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops, - select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper); - if (!maybe_translated) return true; + module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops, + options.emit_custom_ops, options.select_user_tf_ops, + options.saved_model_tags, options.op_or_arg_name_mapper); + if (!maybe_translated) return false; *serialized_flatbuffer = std::move(*maybe_translated); - return false; + return true; } + +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/tensorflow/compiler/mlir/lite/flatbuffer_export.h index c47bffbf6bd..73b71668564 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.h @@ -23,43 +23,26 @@ limitations under the License. #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" namespace tflite { +// Options for exporting to Flatbuffer. +struct FlatbufferExportOptions { + bool emit_builtin_tflite_ops = false; + bool emit_select_tf_ops = false; + bool emit_custom_ops = false; + // When exporting from SavedModel, this will have the requested tags. + std::unordered_set saved_model_tags; + // TF custom op passed by the user. + std::unordered_set select_user_tf_ops; + // OpOrArgNameMapper to convert location of the op to name in flatbuffer. + // If not set, a default mapper will be used. + tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr; +}; // Translates the given MLIR `module` into a FlatBuffer and stores the -// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to -// convert location of the op to name in flatbuffer. Returns true if translation -// fails, otherwise returns false. +// serialized flatbuffer into the string. +// Returns true on successful exporting, false otherwise. bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, - std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, - bool emit_select_tf_ops, - bool emit_custom_ops); - -// Same as above but takes SavedModel tags of the model. -bool MlirToFlatBufferTranslateFunction( - mlir::ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const std::unordered_set& saved_model_tags); - -// Same as the above but with a custom op name mapper. -bool MlirToFlatBufferTranslateFunction( - mlir::ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); - -// Same as above but takes SavedModel tags of the model. -bool MlirToFlatBufferTranslateFunction( - mlir::ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const std::unordered_set& saved_model_tags, - tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); - -// Same as the above but with a list of allowed user's defined ops. -bool MlirToFlatBufferTranslateFunction( - mlir::ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, - const std::unordered_set& select_user_tf_ops, - const std::unordered_set& saved_model_tags, - tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); + const FlatbufferExportOptions& options, + std::string* serialized_flatbuffer); } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 901199e4bee..1c3788a0b5e 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -158,9 +158,13 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction( op_or_arg_name_mapper = std::make_unique(); } - if (tflite::MlirToFlatBufferTranslateFunction( - module, &serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, op_or_arg_name_mapper.get())) + tflite::FlatbufferExportOptions options; + options.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + options.emit_custom_ops = emit_custom_ops; + options.emit_select_tf_ops = emit_select_tf_ops; + options.op_or_arg_name_mapper = op_or_arg_name_mapper.get(); + if (!tflite::MlirToFlatBufferTranslateFunction(module, options, + &serialized_flatbuffer)) return mlir::failure(); output << serialized_flatbuffer; diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index 93d0c07fab8..6e5af0889c5 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -118,9 +118,12 @@ int main(int argc, char** argv) { // Convert to flatbuffer. std::string serialized_flatbuffer; - if (tflite::MlirToFlatBufferTranslateFunction( - module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops)) + tflite::FlatbufferExportOptions options; + options.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + options.emit_custom_ops = emit_custom_ops; + options.emit_select_tf_ops = emit_select_tf_ops; + if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, + &serialized_flatbuffer)) return 1; // Create TFLite interpreter & invoke converted program. diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 6b267e726de..f8870cbfb57 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -108,9 +108,12 @@ TfLiteStatus QuantizeModel( // Export the results to the builder std::string result; - if (tflite::MlirToFlatBufferTranslateFunction( - module.get(), &result, /*emit_builtin_tflite_ops=*/true, - /*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) { + tflite::FlatbufferExportOptions options; + options.emit_builtin_tflite_ops = true; + options.emit_select_tf_ops = true; + options.emit_custom_ops = true; + if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, + &result)) { error_reporter->Report("Failed to export MLIR to flatbuffer."); return kTfLiteError; } diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index eed9529a969..d3482f706c7 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -68,9 +68,12 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model, // Export the results to the builder std::string result; - if (tflite::MlirToFlatBufferTranslateFunction( - module.get(), &result, /*emit_builtin_tflite_ops=*/true, - /*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) { + tflite::FlatbufferExportOptions options; + options.emit_builtin_tflite_ops = true; + options.emit_select_tf_ops = true; + options.emit_custom_ops = true; + if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, + &result)) { error_reporter->Report("Failed to export MLIR to flatbuffer."); return kTfLiteError; } diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index e1312664f03..a9a192b8552 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -172,20 +172,29 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( // Write MLIR TFLite dialect into FlatBuffer OpOrArgLocNameMapper op_or_arg_name_mapper; if (!quant_specs.RunWeightQuantization()) { - if (tflite::MlirToFlatBufferTranslateFunction( - module, result, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops, select_user_tf_ops, saved_model_tags, - &op_or_arg_name_mapper)) { + tflite::FlatbufferExportOptions options; + options.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + options.emit_select_tf_ops = emit_select_tf_ops; + options.select_user_tf_ops = select_user_tf_ops; + options.emit_custom_ops = emit_custom_ops; + options.saved_model_tags = saved_model_tags; + options.op_or_arg_name_mapper = &op_or_arg_name_mapper; + if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) { return statusHandler.ConsumeStatus(); } } else { // Post-training weight quantization path. Once MLIR has support for this, // we can remove this else statement. std::string pre_quantized_result; - if (tflite::MlirToFlatBufferTranslateFunction( - module, &pre_quantized_result, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, - saved_model_tags, &op_or_arg_name_mapper)) { + tflite::FlatbufferExportOptions options; + options.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + options.emit_select_tf_ops = emit_select_tf_ops; + options.select_user_tf_ops = select_user_tf_ops; + options.emit_custom_ops = emit_custom_ops; + options.saved_model_tags = saved_model_tags; + options.op_or_arg_name_mapper = &op_or_arg_name_mapper; + if (!tflite::MlirToFlatBufferTranslateFunction(module, options, + &pre_quantized_result)) { return statusHandler.ConsumeStatus(); } flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);