Support real custom ops for Toco --allow_eager_ops flow.

PiperOrigin-RevId: 217020295
This commit is contained in:
Yu-Cheng Ling 2018-10-13 20:26:25 -07:00 committed by TensorFlower Gardener
parent 109a0c1b15
commit e4b1832849
5 changed files with 44 additions and 1 deletions

View File

@ -26,6 +26,7 @@ cc_library(
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/memory",
@ -42,6 +43,7 @@ tf_cc_test(
deps = [
":operator",
"//tensorflow/contrib/lite/toco:tooling_util",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
@ -71,6 +73,7 @@ tf_cc_test(
tags = ["no_oss"],
deps = [
":types",
"//tensorflow/core:ops",
"@com_google_googletest//:gtest_main",
],
)
@ -106,6 +109,7 @@ tf_cc_test(
deps = [
":export",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/core:ops",
"@com_google_googletest//:gtest_main",
],
)
@ -141,6 +145,7 @@ tf_cc_test(
":import",
"//tensorflow/contrib/lite:schema_fbs_version",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/core:ops",
"@com_google_googletest//:gtest_main",
"@flatbuffers",
],

View File

@ -140,7 +140,7 @@ OperatorKey GetOperatorKey(
// TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
// to populate a regular custom op. We need to find a way to fix this.
if (allow_flex_ops) {
if (ShouldExportAsFlexOp(allow_flex_ops, unsupported_op.tensorflow_op)) {
key.is_flex_op = true;
key.flex_tensorflow_op = tensorflow_op;
key.custom_code =

View File

@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/tflite/types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/util/ptr_util.h"
namespace toco {
@ -1258,6 +1260,16 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
if (ShouldExportAsFlexOp(allow_flex_ops_, node_def.op())) {
fbb->Vector([&]() {
fbb->String(node_def.op());
fbb->String(op.tensorflow_node_def);
});
fbb->Finish();
LOG(INFO) << "Writing flex op: " << node_def.op();
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
@ -1588,6 +1600,21 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
return result;
}
bool ShouldExportAsFlexOp(bool allow_flex_ops,
const string& tensorflow_op_name) {
// If Flex ops aren't allow at all, simply return false.
if (!allow_flex_ops) {
return false;
}
// Check if we can find the `OpDef` for the TensorFlow op. If we can find
// it, export the op as an Flex op. Otherwise, export it as a regular custom
// op.
const tensorflow::OpDef* op_def = nullptr;
return tensorflow::OpRegistry::Global()
->LookUpOpDef(tensorflow_op_name, &op_def)
.ok();
}
} // namespace tflite
} // namespace toco

View File

@ -113,6 +113,11 @@ class BaseOperator {
OperatorType type_;
};
// Helper function to determine if a unsupported TensorFlow op should be
// exported as an Flex op or a regular custom op.
bool ShouldExportAsFlexOp(bool allow_flex_ops,
const string& tensorflow_op_name);
} // namespace tflite
} // namespace toco

View File

@ -569,6 +569,12 @@ TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
EXPECT_TRUE(output_node_def.attr().empty());
}
TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));
EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));
}
} // namespace
} // namespace tflite