From c8fd0ce78e236355f435dec1e9fc6cf11cdeed36 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Tue, 4 Jun 2019 23:04:02 -0700 Subject: [PATCH] Add custom op support to calibrator. Also fix incorrectly set version while i am here. PiperOrigin-RevId: 251583288 --- .../optimize/calibration/calibration_common.h | 13 ++++++ .../tools/optimize/calibration/calibrator.cc | 26 ++++++++---- .../calibration/logging_op_resolver.cc | 32 +++++++++++--- .../calibration/logging_op_resolver.h | 11 +++-- .../calibration/logging_op_resolver_test.cc | 42 +++++++++++++++++-- 5 files changed, 105 insertions(+), 19 deletions(-) diff --git a/tensorflow/lite/tools/optimize/calibration/calibration_common.h b/tensorflow/lite/tools/optimize/calibration/calibration_common.h index 1ff2d3f18a6..52498edcba9 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibration_common.h +++ b/tensorflow/lite/tools/optimize/calibration/calibration_common.h @@ -25,16 +25,28 @@ namespace optimize { namespace calibration { using BuiltinOperatorKey = std::pair; +using CustomOperatorKey = std::pair; + using BuiltinOpsSet = std::unordered_set< BuiltinOperatorKey, op_resolver_hasher::OperatorKeyHasher>; +using CustomOpsSet = std::unordered_set< + CustomOperatorKey, + op_resolver_hasher::OperatorKeyHasher>; + template class BuiltinOpsMap : public std::unordered_map< BuiltinOperatorKey, T, op_resolver_hasher::OperatorKeyHasher> {}; +template +class CustomOpsMap + : public std::unordered_map< + CustomOperatorKey, T, + op_resolver_hasher::OperatorKeyHasher> {}; + // An alias for |TfLiteRegistration.invoke|. using KernelEvalFuncPtr = TfLiteStatus (*)(TfLiteContext*, TfLiteNode*); @@ -53,6 +65,7 @@ struct OperatorInfo { // Outputs that need to be logged. std::vector loggable_outputs; const TfLiteRegistration* registration; + int version; }; } // namespace calibration diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator.cc b/tensorflow/lite/tools/optimize/calibration/calibrator.cc index 7a9a4943704..f1fbf22437f 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibrator.cc +++ b/tensorflow/lite/tools/optimize/calibration/calibrator.cc @@ -79,8 +79,12 @@ class Calibrator { KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const { auto op_info = node_ptr_opinfo_map_.at(node); + if (op_info.is_custom_op) { + return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(), + op_info.version); + } return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code, - 1); + op_info.version); } // A registry of |Calibrator| objects per |TfLiteContext|. @@ -282,7 +286,8 @@ TfLiteStatus BuildLoggingInterpreter( auto operators = primary_subgraph->operators(); auto tensors = primary_subgraph->tensors(); std::unordered_map node_to_opinfo; - BuiltinOpsSet op_and_versions; + BuiltinOpsSet builtin_op_and_versions; + CustomOpsSet custom_op_and_versions; for (size_t i = 0; i < operators->size(); i++) { OperatorInfo op_info; @@ -292,6 +297,7 @@ TfLiteStatus BuildLoggingInterpreter( op_info.builtin_op_code = operator_code->builtin_code(); op_info.name = GetOpName(*operator_code); op_info.is_custom_op = operator_code->custom_code() != nullptr; + op_info.version = operator_code->version(); auto op_inputs = op->inputs(); auto op_outputs = op->outputs(); @@ -301,21 +307,25 @@ TfLiteStatus BuildLoggingInterpreter( GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers); op_info.loggable_outputs = GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers); - if (!op_info.is_custom_op) { - op_info.registration = op_resolver.FindOp(operator_code->builtin_code(), - operator_code->version()); - } else { + if (op_info.is_custom_op) { op_info.registration = op_resolver.FindOp(op_info.name.c_str(), operator_code->version()); + custom_op_and_versions.insert( + {op_info.name.c_str(), operator_code->version()}); + } else { + op_info.registration = op_resolver.FindOp(operator_code->builtin_code(), + operator_code->version()); + builtin_op_and_versions.insert( + {op_info.builtin_op_code, operator_code->version()}); } node_to_opinfo[i] = op_info; - op_and_versions.insert({op_info.builtin_op_code, operator_code->version()}); } // Prepare the logging op resolver to use |LoggingEval| for kernel // invocations. auto logging_op_resolver = absl::make_unique( - op_and_versions, op_resolver, LoggingEval); + builtin_op_and_versions, custom_op_and_versions, op_resolver, + LoggingEval); tflite::InterpreterBuilder(model, *logging_op_resolver)(interpreter); if (!(*interpreter)) { diff --git a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc index d2a09e898ae..199318c5db2 100644 --- a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc +++ b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.cc @@ -20,10 +20,11 @@ namespace tflite { namespace optimize { namespace calibration { -LoggingOpResolver::LoggingOpResolver(const BuiltinOpsSet& ops_to_replace, - const OpResolver& base_resolver, - KernelEvalFuncPtr logging_eval_fn) { - for (const auto& op_and_version : ops_to_replace) { +LoggingOpResolver::LoggingOpResolver( + const BuiltinOpsSet& builtin_ops_to_replace, + const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver, + KernelEvalFuncPtr logging_eval_fn) { + for (const auto& op_and_version : builtin_ops_to_replace) { const TfLiteRegistration* base_registration = base_resolver.FindOp(op_and_version.first, op_and_version.second); BuiltinOperatorKey key = op_and_version; @@ -33,6 +34,16 @@ LoggingOpResolver::LoggingOpResolver(const BuiltinOpsSet& ops_to_replace, logging_registation->invoke = logging_eval_fn; builtin_op_registration_map_[key] = std::move(logging_registation); } + for (const auto& op_and_version : custom_ops_to_replace) { + const TfLiteRegistration* base_registration = base_resolver.FindOp( + op_and_version.first.c_str(), op_and_version.second); + CustomOperatorKey key = op_and_version; + custom_op_evalfn_map_[key] = base_registration->invoke; + auto logging_registation = + absl::make_unique(*base_registration); + logging_registation->invoke = logging_eval_fn; + custom_op_registration_map_[key] = std::move(logging_registation); + } } const TfLiteRegistration* LoggingOpResolver::FindOp(BuiltinOperator op, @@ -53,9 +64,20 @@ KernelEvalFuncPtr LoggingOpResolver::GetWrappedKernelInvoke(BuiltinOperator op, const TfLiteRegistration* LoggingOpResolver::FindOp(const char* op, int version) const { - // TODO(b/121374947): Support custom ops as well. + CustomOperatorKey key = {op, version}; + if (custom_op_registration_map_.find(key) != + custom_op_registration_map_.end()) { + return custom_op_registration_map_.at(key).get(); + } + return nullptr; } + +KernelEvalFuncPtr LoggingOpResolver::GetWrappedKernelInvoke(const char* op, + int version) const { + return custom_op_evalfn_map_.at({op, version}); +} + } // namespace calibration } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h index af4127e42f7..bbdfef60d92 100644 --- a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h +++ b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h @@ -26,6 +26,7 @@ limitations under the License. namespace tflite { namespace optimize { namespace calibration { + // A resolver that replaces the kernel invocations with a wrapper // eval function. class LoggingOpResolver : public OpResolver { @@ -33,23 +34,27 @@ class LoggingOpResolver : public OpResolver { // Creates an instance of |LoggingOpResolver|. // All |TfLiteRegistration.invoke| functions are replaced by // |logging_eval_fn|. - // TODO(shashishekhar): This interface needs to change for custom ops and + // TODO(shashishekhar): This interface needs to change for // BuiltinOps that need special logging implementations. - LoggingOpResolver(const BuiltinOpsSet& ops_to_replace, + LoggingOpResolver(const BuiltinOpsSet& builtin_ops_to_replace, + const CustomOpsSet& custom_ops_to_replace, const OpResolver& base_resolver, KernelEvalFuncPtr logging_eval_fn); const TfLiteRegistration* FindOp(BuiltinOperator op, int version) const override; - KernelEvalFuncPtr GetWrappedKernelInvoke(BuiltinOperator op, int version) const; + const TfLiteRegistration* FindOp(const char* op, int version) const override; + KernelEvalFuncPtr GetWrappedKernelInvoke(const char* op, int version) const; private: BuiltinOpsMap> builtin_op_registration_map_; BuiltinOpsMap builtin_op_evalfn_map_; + CustomOpsMap> custom_op_registration_map_; + CustomOpsMap custom_op_evalfn_map_; }; } // namespace calibration diff --git a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc index d8d29ad8eff..25507daaa91 100644 --- a/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc +++ b/tensorflow/lite/tools/optimize/calibration/logging_op_resolver_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h" + #include #include #include "tensorflow/lite/mutable_op_resolver.h" @@ -38,6 +39,14 @@ TfLiteStatus AddEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus CustomPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus CustomEval(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + TfLiteStatus WrappingInvoke(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } @@ -60,7 +69,8 @@ TEST(LoggingOpResolverTest, KernelInvokesAreReplaced) { {BuiltinOperator_ADD, /*version*/ 1}, }; - LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke); + LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver, + WrappingInvoke); auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1); @@ -93,7 +103,8 @@ TEST(LoggingOpResolverTest, OriginalKernelInvokesAreRetained) { {BuiltinOperator_ADD, /*version*/ 1}, }; - LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke); + LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver, + WrappingInvoke); auto kernel_invoke = resolver.GetWrappedKernelInvoke(BuiltinOperator_CONV_2D, 1); EXPECT_TRUE(kernel_invoke == ConvEval); @@ -119,7 +130,8 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) { {BuiltinOperator_CONV_2D, /*version*/ 1}, }; - LoggingOpResolver resolver(ops_to_replace, base_resolver, WrappingInvoke); + LoggingOpResolver resolver(ops_to_replace, CustomOpsSet(), base_resolver, + WrappingInvoke); auto reg = resolver.FindOp(BuiltinOperator_CONV_2D, 1); EXPECT_EQ(reg->builtin_code, BuiltinOperator_CONV_2D); EXPECT_TRUE(reg->prepare == ConvPrepare); @@ -129,6 +141,30 @@ TEST(LoggingOpResolverTest, OnlyOpsInReplacementSetAreReplaces) { EXPECT_EQ(nullptr, reg); } +TEST(LoggingOpResolverTest, CustomOps) { + MutableOpResolver base_resolver; + TfLiteRegistration custom_registration = {}; + custom_registration.prepare = CustomPrepare; + custom_registration.invoke = CustomEval; + + std::string custom_op_name = "custom"; + base_resolver.AddCustom(custom_op_name.c_str(), &custom_registration); + + CustomOpsSet ops_to_replace = { + {custom_op_name, /*version*/ 1}, + }; + + LoggingOpResolver resolver(BuiltinOpsSet(), ops_to_replace, base_resolver, + WrappingInvoke); + + auto reg = resolver.FindOp(custom_op_name.c_str(), 1); + + EXPECT_EQ(reg->builtin_code, BuiltinOperator_CUSTOM); + EXPECT_EQ(reg->custom_name, custom_op_name.c_str()); + EXPECT_TRUE(reg->prepare == CustomPrepare); + EXPECT_TRUE(reg->invoke == WrappingInvoke); +} + } // namespace } // namespace calibration } // namespace optimize