Initial implementation of WHILE op

PiperOrigin-RevId: 233138666
This commit is contained in:
Yu-Cheng Ling 2019-02-08 15:34:29 -08:00 committed by TensorFlower Gardener
parent 14554b2371
commit 9913382f56
8 changed files with 785 additions and 111 deletions

View File

@ -38,6 +38,7 @@ class TfLiteIntArrayView {
const_iterator begin() const { return int_array_->data; }
const_iterator end() const { return &int_array_->data[int_array_->size]; }
size_t size() const { return end() - begin(); }
int operator[](size_t pos) const { return int_array_->data[pos]; }
private:
const TfLiteIntArray* int_array_;

View File

@ -225,6 +225,7 @@ cc_library(
"unidirectional_sequence_rnn.cc",
"unique.cc",
"unpack.cc",
"while.cc",
"zeros_like.cc",
],
hdrs = [
@ -1227,6 +1228,23 @@ tf_cc_test(
],
)
tf_cc_test(
name = "while_test",
size = "small",
srcs = ["while_test.cc"],
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
":kernel_util",
":subgraph_test_util",
"//tensorflow/lite:builtin_op_data",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
tf_cc_test(
name = "fill_test",
size = "small",

View File

@ -24,25 +24,21 @@ limitations under the License.
namespace tflite {
using subgraph_test_util::BuildAddSubgraph;
using subgraph_test_util::BuildIfSubgraph;
using subgraph_test_util::BuildMulSubgraph;
using subgraph_test_util::BuildPadSubgraph;
using subgraph_test_util::CheckIntTensor;
using subgraph_test_util::ControlFlowOpTest;
using subgraph_test_util::FillIntTensor;
namespace {
// A simple test that performs `ADD` if condition is true, and `MUL` otherwise.
// The computation is: `cond ? a + b : a * b`.
class SimpleIfTest : public ::testing::Test {
class SimpleIfTest : public ControlFlowOpTest {
protected:
void SetUp() override {
interpreter_.reset(new Interpreter);
interpreter_->AddSubgraphs(2);
BuildAddSubgraph(interpreter_->subgraph(1));
BuildMulSubgraph(interpreter_->subgraph(2));
BuildIfSubgraph(&interpreter_->primary_subgraph());
builder_->BuildAddSubgraph(interpreter_->subgraph(1));
builder_->BuildMulSubgraph(interpreter_->subgraph(2));
builder_->BuildIfSubgraph(&interpreter_->primary_subgraph());
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
@ -52,7 +48,6 @@ class SimpleIfTest : public ::testing::Test {
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
}
std::unique_ptr<Interpreter> interpreter_;
};
TEST_F(SimpleIfTest, TestIfTrue) {
@ -71,14 +66,13 @@ TEST_F(SimpleIfTest, TestIfFalse) {
// Test IF op using subgraphs with dynamically sized outputs.
// The computation is: `cond ? a + b : pad(a, b)`.
class DynamicSubgraphIfTest : public ::testing::Test {
class DynamicSubgraphIfTest : public ControlFlowOpTest {
protected:
void SetUp() override {
interpreter_.reset(new Interpreter);
interpreter_->AddSubgraphs(2);
BuildAddSubgraph(interpreter_->subgraph(1));
BuildPadSubgraph(interpreter_->subgraph(2));
BuildIfSubgraph(&interpreter_->primary_subgraph());
builder_->BuildAddSubgraph(interpreter_->subgraph(1));
builder_->BuildPadSubgraph(interpreter_->subgraph(2));
builder_->BuildIfSubgraph(&interpreter_->primary_subgraph());
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
@ -88,7 +82,6 @@ class DynamicSubgraphIfTest : public ::testing::Test {
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
}
std::unique_ptr<Interpreter> interpreter_;
};
TEST_F(DynamicSubgraphIfTest, TestIfTrue) {

View File

@ -31,92 +31,142 @@ TfLiteRegistration* Register_ADD();
TfLiteRegistration* Register_MUL();
// ADD and MUL are used to test dynamic sized subgraphs.
TfLiteRegistration* Register_PAD();
TfLiteRegistration* Register_LESS_EQUAL();
} // namespace builtin
namespace custom {
TfLiteRegistration* Register_IF();
TfLiteRegistration* Register_WHILE();
} // namespace custom
} // namespace ops
namespace subgraph_test_util {
namespace {
void SetupTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) {
ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 0,
nullptr, {}, false),
kTfLiteOk);
}
void BuildAddSubgraph(Subgraph* subgraph) {
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
} // namespace
SetupTensor(subgraph, 0, kTfLiteInt32);
SetupTensor(subgraph, 1, kTfLiteInt32);
SetupTensor(subgraph, 2, kTfLiteInt32);
SubgraphBuilder::~SubgraphBuilder() {
for (auto buffer : buffers_) {
free(buffer);
}
}
void SubgraphBuilder::BuildAddSubgraph(Subgraph* subgraph) {
const int kInput1 = 0;
const int kInput2 = 1;
const int kOutput = 2;
const int kTensorCount = 3;
// kInput1(0) --> +---+
// |ADD| --> kOutput(2)
// kInput2(1) --> +---+
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
SetupTensor(subgraph, kInput1, kTfLiteInt32);
SetupTensor(subgraph, kInput2, kTfLiteInt32);
SetupTensor(subgraph, kOutput, kTfLiteInt32);
TfLiteAddParams* params =
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
params->activation = kTfLiteActNone;
int node_index;
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
::tflite::ops::builtin::Register_ADD(),
&node_index);
subgraph->AddNodeWithParameters(
{kInput1, kInput2}, {kOutput}, nullptr, 0, params,
::tflite::ops::builtin::Register_ADD(), &node_index);
}
// Build a subgraph with an mul op. Helper function for testing.
void BuildMulSubgraph(Subgraph* subgraph) {
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
void SubgraphBuilder::BuildMulSubgraph(Subgraph* subgraph) {
const int kInput1 = 0;
const int kInput2 = 1;
const int kOutput = 2;
const int kTensorCount = 3;
// kInput1(0) --> +---+
// |MUL| --> kOutput(2)
// kInput2(1) --> +---+
SetupTensor(subgraph, 0, kTfLiteInt32);
SetupTensor(subgraph, 1, kTfLiteInt32);
SetupTensor(subgraph, 2, kTfLiteInt32);
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
SetupTensor(subgraph, kInput1, kTfLiteInt32);
SetupTensor(subgraph, kInput2, kTfLiteInt32);
SetupTensor(subgraph, kOutput, kTfLiteInt32);
TfLiteMulParams* params =
reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
params->activation = kTfLiteActNone;
int node_index;
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
::tflite::ops::builtin::Register_MUL(),
&node_index);
subgraph->AddNodeWithParameters(
{kInput1, kInput2}, {kOutput}, nullptr, 0, params,
::tflite::ops::builtin::Register_MUL(), &node_index);
}
// Build a subgraph with a pad op. Helper function for testing.
void BuildPadSubgraph(Subgraph* subgraph) {
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
void SubgraphBuilder::BuildPadSubgraph(Subgraph* subgraph) {
const int kInput1 = 0;
const int kInput2 = 1;
const int kOutput = 2;
const int kTensorCount = 3;
// kInput1(0) --> +---+
// |PAD| --> kOutput(2)
// kInput2(1) --> +---+
SetupTensor(subgraph, 0, kTfLiteInt32);
SetupTensor(subgraph, 1, kTfLiteInt32);
SetupTensor(subgraph, 2, kTfLiteInt32);
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
SetupTensor(subgraph, kInput1, kTfLiteInt32);
SetupTensor(subgraph, kInput2, kTfLiteInt32);
SetupTensor(subgraph, kOutput, kTfLiteInt32);
TfLitePadParams* params =
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
int node_index;
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
::tflite::ops::builtin::Register_PAD(),
&node_index);
subgraph->AddNodeWithParameters(
{kInput1, kInput2}, {kOutput}, nullptr, 0, params,
::tflite::ops::builtin::Register_PAD(), &node_index);
}
void BuildIfSubgraph(Subgraph* subgraph) {
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(4, &first_new_tensor_index), kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({0, 1, 2}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({3}), kTfLiteOk);
void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
const int kCondInput = 0;
const int kInput1 = 1;
const int kInput2 = 2;
const int kOutput = 3;
const int kTensorCount = 4;
SetupTensor(subgraph, 0, kTfLiteBool);
SetupTensor(subgraph, 1, kTfLiteInt32);
SetupTensor(subgraph, 2, kTfLiteInt32);
SetupTensor(subgraph, 3, kTfLiteInt32);
// kCondInput(0) --> +----+
// kInput1(1) ----> | IF | --> kOutput(3)
// kInput2(2) ----> +----+
int first_new_tensor_index;
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
kTfLiteOk);
ASSERT_EQ(first_new_tensor_index, 0);
ASSERT_EQ(subgraph->SetInputs({kCondInput, kInput1, kInput2}), kTfLiteOk);
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
SetupTensor(subgraph, kCondInput, kTfLiteBool);
SetupTensor(subgraph, kInput1, kTfLiteInt32);
SetupTensor(subgraph, kInput2, kTfLiteInt32);
SetupTensor(subgraph, kOutput, kTfLiteInt32);
flexbuffers::Builder fbb;
fbb.Map([&]() {
@ -128,11 +178,197 @@ void BuildIfSubgraph(Subgraph* subgraph) {
int node_index;
subgraph->AddNodeWithParameters(
{0, 1, 2}, {3}, reinterpret_cast<const char*>(buffer.data()),
buffer.size(), nullptr, ::tflite::ops::custom::Register_IF(),
{kCondInput, kInput1, kInput2}, {kOutput},
reinterpret_cast<const char*>(buffer.data()), buffer.size(), nullptr,
::tflite::ops::custom::Register_IF(), &node_index);
}
// Builds the condition graph used by WHILE tests: output = (input1 <= rhs),
// computed elementwise against a constant RHS tensor owned by this builder.
// The second input exists only so the condition graph's signature matches
// the loop body's; no node in this graph consumes it.
void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
  constexpr int kCounterInput = 0;
  constexpr int kIgnoredInput = 1;
  constexpr int kCondOutput = 2;
  constexpr int kRhsConstant = 3;
  constexpr int kNumTensors = 4;

  // kCounterInput(0) --> +------------+
  //                      | LESS_EQUAL | --> kCondOutput(2)
  // kRhsConstant(3) ---> +------------+
  //
  // kIgnoredInput(1) --> (not consumed by any node)
  int first_tensor = -1;
  ASSERT_EQ(subgraph->AddTensors(kNumTensors, &first_tensor), kTfLiteOk);
  ASSERT_EQ(first_tensor, 0);
  ASSERT_EQ(subgraph->SetInputs({kCounterInput, kIgnoredInput}), kTfLiteOk);
  ASSERT_EQ(subgraph->SetOutputs({kCondOutput}), kTfLiteOk);

  SetupTensor(subgraph, kCounterInput, kTfLiteInt32);
  SetupTensor(subgraph, kIgnoredInput, kTfLiteInt32);
  SetupTensor(subgraph, kCondOutput, kTfLiteBool);
  CreateConstantInt32Tensor(subgraph, kRhsConstant, {1}, {rhs});

  // LESS_EQUAL takes no params struct, hence the nullptr builtin data.
  int node = -1;
  subgraph->AddNodeWithParameters(
      {kCounterInput, kRhsConstant}, {kCondOutput}, nullptr, 0, nullptr,
      ::tflite::ops::builtin::Register_LESS_EQUAL(), &node);
}
// Builds the accumulate loop body used by WHILE tests.
// 2 inputs, 2 outputs: (counter, value) -> (counter + 1, counter + 1 + value).
// The step (+1) is a constant tensor whose buffer is owned by this builder.
void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
  const int kInputCounter = 0;
  const int kInputValue = 1;
  const int kOutputCounter = 2;
  const int kOutputValue = 3;
  const int kConstStep = 4;
  const int kTensorCount = 5;

  // kInputCounter(0) --> +-----+
  //                      | ADD | --> kOutputCounter(2)
  // kConstStep(4) -----> +-----+        |
  //                                     |
  //                                     v
  //                                  +-----+
  //                                  | ADD | --> kOutputValue(3)
  // kInputValue(1) ----------------> +-----+
  int first_new_tensor_index;
  ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
            kTfLiteOk);
  ASSERT_EQ(first_new_tensor_index, 0);
  ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk);
  ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue}), kTfLiteOk);

  SetupTensor(subgraph, kInputCounter, kTfLiteInt32);
  SetupTensor(subgraph, kInputValue, kTfLiteInt32);
  SetupTensor(subgraph, kOutputCounter, kTfLiteInt32);
  SetupTensor(subgraph, kOutputValue, kTfLiteInt32);
  CreateConstantInt32Tensor(subgraph, kConstStep, {1}, {1});

  int node_index;
  // Each node needs its own params allocation: the runtime takes ownership
  // of the params and frees them with the node.
  TfLiteAddParams* params =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  params->activation = kTfLiteActNone;
  // Consistency fix: use the named tensor constants declared above instead of
  // raw indices ({0, 4}/{2} and {2, 1}/{3}), matching the other builders in
  // this file. Tensor wiring is unchanged.
  subgraph->AddNodeWithParameters(
      {kInputCounter, kConstStep}, {kOutputCounter}, nullptr, 0, params,
      ::tflite::ops::builtin::Register_ADD(), &node_index);
  params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  params->activation = kTfLiteActNone;
  subgraph->AddNodeWithParameters(
      {kOutputCounter, kInputValue}, {kOutputValue}, nullptr, 0, params,
      ::tflite::ops::builtin::Register_ADD(), &node_index);
}
// Builds a pad loop body used by WHILE tests. Each iteration increments the
// counter and pads the value tensor, enlarging it:
//   (counter, value) -> (counter + 1, tf.pad(value, padding)).
// `padding` must have an even number of elements ((before, after) pairs);
// it is materialized as a constant tensor owned by this builder.
void SubgraphBuilder::BuildPadLoopBodySubgraph(Subgraph* subgraph,
                                               const std::vector<int> padding) {
  const int kInputCounter = 0;
  const int kInputValue = 1;
  const int kOutputCounter = 2;
  const int kOutputValue = 3;
  const int kConstStep = 4;
  const int kConstPadding = 5;
  const int kTensorCount = 6;

  // kInputCounter(0) --> +-----+
  //                      | ADD | --> kOutputCounter(2)
  // kConstStep(4) -----> +-----+
  //
  // kInputValue(1) ----> +-----+
  //                      | PAD | --> kOutputValue(3)
  // kConstPadding(5) --> +-----+
  int first_new_tensor_index;
  ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
            kTfLiteOk);
  ASSERT_EQ(first_new_tensor_index, 0);
  ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk);
  ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue}), kTfLiteOk);

  SetupTensor(subgraph, kInputCounter, kTfLiteInt32);
  SetupTensor(subgraph, kInputValue, kTfLiteInt32);
  SetupTensor(subgraph, kOutputCounter, kTfLiteInt32);
  SetupTensor(subgraph, kOutputValue, kTfLiteInt32);

  CreateConstantInt32Tensor(subgraph, kConstStep, {1}, {1});
  ASSERT_EQ(padding.size() % 2, 0);
  int padding_dims = padding.size();
  CreateConstantInt32Tensor(subgraph, kConstPadding, {1, padding_dims},
                            padding);

  int node_index;
  TfLiteAddParams* add_params =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  add_params->activation = kTfLiteActNone;
  subgraph->AddNodeWithParameters(
      {kInputCounter, kConstStep}, {kOutputCounter}, nullptr, 0, add_params,
      ::tflite::ops::builtin::Register_ADD(), &node_index);
  // Bug fix: this allocation previously used sizeof(TfLiteAddParams) for a
  // TfLitePadParams object — allocate the correct type's size.
  TfLitePadParams* pad_params =
      reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
  subgraph->AddNodeWithParameters(
      {kInputValue, kConstPadding}, {kOutputValue}, nullptr, 0, pad_params,
      ::tflite::ops::builtin::Register_PAD(), &node_index);
}
// Builds a graph containing a single custom WHILE node with 2 inputs and
// 2 outputs. The node's flexbuffer options hard-wire subgraph 1 as the
// condition graph and subgraph 2 as the body graph, matching how the
// control-flow tests lay out their interpreters.
void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
  const int kInput1 = 0;
  const int kInput2 = 1;
  const int kOutput1 = 2;
  const int kOutput2 = 3;
  const int kTensorCount = 4;

  // kInput1(0) --> +-------+ --> kOutput1(2)
  //                | WHILE |
  // kInput2(1) --> +-------+ --> kOutput2(3)
  int first_new_tensor_index;
  ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
            kTfLiteOk);
  ASSERT_EQ(first_new_tensor_index, 0);
  ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
  ASSERT_EQ(subgraph->SetOutputs({kOutput1, kOutput2}), kTfLiteOk);

  SetupTensor(subgraph, kInput1, kTfLiteInt32);
  SetupTensor(subgraph, kInput2, kTfLiteInt32);
  SetupTensor(subgraph, kOutput1, kTfLiteInt32);
  SetupTensor(subgraph, kOutput2, kTfLiteInt32);

  // Custom ops carry their options as a flexbuffer blob instead of a
  // builtin params struct.
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.Int("cond_subgraph_index", 1);
    fbb.Int("body_subgraph_index", 2);
  });
  fbb.Finish();
  const auto& buffer = fbb.GetBuffer();

  int node_index;
  // Consistency fix: use the named tensor constants declared above instead
  // of raw indices ({0, 1}/{2, 3}). Tensor wiring is unchanged.
  subgraph->AddNodeWithParameters(
      {kInput1, kInput2}, {kOutput1, kOutput2},
      reinterpret_cast<const char*>(buffer.data()), buffer.size(), nullptr,
      ::tflite::ops::custom::Register_WHILE(), &node_index);
}
// Installs a read-only kTfLiteInt32 tensor at `tensor_index` of `subgraph`
// with the given `shape`, filled with `data`. The backing buffer is
// heap-allocated here and tracked in `buffers_` because the subgraph keeps
// only a non-owning pointer to read-only tensor data; the builder's
// destructor frees it, so the builder must outlive the subgraph.
void SubgraphBuilder::CreateConstantInt32Tensor(Subgraph* subgraph,
                                                int tensor_index,
                                                const std::vector<int>& shape,
                                                const std::vector<int>& data) {
  // 0-D (scalar) constants are not supported by this helper.
  ASSERT_GT(shape.size(), 0);
  int num_elements = 1;
  for (int dim : shape) {
    num_elements *= dim;
  }
  // `data` must supply exactly one value per element of `shape`.
  ASSERT_EQ(data.size(), num_elements);
  size_t size_in_bytes = sizeof(int32_t) * num_elements;
  // NOTE(review): malloc only guarantees fundamental alignment; presumably
  // that is sufficient for int32 tensor data here — confirm against TFLite's
  // buffer alignment requirements.
  int32_t* buffer = reinterpret_cast<int32_t*>(malloc(size_in_bytes));
  for (int i = 0; i < num_elements; ++i) {
    buffer[i] = data[i];
  }
  // Retain ownership so the buffer stays valid for the subgraph's lifetime.
  buffers_.push_back(buffer);
  ASSERT_EQ(subgraph->SetTensorParametersReadOnly(
                tensor_index, kTfLiteInt32, "", shape, {},
                reinterpret_cast<const char*>(buffer), size_in_bytes),
            kTfLiteOk);
}
void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data) {
int count = NumElements(tensor);
ASSERT_EQ(count, data.size());
@ -155,5 +391,19 @@ void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
}
}
void CheckBoolTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
const std::vector<bool>& data) {
ASSERT_EQ(tensor->dims->size, shape.size());
for (int i = 0; i < tensor->dims->size; ++i) {
ASSERT_EQ(tensor->dims->data[i], shape[i]);
}
ASSERT_EQ(tensor->type, kTfLiteBool);
int count = NumElements(tensor);
ASSERT_EQ(count, data.size());
for (int i = 0; i < count; ++i) {
EXPECT_EQ(tensor->data.b[i], data[i]);
}
}
} // namespace subgraph_test_util
} // namespace tflite

