Add custom op support to calibrator.

Also fix incorrectly set version while i am here.

PiperOrigin-RevId: 251583288
This commit is contained in:
Suharsh Sivakumar 2019-06-04 23:04:02 -07:00 committed by TensorFlower Gardener
parent 19665c6333
commit c8fd0ce78e
5 changed files with 105 additions and 19 deletions

View File

@ -25,16 +25,28 @@ namespace optimize {
namespace calibration {
using BuiltinOperatorKey = std::pair<BuiltinOperator, int>;
using CustomOperatorKey = std::pair<std::string, int>;
using BuiltinOpsSet = std::unordered_set<
BuiltinOperatorKey,
op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey>>;
using CustomOpsSet = std::unordered_set<
CustomOperatorKey,
op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey>>;
template <typename T>
class BuiltinOpsMap
: public std::unordered_map<
BuiltinOperatorKey, T,
op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey>> {};
template <typename T>
class CustomOpsMap
: public std::unordered_map<
CustomOperatorKey, T,
op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey>> {};
// An alias for |TfLiteRegistration.invoke|.
using KernelEvalFuncPtr = TfLiteStatus (*)(TfLiteContext*, TfLiteNode*);
@ -53,6 +65,7 @@ struct OperatorInfo {
// Outputs that need to be logged.
std::vector<int> loggable_outputs;
const TfLiteRegistration* registration;
int version;
};
} // namespace calibration

View File

@ -79,8 +79,12 @@ class Calibrator {
KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
auto op_info = node_ptr_opinfo_map_.at(node);
if (op_info.is_custom_op) {
return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(),
op_info.version);
}
return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
1);
op_info.version);
}
// A registry of |Calibrator| objects per |TfLiteContext|.
@ -282,7 +286,8 @@ TfLiteStatus BuildLoggingInterpreter(
auto operators = primary_subgraph->operators();
auto tensors = primary_subgraph->tensors();
std::unordered_map<int, OperatorInfo> node_to_opinfo;
BuiltinOpsSet op_and_versions;
BuiltinOpsSet builtin_op_and_versions;
CustomOpsSet custom_op_and_versions;
for (size_t i = 0; i < operators->size(); i++) {
OperatorInfo op_info;
@ -292,6 +297,7 @@ TfLiteStatus BuildLoggingInterpreter(
op_info.builtin_op_code = operator_code->builtin_code();
op_info.name = GetOpName(*operator_code);
op_info.is_custom_op = operator_code->custom_code() != nullptr;
op_info.version = operator_code->version();
auto op_inputs = op->inputs();
auto op_outputs = op->outputs();
@ -301,21 +307,25 @@ TfLiteStatus BuildLoggingInterpreter(
GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
op_info.loggable_outputs =
GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
if (!op_info.is_custom_op) {
op_info.registration = op_resolver.FindOp(operator_code->builtin_code(),
operator_code->version());
} else {
if (op_info.is_custom_op) {
op_info.registration =
op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
custom_op_and_versions.insert(
{op_info.name.c_str(), operator_code->version()});
} else {
op_info.registration = op_resolver.FindOp(operator_code->builtin_code(),
operator_code->version());
builtin_op_and_versions.insert(
{op_info.builtin_op_code, operator_code->version()});
}
node_to_opinfo[i] = op_info;
op_and_versions.insert({op_info.builtin_op_code, operator_code->version()});
}
// Prepare the logging op resolver to use |LoggingEval| for kernel
// invocations.
auto logging_op_resolver = absl::make_unique<LoggingOpResolver>(
op_and_versions, op_resolver, LoggingEval);
builtin_op_and_versions, custom_op_and_versions, op_resolver,
LoggingEval);
tflite::InterpreterBuilder(model, *logging_op_resolver)(interpreter);
if (!(*interpreter)) {

View File

@ -20,10 +20,11 @@ namespace tflite {
namespace optimize {
namespace calibration {
LoggingOpResolver::LoggingOpResolver(const BuiltinOpsSet& ops_to_replace,
const OpResolver& base_resolver,
KernelEvalFuncPtr logging_eval_fn) {
for (const auto& op_and_version : ops_to_replace) {
LoggingOpResolver::LoggingOpResolver(
const BuiltinOpsSet& builtin_ops_to_replace,
const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver,
KernelEvalFuncPtr logging_eval_fn) {
for (const auto& op_and_version : builtin_ops_to_replace) {
const TfLiteRegistration* base_registration =
base_resolver.FindOp(op_and_version.first, op_and_version.second);
BuiltinOperatorKey key = op_and_version;
@ -33,6 +34,16 @@ LoggingOpResolver::LoggingOpResolver(const BuiltinOpsSet& ops_to_replace,
logging_registation->invoke = logging_eval_fn;
builtin_op_registration_map_[key] = std::move(logging_registation);
}
for (const auto& op_and_version : custom_ops_to_replace) {
const TfLiteRegistration* base_registration = base_resolver.FindOp(
op_and_version.first.c_str(), op_and_version.second);
CustomOperatorKey key = op_and_version;
custom_op_evalfn_map_[key] = base_registration->invoke;
auto logging_registation =
absl::make_unique<TfLiteRegistration>(*base_registration);
logging_registation->invoke = logging_eval_fn;
custom_op_registration_map_[key] = std::move(logging_registation);
}
}
const TfLiteRegistration* LoggingOpResolver::FindOp(BuiltinOperator op,
@ -53,9 +64,20 @@ KernelEvalFuncPtr LoggingOpResolver::GetWrappedKernelInvoke(BuiltinOperator op,
const TfLiteRegistration* LoggingOpResolver::FindOp(const char* op,
int version) const {
// TODO(b/121374947): Support custom ops as well.
CustomOperatorKey key = {op, version};
if (custom_op_registration_map_.find(key) !=
custom_op_registration_map_.end()) {
return custom_op_registration_map_.at(key).get();
}
return nullptr;
}
KernelEvalFuncPtr LoggingOpResolver::GetWrappedKernelInvoke(const char* op,
int version) const {
return custom_op_evalfn_map_.at({op, version});
}
} // namespace calibration
} // namespace optimize
} // namespace tflite

View File

@ -26,6 +26,7 @@ limitations under the License.
namespace tflite {
namespace optimize {
namespace calibration {
// A resolver that replaces the kernel invocations with a wrapper
// eval function.
class LoggingOpResolver : public OpResolver {
@ -33,23 +34,27 @@ class LoggingOpResolver : public OpResolver {
// Creates an instance of |LoggingOpResolver|.
// All |TfLiteRegistration.invoke| functions are replaced by
// |logging_eval_fn|.
// TODO(shashishekhar): This interface needs to change for custom ops and
// TODO(shashishekhar): This interface needs to change for
// BuiltinOps that need special logging implementations.
LoggingOpResolver(const BuiltinOpsSet& ops_to_replace,
LoggingOpResolver(const BuiltinOpsSet& builtin_ops_to_replace,
const CustomOpsSet& custom_ops_to_replace,
const OpResolver& base_resolver,
KernelEvalFuncPtr logging_eval_fn);
const TfLiteRegistration* FindOp(BuiltinOperator op,
int version) const override;
KernelEvalFuncPtr GetWrappedKernelInvoke(BuiltinOperator op,
int version) const;
const TfLiteRegistration* FindOp(const char* op, int version) const override;
KernelEvalFuncPtr GetWrappedKernelInvoke(const char* op, int version) const;
private:
BuiltinOpsMap<std::unique_ptr<TfLiteRegistration>>
builtin_op_registration_map_;
BuiltinOpsMap<KernelEvalFuncPtr> builtin_op_evalfn_map_;
CustomOpsMap<std::unique_ptr<TfLiteRegistration>> custom_op_registration_map_;
CustomOpsMap<KernelEvalFuncPtr> custom_op_evalfn_map_;
};
} // namespace calibration

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/mutable_op_resolver.h"
@ -38,6 +39,14 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus CustomPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus CustomEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus WrappingInvoke(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
@ -60,7 +69,8 @@ TEST(LoggingOpResolverTest, KernelInvokesAreReplaced) {
{BuiltinOperator_ADD, /*version*/ 1},
};
LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke);
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke);
auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
@ -93,7 +103,8 @@ TEST(LoggingOpResolverTest, OriginalKernelInvokesAreRetained) {
{BuiltinOperator_ADD, /*version*/ 1},
};
LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke);
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke);
auto kernel_invoke =
resolver.GetWrappedKernelInvoke(BuiltinOperator_CONV_2D, 1);
EXPECT_TRUE(kernel_invoke == ConvEval);
@ -119,7 +130,8 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) {
{BuiltinOperator_CONV_2D, /*version*/ 1},
};
LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke);
LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver,
WrappingInvoke);
auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1);
EXPECT_EQ(reg->builtin_code, BuiltinOperator_CONV_2D);
EXPECT_TRUE(reg->prepare == ConvPrepare);
@ -129,6 +141,30 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) {
EXPECT_EQ(nullptr, reg);
}
TEST(LoggingOpResolverTest, CustomOps) {
MutableOpResolver base_resolver;
TfLiteRegistration custom_registration = {};
custom_registration.prepare = CustomPrepare;
custom_registration.invoke = CustomEval;
std::string custom_op_name = "custom";
base_resolver.AddCustom(custom_op_name.c_str(), &custom_registration);
CustomOpsSet ops_to_replace = {
{custom_op_name, /*version*/ 1},
};
LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver,
WrappingInvoke);
auto reg = resolver.FindOp(custom_op_name.c_str(), 1);
EXPECT_EQ(reg->builtin_code, BuiltinOperator_CUSTOM);
EXPECT_EQ(reg->custom_name, custom_op_name.c_str());
EXPECT_TRUE(reg->prepare == CustomPrepare);
EXPECT_TRUE(reg->invoke == WrappingInvoke);
}
} // namespace
} // namespace calibration
} // namespace optimize