From bb15fff6b9c92c5926fa92229d35ca4f5f9eb822 Mon Sep 17 00:00:00 2001 From: Sachin Joglekar <srjoglekar@google.com> Date: Thu, 16 May 2019 15:37:52 -0700 Subject: [PATCH] API changes for evaluation framework, mainly to support adding arbitrary delegates PiperOrigin-RevId: 248613883 --- tensorflow/lite/tools/evaluation/stages/BUILD | 2 ++ .../stages/image_classification_stage.h | 6 ++++++ .../evaluation/stages/tflite_inference_stage.cc | 11 ++++++++++- .../evaluation/stages/tflite_inference_stage.h | 10 ++++++++-- .../stages/tflite_inference_stage_test.cc | 15 +++++++++++++++ 5 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD index 6a642550a8a..f12f9f06908 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD @@ -105,6 +105,7 @@ cc_library( "//tensorflow/core:stats_calculator_portable", "//tensorflow/core:tflite_portable_logging", "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_internal", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/profiling:time", "//tensorflow/lite/tools/evaluation:evaluation_stage", @@ -123,6 +124,7 @@ cc_test( deps = [ ":tflite_inference_stage", "//tensorflow/lite/c:c_api_internal", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h index 73d4c3efb7e..31c956b0c90 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/image_classification_stage.h @@ -58,6 +58,12 @@ class ImageClassificationStage : public EvaluationStage { ground_truth_label_ = ground_truth_label; } + // Provides a pointer to the underlying TfLiteInferenceStage. + // Returns non-null value only if this stage has been initialized. + TfliteInferenceStage* const GetInferenceStage() { + return inference_stage_.get(); + } + private: const std::vector<std::string>* all_labels_ = nullptr; std::unique_ptr<ImagePreprocessingStage> preprocessing_stage_; diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc index bc0eab03338..0f7070d3d05 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc @@ -39,6 +39,15 @@ TfLiteModelInfo GetTfliteModelInfo(const Interpreter& interpreter) { } // namespace +TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate( + TfLiteDelegate* delegate) { + if (!interpreter_) { + LOG(ERROR) << "Stage not initialized before calling ApplyCustomDelegate"; + return kTfLiteError; + } + return interpreter_->ModifyGraphWithDelegate(delegate); +} + TfLiteStatus TfliteInferenceStage::Init() { if (!config_.specification().has_tflite_inference_params()) { LOG(ERROR) << "TfliteInferenceParams not provided"; @@ -109,7 +118,7 @@ TfLiteStatus TfliteInferenceStage::Run() { // Copy input data. for (int i = 0; i < interpreter_->inputs().size(); ++i) { TfLiteTensor* tensor = interpreter_->tensor(interpreter_->inputs()[i]); - std::memcpy(tensor->data.raw, (*inputs_)[i], tensor->bytes); + tensor->data.raw = static_cast<char*>(inputs_->at(i)); } // Invoke. diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h index ecb9c7f7a86..fa53a49079f 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h @@ -16,9 +16,11 @@ limitations under the License. #define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_TFLITE_INFERENCE_STAGE_H_ #include <stdint.h> + #include <vector> #include "tensorflow/core/util/stats_calculator.h" +#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" @@ -50,10 +52,14 @@ class TfliteInferenceStage : public EvaluationStage { // Call before Run(). // This class does not take ownership of raw_input_ptrs. - void SetInputs(std::vector<void*>& raw_input_ptrs) { + void SetInputs(const std::vector<void*>& raw_input_ptrs) { inputs_ = &raw_input_ptrs; } + // Applies provided delegate to the underlying TFLite Interpreter. + // NOTE: TFLiteInferenceStage does not take ownership of delegate. + TfLiteStatus ApplyCustomDelegate(TfLiteDelegate* delegate); + // Read-only view of a TfliteModelInfo. TfliteInferenceStage retains // ownership. // Only available after Init is done. @@ -70,7 +76,7 @@ class TfliteInferenceStage : public EvaluationStage { std::vector<Interpreter::TfLiteDelegatePtr> delegates_; TfLiteModelInfo model_info_; - std::vector<void*>* inputs_ = nullptr; + const std::vector<void*>* inputs_ = nullptr; std::vector<void*> outputs_; tensorflow::Stat<int64_t> latency_stats_; diff --git a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc index f86de894713..beb9c49c338 100644 --- a/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc +++ b/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage_test.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h" #include <stdint.h> + #include <string> #include <gtest/gtest.h> #include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" @@ -155,6 +157,19 @@ TEST(TfliteInferenceStage, CorrectOutput) { metrics.process_metrics().tflite_inference_metrics().num_inferences(), 2); } +TEST(TfliteInferenceStage, CustomDelegate) { + // Create stage. + EvaluationStageConfig config = GetTfliteInferenceStageConfig(); + TfliteInferenceStage stage(config); + + TfLiteDelegate* test_delegate = NnApiDelegate(); + + // Delegate application should only work after initialization of stage. + EXPECT_NE(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk); + EXPECT_EQ(stage.Init(), kTfLiteOk); + EXPECT_EQ(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk); +} + } // namespace } // namespace evaluation } // namespace tflite