View File

@ -20,29 +20,87 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
#define TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
#include <gtest/gtest.h>
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
namespace tflite {
namespace subgraph_test_util {
// Build a subgraph with a single Add op.
// 2 inputs. 1 output.
void BuildAddSubgraph(Subgraph* subgraph);
// TODO(ycling): This file should be renamed as
// `control_flow_test_util` to avoid confusion. I'll do it immediately
// in a separated change.
class SubgraphBuilder {
public:
~SubgraphBuilder();
// Build a subgraph with a single Mul op.
// 2 inputs. 1 output.
void BuildMulSubgraph(Subgraph* subgraph);
// Build a subgraph with a single Add op.
// 2 inputs. 1 output.
void BuildAddSubgraph(Subgraph* subgraph);
// Build a subgraph with a single Pad op.
// 2 inputs. 1 output.
void BuildPadSubgraph(Subgraph* subgraph);
// Build a subgraph with a single Mul op.
// 2 inputs. 1 output.
void BuildMulSubgraph(Subgraph* subgraph);
// Build a subgraph with a single If op.
// 3 inputs:
// The 1st input is condition with boolean type.
// The 2nd and 3rd inputs are feed input the branch subgraphs.
// 1 output.
void BuildIfSubgraph(Subgraph* subgraph);
// Build a subgraph with a single Pad op.
// 2 inputs. 1 output.
void BuildPadSubgraph(Subgraph* subgraph);
// Build a subgraph with a single If op.
// 3 inputs:
// The 1st input is condition with boolean type.
// The 2nd and 3rd inputs are feed input the branch subgraphs.
// 1 output.
void BuildIfSubgraph(Subgraph* subgraph);
// Build a subgraph with a single LessEqual op.
// The subgraph is used as the condition subgraph for testing the `While` op.
// 2 inputs:
// The 1st input is a counter with `kTfLiteInt32` type.
// The 2nd input is ignored in this subgraph.
// 1 output with `kTfLiteBool` type.
// Equivalent to (input <= rhs).
void BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs);
// An accumulate loop body subgraph. Used to produce the triangle number
// sequence. 2 inputs and 2 outputs.
// Equivalent to (counter, value) -> (counter + 1, counter + 1 + value).
void BuildAccumulateLoopBodySubgraph(Subgraph* subgraph);
// A pad loop body subgraph. When used in a loop, it repeatedly enlarges the
// tensor.
// 2 inputs and 2 outputs.
// Equivalent to (counter, value) -> (counter + 1, tf.pad(value, padding))
// Note the padding is created as a constant tensor.
void BuildPadLoopBodySubgraph(Subgraph* subgraph,
const std::vector<int> padding);
// Build a subgraph with a single While op.
// 2 inputs, 2 outputs.
void BuildWhileSubgraph(Subgraph* subgraph);
private:
void CreateConstantInt32Tensor(Subgraph* subgraph, int tensor_index,
const std::vector<int>& shape,
const std::vector<int>& data);
std::vector<void*> buffers_;
};
// Common fixture for control-flow (IF/WHILE) op tests. Owns an Interpreter
// and the SubgraphBuilder used to populate its subgraphs.
class ControlFlowOpTest : public ::testing::Test {
 public:
  ControlFlowOpTest()
      : interpreter_(new Interpreter), builder_(new SubgraphBuilder) {}
  ~ControlFlowOpTest() override {
    // Tear down the interpreter before the builder: the builder owns the
    // constant-tensor buffers that the interpreter's subgraphs point into
    // (see SubgraphBuilder::CreateConstantInt32Tensor).
    interpreter_.reset();
    builder_.reset();
  }

