diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index 71cdb7703e9..fcb628fec8f 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:ptr_util", "@com_google_absl//absl/memory", @@ -42,6 +43,7 @@ tf_cc_test( deps = [ ":operator", "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", "@flatbuffers", @@ -71,6 +73,7 @@ tf_cc_test( tags = ["no_oss"], deps = [ ":types", + "//tensorflow/core:ops", "@com_google_googletest//:gtest_main", ], ) @@ -106,6 +109,7 @@ tf_cc_test( deps = [ ":export", "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:ops", "@com_google_googletest//:gtest_main", ], ) @@ -141,6 +145,7 @@ tf_cc_test( ":import", "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:ops", "@com_google_googletest//:gtest_main", "@flatbuffers", ], diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index c23043789c0..8dcb7957384 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -140,7 +140,7 @@ OperatorKey GetOperatorKey( // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way // to populate a regular custom op. We need to find a way to fix this. - if (allow_flex_ops) { + if (ShouldExportAsFlexOp(allow_flex_ops, unsupported_op.tensorflow_op)) { key.is_flex_op = true; key.flex_tensorflow_op = tensorflow_op; key.custom_code = diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index e08a61d357d..1ee71d4341c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/tflite/types.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/util/ptr_util.h" namespace toco { @@ -1258,6 +1260,16 @@ class TensorFlowUnsupported : public BaseOperator { return std::unique_ptr(); } + if (ShouldExportAsFlexOp(allow_flex_ops_, node_def.op())) { + fbb->Vector([&]() { + fbb->String(node_def.op()); + fbb->String(op.tensorflow_node_def); + }); + fbb->Finish(); + LOG(INFO) << "Writing flex op: " << node_def.op(); + return std::unique_ptr(fbb.release()); + } + bool has_valid_attr = false; size_t map_start = fbb->StartMap(); for (const auto& pair : node_def.attr()) { @@ -1588,6 +1600,21 @@ std::map> BuildOperatorByNameMap( return result; } +bool ShouldExportAsFlexOp(bool allow_flex_ops, + const string& tensorflow_op_name) { + // If Flex ops aren't allow at all, simply return false. + if (!allow_flex_ops) { + return false; + } + // Check if we can find the `OpDef` for the TensorFlow op. If we can find + // it, export the op as an Flex op. Otherwise, export it as a regular custom + // op. + const tensorflow::OpDef* op_def = nullptr; + return tensorflow::OpRegistry::Global() + ->LookUpOpDef(tensorflow_op_name, &op_def) + .ok(); +} + } // namespace tflite } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 6e4e0a16d18..6e2a41bf53a 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -113,6 +113,11 @@ class BaseOperator { OperatorType type_; }; +// Helper function to determine if a unsupported TensorFlow op should be +// exported as an Flex op or a regular custom op. +bool ShouldExportAsFlexOp(bool allow_flex_ops, + const string& tensorflow_op_name); + } // namespace tflite } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 0bc591e6471..66896a49c09 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -569,6 +569,12 @@ TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { EXPECT_TRUE(output_node_def.attr().empty()); } +TEST_F(OperatorTest, TestShouldExportAsFlexOp) { + EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D")); + EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D")); + EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp")); +} + } // namespace } // namespace tflite