Update TFLite converter to export saved model signature def to the tflite schema.
PiperOrigin-RevId: 337402305 Change-Id: I81834461f54c7be1c114dbffcb1ad31f08e0d2be
This commit is contained in:
parent
ed408b579e
commit
a70f132c8f
@ -326,6 +326,21 @@ static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
|
||||
|
||||
namespace {
|
||||
|
||||
// Helper struct that wraps inputs/outputs of a single SignatureDef.
|
||||
struct SignatureDefData {
|
||||
// Note, we are using maps here to make order deterministic
|
||||
// for easily testing only.
|
||||
|
||||
// Inputs defined in the signature def mapped to tensor names.
|
||||
std::map<std::string, std::string> inputs;
|
||||
// Outputs defined in the signature def mapped to tensor names.
|
||||
std::map<std::string, std::string> outputs;
|
||||
// Method name exported by the signature def.
|
||||
std::string method_name;
|
||||
// SignatureDef key.
|
||||
std::string signature_def_key;
|
||||
};
|
||||
|
||||
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
|
||||
class Translator {
|
||||
public:
|
||||
@ -334,16 +349,19 @@ class Translator {
|
||||
// internal error.
|
||||
static Optional<std::string> Translate(
|
||||
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
|
||||
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
|
||||
private:
|
||||
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
|
||||
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper)
|
||||
: module_(module),
|
||||
name_mapper_(*op_or_arg_name_mapper),
|
||||
builder_(kInitialBufferSize) {
|
||||
builder_(kInitialBufferSize),
|
||||
saved_model_tags_(saved_model_tags) {
|
||||
// The first buffer must be empty according to the schema definition.
|
||||
empty_buffer_ = tflite::CreateBuffer(builder_);
|
||||
buffers_.push_back(empty_buffer_);
|
||||
@ -450,6 +468,17 @@ class Translator {
|
||||
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
|
||||
CreateMetadataVector();
|
||||
|
||||
// Builds and returns list of tfl.SignatureDef sections in the model.
|
||||
Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
|
||||
CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);
|
||||
|
||||
// Returns list of offsets for the passed 'items' in TensorMap structure
|
||||
// inside the flatbuffer.
|
||||
// 'items' is a map from tensor name in signatureDef to tensor name in
|
||||
// the model.
|
||||
std::vector<BufferOffset<tflite::TensorMap>> GetList(
|
||||
const std::map<std::string, std::string>& items);
|
||||
|
||||
// Uses the tf.entry_function attribute (if set) to initialize the op to name
|
||||
// mapping.
|
||||
void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);
|
||||
@ -472,6 +501,8 @@ class Translator {
|
||||
BufferOffset<tflite::Buffer> empty_buffer_;
|
||||
|
||||
std::vector<BufferOffset<tflite::Buffer>> buffers_;
|
||||
// Maps tensor name in the graph to the tensor index.
|
||||
absl::flat_hash_map<std::string, int> tensor_index_map_;
|
||||
|
||||
// Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
|
||||
absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
|
||||
@ -490,6 +521,9 @@ class Translator {
|
||||
// The failed ops during legalization.
|
||||
std::set<std::string> failed_flex_ops_;
|
||||
std::set<std::string> failed_custom_ops_;
|
||||
|
||||
// Set of saved model tags, if any.
|
||||
const std::unordered_set<std::string> saved_model_tags_;
|
||||
};
|
||||
|
||||
std::string Translator::UniqueName(mlir::Value val) {
|
||||
@ -1131,6 +1165,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
|
||||
}
|
||||
|
||||
tensor_index_map.insert({value, tensors.size()});
|
||||
tensor_index_map_[name] = tensors.size();
|
||||
auto tensor_or = BuildTensor(value, name, buffers_.size());
|
||||
if (!tensor_or) return false;
|
||||
tensors.push_back(*tensor_or);
|
||||
@ -1286,6 +1321,149 @@ Translator::CreateMetadataVector() {
|
||||
return builder_.CreateVector(metadata);
|
||||
}
|
||||
|
||||
// Helper method that returns list of all strings in a StringAttr identified
|
||||
// by 'attr_key' and values are separated by a comma.
|
||||
llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
|
||||
mlir::DictionaryAttr attr, const std::string& attr_key) {
|
||||
llvm::SmallVector<llvm::StringRef, 2> result;
|
||||
if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
|
||||
str.getValue().split(result, ',', /*MaxSplit=*/-1,
|
||||
/*KeepEmpty=*/false);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helper method that return list of string for all the StringAttr in the
|
||||
// Attribute identified by 'attr_name'.
|
||||
std::vector<std::string> GetStringsFromDictionaryAttr(
|
||||
const llvm::SmallVector<mlir::MutableDictionaryAttr, 4>& dict_attrs,
|
||||
const std::string& attr_name) {
|
||||
std::vector<std::string> result;
|
||||
for (const auto& arg_attr : dict_attrs) {
|
||||
auto attrs = arg_attr.getAttrs();
|
||||
for (const auto attr : attrs) {
|
||||
if (attr.first.str() == attr_name) {
|
||||
auto array_attr = attr.second.dyn_cast_or_null<mlir::ArrayAttr>();
|
||||
if (!array_attr || array_attr.empty()) continue;
|
||||
auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
|
||||
if (!string_attr) continue;
|
||||
result.push_back(string_attr.getValue().str());
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<SignatureDefData> BuildSignaturedef(
|
||||
FuncOp main_op, const std::string& saved_model_tag) {
|
||||
static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
|
||||
static const char kEntryFunctionAttributes[] = "tf.entry_function";
|
||||
|
||||
// Fetch inputs and outputs from the signature.
|
||||
llvm::SmallVector<mlir::MutableDictionaryAttr, 4> arg_attrs, res_attrs;
|
||||
main_op.getAllArgAttrs(arg_attrs);
|
||||
main_op.getAllResultAttrs(res_attrs);
|
||||
std::vector<std::string> sig_def_inputs =
|
||||
GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
|
||||
std::vector<std::string> sig_def_outputs =
|
||||
GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
|
||||
|
||||
// If no defined saved model signature, then return empty list.
|
||||
// This can happen when we are converting model not from SavedModel.
|
||||
if (sig_def_inputs.empty() || sig_def_outputs.empty()) return {};
|
||||
|
||||
// Fetch function inputs and outputs tensor names.
|
||||
auto dict_attr =
|
||||
main_op.getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
|
||||
if (!dict_attr) return {};
|
||||
|
||||
// Get Input and output tensor names from attribute.
|
||||
llvm::SmallVector<llvm::StringRef, 2> input_names =
|
||||
GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
|
||||
llvm::SmallVector<llvm::StringRef, 2> output_names =
|
||||
GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
|
||||
|
||||
// Verify input size match the number of arguments.
|
||||
if (input_names.size() != main_op.getNumArguments()) {
|
||||
main_op.emitWarning() << "invalid entry function specification";
|
||||
return {};
|
||||
}
|
||||
// Verify output size match the number of arguments.
|
||||
auto term = main_op.back().getTerminator();
|
||||
if (output_names.size() != term->getNumOperands()) {
|
||||
main_op.emitWarning() << "output names (" << output_names.size()
|
||||
<< ") != terminator operands ("
|
||||
<< term->getNumOperands() << ")";
|
||||
return {};
|
||||
}
|
||||
// Verify number of tensors for inputs and outputs matches size
|
||||
// of the list in the signature def.
|
||||
if (input_names.size() != sig_def_inputs.size() ||
|
||||
output_names.size() != sig_def_outputs.size()) {
|
||||
main_op.emitWarning(
|
||||
"Mismatch between signature def inputs/outputs and main function "
|
||||
"arguments.");
|
||||
return {};
|
||||
}
|
||||
// Exported method name.
|
||||
auto exported_name =
|
||||
main_op.getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
|
||||
if (exported_name.empty()) {
|
||||
main_op.emitError("Empty exported names for main Function");
|
||||
return {};
|
||||
}
|
||||
// Fill the SignatureDefData container.
|
||||
// We create vector of size 1 as TFLite now supports only 1 signatureDef.
|
||||
std::vector<SignatureDefData> result(1);
|
||||
for (int i = 0; i < input_names.size(); ++i) {
|
||||
result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
|
||||
}
|
||||
for (int i = 0; i < output_names.size(); ++i) {
|
||||
result[0].outputs[sig_def_outputs[i]] = output_names[i].str();
|
||||
}
|
||||
if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
|
||||
result[0].method_name = name_attr.getValue().str();
|
||||
result[0].signature_def_key = saved_model_tag;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
|
||||
const std::map<std::string, std::string>& items) {
|
||||
std::vector<BufferOffset<tflite::TensorMap>> result;
|
||||
for (const auto& item : items) {
|
||||
auto name_buf = builder_.CreateString(item.first);
|
||||
tflite::TensorMapBuilder tensor_map_builder(builder_);
|
||||
tensor_map_builder.add_name(name_buf);
|
||||
tensor_map_builder.add_tensor_index(tensor_index_map_[item.second]);
|
||||
result.push_back(tensor_map_builder.Finish());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
|
||||
Translator::CreateSignatureDefs(
|
||||
const std::vector<SignatureDefData>& signature_defs) {
|
||||
std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
|
||||
for (const auto& signature_def_data : signature_defs) {
|
||||
auto inputs = GetList(signature_def_data.inputs);
|
||||
auto outputs = GetList(signature_def_data.outputs);
|
||||
auto inputs_buf = builder_.CreateVector(inputs);
|
||||
auto outputs_buf = builder_.CreateVector(outputs);
|
||||
auto method_name_buf =
|
||||
builder_.CreateString(signature_def_data.method_name);
|
||||
auto signature_def_key_buf =
|
||||
builder_.CreateString(signature_def_data.signature_def_key);
|
||||
tflite::SignatureDefBuilder sig_def_builder(builder_);
|
||||
sig_def_builder.add_inputs(inputs_buf);
|
||||
sig_def_builder.add_outputs(outputs_buf);
|
||||
sig_def_builder.add_method_name(method_name_buf);
|
||||
sig_def_builder.add_key(signature_def_key_buf);
|
||||
signature_defs_buffer.push_back(sig_def_builder.Finish());
|
||||
}
|
||||
|
||||
return builder_.CreateVector(signature_defs_buffer);
|
||||
}
|
||||
|
||||
bool UpdateEntryFunction(ModuleOp module) {
|
||||
if (module.lookupSymbol<FuncOp>("main") != nullptr) {
|
||||
// We already have an entry function.
|
||||
@ -1312,11 +1490,12 @@ bool UpdateEntryFunction(ModuleOp module) {
|
||||
|
||||
Optional<std::string> Translator::Translate(
|
||||
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
|
||||
bool emit_custom_ops, OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
if (!UpdateEntryFunction(module)) return llvm::None;
|
||||
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
|
||||
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops, op_or_arg_name_mapper);
|
||||
emit_custom_ops, tags, op_or_arg_name_mapper);
|
||||
return translator.TranslateInternal();
|
||||
}
|
||||
|
||||
@ -1392,10 +1571,17 @@ Optional<std::string> Translator::TranslateInternal() {
|
||||
auto metadata = CreateMetadataVector();
|
||||
if (!metadata) return llvm::None;
|
||||
|
||||
auto model = tflite::CreateModel(
|
||||
builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_),
|
||||
builder_.CreateVector(subgraphs), description,
|
||||
builder_.CreateVector(buffers_), metadata_buffer, *metadata);
|
||||
// Build SignatureDef
|
||||
// We only have 1 entry point 'main' function, so build only 1 signature def.
|
||||
auto main_fn_signature_def = BuildSignaturedef(
|
||||
main_fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin());
|
||||
auto signature_defs = CreateSignatureDefs(main_fn_signature_def);
|
||||
|
||||
auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
|
||||
builder_.CreateVector(opcodes_),
|
||||
builder_.CreateVector(subgraphs),
|
||||
description, builder_.CreateVector(buffers_),
|
||||
metadata_buffer, *metadata, *signature_defs);
|
||||
tflite::FinishModelBuffer(builder_, model);
|
||||
tflite::UpdateOpVersion(builder_.GetBufferPointer());
|
||||
tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
|
||||
@ -1519,12 +1705,10 @@ bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
auto maybe_translated =
|
||||
Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops, op_or_arg_name_mapper);
|
||||
if (!maybe_translated) return true;
|
||||
*serialized_flatbuffer = std::move(*maybe_translated);
|
||||
return false;
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
|
||||
op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
@ -1534,5 +1718,30 @@ bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, &op_or_arg_name_mapper);
|
||||
emit_select_tf_ops, emit_custom_ops, /*saved_model_tags=*/{},
|
||||
&op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags) {
|
||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, saved_model_tags,
|
||||
&op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
auto maybe_translated = Translator::Translate(
|
||||
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
|
||||
saved_model_tags, op_or_arg_name_mapper);
|
||||
if (!maybe_translated) return true;
|
||||
*serialized_flatbuffer = std::move(*maybe_translated);
|
||||
return false;
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||
@ -33,11 +34,24 @@ bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
|
||||
bool emit_select_tf_ops,
|
||||
bool emit_custom_ops);
|
||||
|
||||
// Same as above but takes SavedModel tags of the model.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags);
|
||||
|
||||
// Same as the above but with a custom op name mapper.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
|
||||
// Same as above but takes SavedModel tags of the model.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
|
||||
|
@ -90,9 +90,10 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
|
||||
pass_config.lower_tensor_list_ops = true;
|
||||
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
|
||||
pass_config, result,
|
||||
/*session=*/llvm::None);
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||
toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
|
||||
result,
|
||||
/*session=*/llvm::None);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -177,7 +177,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
|
||||
// TODO(b/153507667): Pass the session object when importing logic is removed.
|
||||
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||
toco_flags, std::move(module), pass_config, result,
|
||||
toco_flags, std::move(module), pass_config, tags, result,
|
||||
/*session=*/llvm::None);
|
||||
return status;
|
||||
}
|
||||
|
@ -273,7 +273,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
|
||||
const mlir::TFL::PassConfig& pass_config, string* result,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
const std::unordered_set<std::string>& saved_model_tags, string* result,
|
||||
llvm::Optional<tensorflow::Session*> session) {
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
|
||||
@ -297,8 +298,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
|
||||
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs, result,
|
||||
&pm);
|
||||
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs,
|
||||
saved_model_tags, result, &pm);
|
||||
if (toco_flags.has_dump_graphviz_dir()) {
|
||||
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
|
||||
// rename once we enable the new converter feature flag.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_
|
||||
|
||||
#include <ostream>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
@ -48,7 +49,8 @@ Status PopulateQuantizationSpecs(
|
||||
// This will also run relevant passes as well.
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
|
||||
const mlir::TFL::PassConfig& pass_config, string* result,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
const std::unordered_set<std::string>& saved_model_tags, string* result,
|
||||
llvm::Optional<tensorflow::Session*> session);
|
||||
|
||||
// Give a warning for any unused flags that have been specified.
|
||||
|
@ -96,5 +96,6 @@ versions {
|
||||
# CHECK-NEXT: metadata: [ {
|
||||
# CHECK-NEXT: name: "min_runtime_version",
|
||||
# CHECK-NEXT: buffer: 4
|
||||
# CHECK-NEXT: } ]
|
||||
# CHECK-NEXT: } ],
|
||||
# CHECK-NEXT: signature_defs: [ ]
|
||||
# CHECK-NEXT: }
|
||||
|
@ -116,6 +116,7 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 10
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
^bb0(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>):
|
||||
|
@ -100,6 +100,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
|
@ -91,6 +91,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const")
|
||||
|
@ -93,6 +93,7 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<-1.23697901> : tensor<32xf32>} : () -> tensor<32xf32> loc("Const")
|
||||
|
@ -97,6 +97,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
|
@ -54,6 +54,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 3
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32}
|
||||
|
@ -47,6 +47,7 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 3
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tf.AddV2"(%arg0, %arg0) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
|
||||
|
@ -60,6 +60,7 @@ func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add")
|
||||
|
@ -60,6 +60,7 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> loc("add")
|
||||
|
@ -99,6 +99,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
|
@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 5
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%cst = constant unit
|
||||
|
@ -69,6 +69,7 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 5
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%cst = constant unit
|
||||
|
@ -166,6 +166,7 @@
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 11
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
|
@ -87,6 +87,7 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
|
@ -258,6 +258,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 26
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
|
@ -320,5 +320,6 @@ func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 23
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
}
|
||||
|
@ -140,6 +140,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 8
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
|
@ -33,4 +33,5 @@ module attributes {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
@ -66,6 +66,7 @@ func @main(tensor<3x!quant.uniform<i8:f32, 0.1>>) -> tensor<3x!quant.uniform<i8:
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_qconst"() { qtype = tensor<3x!quant.uniform<i8:f32, 0.1>>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform<i8:f32, 0.1>>
|
||||
|
@ -66,6 +66,7 @@ func @main(tensor<3x!quant.uniform<i8:f32, 1.0>>) -> tensor<3x!quant.uniform<i8:
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_qconst"() { qtype = tensor<3x!quant.uniform<i8:f32, 1.0>>, value = dense<2> : tensor<3xi8>} : () -> tensor<3x!quant.uniform<i8:f32, 1.0>>
|
||||
|
@ -55,6 +55,7 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 3
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool")
|
||||
|
@ -48,6 +48,7 @@
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 3
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform<u8:f32, 0.1>>) -> tensor<4xf32> {
|
||||
|
@ -165,6 +165,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 10
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<[1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32> loc("Const")
|
||||
|
@ -59,6 +59,7 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 4
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<[6]> : tensor<1xi32>} : () -> tensor<1xi32> loc("Const")
|
||||
|
@ -0,0 +1,117 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: deprecated_builtin_code: 9,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: FULLY_CONNECTED
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 1, 384 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "serving_default_input2:0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: shape_signature: [ -1, 384 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 384 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "serving_default_input1:0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: shape_signature: [ -1, 384 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 5 ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "std.constant",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 5, 384 ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "std.constant1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 5, 384 ],
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "std.constant2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 5 ],
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "StatefulPartitionedCall:0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: shape_signature: [ -1, 5 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 5 ],
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "StatefulPartitionedCall:1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: shape_signature: [ -1, 5 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 6, 5 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 0, 3, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 5 ],
|
||||
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: inputs: [ 0, 4, 2 ],
|
||||
// CHECK-NEXT: outputs: [ 6 ],
|
||||
// CHECK-NEXT: builtin_options_type: FullyConnectedOptions,
|
||||
// CHECK-NEXT: builtin_options: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
|
||||
// CHECK: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 8
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: signature_defs: [ {
|
||||
// CHECK-NEXT: inputs: [ {
|
||||
// CHECK-NEXT: name: "input1",
|
||||
// CHECK-NEXT: tensor_index: 1
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: name: "input2"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: outputs: [ {
|
||||
// CHECK-NEXT: name: "end_logits",
|
||||
// CHECK-NEXT: tensor_index: 5
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: name: "start_logits",
|
||||
// CHECK-NEXT: tensor_index: 6
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: method_name: "serving_default",
|
||||
// CHECK-NEXT: key: ""
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 554 : i32}, tf_saved_model.semantics} {
|
||||
func @main(%arg0: tensor<?x384xf32> {tf_saved_model.index_path = ["input2"]}, %arg1: tensor<?x384xf32> {tf_saved_model.index_path = ["input1"]}) -> (tensor<?x5xf32> {tf_saved_model.index_path = ["start_logits"]}, tensor<?x5xf32> {tf_saved_model.index_path = ["end_logits"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input2:0,serving_default_input1:0", outputs = "StatefulPartitionedCall:1,StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%cst = constant dense<0.000000e+00> : tensor<5xf32>
|
||||
%cst_0 = constant dense<1.0> : tensor<5x384xf32>
|
||||
%cst_1 = constant dense<1.0> : tensor<5x384xf32>
|
||||
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<?x384xf32>, tensor<5x384xf32>, tensor<5xf32>) -> tensor<?x5xf32>
|
||||
%1 = "tfl.fully_connected"(%arg0, %cst_1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<?x384xf32>, tensor<5x384xf32>, tensor<5xf32>) -> tensor<?x5xf32>
|
||||
return %1, %0 : tensor<?x5xf32>, tensor<?x5xf32>
|
||||
}
|
||||
}
|
@ -105,6 +105,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 6
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const")
|
||||
|
@ -87,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 7
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
|
@ -88,6 +88,7 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) ->
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 7
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
|
@ -199,6 +199,7 @@
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 14
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
|
||||
|
@ -70,6 +70,7 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 5
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
%cst = constant unit
|
||||
|
@ -257,6 +257,7 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 26
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
|
@ -87,6 +87,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 7
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
|
@ -199,6 +199,7 @@
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
// CHECK-NEXT: buffer: 14
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT: signature_defs: [ ]
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
|
||||
|
@ -143,6 +143,7 @@ int main(int argc, char **argv) {
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> module;
|
||||
std::unordered_set<std::string> tags;
|
||||
|
||||
tensorflow::GraphImportConfig specs;
|
||||
specs.upgrade_legacy = upgrade_legacy;
|
||||
@ -161,8 +162,7 @@ int main(int argc, char **argv) {
|
||||
module = tensorflow::errors::InvalidArgument(
|
||||
"Importing saved model should not have input_mlir set");
|
||||
|
||||
std::unordered_set<std::string> tags =
|
||||
absl::StrSplit(saved_model_tags, ',');
|
||||
tags = absl::StrSplit(saved_model_tags, ',');
|
||||
std::vector<std::string> exported_names_vector =
|
||||
absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty());
|
||||
absl::Span<std::string> exported_names(exported_names_vector);
|
||||
@ -241,7 +241,7 @@ int main(int argc, char **argv) {
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, &result, &pm);
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm);
|
||||
if (!status.ok()) return kTrFailure;
|
||||
|
||||
std::string error_msg;
|
||||
|
@ -137,8 +137,9 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result,
|
||||
mlir::PassManager* pass_manager) {
|
||||
const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
std::string* result, mlir::PassManager* pass_manager) {
|
||||
// Explicitly disable dumping Op details on failures.
|
||||
module.getContext()->printOpOnDiagnostic(false);
|
||||
|
||||
@ -171,7 +172,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
if (!quant_specs.RunWeightQuantization()) {
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops)) {
|
||||
emit_custom_ops, saved_model_tags)) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
} else {
|
||||
@ -180,7 +181,7 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
std::string pre_quantized_result;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, &pre_quantized_result, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops)) {
|
||||
emit_select_tf_ops, emit_custom_ops, saved_model_tags)) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
|
||||
|
@ -63,8 +63,9 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result,
|
||||
mlir::PassManager* pass_manager);
|
||||
const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
std::string* result, mlir::PassManager* pass_manager);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_
|
||||
|
Loading…
Reference in New Issue
Block a user