Refactoring: Move subgraph testing utility function to subgraph_test_util module.
This is a no-op change to improve code quality. PiperOrigin-RevId: 232758236
This commit is contained in:
parent
3d704b754f
commit
6103a3d46a
@ -1218,9 +1218,10 @@ tf_cc_test(
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":kernel_util",
|
||||
":subgraph_test_util",
|
||||
":test_util",
|
||||
"//tensorflow/lite:builtin_op_data",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@flatbuffers",
|
||||
],
|
||||
@ -1285,3 +1286,31 @@ tf_cc_test(
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "subgraph_test_util",
|
||||
testonly = 1,
|
||||
srcs = ["subgraph_test_util.cc"],
|
||||
hdrs = ["subgraph_test_util.h"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":kernel_util",
|
||||
":test_util",
|
||||
"//tensorflow/lite:builtin_op_data",
|
||||
"//tensorflow/lite:framework",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "subgraph_test_util_test",
|
||||
size = "small",
|
||||
srcs = ["subgraph_test_util_test.cc"],
|
||||
deps = [
|
||||
":subgraph_test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -12,192 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/subgraph_test_util.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
// ADD and MUL are used to test simple branch.
|
||||
TfLiteRegistration* Register_ADD();
|
||||
TfLiteRegistration* Register_MUL();
|
||||
// Pad is used to test dynamic sized subgraphs.
|
||||
TfLiteRegistration* Register_PAD();
|
||||
} // namespace builtin
|
||||
namespace custom {
|
||||
TfLiteRegistration* Register_IF();
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
||||
using subgraph_test_util::BuildAddSubgraph;
|
||||
using subgraph_test_util::BuildIfSubgraph;
|
||||
using subgraph_test_util::BuildMulSubgraph;
|
||||
using subgraph_test_util::BuildPadSubgraph;
|
||||
using subgraph_test_util::CheckIntTensor;
|
||||
using subgraph_test_util::FillIntTensor;
|
||||
|
||||
namespace {
|
||||
|
||||
void SetupTensor(Subgraph* subgraph, int tensor_index,
|
||||
const std::vector<int>& shape, TfLiteType type) {
|
||||
ASSERT_EQ(subgraph->SetTensorParametersReadWrite(
|
||||
tensor_index, type, "", shape.size(), shape.data(), {}, false),
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
// TODO(ycling): Consider to move all the test helper functions to another
|
||||
// build target (e.g. subgraph_test_util).
|
||||
// Build a subgraph with an add op. Helper function for testing.
|
||||
void BuildAddSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, {2}, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, {1, 2}, kTfLiteInt32);
|
||||
// Intentionally set the wrong output size for testing. This should be
|
||||
// overridden by Prepare function.
|
||||
SetupTensor(subgraph, 2, {100}, kTfLiteInt32);
|
||||
|
||||
TfLiteAddParams* params =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_ADD(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
// Build a subgraph with an mul op. Helper function for testing.
|
||||
void BuildMulSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, {2}, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, {1, 2}, kTfLiteInt32);
|
||||
// Intentionally set the wrong output size for testing. This should be
|
||||
// overridden by Prepare function.
|
||||
SetupTensor(subgraph, 2, {100}, kTfLiteInt32);
|
||||
|
||||
TfLiteMulParams* params =
|
||||
reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_MUL(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
// Build a subgraph with a pad op. Helper function for testing.
|
||||
void BuildPadSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, {2}, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, {1, 2}, kTfLiteInt32);
|
||||
// Intentionally set the wrong output size for testing. This should be
|
||||
// overridden by Prepare function.
|
||||
SetupTensor(subgraph, 2, {100}, kTfLiteInt32);
|
||||
|
||||
TfLitePadParams* params =
|
||||
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_PAD(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
void BuildIfSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(4, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1, 2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({3}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, {1}, kTfLiteBool);
|
||||
SetupTensor(subgraph, 1, {2}, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, {1, 2}, kTfLiteInt32);
|
||||
// Intentionally set the wrong output size for testing. This should be
|
||||
// overridden by Prepare function.
|
||||
SetupTensor(subgraph, 3, {100}, kTfLiteInt32);
|
||||
|
||||
flexbuffers::Builder fbb;
|
||||
fbb.Map([&]() {
|
||||
fbb.Int("then_subgraph_index", 1);
|
||||
fbb.Int("else_subgraph_index", 2);
|
||||
});
|
||||
fbb.Finish();
|
||||
const auto& buffer = fbb.GetBuffer();
|
||||
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters(
|
||||
{0, 1, 2}, {3}, reinterpret_cast<const char*>(buffer.data()),
|
||||
buffer.size(), nullptr, ::tflite::ops::custom::Register_IF(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data) {
|
||||
int count = NumElements(tensor);
|
||||
ASSERT_EQ(count, data.size());
|
||||
for (int i = 0; i < count; ++i) {
|
||||
tensor->data.i32[i] = data[i];
|
||||
}
|
||||
}
|
||||
|
||||
void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
|
||||
const std::vector<int32_t>& data) {
|
||||
ASSERT_EQ(tensor->dims->size, shape.size());
|
||||
for (int i = 0; i < tensor->dims->size; ++i) {
|
||||
ASSERT_EQ(tensor->dims->data[i], shape[i]);
|
||||
}
|
||||
ASSERT_EQ(tensor->type, kTfLiteInt32);
|
||||
int count = NumElements(tensor);
|
||||
ASSERT_EQ(count, data.size());
|
||||
for (int i = 0; i < count; ++i) {
|
||||
EXPECT_EQ(tensor->data.i32[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// TestHelperfunctionTest tests the helper functions defined in this file.
|
||||
TEST(TestHelperfunctionTest, TestBuildAddSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildAddSubgraph(&interpreter->primary_subgraph());
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
CheckIntTensor(output, {1, 2}, {6, 9});
|
||||
}
|
||||
|
||||
TEST(TestHelperfunctionTest, TestBuildMulSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildMulSubgraph(&interpreter->primary_subgraph());
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
CheckIntTensor(output, {1, 2}, {5, 14});
|
||||
}
|
||||
|
||||
TEST(TestHelperfunctionTest, TestBuildPadSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildPadSubgraph(&interpreter->primary_subgraph());
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
|
||||
}
|
||||
|
||||
// A simple test that performs `ADD` if condition is true, and `MUL` otherwise.
|
||||
// The computation is: `cond ? a + b : a * b`.
|
||||
class SimpleIfTest : public ::testing::Test {
|
||||
@ -208,7 +43,12 @@ class SimpleIfTest : public ::testing::Test {
|
||||
BuildAddSubgraph(interpreter_->subgraph(1));
|
||||
BuildMulSubgraph(interpreter_->subgraph(2));
|
||||
BuildIfSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
|
||||
}
|
||||
@ -239,7 +79,12 @@ class DynamicSubgraphIfTest : public ::testing::Test {
|
||||
BuildAddSubgraph(interpreter_->subgraph(1));
|
||||
BuildPadSubgraph(interpreter_->subgraph(2));
|
||||
BuildIfSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
|
||||
}
|
||||
|
159
tensorflow/lite/kernels/subgraph_test_util.cc
Normal file
159
tensorflow/lite/kernels/subgraph_test_util.cc
Normal file
@ -0,0 +1,159 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/subgraph_test_util.h"
|
||||
|
||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
// ADD and MUL are used to test simple branch.
|
||||
TfLiteRegistration* Register_ADD();
|
||||
TfLiteRegistration* Register_MUL();
|
||||
// ADD and MUL are used to test dynamic sized subgraphs.
|
||||
TfLiteRegistration* Register_PAD();
|
||||
} // namespace builtin
|
||||
namespace custom {
|
||||
TfLiteRegistration* Register_IF();
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
||||
namespace subgraph_test_util {
|
||||
|
||||
void SetupTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) {
|
||||
ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 0,
|
||||
nullptr, {}, false),
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
void BuildAddSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
|
||||
TfLiteAddParams* params =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_ADD(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
// Build a subgraph with an mul op. Helper function for testing.
|
||||
void BuildMulSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
|
||||
TfLiteMulParams* params =
|
||||
reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_MUL(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
// Build a subgraph with a pad op. Helper function for testing.
|
||||
void BuildPadSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
|
||||
TfLitePadParams* params =
|
||||
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_PAD(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
void BuildIfSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(4, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1, 2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({3}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteBool);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 3, kTfLiteInt32);
|
||||
|
||||
flexbuffers::Builder fbb;
|
||||
fbb.Map([&]() {
|
||||
fbb.Int("then_subgraph_index", 1);
|
||||
fbb.Int("else_subgraph_index", 2);
|
||||
});
|
||||
fbb.Finish();
|
||||
const auto& buffer = fbb.GetBuffer();
|
||||
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters(
|
||||
{0, 1, 2}, {3}, reinterpret_cast<const char*>(buffer.data()),
|
||||
buffer.size(), nullptr, ::tflite::ops::custom::Register_IF(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data) {
|
||||
int count = NumElements(tensor);
|
||||
ASSERT_EQ(count, data.size());
|
||||
for (int i = 0; i < count; ++i) {
|
||||
tensor->data.i32[i] = data[i];
|
||||
}
|
||||
}
|
||||
|
||||
void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
|
||||
const std::vector<int32_t>& data) {
|
||||
ASSERT_EQ(tensor->dims->size, shape.size());
|
||||
for (int i = 0; i < tensor->dims->size; ++i) {
|
||||
ASSERT_EQ(tensor->dims->data[i], shape[i]);
|
||||
}
|
||||
ASSERT_EQ(tensor->type, kTfLiteInt32);
|
||||
int count = NumElements(tensor);
|
||||
ASSERT_EQ(count, data.size());
|
||||
for (int i = 0; i < count; ++i) {
|
||||
EXPECT_EQ(tensor->data.i32[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace subgraph_test_util
|
||||
} // namespace tflite
|
62
tensorflow/lite/kernels/subgraph_test_util.h
Normal file
62
tensorflow/lite/kernels/subgraph_test_util.h
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This module provides helper functions for testing the interaction between
|
||||
// control flow ops and subgraphs.
|
||||
// For convenience, we mostly only use `kTfLiteInt32` in this module.
|
||||
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
|
||||
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace subgraph_test_util {
|
||||
|
||||
// Build a subgraph with a single Add op.
|
||||
// 2 inputs. 1 output.
|
||||
void BuildAddSubgraph(Subgraph* subgraph);
|
||||
|
||||
// Build a subgraph with a single Mul op.
|
||||
// 2 inputs. 1 output.
|
||||
void BuildMulSubgraph(Subgraph* subgraph);
|
||||
|
||||
// Build a subgraph with a single Pad op.
|
||||
// 2 inputs. 1 output.
|
||||
void BuildPadSubgraph(Subgraph* subgraph);
|
||||
|
||||
// Build a subgraph with a single If op.
|
||||
// 3 inputs:
|
||||
// The 1st input is condition with boolean type.
|
||||
// The 2nd and 3rd inputs are feed input the branch subgraphs.
|
||||
// 1 output.
|
||||
void BuildIfSubgraph(Subgraph* subgraph);
|
||||
|
||||
// Fill a `TfLiteTensor` with a 32-bits integer vector.
|
||||
// Preconditions:
|
||||
// * The tensor must have `kTfLiteInt32` type.
|
||||
// * The tensor must be allocated.
|
||||
// * The element count of the tensor must be equal to the length or
|
||||
// the vector.
|
||||
void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data);
|
||||
|
||||
// Check if the shape and data of a tensor is as expected.
|
||||
void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
|
||||
const std::vector<int32_t>& data);
|
||||
|
||||
} // namespace subgraph_test_util
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
|
84
tensorflow/lite/kernels/subgraph_test_util_test.cc
Normal file
84
tensorflow/lite/kernels/subgraph_test_util_test.cc
Normal file
@ -0,0 +1,84 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/subgraph_test_util.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace subgraph_test_util {
|
||||
|
||||
namespace {
|
||||
|
||||
// SubGraphTestUtilTest tests the helper functions defined in this file.
|
||||
TEST(SubGraphTestUtilTest, TestBuildAddSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildAddSubgraph(&interpreter->primary_subgraph());
|
||||
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
CheckIntTensor(output, {1, 2}, {6, 9});
|
||||
}
|
||||
|
||||
TEST(SubGraphTestUtilTest, TestBuildMulSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildMulSubgraph(&interpreter->primary_subgraph());
|
||||
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
CheckIntTensor(output, {1, 2}, {5, 14});
|
||||
}
|
||||
|
||||
TEST(SubGraphTestUtilTest, TestBuildPadSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildPadSubgraph(&interpreter->primary_subgraph());
|
||||
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace subgraph_test_util
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
Loading…
Reference in New Issue
Block a user