Add custom op support to calibrator.
Also fix an incorrectly set version while I am here.

PiperOrigin-RevId: 251583288
parent 19665c6333
commit c8fd0ce78e
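At a high level, this change teaches the calibration path to wrap custom-op kernels in addition to builtins. A minimal usage sketch, assuming a KernelEvalFuncPtr named MyLoggingEval and a custom op called "MyCustomOp" (both illustrative stand-ins, not part of this commit):

  tflite::MutableOpResolver base_resolver;
  // Assume builtins and the custom op "MyCustomOp" are registered on base_resolver.
  BuiltinOpsSet builtin_ops = {{BuiltinOperator_CONV_2D, /*version=*/1}};
  CustomOpsSet custom_ops = {{"MyCustomOp", /*version=*/1}};
  // Constructor shape introduced by this commit: builtin set, custom set,
  // base resolver, and the wrapping eval function.
  LoggingOpResolver logging_resolver(builtin_ops, custom_ops, base_resolver,
                                     MyLoggingEval);

The resolver replaces each listed kernel's invoke with the logging eval function while keeping the original invoke retrievable, which is what the calibrator relies on in the diff below.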
@@ -25,16 +25,28 @@ namespace optimize {
 namespace calibration {
 
 using BuiltinOperatorKey = std::pair<BuiltinOperator, int>;
 
+using CustomOperatorKey = std::pair<std::string, int>;
+
 using BuiltinOpsSet = std::unordered_set<
     BuiltinOperatorKey,
     op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey>>;
 
+using CustomOpsSet = std::unordered_set<
+    CustomOperatorKey,
+    op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey>>;
+
 template <typename T>
 class BuiltinOpsMap
     : public std::unordered_map<
           BuiltinOperatorKey, T,
           op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey>> {};
 
+template <typename T>
+class CustomOpsMap
+    : public std::unordered_map<
+          CustomOperatorKey, T,
+          op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey>> {};
+
 // An alias for |TfLiteRegistration.invoke|.
 using KernelEvalFuncPtr = TfLiteStatus (*)(TfLiteContext*, TfLiteNode*);

@@ -53,6 +65,7 @@ struct OperatorInfo {
   // Outputs that need to be logged.
   std::vector<int> loggable_outputs;
   const TfLiteRegistration* registration;
+  int version;
 };
 
 }  // namespace calibration
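The new aliases mirror the existing builtin containers, keyed by {custom op name, version} instead of {builtin op code, version}. A small illustrative sketch (the op name "MyCustomOp" is made up):

  CustomOpsSet custom_ops;
  custom_ops.insert({"MyCustomOp", /*version=*/1});
  CustomOpsMap<int> op_hit_counts;    // any value type; the key is {name, version}
  op_hit_counts[{"MyCustomOp", 1}]++;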
@@ -79,8 +79,12 @@ class Calibrator {
 
 KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
   auto op_info = node_ptr_opinfo_map_.at(node);
+  if (op_info.is_custom_op) {
+    return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(),
+                                                        op_info.version);
+  }
   return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
-                                                      1);
+                                                      op_info.version);
 }
 
 // A registry of |Calibrator| objects per |TfLiteContext|.
@@ -282,7 +286,8 @@ TfLiteStatus BuildLoggingInterpreter(
   auto operators = primary_subgraph->operators();
   auto tensors = primary_subgraph->tensors();
   std::unordered_map<int, OperatorInfo> node_to_opinfo;
-  BuiltinOpsSet op_and_versions;
+  BuiltinOpsSet builtin_op_and_versions;
+  CustomOpsSet custom_op_and_versions;
 
   for (size_t i = 0; i < operators->size(); i++) {
     OperatorInfo op_info;

@@ -292,6 +297,7 @@ TfLiteStatus BuildLoggingInterpreter(
     op_info.builtin_op_code = operator_code->builtin_code();
     op_info.name = GetOpName(*operator_code);
     op_info.is_custom_op = operator_code->custom_code() != nullptr;
+    op_info.version = operator_code->version();
 
     auto op_inputs = op->inputs();
     auto op_outputs = op->outputs();

@@ -301,21 +307,25 @@ TfLiteStatus BuildLoggingInterpreter(
         GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
     op_info.loggable_outputs =
         GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
-    if (!op_info.is_custom_op) {
-      op_info.registration = op_resolver.FindOp(operator_code->builtin_code(),
-                                                operator_code->version());
-    } else {
+    if (op_info.is_custom_op) {
       op_info.registration =
           op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
+      custom_op_and_versions.insert(
+          {op_info.name.c_str(), operator_code->version()});
+    } else {
+      op_info.registration = op_resolver.FindOp(operator_code->builtin_code(),
+                                                operator_code->version());
+      builtin_op_and_versions.insert(
+          {op_info.builtin_op_code, operator_code->version()});
     }
     node_to_opinfo[i] = op_info;
-    op_and_versions.insert({op_info.builtin_op_code, operator_code->version()});
   }
 
   // Prepare the logging op resolver to use |LoggingEval| for kernel
   // invocations.
   auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
-      op_and_versions, op_resolver, LoggingEval);
+      builtin_op_and_versions, custom_op_and_versions, op_resolver,
+      LoggingEval);
   tflite::InterpreterBuilder(model, *logging_op_resolver)(interpreter);
 
   if (!(*interpreter)) {
@@ -20,10 +20,11 @@ namespace tflite {
 namespace optimize {
 namespace calibration {
 
-LoggingOpResolver::LoggingOpResolver(const BuiltinOpsSet& ops_to_replace,
-                                     const OpResolver& base_resolver,
-                                     KernelEvalFuncPtr logging_eval_fn) {
-  for (const auto& op_and_version : ops_to_replace) {
+LoggingOpResolver::LoggingOpResolver(
+    const BuiltinOpsSet& builtin_ops_to_replace,
+    const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver,
+    KernelEvalFuncPtr logging_eval_fn) {
+  for (const auto& op_and_version : builtin_ops_to_replace) {
     const TfLiteRegistration* base_registration =
         base_resolver.FindOp(op_and_version.first, op_and_version.second);
     BuiltinOperatorKey key = op_and_version;

@@ -33,6 +34,16 @@ LoggingOpResolver::LoggingOpResolver(const BuiltinOpsSet& ops_to_replace,
     logging_registation->invoke = logging_eval_fn;
     builtin_op_registration_map_[key] = std::move(logging_registation);
   }
+  for (const auto& op_and_version : custom_ops_to_replace) {
+    const TfLiteRegistration* base_registration = base_resolver.FindOp(
+        op_and_version.first.c_str(), op_and_version.second);
+    CustomOperatorKey key = op_and_version;
+    custom_op_evalfn_map_[key] = base_registration->invoke;
+    auto logging_registation =
+        absl::make_unique<TfLiteRegistration>(*base_registration);
+    logging_registation->invoke = logging_eval_fn;
+    custom_op_registration_map_[key] = std::move(logging_registation);
+  }
 }
 
 const TfLiteRegistration* LoggingOpResolver::FindOp(BuiltinOperator op,

@@ -53,9 +64,20 @@ KernelEvalFuncPtr LoggingOpResolver::GetWrappedKernelInvoke(BuiltinOperator op,
 
 const TfLiteRegistration* LoggingOpResolver::FindOp(const char* op,
                                                     int version) const {
-  // TODO(b/121374947): Support custom ops as well.
+  CustomOperatorKey key = {op, version};
+  if (custom_op_registration_map_.find(key) !=
+      custom_op_registration_map_.end()) {
+    return custom_op_registration_map_.at(key).get();
+  }
+
   return nullptr;
 }
 
+KernelEvalFuncPtr LoggingOpResolver::GetWrappedKernelInvoke(const char* op,
+                                                            int version) const {
+  return custom_op_evalfn_map_.at({op, version});
+}
+
 }  // namespace calibration
 }  // namespace optimize
 }  // namespace tflite
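A brief caller-side sketch of the two new lookups (the resolver variable and the op name "MyCustomOp" are illustrative). Note that FindOp(const char*, int) returns nullptr for ops that were not wrapped, while GetWrappedKernelInvoke(const char*, int) uses unordered_map::at and would throw for an unknown op:

  const TfLiteRegistration* reg = resolver.FindOp("MyCustomOp", /*version=*/1);
  if (reg != nullptr) {
    // reg->invoke is the logging wrapper; the original kernel is still reachable:
    KernelEvalFuncPtr original = resolver.GetWrappedKernelInvoke("MyCustomOp", 1);
  }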
@@ -26,6 +26,7 @@ limitations under the License.
 namespace tflite {
 namespace optimize {
 namespace calibration {
 
 // A resolver that replaces the kernel invocations with a wrapper
 // eval function.
 class LoggingOpResolver : public OpResolver {

@@ -33,23 +34,27 @@ class LoggingOpResolver : public OpResolver {
   // Creates an instance of |LoggingOpResolver|.
   // All |TfLiteRegistration.invoke| functions are replaced by
   // |logging_eval_fn|.
-  // TODO(shashishekhar): This interface needs to change for custom ops and
+  // TODO(shashishekhar): This interface needs to change for
   // BuiltinOps that need special logging implementations.
-  LoggingOpResolver(const BuiltinOpsSet& ops_to_replace,
+  LoggingOpResolver(const BuiltinOpsSet& builtin_ops_to_replace,
+                    const CustomOpsSet& custom_ops_to_replace,
                     const OpResolver& base_resolver,
                     KernelEvalFuncPtr logging_eval_fn);
 
   const TfLiteRegistration* FindOp(BuiltinOperator op,
                                    int version) const override;
 
   KernelEvalFuncPtr GetWrappedKernelInvoke(BuiltinOperator op,
                                            int version) const;
 
   const TfLiteRegistration* FindOp(const char* op, int version) const override;
+  KernelEvalFuncPtr GetWrappedKernelInvoke(const char* op, int version) const;
 
  private:
   BuiltinOpsMap<std::unique_ptr<TfLiteRegistration>>
       builtin_op_registration_map_;
   BuiltinOpsMap<KernelEvalFuncPtr> builtin_op_evalfn_map_;
+  CustomOpsMap<std::unique_ptr<TfLiteRegistration>> custom_op_registration_map_;
+  CustomOpsMap<KernelEvalFuncPtr> custom_op_evalfn_map_;
 };
 
 }  // namespace calibration
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/mutable_op_resolver.h"

@@ -38,6 +39,14 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus CustomPrepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus CustomEval(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
 TfLiteStatus WrappingInvoke(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }

@@ -60,7 +69,8 @@ TEST(LoggingOpResolverTest, KernelInvokesAreReplaced) {
       {BuiltinOperator_ADD, /*version*/ 1},
   };
 
-  LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke);
+  LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
+                             WrappingInvoke);
 
   auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
 

@@ -93,7 +103,8 @@ TEST(LoggingOpResolverTest, OriginalKernelInvokesAreRetained) {
       {BuiltinOperator_ADD, /*version*/ 1},
   };
 
-  LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke);
+  LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
+                             WrappingInvoke);
   auto kernel_invoke =
       resolver.GetWrappedKernelInvoke(BuiltinOperator_CONV_2D, 1);
   EXPECT_TRUE(kernel_invoke == ConvEval);

@@ -119,7 +130,8 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) {
       {BuiltinOperator_CONV_2D, /*version*/ 1},
   };
 
-  LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke);
+  LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
+                             WrappingInvoke);
   auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
   EXPECT_EQ(reg->builtin_code, BuiltinOperator_CONV_2D);
   EXPECT_TRUE(reg->prepare == ConvPrepare);

@@ -129,6 +141,30 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) {
   EXPECT_EQ(nullptr, reg);
 }
 
+TEST(LoggingOpResolverTest, CustomOps) {
+  MutableOpResolver base_resolver;
+  TfLiteRegistration custom_registration = {};
+  custom_registration.prepare = CustomPrepare;
+  custom_registration.invoke = CustomEval;
+
+  std::string custom_op_name = "custom";
+  base_resolver.AddCustom(custom_op_name.c_str(), &custom_registration);
+
+  CustomOpsSet ops_to_replace = {
+      {custom_op_name, /*version*/ 1},
+  };
+
+  LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
+                             WrappingInvoke);
+
+  auto reg = resolver.FindOp(custom_op_name.c_str(), 1);
+
+  EXPECT_EQ(reg->builtin_code, BuiltinOperator_CUSTOM);
+  EXPECT_EQ(reg->custom_name, custom_op_name.c_str());
+  EXPECT_TRUE(reg->prepare == CustomPrepare);
+  EXPECT_TRUE(reg->invoke == WrappingInvoke);
+}
+
 }  // namespace
 }  // namespace calibration
 }  // namespace optimize