 protected:
  std::unique_ptr<Interpreter> interpreter_;
  std::unique_ptr<SubgraphBuilder> builder_;
};
// Fill a `TfLiteTensor` with a 32-bits integer vector.
// Preconditions:
@ -52,9 +110,12 @@ void BuildIfSubgraph(Subgraph* subgraph);
// the vector.
void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data);
// Check if the shape and data of a tensor is as expected.
// Check if the shape and int32 data of a tensor is as expected.
void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
const std::vector<int32_t>& data);
// Check if the shape and bool data of a tensor is as expected.
void CheckBoolTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
const std::vector<bool>& data);
} // namespace subgraph_test_util
} // namespace tflite

View File

@ -24,55 +24,128 @@ namespace subgraph_test_util {
namespace {
// SubGraphTestUtilTest tests the helper functions defined in this file.
TEST(SubGraphTestUtilTest, TestBuildAddSubgraph) {
std::unique_ptr<Interpreter> interpreter(new Interpreter);
BuildAddSubgraph(&interpreter->primary_subgraph());
class SubgraphBuilderTest : public ::testing::Test {
public:
SubgraphBuilderTest()
: interpreter_(new Interpreter), builder_(new SubgraphBuilder) {}
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
~SubgraphBuilderTest() override {
interpreter_.reset();
builder_.reset();
}
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
protected:
void TestAccumelateLoopBody(int input1, int input2, int output1,
int output2) {
interpreter_.reset(new Interpreter);
builder_->BuildAccumulateLoopBodySubgraph(
&interpreter_->primary_subgraph());
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {input1});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {input2});
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
TfLiteTensor* output_tensor1 =
interpreter_->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output_tensor1, {1}, {output1});
TfLiteTensor* output_tensor2 =
interpreter_->tensor(interpreter_->outputs()[1]);
CheckIntTensor(output_tensor2, {1}, {output2});
}
std::unique_ptr<Interpreter> interpreter_;
std::unique_ptr<SubgraphBuilder> builder_;
};
TEST_F(SubgraphBuilderTest, TestBuildAddSubgraph) {
builder_->BuildAddSubgraph(&interpreter_->primary_subgraph());
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2});
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2});
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output, {1, 2}, {6, 9});
}
TEST(SubGraphTestUtilTest, TestBuildMulSubgraph) {
std::unique_ptr<Interpreter> interpreter(new Interpreter);
BuildMulSubgraph(&interpreter->primary_subgraph());
TEST_F(SubgraphBuilderTest, TestBuildMulSubgraph) {
builder_->BuildMulSubgraph(&interpreter_->primary_subgraph());
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2});
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2});
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output, {1, 2}, {5, 14});
}
TEST(SubGraphTestUtilTest, TestBuildPadSubgraph) {
std::unique_ptr<Interpreter> interpreter(new Interpreter);
BuildPadSubgraph(&interpreter->primary_subgraph());
TEST_F(SubgraphBuilderTest, TestBuildPadSubgraph) {
builder_->BuildPadSubgraph(&interpreter_->primary_subgraph());
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2});
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7});
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2});
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
}
// Verifies the LESS_EQUAL condition subgraph with rhs = 3:
// output[i] = (input1[i] <= 3). The second input is only resized, never
// filled — the subgraph ignores it by construction.
TEST_F(SubgraphBuilderTest, TestBuildLessEqualCondSubgraph) {
  builder_->BuildLessEqualCondSubgraph(&interpreter_->primary_subgraph(), 3);
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {5});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {10, 10});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  // Test [1, 2, 3, 4, 5] <= 3 == [true, true, true, false, false]
  // (with broadcasting).
  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]),
                {1, 2, 3, 4, 5});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckBoolTensor(output, {5}, {true, true, true, false, false});
}
// Runs single iterations of the accumulate body,
// (counter, value) -> (counter + 1, counter + 1 + value),
// walking the triangle-number sequence: 1, 3, 6, 10.
TEST_F(SubgraphBuilderTest, TestBuildAccumulateLoopBodySubgraph) {
  TestAccumelateLoopBody(1, 1, 2, 3);
  TestAccumelateLoopBody(2, 3, 3, 6);
  TestAccumelateLoopBody(3, 6, 4, 10);
}
// Runs the pad loop body once with padding (1, 2): the counter goes 1 -> 2,
// and the 5-element value is zero-padded to length 8 (1 before + 2 after).
TEST_F(SubgraphBuilderTest, TestBuildPadLoopBodySubgraph) {
  builder_->BuildPadLoopBodySubgraph(&interpreter_->primary_subgraph(), {1, 2});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {5});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]),
                {0, 5, 7, 0, 0});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output1, {1}, {2});
  TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
  CheckIntTensor(output2, {8}, {0, 0, 5, 7, 0, 0, 0, 0});
}
} // namespace
} // namespace subgraph_test_util
} // namespace tflite

