Support flex ops in calibration optimization

This CL makes the tool generate a user-friendly error message as well.

In order to use the correct logger for mobile, it uses the error_reporter.

PiperOrigin-RevId: 316563081
Change-Id: Ib56f80330087750777725ed6ad3c97f54b1fa80b
This commit is contained in:
Jaesung Chung 2020-06-15 15:57:32 -07:00 committed by TensorFlower Gardener
parent 1158611838
commit 7292433984
5 changed files with 89 additions and 8 deletions

View File

@ -89,6 +89,8 @@ cc_library(
deps = [
":calibration_common",
"//tensorflow/lite:framework",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite:util",
"//tensorflow/lite/core/api",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",

View File

@ -378,8 +378,8 @@ TfLiteStatus BuildLoggingInterpreter(
// Prepare the logging op resolver to use |LoggingEval| for kernel
// invocations.
auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
builtin_op_and_versions, custom_op_and_versions, op_resolver,
LoggingEval);
builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval,
error_reporter);
tflite::InterpreterBuilder(tflite_model, *logging_op_resolver,
error_reporter)(interpreter);

View File

@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/util.h"
namespace tflite {
namespace optimize {
@ -23,10 +27,18 @@ namespace calibration {
LoggingOpResolver::LoggingOpResolver(
const BuiltinOpsSet& builtin_ops_to_replace,
const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver,
KernelEvalFuncPtr logging_eval_fn) {
KernelEvalFuncPtr logging_eval_fn, ErrorReporter* error_reporter) {
std::vector<std::string> unresolved_builtin_ops;
std::vector<std::string> unresolved_custom_ops;
for (const auto& op_and_version : builtin_ops_to_replace) {
const TfLiteRegistration* base_registration =
base_resolver.FindOp(op_and_version.first, op_and_version.second);
if (!base_registration) {
unresolved_builtin_ops.push_back(
EnumNameBuiltinOperator(op_and_version.first));
continue;
}
BuiltinOperatorKey key = op_and_version;
builtin_op_evalfn_map_[key] = base_registration->invoke;
auto logging_registration =
@ -37,6 +49,11 @@ LoggingOpResolver::LoggingOpResolver(
for (const auto& op_and_version : custom_ops_to_replace) {
const TfLiteRegistration* base_registration = base_resolver.FindOp(
op_and_version.first.c_str(), op_and_version.second);
if (!base_registration) {
if (!IsFlexOp(op_and_version.first.c_str()))
unresolved_custom_ops.push_back(op_and_version.first.c_str());
continue;
}
CustomOperatorKey key = op_and_version;
custom_op_evalfn_map_[key] = base_registration->invoke;
auto logging_registration =
@ -44,6 +61,20 @@ LoggingOpResolver::LoggingOpResolver(
logging_registration->invoke = logging_eval_fn;
custom_op_registration_map_[key] = std::move(logging_registration);
}
if (!unresolved_builtin_ops.empty() || !unresolved_custom_ops.empty()) {
if (!error_reporter) return;
std::string error_message =
"Failed to initialize op resolver for calibration:";
if (!unresolved_builtin_ops.empty())
absl::StrAppend(&error_message, "\nThere are unresolved builtin ops: [",
absl::StrJoin(unresolved_builtin_ops, ", "), "]");
if (!unresolved_custom_ops.empty()) {
absl::StrAppend(&error_message, "\nThere are unresolved custom ops: [",
absl::StrJoin(unresolved_builtin_ops, ", "), "]");
}
TF_LITE_REPORT_ERROR(error_reporter, error_message.c_str());
}
}
const TfLiteRegistration* LoggingOpResolver::FindOp(BuiltinOperator op,

View File

@ -39,7 +39,8 @@ class LoggingOpResolver : public OpResolver {
LoggingOpResolver(const BuiltinOpsSet& builtin_ops_to_replace,
const CustomOpsSet& custom_ops_to_replace,
const OpResolver& base_resolver,
KernelEvalFuncPtr logging_eval_fn);
KernelEvalFuncPtr logging_eval_fn,
ErrorReporter* error_reporter);
const TfLiteRegistration* FindOp(BuiltinOperator op,
int version) const override;

View File

@ -70,7 +70,7 @@ TEST(LoggingOpResolverTest, KernelInvokesAreReplaced) {
};
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke);
WrappingInvoke, /*error_reporter=*/nullptr);
auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
@ -104,7 +104,7 @@ TEST(LoggingOpResolverTest, OriginalKernelInvokesAreRetained) {
};
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke);
WrappingInvoke, /*error_reporter=*/nullptr);
auto kernel_invoke =
resolver.GetWrappedKernelInvoke(BuiltinOperator_CONV_2D, 1);
EXPECT_TRUE(kernel_invoke == ConvEval);
@ -131,7 +131,7 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) {
};
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke);
WrappingInvoke, /*error_reporter=*/nullptr);
auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
EXPECT_EQ(reg->builtin_code, BuiltinOperator_CONV_2D);
EXPECT_TRUE(reg->prepare == ConvPrepare);
@ -155,7 +155,7 @@ TEST(LoggingOpResolverTest, CustomOps) {
};
LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
WrappingInvoke);
WrappingInvoke, /*error_reporter=*/nullptr);
auto reg = resolver.FindOp(custom_op_name.c_str(), 1);
@ -165,6 +165,53 @@ TEST(LoggingOpResolverTest, CustomOps) {
EXPECT_TRUE(reg->invoke == WrappingInvoke);
}
TEST(LoggingOpResolverTest, UnresolvedCustomOps) {
// No custom op registration.
MutableOpResolver base_resolver;
std::string custom_op_name = "unresolved_custom_op";
CustomOpsSet ops_to_replace = {
{custom_op_name, /*version*/ 1},
};
// Expect no death.
LoggingOpResolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
WrappingInvoke, /*error_reporter=*/nullptr);
}
TEST(LoggingOpResolverTest, UnresolvedBuiltinOps) {
// No builtin op registration.
MutableOpResolver base_resolver;
BuiltinOpsSet ops_to_replace = {
{BuiltinOperator_CONV_2D, /*version*/ 1},
{BuiltinOperator_ADD, /*version*/ 1},
};
// Expect no death.
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke, /*error_reporter=*/nullptr);
}
TEST(LoggingOpResolverTest, FlexOps) {
// No flex op registration.
MutableOpResolver base_resolver;
std::string custom_op_name = "FlexAdd";
CustomOpsSet ops_to_replace = {
{custom_op_name, /*version*/ 1},
};
LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
WrappingInvoke, /*error_reporter=*/nullptr);
auto reg = resolver.FindOp(custom_op_name.c_str(), 1);
EXPECT_TRUE(!reg);
}
} // namespace
} // namespace calibration
} // namespace optimize