From bb15fff6b9c92c5926fa92229d35ca4f5f9eb822 Mon Sep 17 00:00:00 2001
From: Sachin Joglekar <srjoglekar@google.com>
Date: Thu, 16 May 2019 15:37:52 -0700
Subject: [PATCH] API changes for evaluation framework, mainly to support
 adding arbitrary delegates

PiperOrigin-RevId: 248613883
---
 tensorflow/lite/tools/evaluation/stages/BUILD     |  2 ++
 .../stages/image_classification_stage.h           |  6 ++++++
 .../evaluation/stages/tflite_inference_stage.cc   | 11 ++++++++++-
 .../evaluation/stages/tflite_inference_stage.h    | 10 ++++++++--
 .../stages/tflite_inference_stage_test.cc         | 15 +++++++++++++++
 5 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD
index 6a642550a8a..f12f9f06908 100644
--- a/tensorflow/lite/tools/evaluation/stages/BUILD
+++ b/tensorflow/lite/tools/evaluation/stages/BUILD
@@ -105,6 +105,7 @@ cc_library(
         "//tensorflow/core:stats_calculator_portable",
         "//tensorflow/core:tflite_portable_logging",
         "//tensorflow/lite:framework",
+        "//tensorflow/lite/c:c_api_internal",
         "//tensorflow/lite/kernels:builtin_ops",
         "//tensorflow/lite/profiling:time",
         "//tensorflow/lite/tools/evaluation:evaluation_stage",
@@ -123,6 +124,7 @@ cc_test(
     deps = [
         ":tflite_inference_stage",
         "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
         "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
         "@com_google_googletest//:gtest_main",
diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h
index 73d4c3efb7e..31c956b0c90 100644
--- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h
+++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h
@@ -58,6 +58,12 @@ class ImageClassificationStage : public EvaluationStage {
     ground_truth_label_ = ground_truth_label;
   }
 
+  // Provides a pointer to the underlying TfLiteInferenceStage.
+  // Returns non-null value only if this stage has been initialized.
+  TfliteInferenceStage* const GetInferenceStage() {
+    return inference_stage_.get();
+  }
+
  private:
   const std::vector<std::string>* all_labels_ = nullptr;
   std::unique_ptr<ImagePreprocessingStage> preprocessing_stage_;
diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc
index bc0eab03338..0f7070d3d05 100644
--- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc
+++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc
@@ -39,6 +39,15 @@ TfLiteModelInfo GetTfliteModelInfo(const Interpreter& interpreter) {
 
 }  // namespace
 
+TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
+    TfLiteDelegate* delegate) {
+  if (!interpreter_) {
+    LOG(ERROR) << "Stage not initialized before calling ApplyCustomDelegate";
+    return kTfLiteError;
+  }
+  return interpreter_->ModifyGraphWithDelegate(delegate);
+}
+
 TfLiteStatus TfliteInferenceStage::Init() {
   if (!config_.specification().has_tflite_inference_params()) {
     LOG(ERROR) << "TfliteInferenceParams not provided";
@@ -109,7 +118,7 @@ TfLiteStatus TfliteInferenceStage::Run() {
   // Copy input data.
   for (int i = 0; i < interpreter_->inputs().size(); ++i) {
     TfLiteTensor* tensor = interpreter_->tensor(interpreter_->inputs()[i]);
-    std::memcpy(tensor->data.raw, (*inputs_)[i], tensor->bytes);
+    tensor->data.raw = static_cast<char*>(inputs_->at(i));
   }
 
   // Invoke.
diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h
index ecb9c7f7a86..fa53a49079f 100644
--- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h
+++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h
@@ -16,9 +16,11 @@ limitations under the License.
 #define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_TFLITE_INFERENCE_STAGE_H_
 
 #include <stdint.h>
+
 #include <vector>
 
 #include "tensorflow/core/util/stats_calculator.h"
+#include "tensorflow/lite/c/c_api_internal.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
@@ -50,10 +52,14 @@ class TfliteInferenceStage : public EvaluationStage {
 
   // Call before Run().
   // This class does not take ownership of raw_input_ptrs.
-  void SetInputs(std::vector<void*>& raw_input_ptrs) {
+  void SetInputs(const std::vector<void*>& raw_input_ptrs) {
     inputs_ = &raw_input_ptrs;
   }
 
+  // Applies provided delegate to the underlying TFLite Interpreter.
+  // NOTE: TFLiteInferenceStage does not take ownership of delegate.
+  TfLiteStatus ApplyCustomDelegate(TfLiteDelegate* delegate);
+
   // Read-only view of a TfliteModelInfo. TfliteInferenceStage retains
   // ownership.
   // Only available after Init is done.
@@ -70,7 +76,7 @@ class TfliteInferenceStage : public EvaluationStage {
   std::vector<Interpreter::TfLiteDelegatePtr> delegates_;
 
   TfLiteModelInfo model_info_;
-  std::vector<void*>* inputs_ = nullptr;
+  const std::vector<void*>* inputs_ = nullptr;
   std::vector<void*> outputs_;
 
   tensorflow::Stat<int64_t> latency_stats_;
diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc
index f86de894713..beb9c49c338 100644
--- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc
+++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc
@@ -15,10 +15,12 @@ limitations under the License.
 #include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
 
 #include <stdint.h>
+
 #include <string>
 
 #include <gtest/gtest.h>
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
 
@@ -155,6 +157,19 @@ TEST(TfliteInferenceStage, CorrectOutput) {
       metrics.process_metrics().tflite_inference_metrics().num_inferences(), 2);
 }
 
+TEST(TfliteInferenceStage, CustomDelegate) {
+  // Create stage.
+  EvaluationStageConfig config = GetTfliteInferenceStageConfig();
+  TfliteInferenceStage stage(config);
+
+  TfLiteDelegate* test_delegate = NnApiDelegate();
+
+  // Delegate application should only work after initialization of stage.
+  EXPECT_NE(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk);
+  EXPECT_EQ(stage.Init(), kTfLiteOk);
+  EXPECT_EQ(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk);
+}
+
 }  // namespace
 }  // namespace evaluation
 }  // namespace tflite