API changes for evaluation framework, mainly to support adding arbitrary delegates

PiperOrigin-RevId: 248613883
This commit is contained in:
Sachin Joglekar 2019-05-16 15:37:52 -07:00 committed by TensorFlower Gardener
parent 5b0521bed3
commit bb15fff6b9
5 changed files with 41 additions and 3 deletions

View File

@ -105,6 +105,7 @@ cc_library(
"//tensorflow/core:stats_calculator_portable",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/profiling:time",
"//tensorflow/lite/tools/evaluation:evaluation_stage",
@ -123,6 +124,7 @@ cc_test(
deps = [
":tflite_inference_stage",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
"@com_google_googletest//:gtest_main",

View File

@ -58,6 +58,12 @@ class ImageClassificationStage : public EvaluationStage {
ground_truth_label_ = ground_truth_label;
}
// Provides a pointer to the underlying TfLiteInferenceStage.
// Returns non-null value only if this stage has been initialized.
TfliteInferenceStage* const GetInferenceStage() {
return inference_stage_.get();
}
private:
const std::vector<std::string>* all_labels_ = nullptr;
std::unique_ptr<ImagePreprocessingStage> preprocessing_stage_;

View File

@ -39,6 +39,15 @@ TfLiteModelInfo GetTfliteModelInfo(const Interpreter& interpreter) {
} // namespace
TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
TfLiteDelegate* delegate) {
if (!interpreter_) {
LOG(ERROR) << "Stage not initialized before calling ApplyCustomDelegate";
return kTfLiteError;
}
return interpreter_->ModifyGraphWithDelegate(delegate);
}
TfLiteStatus TfliteInferenceStage::Init() {
if (!config_.specification().has_tflite_inference_params()) {
LOG(ERROR) << "TfliteInferenceParams not provided";
@ -109,7 +118,7 @@ TfLiteStatus TfliteInferenceStage::Run() {
// Copy input data.
for (int i = 0; i < interpreter_->inputs().size(); ++i) {
TfLiteTensor* tensor = interpreter_->tensor(interpreter_->inputs()[i]);
std::memcpy(tensor->data.raw, (*inputs_)[i], tensor->bytes);
tensor->data.raw = static_cast<char*>(inputs_->at(i));
}
// Invoke.

View File

@ -16,9 +16,11 @@ limitations under the License.
#define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_TFLITE_INFERENCE_STAGE_H_
#include <stdint.h>
#include <vector>
#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
@ -50,10 +52,14 @@ class TfliteInferenceStage : public EvaluationStage {
// Call before Run().
// This class does not take ownership of raw_input_ptrs.
void SetInputs(std::vector<void*>& raw_input_ptrs) {
void SetInputs(const std::vector<void*>& raw_input_ptrs) {
inputs_ = &raw_input_ptrs;
}
// Applies provided delegate to the underlying TFLite Interpreter.
// NOTE: TFLiteInferenceStage does not take ownership of delegate.
TfLiteStatus ApplyCustomDelegate(TfLiteDelegate* delegate);
// Read-only view of a TfliteModelInfo. TfliteInferenceStage retains
// ownership.
// Only available after Init is done.
@ -70,7 +76,7 @@ class TfliteInferenceStage : public EvaluationStage {
std::vector<Interpreter::TfLiteDelegatePtr> delegates_;
TfLiteModelInfo model_info_;
std::vector<void*>* inputs_ = nullptr;
const std::vector<void*>* inputs_ = nullptr;
std::vector<void*> outputs_;
tensorflow::Stat<int64_t> latency_stats_;

View File

@ -15,10 +15,12 @@ limitations under the License.
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
#include <stdint.h>
#include <string>
#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
@ -155,6 +157,19 @@ TEST(TfliteInferenceStage, CorrectOutput) {
metrics.process_metrics().tflite_inference_metrics().num_inferences(), 2);
}
TEST(TfliteInferenceStage, CustomDelegate) {
// Create stage.
EvaluationStageConfig config = GetTfliteInferenceStageConfig();
TfliteInferenceStage stage(config);
TfLiteDelegate* test_delegate = NnApiDelegate();
// Delegate application should only work after initialization of stage.
EXPECT_NE(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk);
EXPECT_EQ(stage.Init(), kTfLiteOk);
EXPECT_EQ(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk);
}
} // namespace
} // namespace evaluation
} // namespace tflite