STT-tensorflow/tensorflow/lite/delegates/utils/simple_delegate_test.cc
Chao Mei 51fe6f6acd Factor out the dummy test delegate from simple_delegate_test and create relevant BUILD rules to show
1. what a customized TfLiteDelegate API looks like
2. how this customized TfLiteDelegate can be plugged into the TFLite benchmark and task evaluation tools.

PiperOrigin-RevId: 315821839
Change-Id: Ia1c5d13bd1711a88786f36e2ef7527497a6391c7
2020-06-10 20:24:29 -07:00

129 lines
5.3 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/utils/dummy_delegate/dummy_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
namespace tflite {
namespace {
class TestDelegate : public ::testing::Test {
protected:
void SetUp() override {
interpreter_.reset(new Interpreter);
interpreter_->AddTensors(5);
interpreter_->SetInputs({0, 1});
interpreter_->SetOutputs({3, 4});
TfLiteQuantizationParams quant;
interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
quant);
interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
quant);
interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
quant);
interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
quant);
interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
quant);
TfLiteRegistration* reg = ops::builtin::Register_ADD();
void* builtin_data_1 = malloc(sizeof(int));
void* builtin_data_2 = malloc(sizeof(int));
void* builtin_data_3 = malloc(sizeof(int));
interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, builtin_data_1,
reg);
interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, builtin_data_2,
reg);
interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, builtin_data_3,
reg);
}
void TearDown() override { interpreter_.reset(); }
protected:
std::unique_ptr<Interpreter> interpreter_;
};
// With ADD allowed, all three ADD nodes should be collapsed into a single
// "DummyDelegate" kernel node. The delegate params of that node describe the
// replaced nodes and the subgraph's boundary tensors.
TEST_F(TestDelegate, BasicDelegate) {
  DummyDelegateOptions options = TfLiteDummyDelegateOptionsDefault();
  options.allowed_builtin_code = kTfLiteBuiltinAdd;
  auto delegate = TfLiteDummyDelegateCreateUnique(&options);
  // Check the status up front. Otherwise a delegation failure would only show
  // up later as an unrelated-looking execution-plan size mismatch.
  ASSERT_EQ(kTfLiteOk,
            interpreter_->ModifyGraphWithDelegate(std::move(delegate)));

  // `1u` keeps the comparison unsigned-vs-unsigned (size() is size_t).
  ASSERT_EQ(interpreter_->execution_plan().size(), 1u);
  const int node = interpreter_->execution_plan()[0];
  const auto* node_and_reg = interpreter_->node_and_registration(node);
  ASSERT_NE(node_and_reg, nullptr);
  EXPECT_STREQ("DummyDelegate", node_and_reg->second.custom_name);
  EXPECT_EQ(1, node_and_reg->second.version);

  const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>(
      node_and_reg->first.builtin_data);
  ASSERT_NE(params, nullptr);

  // All three original nodes were replaced.
  ASSERT_EQ(params->nodes_to_replace->size, 3);
  EXPECT_EQ(params->nodes_to_replace->data[0], 0);
  EXPECT_EQ(params->nodes_to_replace->data[1], 1);
  EXPECT_EQ(params->nodes_to_replace->data[2], 2);

  // The subgraph consumes the graph inputs...
  ASSERT_EQ(params->input_tensors->size, 2);
  EXPECT_EQ(params->input_tensors->data[0], 0);
  EXPECT_EQ(params->input_tensors->data[1], 1);

  // ...and produces the graph outputs.
  ASSERT_EQ(params->output_tensors->size, 2);
  EXPECT_EQ(params->output_tensors->data[0], 3);
  EXPECT_EQ(params->output_tensors->data[1], 4);
}
// When the delegate only accepts SUB, none of the ADD nodes are claimed. The
// graph must be left with its original three nodes.
TEST_F(TestDelegate, NoNodesToDelegate) {
  DummyDelegateOptions options = TfLiteDummyDelegateOptionsDefault();
  options.allowed_builtin_code = kTfLiteBuiltinSub;
  auto delegate = TfLiteDummyDelegateCreateUnique(&options);
  // Claiming nothing is not an error. Check the status so that a real failure
  // is reported as such rather than as a plan-size mismatch.
  ASSERT_EQ(kTfLiteOk,
            interpreter_->ModifyGraphWithDelegate(std::move(delegate)));
  // `3u` keeps the comparison unsigned-vs-unsigned (size() is size_t).
  ASSERT_EQ(interpreter_->execution_plan().size(), 3u);
}
// A dummy delegate configured with error_during_prepare must make
// ModifyGraphWithDelegate report kTfLiteDelegateError.
TEST_F(TestDelegate, DelegateFailedPrepare) {
  auto options = TfLiteDummyDelegateOptionsDefault();
  options.error_during_prepare = true;
  options.allowed_builtin_code = kTfLiteBuiltinAdd;

  const TfLiteStatus modify_status = interpreter_->ModifyGraphWithDelegate(
      TfLiteDummyDelegateCreateUnique(&options));
  ASSERT_EQ(kTfLiteDelegateError, modify_status);
}
// With error_during_invoke set, delegation itself is expected to succeed. The
// configured failure must only surface when the interpreter runs, as
// kTfLiteError.
TEST_F(TestDelegate, DelegateFailedInvoke) {
  auto options = TfLiteDummyDelegateOptionsDefault();
  options.error_during_invoke = true;
  options.allowed_builtin_code = kTfLiteBuiltinAdd;

  const TfLiteStatus modify_status = interpreter_->ModifyGraphWithDelegate(
      TfLiteDummyDelegateCreateUnique(&options));
  ASSERT_EQ(kTfLiteOk, modify_status);

  const TfLiteStatus invoke_status = interpreter_->Invoke();
  ASSERT_EQ(kTfLiteError, invoke_status);
}
// A dummy delegate configured with error_during_init must make
// ModifyGraphWithDelegate report kTfLiteDelegateError.
TEST_F(TestDelegate, DelegateFailedInit) {
  auto options = TfLiteDummyDelegateOptionsDefault();
  options.error_during_init = true;
  options.allowed_builtin_code = kTfLiteBuiltinAdd;

  const TfLiteStatus modify_status = interpreter_->ModifyGraphWithDelegate(
      TfLiteDummyDelegateCreateUnique(&options));
  ASSERT_EQ(kTfLiteDelegateError, modify_status);
}
} // namespace
} // namespace tflite