Support flex ops in calibration optimization

This CL makes the tool generate a user-friendly error message as well.

PiperOrigin-RevId: 316297488
Change-Id: Icc66ae3273c01e9f909f000c9f94b02647477fc3
This commit is contained in:
A. Unique TensorFlower 2020-06-13 18:18:59 -07:00 committed by TensorFlower Gardener
parent 5d49dc5526
commit 5a6c606736
3 changed files with 0 additions and 79 deletions

View File

@ -88,9 +88,7 @@ cc_library(
copts = tflite_copts(),
deps = [
":calibration_common",
"//tensorflow/core/platform:logging",
"//tensorflow/lite:framework",
"//tensorflow/lite:util",
"//tensorflow/lite/core/api",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",

View File

@ -15,10 +15,6 @@ limitations under the License.
#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace optimize {
@ -28,17 +24,9 @@ LoggingOpResolver::LoggingOpResolver(
const BuiltinOpsSet& builtin_ops_to_replace,
const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver,
KernelEvalFuncPtr logging_eval_fn) {
std::vector<std::string> unresolved_builtin_ops;
std::vector<std::string> unresolved_custom_ops;
for (const auto& op_and_version : builtin_ops_to_replace) {
const TfLiteRegistration* base_registration =
base_resolver.FindOp(op_and_version.first, op_and_version.second);
if (!base_registration) {
unresolved_builtin_ops.push_back(
EnumNameBuiltinOperator(op_and_version.first));
continue;
}
BuiltinOperatorKey key = op_and_version;
builtin_op_evalfn_map_[key] = base_registration->invoke;
auto logging_registration =
@ -49,11 +37,6 @@ LoggingOpResolver::LoggingOpResolver(
for (const auto& op_and_version : custom_ops_to_replace) {
const TfLiteRegistration* base_registration = base_resolver.FindOp(
op_and_version.first.c_str(), op_and_version.second);
if (!base_registration) {
if (!IsFlexOp(op_and_version.first.c_str()))
unresolved_custom_ops.push_back(op_and_version.first.c_str());
continue;
}
CustomOperatorKey key = op_and_version;
custom_op_evalfn_map_[key] = base_registration->invoke;
auto logging_registration =
@ -61,19 +44,6 @@ LoggingOpResolver::LoggingOpResolver(
logging_registration->invoke = logging_eval_fn;
custom_op_registration_map_[key] = std::move(logging_registration);
}
if (!unresolved_builtin_ops.empty() || !unresolved_custom_ops.empty()) {
std::string error_message =
"Failed to initialize op resolver for calibration:";
if (!unresolved_builtin_ops.empty())
absl::StrAppend(&error_message, "\nThere are unresolved builtin ops: [",
absl::StrJoin(unresolved_builtin_ops, ", "), "]");
if (!unresolved_custom_ops.empty()) {
absl::StrAppend(&error_message, "\nThere are unresolved custom ops: [",
absl::StrJoin(unresolved_builtin_ops, ", "), "]");
}
LOG(ERROR) << error_message;
}
}
const TfLiteRegistration* LoggingOpResolver::FindOp(BuiltinOperator op,

View File

@ -165,53 +165,6 @@ TEST(LoggingOpResolverTest, CustomOps) {
EXPECT_TRUE(reg->invoke == WrappingInvoke);
}
TEST(LoggingOpResolverTest, UnresolvedCustomOps) {
// No custom op registration.
MutableOpResolver base_resolver;
std::string custom_op_name = "unresolved_custom_op";
CustomOpsSet ops_to_replace = {
{custom_op_name, /*version*/ 1},
};
// Expect no death.
LoggingOpResolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
WrappingInvoke);
}
TEST(LoggingOpResolverTest, UnresolvedBuiltinOps) {
// No builtin op registration.
MutableOpResolver base_resolver;
BuiltinOpsSet ops_to_replace = {
{BuiltinOperator_CONV_2D, /*version*/ 1},
{BuiltinOperator_ADD, /*version*/ 1},
};
// Expect no death.
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke);
}
TEST(LoggingOpResolverTest, FlexOps) {
// No flex op registration.
MutableOpResolver base_resolver;
std::string custom_op_name = "FlexAdd";
CustomOpsSet ops_to_replace = {
{custom_op_name, /*version*/ 1},
};
LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
WrappingInvoke);
auto reg = resolver.FindOp(custom_op_name.c_str(), 1);
EXPECT_TRUE(!reg);
}
} // namespace
} // namespace calibration
} // namespace optimize