API changes for evaluation framework, mainly to support adding arbitrary delegates
PiperOrigin-RevId: 248613883
This commit is contained in:
parent
5b0521bed3
commit
bb15fff6b9
tensorflow/lite/tools/evaluation/stages
@ -105,6 +105,7 @@ cc_library(
|
||||
"//tensorflow/core:stats_calculator_portable",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/profiling:time",
|
||||
"//tensorflow/lite/tools/evaluation:evaluation_stage",
|
||||
@ -123,6 +124,7 @@ cc_test(
|
||||
deps = [
|
||||
":tflite_inference_stage",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/delegates/nnapi:nnapi_delegate",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
|
@ -58,6 +58,12 @@ class ImageClassificationStage : public EvaluationStage {
|
||||
ground_truth_label_ = ground_truth_label;
|
||||
}
|
||||
|
||||
// Provides a pointer to the underlying TfLiteInferenceStage.
|
||||
// Returns non-null value only if this stage has been initialized.
|
||||
TfliteInferenceStage* const GetInferenceStage() {
|
||||
return inference_stage_.get();
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<std::string>* all_labels_ = nullptr;
|
||||
std::unique_ptr<ImagePreprocessingStage> preprocessing_stage_;
|
||||
|
@ -39,6 +39,15 @@ TfLiteModelInfo GetTfliteModelInfo(const Interpreter& interpreter) {
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
|
||||
TfLiteDelegate* delegate) {
|
||||
if (!interpreter_) {
|
||||
LOG(ERROR) << "Stage not initialized before calling ApplyCustomDelegate";
|
||||
return kTfLiteError;
|
||||
}
|
||||
return interpreter_->ModifyGraphWithDelegate(delegate);
|
||||
}
|
||||
|
||||
TfLiteStatus TfliteInferenceStage::Init() {
|
||||
if (!config_.specification().has_tflite_inference_params()) {
|
||||
LOG(ERROR) << "TfliteInferenceParams not provided";
|
||||
@ -109,7 +118,7 @@ TfLiteStatus TfliteInferenceStage::Run() {
|
||||
// Copy input data.
|
||||
for (int i = 0; i < interpreter_->inputs().size(); ++i) {
|
||||
TfLiteTensor* tensor = interpreter_->tensor(interpreter_->inputs()[i]);
|
||||
std::memcpy(tensor->data.raw, (*inputs_)[i], tensor->bytes);
|
||||
tensor->data.raw = static_cast<char*>(inputs_->at(i));
|
||||
}
|
||||
|
||||
// Invoke.
|
||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_TFLITE_INFERENCE_STAGE_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/util/stats_calculator.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
@ -50,10 +52,14 @@ class TfliteInferenceStage : public EvaluationStage {
|
||||
|
||||
// Call before Run().
|
||||
// This class does not take ownership of raw_input_ptrs.
|
||||
void SetInputs(std::vector<void*>& raw_input_ptrs) {
|
||||
void SetInputs(const std::vector<void*>& raw_input_ptrs) {
|
||||
inputs_ = &raw_input_ptrs;
|
||||
}
|
||||
|
||||
// Applies provided delegate to the underlying TFLite Interpreter.
|
||||
// NOTE: TFLiteInferenceStage does not take ownership of delegate.
|
||||
TfLiteStatus ApplyCustomDelegate(TfLiteDelegate* delegate);
|
||||
|
||||
// Read-only view of a TfliteModelInfo. TfliteInferenceStage retains
|
||||
// ownership.
|
||||
// Only available after Init is done.
|
||||
@ -70,7 +76,7 @@ class TfliteInferenceStage : public EvaluationStage {
|
||||
std::vector<Interpreter::TfLiteDelegatePtr> delegates_;
|
||||
|
||||
TfLiteModelInfo model_info_;
|
||||
std::vector<void*>* inputs_ = nullptr;
|
||||
const std::vector<void*>* inputs_ = nullptr;
|
||||
std::vector<void*> outputs_;
|
||||
|
||||
tensorflow::Stat<int64_t> latency_stats_;
|
||||
|
@ -15,10 +15,12 @@ limitations under the License.
|
||||
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
|
||||
@ -155,6 +157,19 @@ TEST(TfliteInferenceStage, CorrectOutput) {
|
||||
metrics.process_metrics().tflite_inference_metrics().num_inferences(), 2);
|
||||
}
|
||||
|
||||
TEST(TfliteInferenceStage, CustomDelegate) {
|
||||
// Create stage.
|
||||
EvaluationStageConfig config = GetTfliteInferenceStageConfig();
|
||||
TfliteInferenceStage stage(config);
|
||||
|
||||
TfLiteDelegate* test_delegate = NnApiDelegate();
|
||||
|
||||
// Delegate application should only work after initialization of stage.
|
||||
EXPECT_NE(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk);
|
||||
EXPECT_EQ(stage.Init(), kTfLiteOk);
|
||||
EXPECT_EQ(stage.ApplyCustomDelegate(test_delegate), kTfLiteOk);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user