Initial implementation of WHILE op
PiperOrigin-RevId: 233138666
This commit is contained in:
parent
14554b2371
commit
9913382f56
@ -38,6 +38,7 @@ class TfLiteIntArrayView {
|
||||
const_iterator begin() const { return int_array_->data; }
|
||||
const_iterator end() const { return &int_array_->data[int_array_->size]; }
|
||||
size_t size() const { return end() - begin(); }
|
||||
int operator[](size_t pos) const { return int_array_->data[pos]; }
|
||||
|
||||
private:
|
||||
const TfLiteIntArray* int_array_;
|
||||
|
@ -225,6 +225,7 @@ cc_library(
|
||||
"unidirectional_sequence_rnn.cc",
|
||||
"unique.cc",
|
||||
"unpack.cc",
|
||||
"while.cc",
|
||||
"zeros_like.cc",
|
||||
],
|
||||
hdrs = [
|
||||
@ -1227,6 +1228,23 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "while_test",
|
||||
size = "small",
|
||||
srcs = ["while_test.cc"],
|
||||
tags = ["tflite_not_portable_ios"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":kernel_util",
|
||||
":subgraph_test_util",
|
||||
"//tensorflow/lite:builtin_op_data",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "fill_test",
|
||||
size = "small",
|
||||
|
@ -24,25 +24,21 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
|
||||
using subgraph_test_util::BuildAddSubgraph;
|
||||
using subgraph_test_util::BuildIfSubgraph;
|
||||
using subgraph_test_util::BuildMulSubgraph;
|
||||
using subgraph_test_util::BuildPadSubgraph;
|
||||
using subgraph_test_util::CheckIntTensor;
|
||||
using subgraph_test_util::ControlFlowOpTest;
|
||||
using subgraph_test_util::FillIntTensor;
|
||||
|
||||
namespace {
|
||||
|
||||
// A simple test that performs `ADD` if condition is true, and `MUL` otherwise.
|
||||
// The computation is: `cond ? a + b : a * b`.
|
||||
class SimpleIfTest : public ::testing::Test {
|
||||
class SimpleIfTest : public ControlFlowOpTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
interpreter_.reset(new Interpreter);
|
||||
interpreter_->AddSubgraphs(2);
|
||||
BuildAddSubgraph(interpreter_->subgraph(1));
|
||||
BuildMulSubgraph(interpreter_->subgraph(2));
|
||||
BuildIfSubgraph(&interpreter_->primary_subgraph());
|
||||
builder_->BuildAddSubgraph(interpreter_->subgraph(1));
|
||||
builder_->BuildMulSubgraph(interpreter_->subgraph(2));
|
||||
builder_->BuildIfSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
||||
@ -52,7 +48,6 @@ class SimpleIfTest : public ::testing::Test {
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
|
||||
}
|
||||
std::unique_ptr<Interpreter> interpreter_;
|
||||
};
|
||||
|
||||
TEST_F(SimpleIfTest, TestIfTrue) {
|
||||
@ -71,14 +66,13 @@ TEST_F(SimpleIfTest, TestIfFalse) {
|
||||
|
||||
// Test IF op using subgraphs with dynamically sized outputs.
|
||||
// The computation is: `cond ? a + b : pad(a, b)`.
|
||||
class DynamicSubgraphIfTest : public ::testing::Test {
|
||||
class DynamicSubgraphIfTest : public ControlFlowOpTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
interpreter_.reset(new Interpreter);
|
||||
interpreter_->AddSubgraphs(2);
|
||||
BuildAddSubgraph(interpreter_->subgraph(1));
|
||||
BuildPadSubgraph(interpreter_->subgraph(2));
|
||||
BuildIfSubgraph(&interpreter_->primary_subgraph());
|
||||
builder_->BuildAddSubgraph(interpreter_->subgraph(1));
|
||||
builder_->BuildPadSubgraph(interpreter_->subgraph(2));
|
||||
builder_->BuildIfSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
||||
@ -88,7 +82,6 @@ class DynamicSubgraphIfTest : public ::testing::Test {
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
|
||||
}
|
||||
std::unique_ptr<Interpreter> interpreter_;
|
||||
};
|
||||
|
||||
TEST_F(DynamicSubgraphIfTest, TestIfTrue) {
|
||||
|
@ -31,92 +31,142 @@ TfLiteRegistration* Register_ADD();
|
||||
TfLiteRegistration* Register_MUL();
|
||||
// ADD and MUL are used to test dynamic sized subgraphs.
|
||||
TfLiteRegistration* Register_PAD();
|
||||
TfLiteRegistration* Register_LESS_EQUAL();
|
||||
} // namespace builtin
|
||||
namespace custom {
|
||||
TfLiteRegistration* Register_IF();
|
||||
TfLiteRegistration* Register_WHILE();
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
||||
namespace subgraph_test_util {
|
||||
|
||||
namespace {
|
||||
|
||||
void SetupTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) {
|
||||
ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 0,
|
||||
nullptr, {}, false),
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
void BuildAddSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
} // namespace
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
SubgraphBuilder::~SubgraphBuilder() {
|
||||
for (auto buffer : buffers_) {
|
||||
free(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
void SubgraphBuilder::BuildAddSubgraph(Subgraph* subgraph) {
|
||||
const int kInput1 = 0;
|
||||
const int kInput2 = 1;
|
||||
const int kOutput = 2;
|
||||
const int kTensorCount = 3;
|
||||
// kInput1(0) --> +---+
|
||||
// |ADD| --> kOutput(2)
|
||||
// kInput2(1) --> +---+
|
||||
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kInput1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutput, kTfLiteInt32);
|
||||
|
||||
TfLiteAddParams* params =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_ADD(),
|
||||
&node_index);
|
||||
subgraph->AddNodeWithParameters(
|
||||
{kInput1, kInput2}, {kOutput}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_ADD(), &node_index);
|
||||
}
|
||||
|
||||
// Build a subgraph with an mul op. Helper function for testing.
|
||||
void BuildMulSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
void SubgraphBuilder::BuildMulSubgraph(Subgraph* subgraph) {
|
||||
const int kInput1 = 0;
|
||||
const int kInput2 = 1;
|
||||
const int kOutput = 2;
|
||||
const int kTensorCount = 3;
|
||||
// kInput1(0) --> +---+
|
||||
// |MUL| --> kOutput(2)
|
||||
// kInput2(1) --> +---+
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kInput1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutput, kTfLiteInt32);
|
||||
|
||||
TfLiteMulParams* params =
|
||||
reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_MUL(),
|
||||
&node_index);
|
||||
subgraph->AddNodeWithParameters(
|
||||
{kInput1, kInput2}, {kOutput}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_MUL(), &node_index);
|
||||
}
|
||||
|
||||
// Build a subgraph with a pad op. Helper function for testing.
|
||||
void BuildPadSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
|
||||
void SubgraphBuilder::BuildPadSubgraph(Subgraph* subgraph) {
|
||||
const int kInput1 = 0;
|
||||
const int kInput2 = 1;
|
||||
const int kOutput = 2;
|
||||
const int kTensorCount = 3;
|
||||
// kInput1(0) --> +---+
|
||||
// |PAD| --> kOutput(2)
|
||||
// kInput2(1) --> +---+
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kInput1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutput, kTfLiteInt32);
|
||||
|
||||
TfLitePadParams* params =
|
||||
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_PAD(),
|
||||
&node_index);
|
||||
subgraph->AddNodeWithParameters(
|
||||
{kInput1, kInput2}, {kOutput}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_PAD(), &node_index);
|
||||
}
|
||||
|
||||
void BuildIfSubgraph(Subgraph* subgraph) {
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(4, &first_new_tensor_index), kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({0, 1, 2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({3}), kTfLiteOk);
|
||||
void SubgraphBuilder::BuildIfSubgraph(Subgraph* subgraph) {
|
||||
const int kCondInput = 0;
|
||||
const int kInput1 = 1;
|
||||
const int kInput2 = 2;
|
||||
const int kOutput = 3;
|
||||
const int kTensorCount = 4;
|
||||
|
||||
SetupTensor(subgraph, 0, kTfLiteBool);
|
||||
SetupTensor(subgraph, 1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, 3, kTfLiteInt32);
|
||||
// kCondInput(0) --> +----+
|
||||
// kInput1(1) ----> | IF | --> kOutput(3)
|
||||
// kInput2(2) ----> +----+
|
||||
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kCondInput, kInput1, kInput2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kCondInput, kTfLiteBool);
|
||||
SetupTensor(subgraph, kInput1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutput, kTfLiteInt32);
|
||||
|
||||
flexbuffers::Builder fbb;
|
||||
fbb.Map([&]() {
|
||||
@ -128,9 +178,195 @@ void BuildIfSubgraph(Subgraph* subgraph) {
|
||||
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters(
|
||||
{0, 1, 2}, {3}, reinterpret_cast<const char*>(buffer.data()),
|
||||
buffer.size(), nullptr, ::tflite::ops::custom::Register_IF(),
|
||||
{kCondInput, kInput1, kInput2}, {kOutput},
|
||||
reinterpret_cast<const char*>(buffer.data()), buffer.size(), nullptr,
|
||||
::tflite::ops::custom::Register_IF(), &node_index);
|
||||
}
|
||||
|
||||
void SubgraphBuilder::BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs) {
|
||||
const int kInput1 = 0;
|
||||
const int kInput2 = 1;
|
||||
const int kOutput = 2;
|
||||
const int kConstRhs = 3;
|
||||
const int kTensorCount = 4;
|
||||
|
||||
// kInput1(0) ----> +------------+
|
||||
// | LESS_EQUAL | --> kOutput(2)
|
||||
// kConstRhs(3) --> +------------+
|
||||
//
|
||||
// kInput2(1) --> (unused)
|
||||
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutput}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kInput1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutput, kTfLiteBool);
|
||||
|
||||
CreateConstantInt32Tensor(subgraph, kConstRhs, {1}, {rhs});
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters(
|
||||
{kInput1, kConstRhs}, {kOutput}, nullptr, 0, nullptr,
|
||||
::tflite::ops::builtin::Register_LESS_EQUAL(), &node_index);
|
||||
}
|
||||
|
||||
void SubgraphBuilder::BuildAccumulateLoopBodySubgraph(Subgraph* subgraph) {
|
||||
const int kInputCounter = 0;
|
||||
const int kInputValue = 1;
|
||||
const int kOutputCounter = 2;
|
||||
const int kOutputValue = 3;
|
||||
const int kConstStep = 4;
|
||||
const int kTensorCount = 5;
|
||||
|
||||
// kInputCounter(0) --> +-----+
|
||||
// | ADD | --> kOutputCounter(2)
|
||||
// kConstStep(4) -----> +-----+ |
|
||||
// |
|
||||
// v
|
||||
// +-----+
|
||||
// | ADD | --> kOutputValue(3)
|
||||
// kInputValue(1) ----------------------+-----+
|
||||
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kInputCounter, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInputValue, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutputCounter, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutputValue, kTfLiteInt32);
|
||||
CreateConstantInt32Tensor(subgraph, kConstStep, {1}, {1});
|
||||
|
||||
int node_index;
|
||||
TfLiteAddParams* params =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
subgraph->AddNodeWithParameters({0, 4}, {2}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_ADD(),
|
||||
&node_index);
|
||||
params = reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
params->activation = kTfLiteActNone;
|
||||
subgraph->AddNodeWithParameters({2, 1}, {3}, nullptr, 0, params,
|
||||
::tflite::ops::builtin::Register_ADD(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
void SubgraphBuilder::BuildPadLoopBodySubgraph(Subgraph* subgraph,
|
||||
const std::vector<int> padding) {
|
||||
const int kInputCounter = 0;
|
||||
const int kInputValue = 1;
|
||||
const int kOutputCounter = 2;
|
||||
const int kOutputValue = 3;
|
||||
const int kConstStep = 4;
|
||||
const int kConstPadding = 5;
|
||||
const int kTensorCount = 6;
|
||||
|
||||
// kInputCounter(0) --> +-----+
|
||||
// | ADD | --> kOutputCounter(2)
|
||||
// kConstStep(4) -----> +-----+
|
||||
//
|
||||
// kInputValue(1) ----> +-----+
|
||||
// | PAD | --> kOutputValue(3)
|
||||
// kConstPadding(5) --> +-----+
|
||||
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kInputCounter, kInputValue}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutputCounter, kOutputValue}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kInputCounter, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInputValue, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutputCounter, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutputValue, kTfLiteInt32);
|
||||
|
||||
CreateConstantInt32Tensor(subgraph, kConstStep, {1}, {1});
|
||||
ASSERT_EQ(padding.size() % 2, 0);
|
||||
int padding_dims = padding.size();
|
||||
CreateConstantInt32Tensor(subgraph, kConstPadding, {1, padding_dims},
|
||||
padding);
|
||||
|
||||
int node_index;
|
||||
TfLiteAddParams* add_params =
|
||||
reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
add_params->activation = kTfLiteActNone;
|
||||
subgraph->AddNodeWithParameters(
|
||||
{kInputCounter, kConstStep}, {kOutputCounter}, nullptr, 0, add_params,
|
||||
::tflite::ops::builtin::Register_ADD(), &node_index);
|
||||
TfLitePadParams* pad_params =
|
||||
reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLiteAddParams)));
|
||||
subgraph->AddNodeWithParameters(
|
||||
{kInputValue, kConstPadding}, {kOutputValue}, nullptr, 0, pad_params,
|
||||
::tflite::ops::builtin::Register_PAD(), &node_index);
|
||||
}
|
||||
|
||||
void SubgraphBuilder::BuildWhileSubgraph(Subgraph* subgraph) {
|
||||
const int kInput1 = 0;
|
||||
const int kInput2 = 1;
|
||||
const int kOutput1 = 2;
|
||||
const int kOutput2 = 3;
|
||||
const int kTensorCount = 4;
|
||||
|
||||
// kInput1(0) --> +-------+ --> kOutput1(2)
|
||||
// | WHILE |
|
||||
// kInput2(1) --> +-------+ --> kOutput2(3)
|
||||
|
||||
int first_new_tensor_index;
|
||||
ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index),
|
||||
kTfLiteOk);
|
||||
ASSERT_EQ(first_new_tensor_index, 0);
|
||||
ASSERT_EQ(subgraph->SetInputs({kInput1, kInput2}), kTfLiteOk);
|
||||
ASSERT_EQ(subgraph->SetOutputs({kOutput1, kOutput2}), kTfLiteOk);
|
||||
|
||||
SetupTensor(subgraph, kInput1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kInput2, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutput1, kTfLiteInt32);
|
||||
SetupTensor(subgraph, kOutput2, kTfLiteInt32);
|
||||
|
||||
flexbuffers::Builder fbb;
|
||||
fbb.Map([&]() {
|
||||
fbb.Int("cond_subgraph_index", 1);
|
||||
fbb.Int("body_subgraph_index", 2);
|
||||
});
|
||||
fbb.Finish();
|
||||
const auto& buffer = fbb.GetBuffer();
|
||||
|
||||
int node_index;
|
||||
subgraph->AddNodeWithParameters(
|
||||
{0, 1}, {2, 3}, reinterpret_cast<const char*>(buffer.data()),
|
||||
buffer.size(), nullptr, ::tflite::ops::custom::Register_WHILE(),
|
||||
&node_index);
|
||||
}
|
||||
|
||||
void SubgraphBuilder::CreateConstantInt32Tensor(Subgraph* subgraph,
|
||||
int tensor_index,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<int>& data) {
|
||||
ASSERT_GT(shape.size(), 0);
|
||||
int num_elements = 1;
|
||||
for (int dim : shape) {
|
||||
num_elements *= dim;
|
||||
}
|
||||
ASSERT_EQ(data.size(), num_elements);
|
||||
size_t size_in_bytes = sizeof(int32_t) * num_elements;
|
||||
// Maybe aligned.
|
||||
int32_t* buffer = reinterpret_cast<int32_t*>(malloc(size_in_bytes));
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
buffer[i] = data[i];
|
||||
}
|
||||
buffers_.push_back(buffer);
|
||||
ASSERT_EQ(subgraph->SetTensorParametersReadOnly(
|
||||
tensor_index, kTfLiteInt32, "", shape, {},
|
||||
reinterpret_cast<const char*>(buffer), size_in_bytes),
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data) {
|
||||
@ -155,5 +391,19 @@ void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
|
||||
}
|
||||
}
|
||||
|
||||
void CheckBoolTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
|
||||
const std::vector<bool>& data) {
|
||||
ASSERT_EQ(tensor->dims->size, shape.size());
|
||||
for (int i = 0; i < tensor->dims->size; ++i) {
|
||||
ASSERT_EQ(tensor->dims->data[i], shape[i]);
|
||||
}
|
||||
ASSERT_EQ(tensor->type, kTfLiteBool);
|
||||
int count = NumElements(tensor);
|
||||
ASSERT_EQ(count, data.size());
|
||||
for (int i = 0; i < count; ++i) {
|
||||
EXPECT_EQ(tensor->data.b[i], data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace subgraph_test_util
|
||||
} // namespace tflite
|
||||
|
@ -20,11 +20,20 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace subgraph_test_util {
|
||||
|
||||
// TODO(ycling): This file should be renamed as
|
||||
// `control_flow_test_util` to avoid confusion. I'll do it immediately
|
||||
// in a separated change.
|
||||
class SubgraphBuilder {
|
||||
public:
|
||||
~SubgraphBuilder();
|
||||
|
||||
// Build a subgraph with a single Add op.
|
||||
// 2 inputs. 1 output.
|
||||
void BuildAddSubgraph(Subgraph* subgraph);
|
||||
@ -44,6 +53,55 @@ void BuildPadSubgraph(Subgraph* subgraph);
|
||||
// 1 output.
|
||||
void BuildIfSubgraph(Subgraph* subgraph);
|
||||
|
||||
// Build a subgraph with a single Less op.
|
||||
// The subgraph is used as the condition subgraph for testing `While` op.
|
||||
// 2 inputs:
|
||||
// The 1st input is a counter with `kTfLiteInt32` type.
|
||||
// The 2nd input is ignored in this subgraph.
|
||||
// 1 output with `kTfLiteBool` type.
|
||||
// Equivalent to (input < rhs).
|
||||
void BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs);
|
||||
|
||||
// An accumulate loop body subgraph. Used to produce triangle number
|
||||
// seqeuence. 2 inputs and 2 outpus
|
||||
// Equivalent to (counter, value) -> (counter + 1, counter + 1 + value)
|
||||
void BuildAccumulateLoopBodySubgraph(Subgraph* subgraph);
|
||||
|
||||
// A pad loop body subgraph. When used in a loop it will repeatively enlarge
|
||||
// the
|
||||
// tensor.
|
||||
// 2 inputs and 2 outputs.
|
||||
// Equivalent to (counter, value) -> (counter + 1, tf.pad(value, padding))
|
||||
// Note the padding is created as a constant tensor.
|
||||
void BuildPadLoopBodySubgraph(Subgraph* subgraph,
|
||||
const std::vector<int> padding);
|
||||
|
||||
// Build a subgraph with a single While op.
|
||||
// 2 inputs, 2 outputs.
|
||||
void BuildWhileSubgraph(Subgraph* subgraph);
|
||||
|
||||
private:
|
||||
void CreateConstantInt32Tensor(Subgraph* subgraph, int tensor_index,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<int>& data);
|
||||
std::vector<void*> buffers_;
|
||||
};
|
||||
|
||||
class ControlFlowOpTest : public ::testing::Test {
|
||||
public:
|
||||
ControlFlowOpTest()
|
||||
: interpreter_(new Interpreter), builder_(new SubgraphBuilder) {}
|
||||
|
||||
~ControlFlowOpTest() override {
|
||||
interpreter_.reset();
|
||||
builder_.reset();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<Interpreter> interpreter_;
|
||||
std::unique_ptr<SubgraphBuilder> builder_;
|
||||
};
|
||||
|
||||
// Fill a `TfLiteTensor` with a 32-bits integer vector.
|
||||
// Preconditions:
|
||||
// * The tensor must have `kTfLiteInt32` type.
|
||||
@ -52,9 +110,12 @@ void BuildIfSubgraph(Subgraph* subgraph);
|
||||
// the vector.
|
||||
void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data);
|
||||
|
||||
// Check if the shape and data of a tensor is as expected.
|
||||
// Check if the shape and int32 data of a tensor is as expected.
|
||||
void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
|
||||
const std::vector<int32_t>& data);
|
||||
// Check if the shape and bool data of a tensor is as expected.
|
||||
void CheckBoolTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
|
||||
const std::vector<bool>& data);
|
||||
|
||||
} // namespace subgraph_test_util
|
||||
} // namespace tflite
|
||||
|
@ -24,55 +24,128 @@ namespace subgraph_test_util {
|
||||
|
||||
namespace {
|
||||
|
||||
// SubGraphTestUtilTest tests the helper functions defined in this file.
|
||||
TEST(SubGraphTestUtilTest, TestBuildAddSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildAddSubgraph(&interpreter->primary_subgraph());
|
||||
class SubgraphBuilderTest : public ::testing::Test {
|
||||
public:
|
||||
SubgraphBuilderTest()
|
||||
: interpreter_(new Interpreter), builder_(new SubgraphBuilder) {}
|
||||
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
~SubgraphBuilderTest() override {
|
||||
interpreter_.reset();
|
||||
builder_.reset();
|
||||
}
|
||||
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
protected:
|
||||
void TestAccumelateLoopBody(int input1, int input2, int output1,
|
||||
int output2) {
|
||||
interpreter_.reset(new Interpreter);
|
||||
builder_->BuildAccumulateLoopBodySubgraph(
|
||||
&interpreter_->primary_subgraph());
|
||||
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {input1});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {input2});
|
||||
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
TfLiteTensor* output_tensor1 =
|
||||
interpreter_->tensor(interpreter_->outputs()[0]);
|
||||
CheckIntTensor(output_tensor1, {1}, {output1});
|
||||
TfLiteTensor* output_tensor2 =
|
||||
interpreter_->tensor(interpreter_->outputs()[1]);
|
||||
CheckIntTensor(output_tensor2, {1}, {output2});
|
||||
}
|
||||
|
||||
std::unique_ptr<Interpreter> interpreter_;
|
||||
std::unique_ptr<SubgraphBuilder> builder_;
|
||||
};
|
||||
|
||||
TEST_F(SubgraphBuilderTest, TestBuildAddSubgraph) {
|
||||
builder_->BuildAddSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
|
||||
CheckIntTensor(output, {1, 2}, {6, 9});
|
||||
}
|
||||
|
||||
TEST(SubGraphTestUtilTest, TestBuildMulSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildMulSubgraph(&interpreter->primary_subgraph());
|
||||
TEST_F(SubgraphBuilderTest, TestBuildMulSubgraph) {
|
||||
builder_->BuildMulSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
|
||||
CheckIntTensor(output, {1, 2}, {5, 14});
|
||||
}
|
||||
|
||||
TEST(SubGraphTestUtilTest, TestBuildPadSubgraph) {
|
||||
std::unique_ptr<Interpreter> interpreter(new Interpreter);
|
||||
BuildPadSubgraph(&interpreter->primary_subgraph());
|
||||
TEST_F(SubgraphBuilderTest, TestBuildPadSubgraph) {
|
||||
builder_->BuildPadSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
|
||||
interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2});
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
|
||||
TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
|
||||
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
|
||||
CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
|
||||
}
|
||||
|
||||
TEST_F(SubgraphBuilderTest, TestBuildLessEqualCondSubgraph) {
|
||||
builder_->BuildLessEqualCondSubgraph(&interpreter_->primary_subgraph(), 3);
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {5});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {10, 10});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
// Test [1, 2, 3, 4, 5] <= 3 == [true, true, true, false, false]
|
||||
// (with broadcasting).
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]),
|
||||
{1, 2, 3, 4, 5});
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
|
||||
CheckBoolTensor(output, {5}, {true, true, true, false, false});
|
||||
}
|
||||
|
||||
TEST_F(SubgraphBuilderTest, TestBuildAccumulateLoopBodySubgraph) {
|
||||
TestAccumelateLoopBody(1, 1, 2, 3);
|
||||
TestAccumelateLoopBody(2, 3, 3, 6);
|
||||
TestAccumelateLoopBody(3, 6, 4, 10);
|
||||
}
|
||||
|
||||
TEST_F(SubgraphBuilderTest, TestBuildPadLoopBodySubgraph) {
|
||||
builder_->BuildPadLoopBodySubgraph(&interpreter_->primary_subgraph(), {1, 2});
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {5});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]),
|
||||
{0, 5, 7, 0, 0});
|
||||
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
|
||||
CheckIntTensor(output1, {1}, {2});
|
||||
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
|
||||
CheckIntTensor(output2, {8}, {0, 0, 5, 7, 0, 0, 0, 0});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace subgraph_test_util
|
||||
} // namespace tflite
|
||||
|
193
tensorflow/lite/kernels/while.cc
Normal file
193
tensorflow/lite/kernels/while.cc
Normal file
@ -0,0 +1,193 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/core/subgraph.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace while_kernel {
|
||||
|
||||
namespace {
|
||||
|
||||
TfLiteStatus ResizeSubgraphInputs(TfLiteContext* context, TfLiteNode* node,
|
||||
Subgraph* subgraph) {
|
||||
int num_inputs = node->inputs->size;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const TfLiteTensor* input = GetInput(context, node, i);
|
||||
std::vector<int> dims(input->dims->data,
|
||||
input->dims->data + input->dims->size);
|
||||
subgraph->ResizeInputTensor(i, dims);
|
||||
TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]);
|
||||
TF_LITE_ENSURE_EQ(context, input->type, subgraph_input->type);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename SrcVector, typename DstVector>
|
||||
TfLiteStatus CopyTensors(TfLiteContext* context, Subgraph* src_subgraph,
|
||||
const SrcVector& src_tensor_indices,
|
||||
Subgraph* dst_subgraph,
|
||||
const DstVector& dst_tensor_indices) {
|
||||
TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
|
||||
dst_tensor_indices.size());
|
||||
for (int i = 0; i < src_tensor_indices.size(); ++i) {
|
||||
const TfLiteTensor* src_tensor =
|
||||
src_subgraph->tensor(src_tensor_indices[i]);
|
||||
TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
|
||||
TF_LITE_ENSURE_EQ(context, src_tensor->bytes, dst_tensor->bytes);
|
||||
memcpy(dst_tensor->data.raw, src_tensor->data.raw, src_tensor->bytes);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
struct OpData {
|
||||
int cond_subgraph_index;
|
||||
int body_subgraph_index;
|
||||
};
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
auto* op_data = new OpData;
|
||||
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
|
||||
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
|
||||
op_data->cond_subgraph_index = m["cond_subgraph_index"].AsInt32();
|
||||
op_data->body_subgraph_index = m["body_subgraph_index"].AsInt32();
|
||||
return op_data;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
int num_inputs = node->inputs->size;
|
||||
// The number of outputs should be the same as number of inputs.
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, num_inputs);
|
||||
|
||||
// Check subgraph indices and get subgraphs.
|
||||
Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||
auto* subgraphs = this_subgraph->GetSubgraphs();
|
||||
TF_LITE_ENSURE(context, op_data->cond_subgraph_index < subgraphs->size());
|
||||
TF_LITE_ENSURE(context, op_data->body_subgraph_index < subgraphs->size());
|
||||
|
||||
Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
|
||||
Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();
|
||||
|
||||
// Check input & output count of the condition subgraph.
|
||||
TF_LITE_ENSURE_EQ(context, cond_subgraph->inputs().size(), num_inputs);
|
||||
TF_LITE_ENSURE_EQ(context, cond_subgraph->outputs().size(), 1);
|
||||
|
||||
// Check input & output count of the body subgraph.
|
||||
TF_LITE_ENSURE_EQ(context, body_subgraph->inputs().size(), num_inputs);
|
||||
TF_LITE_ENSURE_EQ(context, body_subgraph->outputs().size(), num_inputs);
|
||||
|
||||
// Prepare and check the condition subgraph.
|
||||
ResizeSubgraphInputs(context, node, cond_subgraph);
|
||||
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
||||
TfLiteTensor* cond_output =
|
||||
cond_subgraph->tensor(cond_subgraph->outputs()[0]);
|
||||
// The condition output must be a single boolean value.
|
||||
TF_LITE_ENSURE_EQ(context, cond_output->type, kTfLiteBool);
|
||||
TF_LITE_ENSURE_EQ(context, cond_output->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, cond_output->dims->data[0], 1);
|
||||
|
||||
// TODO(ycling): Handle the case where condition graph has dynamic
|
||||
// sized tensors.
|
||||
|
||||
// Prepare and check the body subgraph.
|
||||
ResizeSubgraphInputs(context, node, body_subgraph);
|
||||
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
TfLiteTensor* body_input =
|
||||
body_subgraph->tensor(body_subgraph->inputs()[i]);
|
||||
TfLiteTensor* body_output =
|
||||
body_subgraph->tensor(body_subgraph->outputs()[i]);
|
||||
TF_LITE_ENSURE_EQ(context, body_input->type, body_output->type);
|
||||
|
||||
// TODO(ycling): Support dynamic sized body subgraph.
|
||||
TF_LITE_ENSURE(context, !IsDynamicTensor(body_output));
|
||||
TF_LITE_ENSURE(context,
|
||||
TfLiteIntArrayEqual(body_input->dims, body_output->dims));
|
||||
|
||||
TfLiteTensor* output = GetOutput(context, node, i);
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCopy(body_output->dims);
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(context, output, output_size));
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
|
||||
auto* subgraphs = this_subgraph->GetSubgraphs();
|
||||
Subgraph* cond_subgraph = (*subgraphs)[op_data->cond_subgraph_index].get();
|
||||
Subgraph* body_subgraph = (*subgraphs)[op_data->body_subgraph_index].get();
|
||||
|
||||
// Currently we copy the input / output between the subgraphs. This isn't
|
||||
// optimized yet.
|
||||
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
CopyTensors(context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||
cond_subgraph, cond_subgraph->inputs()));
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
CopyTensors(context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||
body_subgraph, body_subgraph->inputs()));
|
||||
|
||||
while (true) {
|
||||
TF_LITE_ENSURE_OK(context, cond_subgraph->Invoke());
|
||||
TfLiteTensor* cond_output =
|
||||
cond_subgraph->tensor(cond_subgraph->outputs()[0]);
|
||||
if (!cond_output->data.b[0]) {
|
||||
break;
|
||||
}
|
||||
TF_LITE_ENSURE_OK(context, body_subgraph->Invoke());
|
||||
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CopyTensors(context, body_subgraph, body_subgraph->outputs(),
|
||||
body_subgraph, body_subgraph->inputs()));
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CopyTensors(context, body_subgraph, body_subgraph->outputs(),
|
||||
cond_subgraph, cond_subgraph->inputs()));
|
||||
}
|
||||
|
||||
// Note that copying from body's output will fail if body is never invoked.
|
||||
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CopyTensors(context, body_subgraph, body_subgraph->inputs(),
|
||||
this_subgraph, TfLiteIntArrayView(node->outputs)));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace while_kernel
|
||||
|
||||
TfLiteRegistration* Register_WHILE() {
|
||||
static TfLiteRegistration r = {while_kernel::Init, while_kernel::Free,
|
||||
while_kernel::Prepare, while_kernel::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
85
tensorflow/lite/kernels/while_test.cc
Normal file
85
tensorflow/lite/kernels/while_test.cc
Normal file
@ -0,0 +1,85 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/subgraph_test_util.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
using subgraph_test_util::CheckIntTensor;
|
||||
using subgraph_test_util::ControlFlowOpTest;
|
||||
using subgraph_test_util::FillIntTensor;
|
||||
|
||||
namespace {
|
||||
|
||||
class WhileTest : public ControlFlowOpTest {};
|
||||
|
||||
// The test builds a model that produces the i-th number of
|
||||
// triangular number sequence.
|
||||
//
|
||||
// TODO(ycling): Consider to improve this test case by adding a
|
||||
// concat into the body subgraph.
|
||||
TEST_F(WhileTest, TestTriangularNumberSequence) {
|
||||
const std::vector<int> expected = {1, 3, 6, 10, 15, 21, 28};
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
interpreter_.reset(new Interpreter);
|
||||
interpreter_->AddSubgraphs(2);
|
||||
builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), i);
|
||||
builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
|
||||
builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1});
|
||||
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
|
||||
CheckIntTensor(output1, {1}, {i + 1});
|
||||
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
|
||||
CheckIntTensor(output2, {1}, {expected[i]});
|
||||
}
|
||||
}
|
||||
|
||||
// This requires dynamic sized subgraphs and it's not supported right now.
|
||||
// TODO(ycling): Support dynamic sized subgraphs.
|
||||
TEST_F(WhileTest, TestPadLoop) {
|
||||
interpreter_.reset(new Interpreter);
|
||||
interpreter_->AddSubgraphs(2);
|
||||
builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), 3);
|
||||
builder_->BuildPadLoopBodySubgraph(interpreter_->subgraph(2), {1, 2});
|
||||
builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
||||
// This is not supported yet. The test ensures thatit doesn't crash and raises
|
||||
// an error properly.
|
||||
ASSERT_NE(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user