[lite] Refactor MlirToFlatBufferTranslateFunction: introduce an options struct and pass it as a single argument, replacing multiple overloads with multiple arguments.

PiperOrigin-RevId: 357763716
Change-Id: I4b999030d04b111b592d7e22a2a40aee065546db
Karim Nosir 2021-02-16 11:21:30 -08:00 committed by TensorFlower Gardener
parent c3555ff1fc
commit e2cd9cda80
7 changed files with 73 additions and 111 deletions
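
In short, the previous overload family collapses into one entry point that takes a tflite::FlatbufferExportOptions struct. For reference, a minimal sketch of the new calling pattern (the ExportModule wrapper and the flag values below are illustrative, not part of this commit):

    // Illustrative sketch: assumes `module` is a valid mlir::ModuleOp in the
    // TFLite dialect and that flatbuffer_export.h is included.
    bool ExportModule(mlir::ModuleOp module, std::string* serialized_flatbuffer) {
      tflite::FlatbufferExportOptions options;
      options.emit_builtin_tflite_ops = true;  // flags that used to be positional bools
      options.emit_select_tf_ops = false;
      options.emit_custom_ops = false;
      // saved_model_tags and select_user_tf_ops default to empty sets;
      // op_or_arg_name_mapper defaults to nullptr, in which case the Translator
      // now falls back to a default OpOrArgLocNameMapper.
      return tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                       serialized_flatbuffer);
    }

Note that the return convention also flips: the new function returns true on success, where the old overloads returned false on success.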


@@ -1710,6 +1710,9 @@ Optional<std::string> Translator::Translate(
     const std::unordered_set<std::string>& select_user_tf_ops,
     const std::unordered_set<std::string>& tags,
     OpOrArgNameMapper* op_or_arg_name_mapper) {
+  OpOrArgLocNameMapper default_op_or_arg_name_mapper;
+  if (!op_or_arg_name_mapper)
+    op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
   if (!UpdateEntryFunction(module)) return llvm::None;
   if (!IsValidTFLiteMlirModule(module)) return llvm::None;
   Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
@@ -1942,69 +1945,23 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
 }  // namespace
 
-// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer
-// format. Returns false on success.
-//
+namespace tflite {
+
 // TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
 // the following:
 //
 // * Quantization
 // * Ops with variable tensors
 //
-bool tflite::MlirToFlatBufferTranslateFunction(
-    ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    OpOrArgNameMapper* op_or_arg_name_mapper) {
-  return MlirToFlatBufferTranslateFunction(
-      module, serialized_flatbuffer, emit_builtin_tflite_ops,
-      emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
-      op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
-    ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
-    bool emit_custom_ops) {
-  OpOrArgLocNameMapper op_or_arg_name_mapper;
-  return MlirToFlatBufferTranslateFunction(
-      module, serialized_flatbuffer, emit_builtin_tflite_ops,
-      emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
-      &op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
-    mlir::ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    const std::unordered_set<std::string>& saved_model_tags) {
-  OpOrArgLocNameMapper op_or_arg_name_mapper;
-  return MlirToFlatBufferTranslateFunction(
-      module, serialized_flatbuffer, emit_builtin_tflite_ops,
-      emit_select_tf_ops, emit_custom_ops, saved_model_tags,
-      &op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
-    mlir::ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    const std::unordered_set<std::string>& saved_model_tags,
-    OpOrArgNameMapper* op_or_arg_name_mapper) {
-  std::unordered_set<std::string> select_user_tf_ops;
-  return MlirToFlatBufferTranslateFunction(
-      module, serialized_flatbuffer, emit_builtin_tflite_ops,
-      emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
-      op_or_arg_name_mapper);
-}
-
-bool tflite::MlirToFlatBufferTranslateFunction(
-    ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    const std::unordered_set<std::string>& select_user_tf_ops,
-    const std::unordered_set<std::string>& saved_model_tags,
-    tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
+bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
+                                       const FlatbufferExportOptions& options,
+                                       std::string* serialized_flatbuffer) {
   auto maybe_translated = Translator::Translate(
-      module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
-      select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
-  if (!maybe_translated) return true;
+      module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
+      options.emit_custom_ops, options.select_user_tf_ops,
+      options.saved_model_tags, options.op_or_arg_name_mapper);
+  if (!maybe_translated) return false;
   *serialized_flatbuffer = std::move(*maybe_translated);
-  return false;
+  return true;
 }
+
+}  // namespace tflite


@@ -23,43 +23,26 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
 
 namespace tflite {
 
+// Options for exporting to Flatbuffer.
+struct FlatbufferExportOptions {
+  bool emit_builtin_tflite_ops = false;
+  bool emit_select_tf_ops = false;
+  bool emit_custom_ops = false;
+  // When exporting from SavedModel, this will have the requested tags.
+  std::unordered_set<std::string> saved_model_tags;
+  // TF custom op passed by the user.
+  std::unordered_set<std::string> select_user_tf_ops;
+  // OpOrArgNameMapper to convert location of the op to name in flatbuffer.
+  // If not set, a default mapper will be used.
+  tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr;
+};
+
 // Translates the given MLIR `module` into a FlatBuffer and stores the
-// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
-// convert location of the op to name in flatbuffer. Returns true if translation
-// fails, otherwise returns false.
+// serialized flatbuffer into the string.
+// Returns true on successful exporting, false otherwise.
 bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
-                                       std::string* serialized_flatbuffer,
-                                       bool emit_builtin_tflite_ops,
-                                       bool emit_select_tf_ops,
-                                       bool emit_custom_ops);
-
-// Same as above but takes SavedModel tags of the model.
-bool MlirToFlatBufferTranslateFunction(
-    mlir::ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    const std::unordered_set<std::string>& saved_model_tags);
-
-// Same as the above but with a custom op name mapper.
-bool MlirToFlatBufferTranslateFunction(
-    mlir::ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
-
-// Same as above but takes SavedModel tags of the model.
-bool MlirToFlatBufferTranslateFunction(
-    mlir::ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    const std::unordered_set<std::string>& saved_model_tags,
-    tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
-
-// Same as the above but with a list of allowed user's defined ops.
-bool MlirToFlatBufferTranslateFunction(
-    mlir::ModuleOp module, std::string* serialized_flatbuffer,
-    bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
-    const std::unordered_set<std::string>& select_user_tf_ops,
-    const std::unordered_set<std::string>& saved_model_tags,
-    tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
+                                       const FlatbufferExportOptions& options,
+                                       std::string* serialized_flatbuffer);
 
 }  // namespace tflite
 
 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
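
For the SavedModel path, the pieces that used to be extra overload parameters now ride on the same struct. A short illustrative sketch (the "serve" tag, the op name, and the `name_mapper` variable are placeholders, not taken from this commit):

    // Illustrative sketch: `module` and `name_mapper` are assumed to exist in
    // the caller; any tensorflow::OpOrArgNameMapper works, and leaving the
    // field nullptr selects the default mapper.
    tflite::FlatbufferExportOptions options;
    options.emit_builtin_tflite_ops = true;
    options.emit_select_tf_ops = true;
    options.emit_custom_ops = false;
    options.saved_model_tags = {"serve"};         // tags requested at SavedModel load
    options.select_user_tf_ops = {"MyCustomOp"};  // user-provided TF custom ops
    options.op_or_arg_name_mapper = &name_mapper;
    std::string serialized_flatbuffer;
    bool ok = tflite::MlirToFlatBufferTranslateFunction(module, options,
                                                        &serialized_flatbuffer);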


@@ -158,9 +158,13 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction(
     op_or_arg_name_mapper =
         std::make_unique<tensorflow::OpOrArgLocNameMapper>();
   }
-  if (tflite::MlirToFlatBufferTranslateFunction(
-          module, &serialized_flatbuffer, emit_builtin_tflite_ops,
-          emit_select_tf_ops, emit_custom_ops, op_or_arg_name_mapper.get()))
+  tflite::FlatbufferExportOptions options;
+  options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+  options.emit_custom_ops = emit_custom_ops;
+  options.emit_select_tf_ops = emit_select_tf_ops;
+  options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
+  if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
+                                                 &serialized_flatbuffer))
     return mlir::failure();
 
   output << serialized_flatbuffer;


@@ -118,9 +118,12 @@ int main(int argc, char** argv) {
 
   // Convert to flatbuffer.
   std::string serialized_flatbuffer;
-  if (tflite::MlirToFlatBufferTranslateFunction(
-          module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
-          emit_select_tf_ops, emit_custom_ops))
+  tflite::FlatbufferExportOptions options;
+  options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+  options.emit_custom_ops = emit_custom_ops;
+  options.emit_select_tf_ops = emit_select_tf_ops;
+  if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
+                                                 &serialized_flatbuffer))
     return 1;
 
   // Create TFLite interpreter & invoke converted program.


@@ -108,9 +108,12 @@ TfLiteStatus QuantizeModel(
 
   // Export the results to the builder
   std::string result;
-  if (tflite::MlirToFlatBufferTranslateFunction(
-          module.get(), &result, /*emit_builtin_tflite_ops=*/true,
-          /*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
+  tflite::FlatbufferExportOptions options;
+  options.emit_builtin_tflite_ops = true;
+  options.emit_select_tf_ops = true;
+  options.emit_custom_ops = true;
+  if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
+                                                 &result)) {
     error_reporter->Report("Failed to export MLIR to flatbuffer.");
     return kTfLiteError;
   }


@@ -68,9 +68,12 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
 
   // Export the results to the builder
   std::string result;
-  if (tflite::MlirToFlatBufferTranslateFunction(
-          module.get(), &result, /*emit_builtin_tflite_ops=*/true,
-          /*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
+  tflite::FlatbufferExportOptions options;
+  options.emit_builtin_tflite_ops = true;
+  options.emit_select_tf_ops = true;
+  options.emit_custom_ops = true;
+  if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
+                                                 &result)) {
     error_reporter->Report("Failed to export MLIR to flatbuffer.");
     return kTfLiteError;
   }


@@ -172,20 +172,29 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
   // Write MLIR TFLite dialect into FlatBuffer
   OpOrArgLocNameMapper op_or_arg_name_mapper;
   if (!quant_specs.RunWeightQuantization()) {
-    if (tflite::MlirToFlatBufferTranslateFunction(
-            module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
-            emit_custom_ops, select_user_tf_ops, saved_model_tags,
-            &op_or_arg_name_mapper)) {
+    tflite::FlatbufferExportOptions options;
+    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+    options.emit_select_tf_ops = emit_select_tf_ops;
+    options.select_user_tf_ops = select_user_tf_ops;
+    options.emit_custom_ops = emit_custom_ops;
+    options.saved_model_tags = saved_model_tags;
+    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
+    if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
      return statusHandler.ConsumeStatus();
    }
  } else {
    // Post-training weight quantization path. Once MLIR has support for this,
    // we can remove this else statement.
    std::string pre_quantized_result;
-    if (tflite::MlirToFlatBufferTranslateFunction(
-            module, &pre_quantized_result, emit_builtin_tflite_ops,
-            emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
-            saved_model_tags, &op_or_arg_name_mapper)) {
+    tflite::FlatbufferExportOptions options;
+    options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
+    options.emit_select_tf_ops = emit_select_tf_ops;
+    options.select_user_tf_ops = select_user_tf_ops;
+    options.emit_custom_ops = emit_custom_ops;
+    options.saved_model_tags = saved_model_tags;
+    options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
+    if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
+                                                   &pre_quantized_result)) {
      return statusHandler.ConsumeStatus();
    }
    flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);