From 5a6c606736fc7445c438dc4c89f364d50657d358 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 13 Jun 2020 18:18:59 -0700 Subject: [PATCH] Support flex ops in calibration optimization This CL makes the tool generate a user-friendly error message as well. PiperOrigin-RevId: 316297488 Change-Id: Icc66ae3273c01e9f909f000c9f94b02647477fc3 --- .../lite/tools/optimize/calibration/BUILD | 2 - .../calibration/logging_op_resolver.cc | 30 ------------ .../calibration/logging_op_resolver_test.cc | 47 ------------------- 3 files changed, 79 deletions(-) diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD index 743c4be828f..a394156786f 100644 --- a/tensorflow/lite/tools/optimize/calibration/BUILD +++ b/tensorflow/lite/tools/optimize/calibration/BUILD @@ -88,9 +88,7 @@ cc_library( copts = tflite_copts(), deps = [ ":calibration_common", - "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", - "//tensorflow/lite:util", "//tensorflow/lite/core/api", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc index 171aecad98d..634b2a76a3a 100644 --- a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc +++ b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc @@ -15,10 +15,6 @@ limitations under the License. #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h" #include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/lite/util.h" namespace tflite { namespace optimize { @@ -28,17 +24,9 @@ LoggingOpResolver::LoggingOpResolver( const BuiltinOpsSet& builtin_ops_to_replace, const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver, KernelEvalFuncPtr logging_eval_fn) { - std::vector unresolved_builtin_ops; - std::vector unresolved_custom_ops; - for (const auto& op_and_version : builtin_ops_to_replace) { const TfLiteRegistration* base_registration = base_resolver.FindOp(op_and_version.first, op_and_version.second); - if (!base_registration) { - unresolved_builtin_ops.push_back( - EnumNameBuiltinOperator(op_and_version.first)); - continue; - } BuiltinOperatorKey key = op_and_version; builtin_op_evalfn_map_[key] = base_registration->invoke; auto logging_registration = @@ -49,11 +37,6 @@ LoggingOpResolver::LoggingOpResolver( for (const auto& op_and_version : custom_ops_to_replace) { const TfLiteRegistration* base_registration = base_resolver.FindOp( op_and_version.first.c_str(), op_and_version.second); - if (!base_registration) { - if (!IsFlexOp(op_and_version.first.c_str())) - unresolved_custom_ops.push_back(op_and_version.first.c_str()); - continue; - } CustomOperatorKey key = op_and_version; custom_op_evalfn_map_[key] = base_registration->invoke; auto logging_registration = @@ -61,19 +44,6 @@ LoggingOpResolver::LoggingOpResolver( logging_registration->invoke = logging_eval_fn; custom_op_registration_map_[key] = std::move(logging_registration); } - - if (!unresolved_builtin_ops.empty() || !unresolved_custom_ops.empty()) { - std::string error_message = - "Failed to initialize op resolver for calibration:"; - if (!unresolved_builtin_ops.empty()) - absl::StrAppend(&error_message, "\nThere are unresolved builtin ops: [", - absl::StrJoin(unresolved_builtin_ops, ", "), "]"); - if (!unresolved_custom_ops.empty()) { - absl::StrAppend(&error_message, "\nThere are unresolved custom ops: [", - absl::StrJoin(unresolved_builtin_ops, ", "), "]"); - } - LOG(ERROR) << error_message; - } } const TfLiteRegistration* LoggingOpResolver::FindOp(BuiltinOperator op, diff --git a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc index bdc7dd1802f..511e4d0288d 100644 --- a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc +++ b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc @@ -165,53 +165,6 @@ TEST(LoggingOpResolverTest, CustomOps) { EXPECT_TRUE(reg->invoke == WrappingInvoke); } -TEST(LoggingOpResolverTest, UnresolvedCustomOps) { - // No custom op registration. - MutableOpResolver base_resolver; - - std::string custom_op_name = "unresolved_custom_op"; - - CustomOpsSet ops_to_replace = { - {custom_op_name, /*version*/ 1}, - }; - - // Expect no death. - LoggingOpResolver(BuiltinOpsSet(), ops_to_replace, base_resolver, - WrappingInvoke); -} - -TEST(LoggingOpResolverTest, UnresolvedBuiltinOps) { - // No builtin op registration. - MutableOpResolver base_resolver; - - BuiltinOpsSet ops_to_replace = { - {BuiltinOperator_CONV_2D, /*version*/ 1}, - {BuiltinOperator_ADD, /*version*/ 1}, - }; - - // Expect no death. - LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver, - WrappingInvoke); -} - -TEST(LoggingOpResolverTest, FlexOps) { - // No flex op registration. - MutableOpResolver base_resolver; - - std::string custom_op_name = "FlexAdd"; - - CustomOpsSet ops_to_replace = { - {custom_op_name, /*version*/ 1}, - }; - - LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver, - WrappingInvoke); - - auto reg = resolver.FindOp(custom_op_name.c_str(), 1); - - EXPECT_TRUE(!reg); -} - } // namespace } // namespace calibration } // namespace optimize