Print all unsupported ops before failing in flatbuffer_translate.cc. This helps the user know what are all the unsupported ops in the current model.
PiperOrigin-RevId: 276569129 Change-Id: Idbc99c54bd0a1bb101c841b5b14fea0964f26b82
This commit is contained in:
parent
c3c312afde
commit
a084b0b4c9
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "flatbuffers/flatbuffers.h" // TF:flatbuffers
|
||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
||||
@ -466,6 +467,10 @@ class Translator {
|
||||
// dialect is not registered.
|
||||
const Dialect* tf_dialect_;
|
||||
const Dialect* tfl_dialect_;
|
||||
|
||||
// The failed ops during legalization.
|
||||
std::vector<std::string> failed_flex_ops_;
|
||||
std::vector<std::string> failed_custom_ops_;
|
||||
};
|
||||
|
||||
std::string Translator::UniqueName(mlir::Operation* op) {
|
||||
@ -800,6 +805,12 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
return llvm::None;
|
||||
}
|
||||
} else {
|
||||
// Insert failed op to `flex_ops` or `custom_ops`.
|
||||
if (IsWhitelistedFlexOp(node_def->op())) {
|
||||
failed_flex_ops_.push_back(node_def->op());
|
||||
} else {
|
||||
failed_custom_ops_.push_back(node_def->op());
|
||||
}
|
||||
return inst->emitOpError("is neither a custom op nor a flex op"),
|
||||
llvm::None;
|
||||
}
|
||||
@ -925,6 +936,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
}
|
||||
}
|
||||
|
||||
bool failed_once = false;
|
||||
for (auto& inst : bb) {
|
||||
if (inst.isKnownTerminator()) break;
|
||||
|
||||
@ -961,9 +973,11 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
if (auto tfl_operator = BuildOperator(&inst, operands, results))
|
||||
operators.push_back(*tfl_operator);
|
||||
else
|
||||
return llvm::None;
|
||||
failed_once = true;
|
||||
}
|
||||
|
||||
if (failed_once) return llvm::None;
|
||||
|
||||
// Get input and output tensor indices for the subgraph.
|
||||
std::vector<int32_t> inputs, outputs;
|
||||
for (auto* arg : bb.getArguments()) {
|
||||
@ -1048,13 +1062,39 @@ Optional<std::string> Translator::TranslateInternal() {
|
||||
// Build subgraph for each of the functions.
|
||||
std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
|
||||
subgraphs.reserve(functions.size());
|
||||
for (auto fn : functions) {
|
||||
auto subgraph_or = BuildSubGraph(fn);
|
||||
if (!subgraph_or)
|
||||
return fn.emitError("failed while converting: '") << fn.getName() << '\'',
|
||||
llvm::None;
|
||||
int first_failed_func = -1;
|
||||
for (int i = 0; i < functions.size(); ++i) {
|
||||
auto subgraph_or = BuildSubGraph(functions[i]);
|
||||
if (!subgraph_or) {
|
||||
if (first_failed_func == -1)
|
||||
// Record the index of the first function that cannot be converted.
|
||||
// Keep looping through all subgraphs in the module to make sure that
|
||||
// we collect the list of missing ops from the entire module.
|
||||
first_failed_func = i;
|
||||
} else {
|
||||
subgraphs.push_back(*subgraph_or);
|
||||
}
|
||||
}
|
||||
|
||||
subgraphs.push_back(*subgraph_or);
|
||||
if (first_failed_func != -1) {
|
||||
std::string failed_flex_ops_list = absl::StrJoin(failed_flex_ops_, ",");
|
||||
std::string failed_custom_ops_list = absl::StrJoin(failed_custom_ops_, ",");
|
||||
std::string err;
|
||||
if (!failed_flex_ops_list.empty())
|
||||
err +=
|
||||
"Ops that can be supported by the flex runtime (enabled via setting "
|
||||
"the -emit-select-tf-ops flag): " +
|
||||
failed_flex_ops_list + ".";
|
||||
if (!failed_custom_ops_list.empty())
|
||||
err +=
|
||||
"Ops that need custom implementation (enabled via setting the "
|
||||
"-emit-custom-ops flag): " +
|
||||
failed_custom_ops_list;
|
||||
|
||||
return functions[first_failed_func].emitError("failed while converting: '")
|
||||
<< functions[first_failed_func].getName() << "\'\n"
|
||||
<< err,
|
||||
llvm::None;
|
||||
}
|
||||
|
||||
std::string model_description;
|
||||
|
@ -1,6 +1,7 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
# CHECK: loc("disable_flex.mlir":96:8): error: 'tf.div' op is a Flex op but Flex ops are not enabled for emission
|
||||
# CHECK-NEXT: Verification failed.
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
// CHECK: error: 'tf.Div' op is neither a custom op nor a flex op
|
||||
// CHECK: error: failed while converting: 'main'
|
||||
// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): Div.
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
|
Loading…
Reference in New Issue
Block a user