Support real custom ops for Toco --allow_eager_ops flow.
PiperOrigin-RevId: 217020295
This commit is contained in:
parent
109a0c1b15
commit
e4b1832849
@ -26,6 +26,7 @@ cc_library(
|
||||
"//tensorflow/contrib/lite/schema:schema_fbs",
|
||||
"//tensorflow/contrib/lite/toco:graph_transformations",
|
||||
"//tensorflow/contrib/lite/toco:model",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:ptr_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
@ -42,6 +43,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":operator",
|
||||
"//tensorflow/contrib/lite/toco:tooling_util",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@flatbuffers",
|
||||
@ -71,6 +73,7 @@ tf_cc_test(
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
":types",
|
||||
"//tensorflow/core:ops",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
@ -106,6 +109,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":export",
|
||||
"//tensorflow/contrib/lite/schema:schema_fbs",
|
||||
"//tensorflow/core:ops",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
@ -141,6 +145,7 @@ tf_cc_test(
|
||||
":import",
|
||||
"//tensorflow/contrib/lite:schema_fbs_version",
|
||||
"//tensorflow/contrib/lite/schema:schema_fbs",
|
||||
"//tensorflow/core:ops",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@flatbuffers",
|
||||
],
|
||||
|
@ -140,7 +140,7 @@ OperatorKey GetOperatorKey(
|
||||
|
||||
// TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
|
||||
// to populate a regular custom op. We need to find a way to fix this.
|
||||
if (allow_flex_ops) {
|
||||
if (ShouldExportAsFlexOp(allow_flex_ops, unsupported_op.tensorflow_op)) {
|
||||
key.is_flex_op = true;
|
||||
key.flex_tensorflow_op = tensorflow_op;
|
||||
key.custom_code =
|
||||
|
@ -23,6 +23,8 @@ limitations under the License.
|
||||
#include "tensorflow/contrib/lite/toco/tflite/types.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace toco {
|
||||
@ -1258,6 +1260,16 @@ class TensorFlowUnsupported : public BaseOperator {
|
||||
return std::unique_ptr<flexbuffers::Builder>();
|
||||
}
|
||||
|
||||
if (ShouldExportAsFlexOp(allow_flex_ops_, node_def.op())) {
|
||||
fbb->Vector([&]() {
|
||||
fbb->String(node_def.op());
|
||||
fbb->String(op.tensorflow_node_def);
|
||||
});
|
||||
fbb->Finish();
|
||||
LOG(INFO) << "Writing flex op: " << node_def.op();
|
||||
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
|
||||
}
|
||||
|
||||
bool has_valid_attr = false;
|
||||
size_t map_start = fbb->StartMap();
|
||||
for (const auto& pair : node_def.attr()) {
|
||||
@ -1588,6 +1600,21 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
|
||||
return result;
|
||||
}
|
||||
|
||||
bool ShouldExportAsFlexOp(bool allow_flex_ops,
|
||||
const string& tensorflow_op_name) {
|
||||
// If Flex ops aren't allow at all, simply return false.
|
||||
if (!allow_flex_ops) {
|
||||
return false;
|
||||
}
|
||||
// Check if we can find the `OpDef` for the TensorFlow op. If we can find
|
||||
// it, export the op as an Flex op. Otherwise, export it as a regular custom
|
||||
// op.
|
||||
const tensorflow::OpDef* op_def = nullptr;
|
||||
return tensorflow::OpRegistry::Global()
|
||||
->LookUpOpDef(tensorflow_op_name, &op_def)
|
||||
.ok();
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
} // namespace toco
|
||||
|
@ -113,6 +113,11 @@ class BaseOperator {
|
||||
OperatorType type_;
|
||||
};
|
||||
|
||||
// Helper function to determine if a unsupported TensorFlow op should be
|
||||
// exported as an Flex op or a regular custom op.
|
||||
bool ShouldExportAsFlexOp(bool allow_flex_ops,
|
||||
const string& tensorflow_op_name);
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
} // namespace toco
|
||||
|
@ -569,6 +569,12 @@ TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
|
||||
EXPECT_TRUE(output_node_def.attr().empty());
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
|
||||
EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));
|
||||
EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
|
||||
EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user