View File

@ -0,0 +1,193 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace while_kernel {
namespace {
// Propagates the WHILE node's input shapes into `subgraph` (the cond or body
// graph): resizes each subgraph input tensor to match the corresponding node
// input and verifies the tensor types agree.
TfLiteStatus ResizeSubgraphInputs(TfLiteContext* context, TfLiteNode* node,
                                  Subgraph* subgraph) {
  int num_inputs = node->inputs->size;
  for (int i = 0; i < num_inputs; ++i) {
    const TfLiteTensor* input = GetInput(context, node, i);
    std::vector<int> dims(input->dims->data,
                          input->dims->data + input->dims->size);
    // Fix: the status of ResizeInputTensor was previously discarded; a
    // failed resize must propagate to the caller instead of being ignored.
    TF_LITE_ENSURE_OK(context, subgraph->ResizeInputTensor(i, dims));
    TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]);
    TF_LITE_ENSURE_EQ(context, input->type, subgraph_input->type);
  }
  return kTfLiteOk;
}
// Byte-copies each tensor listed in `src_tensor_indices` (within
// `src_subgraph`) into the tensor at the same position of
// `dst_tensor_indices` (within `dst_subgraph`). The index lists must have
// equal length, and each src/dst pair must already have identical byte
// sizes — shape/type agreement is established earlier, in Prepare.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensors(TfLiteContext* context, Subgraph* src_subgraph,
                         const SrcVector& src_tensor_indices,
                         Subgraph* dst_subgraph,
                         const DstVector& dst_tensor_indices) {
  TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                    dst_tensor_indices.size());
  for (int i = 0; i < src_tensor_indices.size(); ++i) {
    const TfLiteTensor* src_tensor =
        src_subgraph->tensor(src_tensor_indices[i]);
    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
    // A size mismatch here indicates a shape bug upstream, not a copy issue.
    TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes);
    memcpy(dst_tensor->data.raw, src_tensor->data.raw, src_tensor->bytes);
  }
  return kTfLiteOk;
}
} // namespace
// Per-node parameters of the WHILE op, parsed once from the custom-op
// flexbuffer in Init(): the indices of the condition and body subgraphs
// within the owning interpreter's subgraph list.
struct OpData {
  int cond_subgraph_index;
  int body_subgraph_index;
};
// Parses the WHILE op's custom options — a flexbuffer map with keys
// `cond_subgraph_index` and `body_subgraph_index` — into a heap-allocated
// OpData. The runtime stores the pointer in node->user_data and releases it
// via Free().
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32();
  op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32();
  return op_data;
}
// Releases the OpData allocated in Init().
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}
// Validates the WHILE node's wiring and allocates its cond/body subgraphs.
// Enforced invariants:
//  * The node has as many outputs as inputs (loop-carried values).
//  * The cond subgraph takes the same number of inputs and produces exactly
//    one output: a single boolean value (kTfLiteBool, shape [1]).
//  * The body subgraph has matching input/output counts, and each body
//    output is static-shaped with the same type and shape as the
//    corresponding body input, so values keep their shape across iterations.
// Node outputs are resized to the body outputs' shapes.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int num_inputs = node->inputs->size;
  // The number of outputs should be the same as number of inputs.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs);
  // Check subgraph indices and get subgraphs.
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
  TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();
  // Check input & output count of the condition subgraph.
  TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1);
  // Check input & output count of the body subgraph.
  TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs);
  TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs);
  // Prepare and check the condition subgraph.
  // NOTE(review): the TfLiteStatus returned by ResizeSubgraphInputs is
  // discarded here (and below for the body subgraph) — consider wrapping
  // both calls in TF_LITE_ENSURE_OK.
  ResizeSubgraphInputs(context, node, cond_subgraph);
  TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
  TfLiteTensor* cond_output =
      cond_subgraph->tensor(cond_subgraph->outputs()[0]);
  // The condition output must be a single boolean value.
  TF_LITE_ENSURE_EQ(context, cond_output->type, kTfLiteBool);
  TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1);
  // TODO(ycling): Handle the case where condition graph has dynamic
  // sized tensors.
  // Prepare and check the body subgraph.
  ResizeSubgraphInputs(context, node, body_subgraph);
  TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
  for (int i = 0; i < num_inputs; ++i) {
    TfLiteTensor* body_input =
        body_subgraph->tensor(body_subgraph->inputs()[i]);
    TfLiteTensor* body_output =
        body_subgraph->tensor(body_subgraph->outputs()[i]);
    TF_LITE_ENSURE_EQ(context, body_input->type, body_output->type);
    // TODO(ycling): Support dynamic sized body subgraph.
    TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));
    // Loop-carried values must keep their shape across iterations.
    TF_LITE_ENSURE(context,
                   TfLiteIntArrayEqual(body_input->dims, body_output->dims));
    TfLiteTensor* output = GetOutput(context, node, i);
    TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims);
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output, output_size));
  }
  return kTfLiteOk;
}
// Executes the WHILE loop: repeatedly invokes the condition subgraph and,
// while it yields true, invokes the body subgraph, feeding the body's
// outputs back as the next iteration's inputs. The body subgraph's input
// tensors double as the loop-carried state.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
  auto* subgraphs = this_subgraph->GetSubgraphs();
  Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
  Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();

  // Currently we copy the input / output between the subgraphs. This isn't
  // optimized yet.
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
  //
  // Seed both subgraphs with the WHILE node's inputs: the condition subgraph
  // needs them to evaluate the first predicate, and the body subgraph's
  // inputs hold the initial loop state.
  TF_LITE_ENSURE_OK(
      context,
      CopyTensors(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                  cond_subgraph, cond_subgraph->inputs()));
  TF_LITE_ENSURE_OK(
      context,
      CopyTensors(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                  body_subgraph, body_subgraph->inputs()));

  while (true) {
    TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke());
    TfLiteTensor* cond_output =
        cond_subgraph->tensor(cond_subgraph->outputs()[0]);
    // Prepare() guarantees the condition output is a single boolean value.
    if (!cond_output->data.b[0]) {
      break;
    }
    TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
    // Feed the body's outputs back as the next iteration's state — both into
    // the body's own inputs and into the condition subgraph's inputs.
    TF_LITE_ENSURE_OK(
        context, CopyTensors(context, body_subgraph, body_subgraph->outputs(),
                             body_subgraph, body_subgraph->inputs()));
    TF_LITE_ENSURE_OK(
        context, CopyTensors(context, body_subgraph, body_subgraph->outputs(),
                             cond_subgraph, cond_subgraph->inputs()));
  }

  // The final state lives in the body subgraph's *inputs* (updated at the
  // end of each iteration above); copying from the body's outputs would fail
  // if the body was never invoked (zero-iteration loop).
  // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
  TF_LITE_ENSURE_OK(
      context, CopyTensors(context, body_subgraph, body_subgraph->inputs(),
                           this_subgraph, TfLiteIntArrayView(node->outputs)));
  return kTfLiteOk;
}
} // namespace while_kernel
// Returns the registration for the custom WHILE op, wiring the kernel's
// Init/Free/Prepare/Eval lifecycle callbacks. The registration is a
// function-local static, so a single instance is shared by all interpreters.
TfLiteRegistration* Register_WHILE() {
  static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free,
                                 while_kernel::Prepare, while_kernel::Eval};
  return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/subgraph_test_util.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
