Support real custom ops for Toco --allow_eager_ops flow.
PiperOrigin-RevId: 217020295
This commit is contained in:
parent
109a0c1b15
commit
e4b1832849
@ -26,6 +26,7 @@ cc_library(
|
|||||||
"//tensorflow/contrib/lite/schema:schema_fbs",
|
"//tensorflow/contrib/lite/schema:schema_fbs",
|
||||||
"//tensorflow/contrib/lite/toco:graph_transformations",
|
"//tensorflow/contrib/lite/toco:graph_transformations",
|
||||||
"//tensorflow/contrib/lite/toco:model",
|
"//tensorflow/contrib/lite/toco:model",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:ptr_util",
|
"//tensorflow/core:ptr_util",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
@ -42,6 +43,7 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":operator",
|
":operator",
|
||||||
"//tensorflow/contrib/lite/toco:tooling_util",
|
"//tensorflow/contrib/lite/toco:tooling_util",
|
||||||
|
"//tensorflow/core:ops",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
"@flatbuffers",
|
"@flatbuffers",
|
||||||
@ -71,6 +73,7 @@ tf_cc_test(
|
|||||||
tags = ["no_oss"],
|
tags = ["no_oss"],
|
||||||
deps = [
|
deps = [
|
||||||
":types",
|
":types",
|
||||||
|
"//tensorflow/core:ops",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -106,6 +109,7 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":export",
|
":export",
|
||||||
"//tensorflow/contrib/lite/schema:schema_fbs",
|
"//tensorflow/contrib/lite/schema:schema_fbs",
|
||||||
|
"//tensorflow/core:ops",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -141,6 +145,7 @@ tf_cc_test(
|
|||||||
":import",
|
":import",
|
||||||
"//tensorflow/contrib/lite:schema_fbs_version",
|
"//tensorflow/contrib/lite:schema_fbs_version",
|
||||||
"//tensorflow/contrib/lite/schema:schema_fbs",
|
"//tensorflow/contrib/lite/schema:schema_fbs",
|
||||||
|
"//tensorflow/core:ops",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
"@flatbuffers",
|
"@flatbuffers",
|
||||||
],
|
],
|
||||||
|
@ -140,7 +140,7 @@ OperatorKey GetOperatorKey(
|
|||||||
|
|
||||||
// TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
|
// TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
|
||||||
// to populate a regular custom op. We need to find a way to fix this.
|
// to populate a regular custom op. We need to find a way to fix this.
|
||||||
if (allow_flex_ops) {
|
if (ShouldExportAsFlexOp(allow_flex_ops, unsupported_op.tensorflow_op)) {
|
||||||
key.is_flex_op = true;
|
key.is_flex_op = true;
|
||||||
key.flex_tensorflow_op = tensorflow_op;
|
key.flex_tensorflow_op = tensorflow_op;
|
||||||
key.custom_code =
|
key.custom_code =
|
||||||
|
@ -23,6 +23,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/contrib/lite/toco/tflite/types.h"
|
#include "tensorflow/contrib/lite/toco/tflite/types.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
namespace toco {
|
namespace toco {
|
||||||
@ -1258,6 +1260,16 @@ class TensorFlowUnsupported : public BaseOperator {
|
|||||||
return std::unique_ptr<flexbuffers::Builder>();
|
return std::unique_ptr<flexbuffers::Builder>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ShouldExportAsFlexOp(allow_flex_ops_, node_def.op())) {
|
||||||
|
fbb->Vector([&]() {
|
||||||
|
fbb->String(node_def.op());
|
||||||
|
fbb->String(op.tensorflow_node_def);
|
||||||
|
});
|
||||||
|
fbb->Finish();
|
||||||
|
LOG(INFO) << "Writing flex op: " << node_def.op();
|
||||||
|
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
|
||||||
|
}
|
||||||
|
|
||||||
bool has_valid_attr = false;
|
bool has_valid_attr = false;
|
||||||
size_t map_start = fbb->StartMap();
|
size_t map_start = fbb->StartMap();
|
||||||
for (const auto& pair : node_def.attr()) {
|
for (const auto& pair : node_def.attr()) {
|
||||||
@ -1588,6 +1600,21 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ShouldExportAsFlexOp(bool allow_flex_ops,
|
||||||
|
const string& tensorflow_op_name) {
|
||||||
|
// If Flex ops aren't allow at all, simply return false.
|
||||||
|
if (!allow_flex_ops) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Check if we can find the `OpDef` for the TensorFlow op. If we can find
|
||||||
|
// it, export the op as an Flex op. Otherwise, export it as a regular custom
|
||||||
|
// op.
|
||||||
|
const tensorflow::OpDef* op_def = nullptr;
|
||||||
|
return tensorflow::OpRegistry::Global()
|
||||||
|
->LookUpOpDef(tensorflow_op_name, &op_def)
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
} // namespace toco
|
} // namespace toco
|
||||||
|
@ -113,6 +113,11 @@ class BaseOperator {
|
|||||||
OperatorType type_;
|
OperatorType type_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Helper function to determine if a unsupported TensorFlow op should be
|
||||||
|
// exported as an Flex op or a regular custom op.
|
||||||
|
bool ShouldExportAsFlexOp(bool allow_flex_ops,
|
||||||
|
const string& tensorflow_op_name);
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
} // namespace toco
|
} // namespace toco
|
||||||
|
@ -569,6 +569,12 @@ TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
|
|||||||
EXPECT_TRUE(output_node_def.attr().empty());
|
EXPECT_TRUE(output_node_def.attr().empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
|
||||||
|
EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));
|
||||||
|
EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
|
||||||
|
EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user