[lite] Refactor MlirToFlatBufferTranslateFunction and make options struct and pass as argument instead of multiple overloads with multiple arguments.
PiperOrigin-RevId: 357763716 Change-Id: I4b999030d04b111b592d7e22a2a40aee065546db
This commit is contained in:
parent
c3555ff1fc
commit
e2cd9cda80
@ -1710,6 +1710,9 @@ Optional<std::string> Translator::Translate(
|
|||||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||||
const std::unordered_set<std::string>& tags,
|
const std::unordered_set<std::string>& tags,
|
||||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||||
|
OpOrArgLocNameMapper default_op_or_arg_name_mapper;
|
||||||
|
if (!op_or_arg_name_mapper)
|
||||||
|
op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
|
||||||
if (!UpdateEntryFunction(module)) return llvm::None;
|
if (!UpdateEntryFunction(module)) return llvm::None;
|
||||||
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
|
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
|
||||||
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
|
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||||
@ -1942,69 +1945,23 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer
|
namespace tflite {
|
||||||
// format. Returns false on success.
|
|
||||||
//
|
|
||||||
// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
|
// TODO(hinsu): Support all valid MLIR modules in TFLite dialect by supporting
|
||||||
// the following:
|
// the following:
|
||||||
//
|
//
|
||||||
// * Quantization
|
// * Quantization
|
||||||
// * Ops with variable tensors
|
// * Ops with variable tensors
|
||||||
//
|
//
|
||||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
|
||||||
ModuleOp module, std::string* serialized_flatbuffer,
|
const FlatbufferExportOptions& options,
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
std::string* serialized_flatbuffer) {
|
||||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
|
||||||
return MlirToFlatBufferTranslateFunction(
|
|
||||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
|
||||||
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
|
|
||||||
op_or_arg_name_mapper);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
|
||||||
ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
|
|
||||||
bool emit_custom_ops) {
|
|
||||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
|
||||||
return MlirToFlatBufferTranslateFunction(
|
|
||||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
|
||||||
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
|
|
||||||
&op_or_arg_name_mapper);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
|
||||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
|
||||||
const std::unordered_set<std::string>& saved_model_tags) {
|
|
||||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
|
||||||
return MlirToFlatBufferTranslateFunction(
|
|
||||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
|
||||||
emit_select_tf_ops, emit_custom_ops, saved_model_tags,
|
|
||||||
&op_or_arg_name_mapper);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
|
||||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
|
||||||
const std::unordered_set<std::string>& saved_model_tags,
|
|
||||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
|
||||||
std::unordered_set<std::string> select_user_tf_ops;
|
|
||||||
return MlirToFlatBufferTranslateFunction(
|
|
||||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
|
||||||
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
|
|
||||||
op_or_arg_name_mapper);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
|
||||||
ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
|
||||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
|
||||||
const std::unordered_set<std::string>& saved_model_tags,
|
|
||||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
|
|
||||||
auto maybe_translated = Translator::Translate(
|
auto maybe_translated = Translator::Translate(
|
||||||
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
|
module, options.emit_builtin_tflite_ops, options.emit_select_tf_ops,
|
||||||
select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
|
options.emit_custom_ops, options.select_user_tf_ops,
|
||||||
if (!maybe_translated) return true;
|
options.saved_model_tags, options.op_or_arg_name_mapper);
|
||||||
|
if (!maybe_translated) return false;
|
||||||
*serialized_flatbuffer = std::move(*maybe_translated);
|
*serialized_flatbuffer = std::move(*maybe_translated);
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace tflite
|
||||||
|
@ -23,43 +23,26 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
// Options for exporting to Flatbuffer.
|
||||||
|
struct FlatbufferExportOptions {
|
||||||
|
bool emit_builtin_tflite_ops = false;
|
||||||
|
bool emit_select_tf_ops = false;
|
||||||
|
bool emit_custom_ops = false;
|
||||||
|
// When exporting from SavedModel, this will have the requested tags.
|
||||||
|
std::unordered_set<std::string> saved_model_tags;
|
||||||
|
// TF custom op passed by the user.
|
||||||
|
std::unordered_set<std::string> select_user_tf_ops;
|
||||||
|
// OpOrArgNameMapper to convert location of the op to name in flatbuffer.
|
||||||
|
// If not set, a default mapper will be used.
|
||||||
|
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
// Translates the given MLIR `module` into a FlatBuffer and stores the
|
// Translates the given MLIR `module` into a FlatBuffer and stores the
|
||||||
// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to
|
// serialized flatbuffer into the string.
|
||||||
// convert location of the op to name in flatbuffer. Returns true if translation
|
// Returns true on successful exporting, false otherwise.
|
||||||
// fails, otherwise returns false.
|
|
||||||
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
|
bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
|
||||||
std::string* serialized_flatbuffer,
|
const FlatbufferExportOptions& options,
|
||||||
bool emit_builtin_tflite_ops,
|
std::string* serialized_flatbuffer);
|
||||||
bool emit_select_tf_ops,
|
|
||||||
bool emit_custom_ops);
|
|
||||||
|
|
||||||
// Same as above but takes SavedModel tags of the model.
|
|
||||||
bool MlirToFlatBufferTranslateFunction(
|
|
||||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
|
||||||
const std::unordered_set<std::string>& saved_model_tags);
|
|
||||||
|
|
||||||
// Same as the above but with a custom op name mapper.
|
|
||||||
bool MlirToFlatBufferTranslateFunction(
|
|
||||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
|
||||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
|
||||||
|
|
||||||
// Same as above but takes SavedModel tags of the model.
|
|
||||||
bool MlirToFlatBufferTranslateFunction(
|
|
||||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
|
||||||
const std::unordered_set<std::string>& saved_model_tags,
|
|
||||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
|
||||||
|
|
||||||
// Same as the above but with a list of allowed user's defined ops.
|
|
||||||
bool MlirToFlatBufferTranslateFunction(
|
|
||||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
|
||||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
|
||||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
|
||||||
const std::unordered_set<std::string>& saved_model_tags,
|
|
||||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
|
||||||
|
@ -158,9 +158,13 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction(
|
|||||||
op_or_arg_name_mapper =
|
op_or_arg_name_mapper =
|
||||||
std::make_unique<tensorflow::OpOrArgLocNameMapper>();
|
std::make_unique<tensorflow::OpOrArgLocNameMapper>();
|
||||||
}
|
}
|
||||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
tflite::FlatbufferExportOptions options;
|
||||||
module, &serialized_flatbuffer, emit_builtin_tflite_ops,
|
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||||
emit_select_tf_ops, emit_custom_ops, op_or_arg_name_mapper.get()))
|
options.emit_custom_ops = emit_custom_ops;
|
||||||
|
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||||
|
options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
|
||||||
|
if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
|
||||||
|
&serialized_flatbuffer))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
output << serialized_flatbuffer;
|
output << serialized_flatbuffer;
|
||||||
|
@ -118,9 +118,12 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
// Convert to flatbuffer.
|
// Convert to flatbuffer.
|
||||||
std::string serialized_flatbuffer;
|
std::string serialized_flatbuffer;
|
||||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
tflite::FlatbufferExportOptions options;
|
||||||
module.get(), &serialized_flatbuffer, emit_builtin_tflite_ops,
|
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||||
emit_select_tf_ops, emit_custom_ops))
|
options.emit_custom_ops = emit_custom_ops;
|
||||||
|
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||||
|
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
|
||||||
|
&serialized_flatbuffer))
|
||||||
return 1;
|
return 1;
|
||||||
|
|
||||||
// Create TFLite interpreter & invoke converted program.
|
// Create TFLite interpreter & invoke converted program.
|
||||||
|
@ -108,9 +108,12 @@ TfLiteStatus QuantizeModel(
|
|||||||
|
|
||||||
// Export the results to the builder
|
// Export the results to the builder
|
||||||
std::string result;
|
std::string result;
|
||||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
tflite::FlatbufferExportOptions options;
|
||||||
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
|
options.emit_builtin_tflite_ops = true;
|
||||||
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
|
options.emit_select_tf_ops = true;
|
||||||
|
options.emit_custom_ops = true;
|
||||||
|
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
|
||||||
|
&result)) {
|
||||||
error_reporter->Report("Failed to export MLIR to flatbuffer.");
|
error_reporter->Report("Failed to export MLIR to flatbuffer.");
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
@ -68,9 +68,12 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
|||||||
|
|
||||||
// Export the results to the builder
|
// Export the results to the builder
|
||||||
std::string result;
|
std::string result;
|
||||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
tflite::FlatbufferExportOptions options;
|
||||||
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
|
options.emit_builtin_tflite_ops = true;
|
||||||
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
|
options.emit_select_tf_ops = true;
|
||||||
|
options.emit_custom_ops = true;
|
||||||
|
if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options,
|
||||||
|
&result)) {
|
||||||
error_reporter->Report("Failed to export MLIR to flatbuffer.");
|
error_reporter->Report("Failed to export MLIR to flatbuffer.");
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
@ -172,20 +172,29 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
|||||||
// Write MLIR TFLite dialect into FlatBuffer
|
// Write MLIR TFLite dialect into FlatBuffer
|
||||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
||||||
if (!quant_specs.RunWeightQuantization()) {
|
if (!quant_specs.RunWeightQuantization()) {
|
||||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
tflite::FlatbufferExportOptions options;
|
||||||
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
|
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||||
emit_custom_ops, select_user_tf_ops, saved_model_tags,
|
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||||
&op_or_arg_name_mapper)) {
|
options.select_user_tf_ops = select_user_tf_ops;
|
||||||
|
options.emit_custom_ops = emit_custom_ops;
|
||||||
|
options.saved_model_tags = saved_model_tags;
|
||||||
|
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
|
||||||
|
if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
|
||||||
return statusHandler.ConsumeStatus();
|
return statusHandler.ConsumeStatus();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Post-training weight quantization path. Once MLIR has support for this,
|
// Post-training weight quantization path. Once MLIR has support for this,
|
||||||
// we can remove this else statement.
|
// we can remove this else statement.
|
||||||
std::string pre_quantized_result;
|
std::string pre_quantized_result;
|
||||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
tflite::FlatbufferExportOptions options;
|
||||||
module, &pre_quantized_result, emit_builtin_tflite_ops,
|
options.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||||
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
|
options.emit_select_tf_ops = emit_select_tf_ops;
|
||||||
saved_model_tags, &op_or_arg_name_mapper)) {
|
options.select_user_tf_ops = select_user_tf_ops;
|
||||||
|
options.emit_custom_ops = emit_custom_ops;
|
||||||
|
options.saved_model_tags = saved_model_tags;
|
||||||
|
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
|
||||||
|
if (!tflite::MlirToFlatBufferTranslateFunction(module, options,
|
||||||
|
&pre_quantized_result)) {
|
||||||
return statusHandler.ConsumeStatus();
|
return statusHandler.ConsumeStatus();
|
||||||
}
|
}
|
||||||
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
|
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user