using subgraph_test_util::CheckIntTensor;
using subgraph_test_util::ControlFlowOpTest;
using subgraph_test_util::FillIntTensor;
namespace {
// Test fixture for the custom WHILE op. ControlFlowOpTest (from
// subgraph_test_util) supplies the interpreter_ and builder_ members used by
// the tests below.
class WhileTest : public ControlFlowOpTest {};
// The test builds a model that produces the i-th number of the
// triangular number sequence: the loop runs i+1 times, so the first output
// is the iteration count and the second is the accumulated sum.
//
// TODO(ycling): Consider to improve this test case by adding a
// concat into the body subgraph.
TEST_F(WhileTest, TestTriangularNumberSequence) {
  const std::vector<int> expected = {1, 3, 6, 10, 15, 21, 28};
  // Hoist a signed bound to avoid the implicit signed/unsigned comparison
  // between `int i` and `expected.size()`.
  const int num_cases = static_cast<int>(expected.size());
  for (int i = 0; i < num_cases; ++i) {
    interpreter_.reset(new Interpreter);
    interpreter_->AddSubgraphs(2);
    builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), i);
    builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
    builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());

    // Both loop-carried values are scalar-like [1] tensors initialized to 1.
    interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
    interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
    ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
    FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
    FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1});

    ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
    TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
    CheckIntTensor(output1, {1}, {i + 1});
    TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
    CheckIntTensor(output2, {1}, {expected[i]});
  }
}
// This requires dynamic sized subgraphs and it's not supported right now.
// TODO(ycling): Support dynamic sized subgraphs.
TEST_F(WhileTest, TestPadLoop) {
  interpreter_.reset(new Interpreter);
  interpreter_->AddSubgraphs(2);
  builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), 3);
  builder_->BuildPadLoopBodySubgraph(interpreter_->subgraph(2), {1, 2});
  builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
  // This is not supported yet. The test ensures that it doesn't crash and
  // raises an error properly.
  ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
}
} // namespace
} // namespace tflite
// Test entry point: routes TF Lite log output to stderr, then runs all
// registered gtest cases.
int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}