[lite] Refactor MlirToFlatBufferTranslateFunction and make options struct and pass as argument instead of multiple overloads with multiple arguments.
PiperOrigin-RevId: 357763716 Change-Id: I4b999030d04b111b592d7e22a2a40aee065546db
This commit is contained in:
parent
c3555ff1fc
commit
e2cd9cda80
@ -1710,6 +1710,9 @@ Optional<std::string> Translator::Translate(
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
OpOrArgLocNameMapper default_op_or_arg_name_mapper;
|
||||
if (!op_or_arg_name_mapper)
|
||||
op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
|
||||
if (!UpdateEntryFunction(module)) return llvm::None;
|
||||
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
|
||||
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
@ -1942,69 +1945,23 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
|
||||
|
||||
} // namespace
|
||||
|
||||
// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer
|
||||
// format. Returns false on success.
|
||||
//
|
||||
namespace tflite {
|
||||
// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
|
||||
// the following:
|
||||
//
|
||||
// * Quantization
|
||||
// * Ops with variable tensors
|
||||
//
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
|
||||
op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
|
||||
bool emit_custom_ops) {
|
||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
|
||||
&op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags) {
|
||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, saved_model_tags,
|
||||
&op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
std::unordered_set<std::string> select_user_tf_ops;
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
|
||||
op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
|
||||
const FlatbufferExportOptions& options,
|
||||
std::string* serialized_flatbuffer) {
|
||||
auto maybe_translated = Translator::Translate(
|
||||
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
|
||||
select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
|
||||
if (!maybe_translated) return true;
|
||||
module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
|
||||
options.emit_custom_ops, options.select_user_tf_ops,
|
||||
options.saved_model_tags, options.op_or_arg_name_mapper);
|
||||
if (!maybe_translated) return false;
|
||||
*serialized_flatbuffer = std::move(*maybe_translated);
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -23,43 +23,26 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
|
||||
namespace tflite {
|
||||
// Options for exporting to Flatbuffer.
|
||||
struct FlatbufferExportOptions {
|
||||
bool emit_builtin_tflite_ops = false;
|
||||
bool emit_select_tf_ops = false;
|
||||
bool emit_custom_ops = false;
|
||||
// When exporting from SavedModel, this will have the requested tags.
|
||||
std::unordered_set<std::string> saved_model_tags;
|
||||
// TF custom op passed by the user.
|
||||
std::unordered_set<std::string> select_user_tf_ops;
|
||||
// OpOrArgNameMapper to convert location of the op to name in flatbuffer.
|
||||
// If not set, a default mapper will be used.
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr;
|
||||
};
|
||||
|
||||
// Translates the given MLIR `module` into a FlatBuffer and stores the
|
||||
// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
|
||||
// convert location of the op to name in flatbuffer. Returns true if translation
|
||||
// fails, otherwise returns false.
|
||||
// serialized flatbuffer into the string.
|
||||
// Returns true on successful exporting, false otherwise.
|
||||
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
|
||||
std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops,
|
||||
bool emit_custom_ops);
|
||||
|
||||
// Same as above but takes SavedModel tags of the model.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags);
|
||||
|
||||
// Same as the above but with a custom op name mapper.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
|
||||
// Same as above but takes SavedModel tags of the model.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
|
||||
// Same as the above but with a list of allowed user's defined ops.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
const FlatbufferExportOptions& options,
|
||||
std::string* serialized_flatbuffer);
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
|
||||
|
@ -158,9 +158,13 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction(
|
||||
op_or_arg_name_mapper =
|
||||
std::make_unique<tensorflow::OpOrArgLocNameMapper>();
|
||||
}
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, &serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, op_or_arg_name_mapper.get()))
|
||||
tflite::FlatbufferExportOptions options;
|
||||
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
options.emit_custom_ops = emit_custom_ops;
|
||||
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||
options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
|
||||
if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
|
||||
&serialized_flatbuffer))
|
||||
return mlir::failure();
|
||||
|
||||
output << serialized_flatbuffer;
|
||||
|
@ -118,9 +118,12 @@ int main(int argc, char** argv) {
|
||||
|
||||
// Convert to flatbuffer.
|
||||
std::string serialized_flatbuffer;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops))
|
||||
tflite::FlatbufferExportOptions options;
|
||||
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
options.emit_custom_ops = emit_custom_ops;
|
||||
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
|
||||
&serialized_flatbuffer))
|
||||
return 1;
|
||||
|
||||
// Create TFLite interpreter & invoke converted program.
|
||||
|
@ -108,9 +108,12 @@ TfLiteStatus QuantizeModel(
|
||||
|
||||
// Export the results to the builder
|
||||
std::string result;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
|
||||
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
|
||||
tflite::FlatbufferExportOptions options;
|
||||
options.emit_builtin_tflite_ops = true;
|
||||
options.emit_select_tf_ops = true;
|
||||
options.emit_custom_ops = true;
|
||||
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
|
||||
&result)) {
|
||||
error_reporter->Report("Failed to export MLIR to flatbuffer.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
@ -68,9 +68,12 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
||||
|
||||
// Export the results to the builder
|
||||
std::string result;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
|
||||
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
|
||||
tflite::FlatbufferExportOptions options;
|
||||
options.emit_builtin_tflite_ops = true;
|
||||
options.emit_select_tf_ops = true;
|
||||
options.emit_custom_ops = true;
|
||||
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
|
||||
&result)) {
|
||||
error_reporter->Report("Failed to export MLIR to flatbuffer.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
@ -172,20 +172,29 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
// Write MLIR TFLite dialect into FlatBuffer
|
||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
||||
if (!quant_specs.RunWeightQuantization()) {
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops, select_user_tf_ops, saved_model_tags,
|
||||
&op_or_arg_name_mapper)) {
|
||||
tflite::FlatbufferExportOptions options;
|
||||
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||
options.select_user_tf_ops = select_user_tf_ops;
|
||||
options.emit_custom_ops = emit_custom_ops;
|
||||
options.saved_model_tags = saved_model_tags;
|
||||
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
|
||||
if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
} else {
|
||||
// Post-training weight quantization path. Once MLIR has support for this,
|
||||
// we can remove this else statement.
|
||||
std::string pre_quantized_result;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, &pre_quantized_result, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
|
||||
saved_model_tags, &op_or_arg_name_mapper)) {
|
||||
tflite::FlatbufferExportOptions options;
|
||||
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||
options.select_user_tf_ops = select_user_tf_ops;
|
||||
options.emit_custom_ops = emit_custom_ops;
|
||||
options.saved_model_tags = saved_model_tags;
|
||||
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
|
||||
if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
|
||||
&pre_quantized_result)) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
|
||||
|
Loading…
x
Reference in New Issue
Block a user