[lite] Refactor MlirToFlatBufferTranslateFunction: introduce an options struct and pass it as a single argument, instead of multiple overloads with multiple arguments.

PiperOrigin-RevId: 357763716
Change-Id: I4b999030d04b111b592d7e22a2a40aee065546db
This commit is contained in:
Karim Nosir 2021-02-16 11:21:30 -08:00 committed by TensorFlower Gardener
parent c3555ff1fc
commit e2cd9cda80
7 changed files with 73 additions and 111 deletions

View File

@@ -1710,6 +1710,9 @@ Optional<std::string> Translator::Translate(
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
OpOrArgLocNameMapper default_op_or_arg_name_mapper;
if (!op_or_arg_name_mapper)
op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
if (!UpdateEntryFunction(module)) return llvm::None;
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
@@ -1942,69 +1945,23 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
} // namespace
// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer
// format. Returns false on success.
//
namespace tflite {
// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
// the following:
//
// * Quantization
// * Ops with variable tensors
//
bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
OpOrArgNameMapper* op_or_arg_name_mapper) {
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops) {
OpOrArgLocNameMapper op_or_arg_name_mapper;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
&op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags) {
OpOrArgLocNameMapper op_or_arg_name_mapper;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, saved_model_tags,
&op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
std::unordered_set<std::string> select_user_tf_ops;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
const FlatbufferExportOptions& options,
std::string* serialized_flatbuffer) {
auto maybe_translated = Translator::Translate(
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
if (!maybe_translated) return true;
module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
options.emit_custom_ops, options.select_user_tf_ops,
options.saved_model_tags, options.op_or_arg_name_mapper);
if (!maybe_translated) return false;
*serialized_flatbuffer = std::move(*maybe_translated);
return false;
return true;
}
} // namespace tflite

View File

@@ -23,43 +23,26 @@ limitations under the License.
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace tflite {
// Options for exporting to Flatbuffer.
struct FlatbufferExportOptions {
bool emit_builtin_tflite_ops = false;
bool emit_select_tf_ops = false;
bool emit_custom_ops = false;
// When exporting from SavedModel, this will have the requested tags.
std::unordered_set<std::string> saved_model_tags;
// TF custom op passed by the user.
std::unordered_set<std::string> select_user_tf_ops;
// OpOrArgNameMapper to convert location of the op to name in flatbuffer.
// If not set, a default mapper will be used.
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr;
};
// Translates the given MLIR `module` into a FlatBuffer and stores the
// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
// convert location of the op to name in flatbuffer. Returns true if translation
// fails, otherwise returns false.
// serialized flatbuffer into the string.
// Returns true on successful exporting, false otherwise.
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops,
bool emit_select_tf_ops,
bool emit_custom_ops);
// Same as above but takes SavedModel tags of the model.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags);
// Same as the above but with a custom op name mapper.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
// Same as above but takes SavedModel tags of the model.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
// Same as the above but with a list of allowed user's defined ops.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
const FlatbufferExportOptions& options,
std::string* serialized_flatbuffer);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_

View File

@@ -158,9 +158,13 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction(
op_or_arg_name_mapper =
std::make_unique<tensorflow::OpOrArgLocNameMapper>();
}
if (tflite::MlirToFlatBufferTranslateFunction(
module, &serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, op_or_arg_name_mapper.get()))
tflite::FlatbufferExportOptions options;
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
options.emit_custom_ops = emit_custom_ops;
options.emit_select_tf_ops = emit_select_tf_ops;
options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
&serialized_flatbuffer))
return mlir::failure();
output << serialized_flatbuffer;

View File

@@ -118,9 +118,12 @@ int main(int argc, char** argv) {
// Convert to flatbuffer.
std::string serialized_flatbuffer;
if (tflite::MlirToFlatBufferTranslateFunction(
module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops))
tflite::FlatbufferExportOptions options;
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
options.emit_custom_ops = emit_custom_ops;
options.emit_select_tf_ops = emit_select_tf_ops;
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
&serialized_flatbuffer))
return 1;
// Create TFLite interpreter & invoke converted program.

View File

@@ -108,9 +108,12 @@ TfLiteStatus QuantizeModel(
// Export the results to the builder
std::string result;
if (tflite::MlirToFlatBufferTranslateFunction(
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
tflite::FlatbufferExportOptions options;
options.emit_builtin_tflite_ops = true;
options.emit_select_tf_ops = true;
options.emit_custom_ops = true;
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
&result)) {
error_reporter->Report("Failed to export MLIR to flatbuffer.");
return kTfLiteError;
}

View File

@@ -68,9 +68,12 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
// Export the results to the builder
std::string result;
if (tflite::MlirToFlatBufferTranslateFunction(
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
tflite::FlatbufferExportOptions options;
options.emit_builtin_tflite_ops = true;
options.emit_select_tf_ops = true;
options.emit_custom_ops = true;
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
&result)) {
error_reporter->Report("Failed to export MLIR to flatbuffer.");
return kTfLiteError;
}

View File

@@ -172,20 +172,29 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
// Write MLIR TFLite dialect into FlatBuffer
OpOrArgLocNameMapper op_or_arg_name_mapper;
if (!quant_specs.RunWeightQuantization()) {
if (tflite::MlirToFlatBufferTranslateFunction(
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, select_user_tf_ops, saved_model_tags,
&op_or_arg_name_mapper)) {
tflite::FlatbufferExportOptions options;
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
options.emit_select_tf_ops = emit_select_tf_ops;
options.select_user_tf_ops = select_user_tf_ops;
options.emit_custom_ops = emit_custom_ops;
options.saved_model_tags = saved_model_tags;
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
return statusHandler.ConsumeStatus();
}
} else {
// Post-training weight quantization path. Once MLIR has support for this,
// we can remove this else statement.
std::string pre_quantized_result;
if (tflite::MlirToFlatBufferTranslateFunction(
module, &pre_quantized_result, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
saved_model_tags, &op_or_arg_name_mapper)) {
tflite::FlatbufferExportOptions options;
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
options.emit_select_tf_ops = emit_select_tf_ops;
options.select_user_tf_ops = select_user_tf_ops;
options.emit_custom_ops = emit_custom_ops;
options.saved_model_tags = saved_model_tags;
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
&pre_quantized_result)) {
return statusHandler.ConsumeStatus();
}
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);