Update the TFLite converter to export the SavedModel SignatureDef to the TFLite schema.

PiperOrigin-RevId: 337402305
Change-Id: I81834461f54c7be1c114dbffcb1ad31f08e0d2be
Karim Nosir 2020-10-15 16:07:22 -07:00 committed by TensorFlower Gardener
parent ed408b579e
commit a70f132c8f
43 changed files with 412 additions and 33 deletions

View File

@ -326,6 +326,21 @@ static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
namespace {
// Helper struct that wraps inputs/outputs of a single SignatureDef.
struct SignatureDefData {
// Note: we use maps here to make the order deterministic,
// purely to make testing easier.
// Inputs defined in the signature def mapped to tensor names.
std::map<std::string, std::string> inputs;
// Outputs defined in the signature def mapped to tensor names.
std::map<std::string, std::string> outputs;
// Method name exported by the signature def.
std::string method_name;
// SignatureDef key.
std::string signature_def_key;
};
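For illustration, here is a minimal sketch of a populated SignatureDefData for a typical serving signature; the names are taken from the new FileCheck test added at the end of this commit, and the key value is illustrative:
// Sketch only: a SignatureDefData as BuildSignaturedef (below) would fill it.
SignatureDefData data;
// SignatureDef input names mapped to graph tensor names.
data.inputs["input1"] = "serving_default_input1:0";
data.inputs["input2"] = "serving_default_input2:0";
// SignatureDef output names mapped to graph tensor names.
data.outputs["start_logits"] = "StatefulPartitionedCall:1";
data.outputs["end_logits"] = "StatefulPartitionedCall:0";
// Exported method name; the key is taken from the first SavedModel tag
// ("serve" here is illustrative, and it is empty when no tags are passed).
data.method_name = "serving_default";
data.signature_def_key = "serve";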
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
class Translator {
public:
@ -334,16 +349,19 @@ class Translator {
// internal error.
static Optional<std::string> Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper);
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper);
private:
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper)
: module_(module),
name_mapper_(*op_or_arg_name_mapper),
builder_(kInitialBufferSize) {
builder_(kInitialBufferSize),
saved_model_tags_(saved_model_tags) {
// The first buffer must be empty according to the schema definition.
empty_buffer_ = tflite::CreateBuffer(builder_);
buffers_.push_back(empty_buffer_);
@ -450,6 +468,17 @@ class Translator {
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
CreateMetadataVector();
// Builds and returns a list of tfl.SignatureDef sections in the model.
Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
// Returns a list of offsets for the passed 'items' in the TensorMap
// structure inside the flatbuffer.
// 'items' is a map from a tensor name in the SignatureDef to a tensor name
// in the model.
std::vector<BufferOffset<tflite::TensorMap>> GetList(
const std::map<std::string, std::string>& items);
// Uses the tf.entry_function attribute (if set) to initialize the op-to-name
// mapping.
void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
@ -472,6 +501,8 @@ class Translator {
BufferOffset<tflite::Buffer> empty_buffer_;
std::vector<BufferOffset<tflite::Buffer>> buffers_;
// Maps tensor name in the graph to the tensor index.
absl::flat_hash_map<std::string, int> tensor_index_map_;
// Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
@ -490,6 +521,9 @@ class Translator {
// The failed ops during legalization.
std::set<std::string> failed_flex_ops_;
std::set<std::string> failed_custom_ops_;
// Set of saved model tags, if any.
const std::unordered_set<std::string> saved_model_tags_;
};
std::string Translator::UniqueName(mlir::Value val) {
@ -1131,6 +1165,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
}
tensor_index_map.insert({value, tensors.size()});
tensor_index_map_[name] = tensors.size();
auto tensor_or = BuildTensor(value, name, buffers_.size());
if (!tensor_or) return false;
tensors.push_back(*tensor_or);
@ -1286,6 +1321,149 @@ Translator::CreateMetadataVector() {
return builder_.CreateVector(metadata);
}
// Helper method that returns a list of all strings in the StringAttr
// identified by 'attr_key'; the values are separated by commas.
llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
mlir::DictionaryAttr attr, const std::string& attr_key) {
llvm::SmallVector<llvm::StringRef, 2> result;
if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
str.getValue().split(result, ',', /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
}
return result;
}
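For reference, a minimal standalone sketch of the llvm::StringRef::split semantics relied on here; KeepEmpty=false drops empty elements, and the input string is illustrative:
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

llvm::SmallVector<llvm::StringRef, 2> SplitExample() {
  llvm::SmallVector<llvm::StringRef, 2> parts;
  // Same call shape as above; "a:0,,b:0" is sample input.
  llvm::StringRef("a:0,,b:0").split(parts, ',', /*MaxSplit=*/-1,
                                    /*KeepEmpty=*/false);
  // parts now holds {"a:0", "b:0"}; the empty middle element is dropped.
  return parts;
}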
// Helper method that returns a list of strings for all the StringAttrs in the
// Attribute identified by 'attr_name'.
std::vector<std::string> GetStringsFromDictionaryAttr(
const llvm::SmallVector<mlir::MutableDictionaryAttr, 4>& dict_attrs,
const std::string& attr_name) {
std::vector<std::string> result;
for (const auto& arg_attr : dict_attrs) {
auto attrs = arg_attr.getAttrs();
for (const auto attr : attrs) {
if (attr.first.str() == attr_name) {
auto array_attr = attr.second.dyn_cast_or_null<mlir::ArrayAttr>();
if (!array_attr || array_attr.empty()) continue;
auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
if (!string_attr) continue;
result.push_back(string_attr.getValue().str());
}
}
}
return result;
}
std::vector<SignatureDefData> BuildSignaturedef(
FuncOp main_op, const std::string& saved_model_tag) {
static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
static const char kEntryFunctionAttributes[] = "tf.entry_function";
// Fetch inputs and outputs from the signature.
llvm::SmallVector<mlir::MutableDictionaryAttr, 4> arg_attrs, res_attrs;
main_op.getAllArgAttrs(arg_attrs);
main_op.getAllResultAttrs(res_attrs);
std::vector<std::string> sig_def_inputs =
GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
std::vector<std::string> sig_def_outputs =
GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
// If there is no saved model signature defined, return an empty list.
// This can happen when the model being converted did not come from a
// SavedModel.
if (sig_def_inputs.empty() || sig_def_outputs.empty()) return {};
// Fetch function inputs and outputs tensor names.
auto dict_attr =
main_op.getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
if (!dict_attr) return {};
// Get Input and output tensor names from attribute.
llvm::SmallVector<llvm::StringRef, 2> input_names =
GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
llvm::SmallVector<llvm::StringRef, 2> output_names =
GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
// Verify that the number of input names matches the number of arguments.
if (input_names.size() != main_op.getNumArguments()) {
main_op.emitWarning() << "invalid entry function specification";
return {};
}
// Verify that the number of output names matches the number of terminator
// operands.
auto term = main_op.back().getTerminator();
if (output_names.size() != term->getNumOperands()) {
main_op.emitWarning() << "output names (" << output_names.size()
<< ") != terminator operands ("
<< term->getNumOperands() << ")";
return {};
}
// Verify that the number of input/output tensor names matches the sizes of
// the lists in the signature def.
if (input_names.size() != sig_def_inputs.size() ||
output_names.size() != sig_def_outputs.size()) {
main_op.emitWarning(
"Mismatch between signature def inputs/outputs and main function "
"arguments.");
return {};
}
// Exported method name.
auto exported_name =
main_op.getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
if (exported_name.empty()) {
main_op.emitError("Empty exported names for main Function");
return {};
}
// Fill the SignatureDefData container.
// We create a vector of size 1, as TFLite currently supports only one
// SignatureDef.
std::vector<SignatureDefData> result(1);
for (int i = 0; i < input_names.size(); ++i) {
result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
}
for (int i = 0; i < output_names.size(); ++i) {
result[0].outputs[sig_def_outputs[i]] = output_names[i].str();
}
if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
result[0].method_name = name_attr.getValue().str();
result[0].signature_def_key = saved_model_tag;
return result;
}
std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
const std::map<std::string, std::string>& items) {
std::vector<BufferOffset<tflite::TensorMap>> result;
for (const auto& item : items) {
auto name_buf = builder_.CreateString(item.first);
tflite::TensorMapBuilder tensor_map_builder(builder_);
tensor_map_builder.add_name(name_buf);
tensor_map_builder.add_tensor_index(tensor_index_map_[item.second]);
result.push_back(tensor_map_builder.Finish());
}
return result;
}
Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
Translator::CreateSignatureDefs(
const std::vector<SignatureDefData>& signature_defs) {
std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
for (const auto& signature_def_data : signature_defs) {
auto inputs = GetList(signature_def_data.inputs);
auto outputs = GetList(signature_def_data.outputs);
auto inputs_buf = builder_.CreateVector(inputs);
auto outputs_buf = builder_.CreateVector(outputs);
auto method_name_buf =
builder_.CreateString(signature_def_data.method_name);
auto signature_def_key_buf =
builder_.CreateString(signature_def_data.signature_def_key);
tflite::SignatureDefBuilder sig_def_builder(builder_);
sig_def_builder.add_inputs(inputs_buf);
sig_def_builder.add_outputs(outputs_buf);
sig_def_builder.add_method_name(method_name_buf);
sig_def_builder.add_key(signature_def_key_buf);
signature_defs_buffer.push_back(sig_def_builder.Finish());
}
return builder_.CreateVector(signature_defs_buffer);
}
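For context, a sketch of reading the exported signatures back from a finished buffer; it assumes the generated C++ accessors mirror the builder fields used above (inputs, outputs, method_name, key):
#include <cstdio>

#include "tensorflow/lite/schema/schema_generated.h"

// 'buffer' is a finished TFLite flatbuffer, e.g. serialized_flatbuffer->data().
void DumpSignatureDefs(const void* buffer) {
  const tflite::Model* model = tflite::GetModel(buffer);
  if (model->signature_defs() == nullptr) return;  // No signatures exported.
  for (const tflite::SignatureDef* sig : *model->signature_defs()) {
    std::printf("method_name=%s key=%s\n", sig->method_name()->c_str(),
                sig->key()->c_str());
    for (const tflite::TensorMap* input : *sig->inputs()) {
      std::printf("  input %s -> tensor %u\n", input->name()->c_str(),
                  input->tensor_index());
    }
  }
}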
bool UpdateEntryFunction(ModuleOp module) {
if (module.lookupSymbol<FuncOp>("main") != nullptr) {
// We already have an entry function.
@ -1312,11 +1490,12 @@ bool UpdateEntryFunction(ModuleOp module) {
Optional<std::string> Translator::Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) {
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
if (!UpdateEntryFunction(module)) return llvm::None;
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, op_or_arg_name_mapper);
emit_custom_ops, tags, op_or_arg_name_mapper);
return translator.TranslateInternal();
}
@ -1392,10 +1571,17 @@ Optional<std::string> Translator::TranslateInternal() {
auto metadata = CreateMetadataVector();
if (!metadata) return llvm::None;
auto model = tflite::CreateModel(
builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_),
builder_.CreateVector(subgraphs), description,
builder_.CreateVector(buffers_), metadata_buffer, *metadata);
// Build the SignatureDef.
// There is only one entry point, the 'main' function, so we build only one
// signature def.
auto main_fn_signature_def = BuildSignaturedef(
main_fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin());
auto signature_defs = CreateSignatureDefs(main_fn_signature_def);
auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
builder_.CreateVector(opcodes_),
builder_.CreateVector(subgraphs),
description, builder_.CreateVector(buffers_),
metadata_buffer, *metadata, *signature_defs);
tflite::FinishModelBuffer(builder_, model);
tflite::UpdateOpVersion(builder_.GetBufferPointer());
tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
@ -1519,12 +1705,10 @@ bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
OpOrArgNameMapper* op_or_arg_name_mapper) {
auto maybe_translated =
Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, op_or_arg_name_mapper);
if (!maybe_translated) return true;
*serialized_flatbuffer = std::move(*maybe_translated);
return false;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
@ -1534,5 +1718,30 @@ bool tflite::MlirToFlatBufferTranslateFunction(
OpOrArgLocNameMapper op_or_arg_name_mapper;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper);
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
&op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags) {
OpOrArgLocNameMapper op_or_arg_name_mapper;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, saved_model_tags,
&op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
auto maybe_translated = Translator::Translate(
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
saved_model_tags, op_or_arg_name_mapper);
if (!maybe_translated) return true;
*serialized_flatbuffer = std::move(*maybe_translated);
return false;
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
#include <string>
#include <unordered_set>
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
@ -33,11 +34,24 @@ bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
bool emit_select_tf_ops,
bool emit_custom_ops);
// Same as above, but also takes the SavedModel tags of the model.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags);
// Same as the above but with a custom op name mapper.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
// Same as above, but also takes the SavedModel tags of the model.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
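A minimal caller-side sketch of the new tag-aware overload declared above; module creation is elided, and the "serve" tag is illustrative:
#include <string>
#include <unordered_set>

#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"

// Returns true on failure, matching the translate functions' convention.
bool ExportWithTags(mlir::ModuleOp module, std::string* serialized_flatbuffer) {
  const std::unordered_set<std::string> saved_model_tags = {"serve"};
  return tflite::MlirToFlatBufferTranslateFunction(
      module, serialized_flatbuffer, /*emit_builtin_tflite_ops=*/true,
      /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false,
      saved_model_tags);
}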

View File

@ -90,9 +90,10 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
pass_config, result,
/*session=*/llvm::None);
return internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
result,
/*session=*/llvm::None);
}
} // namespace tensorflow

View File

@ -177,7 +177,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// TODO(b/153507667): Pass the session object when importing logic is removed.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, result,
toco_flags, std::move(module), pass_config, tags, result,
/*session=*/llvm::None);
return status;
}

View File

@ -273,7 +273,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config, string* result,
const mlir::TFL::PassConfig& pass_config,
const std::unordered_set<std::string>& saved_model_tags, string* result,
llvm::Optional<tensorflow::Session*> session) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
@ -297,8 +298,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
&pm);
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs,
saved_model_tags, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
#include <ostream>
#include <unordered_set>
#include <utility>
#include "llvm/ADT/Optional.h"
@ -48,7 +49,8 @@ Status PopulateQuantizationSpecs(
// This will also run relevant passes as well.
Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config, string* result,
const mlir::TFL::PassConfig& pass_config,
const std::unordered_set<std::string>& saved_model_tags, string* result,
llvm::Optional<tensorflow::Session*> session);
// Give a warning for any unused flags that have been specified.

View File

@ -96,5 +96,6 @@ versions {
# CHECK-NEXT: metadata: [ {
# CHECK-NEXT: name: "min_runtime_version",
# CHECK-NEXT: buffer: 4
# CHECK-NEXT: } ]
# CHECK-NEXT: } ],
# CHECK-NEXT: signature_defs: [ ]
# CHECK-NEXT: }

View File

@ -116,6 +116,7 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 10
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
^bb0(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>):

View File

@ -100,6 +100,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

View File

@ -91,6 +91,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const")

View File

@ -93,6 +93,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const")

View File

@ -97,6 +97,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

View File

@ -54,6 +54,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
// IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32}

View File

@ -47,6 +47,7 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
%0 = "tf.AddV2"(%arg0, %arg0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>

View File

@ -60,6 +60,7 @@ func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add")

View File

@ -60,6 +60,7 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add")

View File

@ -99,6 +99,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

View File

@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%cst = constant unit

View File

@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%cst = constant unit

View File

@ -166,6 +166,7 @@
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 11
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {

View File

@ -87,6 +87,7 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@ -258,6 +258,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 26
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@ -320,5 +320,6 @@ func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 23
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
}

View File

@ -140,6 +140,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 8
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")

View File

@ -33,4 +33,5 @@ module attributes {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }

View File

@ -66,6 +66,7 @@ func @main(tensor<3x!quant.uniform<i8:f32, 0.1>>) -> tensor<3x!quant.uniform<i8:
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_qconst"() { qtype = tensor<3x!quant.uniform<i8:f32, 0.1>>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform<i8:f32, 0.1>>

View File

@ -66,6 +66,7 @@ func @main(tensor<3x!quant.uniform<i8:f32, 1.0>>) -> tensor<3x!quant.uniform<i8:
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_qconst"() { qtype = tensor<3x!quant.uniform<i8:f32, 1.0>>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform<i8:f32, 1.0>>

View File

@ -55,6 +55,7 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
%0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool")

View File

@ -48,6 +48,7 @@
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 3
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform<u8:f32, 0.1>>) -> tensor<4xf32> {

View File

@ -165,6 +165,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 10
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%0 = "tfl.pseudo_const" () {value = dense<[1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32> loc("Const")

View File

@ -59,6 +59,7 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
%0 = "tfl.pseudo_const" () {value = dense<[6]> : tensor<1xi32>} : () -> tensor<1xi32> loc("Const")

View File

@ -0,0 +1,117 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: deprecated_builtin_code: 9,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: FULLY_CONNECTED
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 1, 384 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "serving_default_input2:0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: shape_signature: [ -1, 384 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 384 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "serving_default_input1:0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: shape_signature: [ -1, 384 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 5 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "std.constant",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 5, 384 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "std.constant1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 5, 384 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "std.constant2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 5 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "StatefulPartitionedCall:0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: shape_signature: [ -1, 5 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 5 ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "StatefulPartitionedCall:1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: shape_signature: [ -1, 5 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 6, 5 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 3, 2 ],
// CHECK-NEXT: outputs: [ 5 ],
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: inputs: [ 0, 4, 2 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 8
// CHECK-NEXT: } ],
// CHECK-NEXT: signature_defs: [ {
// CHECK-NEXT: inputs: [ {
// CHECK-NEXT: name: "input1",
// CHECK-NEXT: tensor_index: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "input2"
// CHECK-NEXT: } ],
// CHECK-NEXT: outputs: [ {
// CHECK-NEXT: name: "end_logits",
// CHECK-NEXT: tensor_index: 5
// CHECK-NEXT: }, {
// CHECK-NEXT: name: "start_logits",
// CHECK-NEXT: tensor_index: 6
// CHECK-NEXT: } ],
// CHECK-NEXT: method_name: "serving_default",
// CHECK-NEXT: key: ""
// CHECK-NEXT: } ]
// CHECK-NEXT:}
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 554 : i32}, tf_saved_model.semantics} {
func @main(%arg0: tensor<?x384xf32> {tf_saved_model.index_path = ["input2"]}, %arg1: tensor<?x384xf32> {tf_saved_model.index_path = ["input1"]}) -> (tensor<?x5xf32> {tf_saved_model.index_path = ["start_logits"]}, tensor<?x5xf32> {tf_saved_model.index_path = ["end_logits"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input2:0,serving_default_input1:0", outputs = "StatefulPartitionedCall:1,StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%cst = constant dense<0.000000e+00> : tensor<5xf32>
%cst_0 = constant dense<1.0> : tensor<5x384xf32>
%cst_1 = constant dense<1.0> : tensor<5x384xf32>
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<?x384xf32>, tensor<5x384xf32>, tensor<5xf32>) -> tensor<?x5xf32>
%1 = "tfl.fully_connected"(%arg0, %cst_1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<?x384xf32>, tensor<5x384xf32>, tensor<5xf32>) -> tensor<?x5xf32>
return %1, %0 : tensor<?x5xf32>, tensor<?x5xf32>
}
}

View File

@ -105,6 +105,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 6
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
%0 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const")

View File

@ -87,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 7
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@ -88,6 +88,7 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) ->
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 7
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@ -199,6 +199,7 @@
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 14
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {

View File

@ -70,6 +70,7 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 5
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT:}
%cst = constant unit

View File

@ -257,6 +257,7 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 26
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@ -87,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 7
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
// CHECK-EMPTY:

View File

@ -199,6 +199,7 @@
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 14
// CHECK-NEXT: } ]
// CHECK-NEXT: signature_defs: [ ]
// CHECK-NEXT: }
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {

View File

@ -143,6 +143,7 @@ int main(int argc, char **argv) {
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
StatusOr<mlir::OwningModuleRef> module;
std::unordered_set<std::string> tags;
tensorflow::GraphImportConfig specs;
specs.upgrade_legacy = upgrade_legacy;
@ -161,8 +162,7 @@ int main(int argc, char **argv) {
module = tensorflow::errors::InvalidArgument(
"Importing saved model should not have input_mlir set");
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
tags = absl::StrSplit(saved_model_tags, ',');
std::vector<std::string> exported_names_vector =
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
absl::Span<std::string> exported_names(exported_names_vector);
@ -241,7 +241,7 @@ int main(int argc, char **argv) {
std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, &result, &pm);
emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm);
if (!status.ok()) return kTrFailure;
std::string error_msg;

View File

@ -137,8 +137,9 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result,
mlir::PassManager* pass_manager) {
const mlir::TFL::QuantizationSpecs& quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
std::string* result, mlir::PassManager* pass_manager) {
// Explicitly disable dumping Op details on failures.
module.getContext()->printOpOnDiagnostic(false);
@ -171,7 +172,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
if (!quant_specs.RunWeightQuantization()) {
if (tflite::MlirToFlatBufferTranslateFunction(
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops)) {
emit_custom_ops, saved_model_tags)) {
return statusHandler.ConsumeStatus();
}
} else {
@ -180,7 +181,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
std::string pre_quantized_result;
if (tflite::MlirToFlatBufferTranslateFunction(
module, &pre_quantized_result, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops)) {
emit_select_tf_ops, emit_custom_ops, saved_model_tags)) {
return statusHandler.ConsumeStatus();
}
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);

View File

@ -63,8 +63,9 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result,
mlir::PassManager* pass_manager);
const mlir::TFL::QuantizationSpecs& quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
std::string* result, mlir::PassManager* pass_manager);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_