Support flex ops in calibration optimization
This CL makes the tool generate a user-friendly error message as well. In order to use the correct logger for mobile, it uses the error_reporter. PiperOrigin-RevId: 316563081 Change-Id: Ib56f80330087750777725ed6ad3c97f54b1fa80b
This commit is contained in:
parent
1158611838
commit
7292433984
@ -89,6 +89,8 @@ cc_library(
|
||||
deps = [
|
||||
":calibration_common",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/core/api",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -378,8 +378,8 @@ TfLiteStatus BuildLoggingInterpreter(
|
||||
// Prepare the logging op resolver to use |LoggingEval| for kernel
|
||||
// invocations.
|
||||
auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
|
||||
builtin_op_and_versions, custom_op_and_versions, op_resolver,
|
||||
LoggingEval);
|
||||
builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval,
|
||||
error_reporter);
|
||||
tflite::InterpreterBuilder(tflite_model, *logging_op_resolver,
|
||||
error_reporter)(interpreter);
|
||||
|
||||
|
@ -15,6 +15,10 @@ limitations under the License.
|
||||
#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace optimize {
|
||||
@ -23,10 +27,18 @@ namespace calibration {
|
||||
LoggingOpResolver::LoggingOpResolver(
|
||||
const BuiltinOpsSet& builtin_ops_to_replace,
|
||||
const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver,
|
||||
KernelEvalFuncPtr logging_eval_fn) {
|
||||
KernelEvalFuncPtr logging_eval_fn, ErrorReporter* error_reporter) {
|
||||
std::vector<std::string> unresolved_builtin_ops;
|
||||
std::vector<std::string> unresolved_custom_ops;
|
||||
|
||||
for (const auto& op_and_version : builtin_ops_to_replace) {
|
||||
const TfLiteRegistration* base_registration =
|
||||
base_resolver.FindOp(op_and_version.first, op_and_version.second);
|
||||
if (!base_registration) {
|
||||
unresolved_builtin_ops.push_back(
|
||||
EnumNameBuiltinOperator(op_and_version.first));
|
||||
continue;
|
||||
}
|
||||
BuiltinOperatorKey key = op_and_version;
|
||||
builtin_op_evalfn_map_[key] = base_registration->invoke;
|
||||
auto logging_registration =
|
||||
@ -37,6 +49,11 @@ LoggingOpResolver::LoggingOpResolver(
|
||||
for (const auto& op_and_version : custom_ops_to_replace) {
|
||||
const TfLiteRegistration* base_registration = base_resolver.FindOp(
|
||||
op_and_version.first.c_str(), op_and_version.second);
|
||||
if (!base_registration) {
|
||||
if (!IsFlexOp(op_and_version.first.c_str()))
|
||||
unresolved_custom_ops.push_back(op_and_version.first.c_str());
|
||||
continue;
|
||||
}
|
||||
CustomOperatorKey key = op_and_version;
|
||||
custom_op_evalfn_map_[key] = base_registration->invoke;
|
||||
auto logging_registration =
|
||||
@ -44,6 +61,20 @@ LoggingOpResolver::LoggingOpResolver(
|
||||
logging_registration->invoke = logging_eval_fn;
|
||||
custom_op_registration_map_[key] = std::move(logging_registration);
|
||||
}
|
||||
|
||||
if (!unresolved_builtin_ops.empty() || !unresolved_custom_ops.empty()) {
|
||||
if (!error_reporter) return;
|
||||
std::string error_message =
|
||||
"Failed to initialize op resolver for calibration:";
|
||||
if (!unresolved_builtin_ops.empty())
|
||||
absl::StrAppend(&error_message, "\nThere are unresolved builtin ops: [",
|
||||
absl::StrJoin(unresolved_builtin_ops, ", "), "]");
|
||||
if (!unresolved_custom_ops.empty()) {
|
||||
absl::StrAppend(&error_message, "\nThere are unresolved custom ops: [",
|
||||
absl::StrJoin(unresolved_builtin_ops, ", "), "]");
|
||||
}
|
||||
TF_LITE_REPORT_ERROR(error_reporter, error_message.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
const TfLiteRegistration* LoggingOpResolver::FindOp(BuiltinOperator op,
|
||||
|
@ -39,7 +39,8 @@ class LoggingOpResolver : public OpResolver {
|
||||
LoggingOpResolver(const BuiltinOpsSet& builtin_ops_to_replace,
|
||||
const CustomOpsSet& custom_ops_to_replace,
|
||||
const OpResolver& base_resolver,
|
||||
KernelEvalFuncPtr logging_eval_fn);
|
||||
KernelEvalFuncPtr logging_eval_fn,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
const TfLiteRegistration* FindOp(BuiltinOperator op,
|
||||
int version) const override;
|
||||
|
@ -70,7 +70,7 @@ TEST(LoggingOpResolverTest, KernelInvokesAreReplaced) {
|
||||
};
|
||||
|
||||
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
|
||||
WrappingInvoke);
|
||||
WrappingInvoke, /*error_reporter=*/nullptr);
|
||||
|
||||
auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
|
||||
|
||||
@ -104,7 +104,7 @@ TEST(LoggingOpResolverTest, OriginalKernelInvokesAreRetained) {
|
||||
};
|
||||
|
||||
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
|
||||
WrappingInvoke);
|
||||
WrappingInvoke, /*error_reporter=*/nullptr);
|
||||
auto kernel_invoke =
|
||||
resolver.GetWrappedKernelInvoke(BuiltinOperator_CONV_2D, 1);
|
||||
EXPECT_TRUE(kernel_invoke == ConvEval);
|
||||
@ -131,7 +131,7 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) {
|
||||
};
|
||||
|
||||
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
|
||||
WrappingInvoke);
|
||||
WrappingInvoke, /*error_reporter=*/nullptr);
|
||||
auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
|
||||
EXPECT_EQ(reg->builtin_code, BuiltinOperator_CONV_2D);
|
||||
EXPECT_TRUE(reg->prepare == ConvPrepare);
|
||||
@ -155,7 +155,7 @@ TEST(LoggingOpResolverTest, CustomOps) {
|
||||
};
|
||||
|
||||
LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
|
||||
WrappingInvoke);
|
||||
WrappingInvoke, /*error_reporter=*/nullptr);
|
||||
|
||||
auto reg = resolver.FindOp(custom_op_name.c_str(), 1);
|
||||
|
||||
@ -165,6 +165,53 @@ TEST(LoggingOpResolverTest, CustomOps) {
|
||||
EXPECT_TRUE(reg->invoke == WrappingInvoke);
|
||||
}
|
||||
|
||||
TEST(LoggingOpResolverTest, UnresolvedCustomOps) {
|
||||
// No custom op registration.
|
||||
MutableOpResolver base_resolver;
|
||||
|
||||
std::string custom_op_name = "unresolved_custom_op";
|
||||
|
||||
CustomOpsSet ops_to_replace = {
|
||||
{custom_op_name, /*version*/ 1},
|
||||
};
|
||||
|
||||
// Expect no death.
|
||||
LoggingOpResolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
|
||||
WrappingInvoke, /*error_reporter=*/nullptr);
|
||||
}
|
||||
|
||||
TEST(LoggingOpResolverTest, UnresolvedBuiltinOps) {
|
||||
// No builtin op registration.
|
||||
MutableOpResolver base_resolver;
|
||||
|
||||
BuiltinOpsSet ops_to_replace = {
|
||||
{BuiltinOperator_CONV_2D, /*version*/ 1},
|
||||
{BuiltinOperator_ADD, /*version*/ 1},
|
||||
};
|
||||
|
||||
// Expect no death.
|
||||
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
|
||||
WrappingInvoke, /*error_reporter=*/nullptr);
|
||||
}
|
||||
|
||||
TEST(LoggingOpResolverTest, FlexOps) {
|
||||
// No flex op registration.
|
||||
MutableOpResolver base_resolver;
|
||||
|
||||
std::string custom_op_name = "FlexAdd";
|
||||
|
||||
CustomOpsSet ops_to_replace = {
|
||||
{custom_op_name, /*version*/ 1},
|
||||
};
|
||||
|
||||
LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
|
||||
WrappingInvoke, /*error_reporter=*/nullptr);
|
||||
|
||||
auto reg = resolver.FindOp(custom_op_name.c_str(), 1);
|
||||
|
||||
EXPECT_TRUE(!reg);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace calibration
|
||||
} // namespace optimize
|
||||
|
Loading…
Reference in New Issue
Block a user