Fix invoking `While` multiple times when the body is dynamic.

PiperOrigin-RevId: 257459721
Yu-Cheng Ling 2019-07-10 12:37:54 -07:00 committed by TensorFlower Gardener
parent 36e969ee0b
commit 17521bcffd
2 changed files with 45 additions and 19 deletions


@ -28,21 +28,36 @@ namespace {
// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
// to `dst_tensor_indices` in `dst_subgraph`.
//
// When `resize_subgraph_inputs` is true, the function calls the destination subgraph's
// `ResizeInputTensor` function, and it may trigger the memory planner to
// reallocate memory.
// When `resize_subgraph_inputs` is false, it implies that `context` belongs to
// `dst_subgraph`, and the function calls `context->ResizeTensor` instead. This
// happens when resizing the `While` op's outputs.
template <typename SrcVector, typename DstVector>
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
Subgraph* src_subgraph,
const SrcVector& src_tensor_indices,
Subgraph* dst_subgraph,
const DstVector& dst_tensor_indices,
bool resize_subgraph_inputs) {
TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
dst_tensor_indices.size());
for (int i = 0; i < src_tensor_indices.size(); ++i) {
const TfLiteTensor* src_tensor =
src_subgraph->tensor(src_tensor_indices[i]);
TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
if (resize_subgraph_inputs) {
std::vector<int> dims(src_tensor->dims->data,
src_tensor->dims->data + src_tensor->dims->size);
dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
} else {
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, dst_tensor,
TfLiteIntArrayCopy(src_tensor->dims)));
}
dst_tensor->type = src_tensor->type;
}
return kTfLiteOk;
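
In short, the new flag selects between two resize paths: passing `true` resizes the destination subgraph's inputs through `ResizeInputTensor` (with the caller running `AllocateTensors()` afterwards), while passing `false` resizes the destination tensor directly through `context->ResizeTensor`. Condensed from the call sites later in this change (a sketch, not standalone code):

// Preparing another subgraph's inputs: resize them, then re-run its planner.
CopyTensorsShapeAndType(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                        cond_subgraph, cond_subgraph->inputs(),
                        /*resize_subgraph_inputs=*/true);
cond_subgraph->AllocateTensors();

// Writing the While op's own dynamic outputs: resize through the owning context.
CopyTensorsShapeAndType(context, cond_subgraph, cond_subgraph->inputs(),
                        this_subgraph, TfLiteIntArrayView(node->outputs),
                        /*resize_subgraph_inputs=*/false);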
@ -130,9 +145,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Prepare and check the condition subgraph.
TF_LITE_ENSURE_OK(
context, CopyTensorsShapeAndType(
context, this_subgraph, TfLiteIntArrayView(node->inputs),
cond_subgraph, cond_subgraph->inputs(), true));
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
TfLiteTensor* cond_output =
cond_subgraph->tensor(cond_subgraph->outputs()[0]);
@ -148,9 +163,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Prepare and check the body subgraph.
TF_LITE_ENSURE_OK(
context, CopyTensorsShapeAndType(
context, this_subgraph, TfLiteIntArrayView(node->inputs),
body_subgraph, body_subgraph->inputs(), true));
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
if (body_subgraph->HasDynamicTensors()) {
op_data->body_has_dynamic_output_tensors = true;
@ -232,6 +247,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// boundary. Currently we copy the input / output between the subgraphs. This
// isn't optimized yet and a lot of redundant copies are made.
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
if (op_data->body_has_dynamic_output_tensors) {
// If the body subgraph has dynamic outputs, the condition subgraph's inputs
// may have been changed by the previous invocation and may need resizing.
TF_LITE_ENSURE_OK(
context, CopyTensorsShapeAndType(
context, this_subgraph, TfLiteIntArrayView(node->inputs),
cond_subgraph, cond_subgraph->inputs(), true));
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
}
TF_LITE_ENSURE_OK(
context,
CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
@ -254,7 +279,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
CopyTensorsShapeAndType(
context, cond_subgraph, cond_subgraph->inputs(),
body_subgraph, body_subgraph->inputs(), true));
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
}
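
Putting the hunks together, the overall flow of `Eval` after this change is roughly the following (a condensed sketch; error handling, the plain `CopyTensorsData` calls, and the loop-exit check are elided, and this is not the verbatim kernel code):

// Condensed sketch of While::Eval after this change.
if (op_data->body_has_dynamic_output_tensors) {
  // A previous invocation may have left cond_subgraph's inputs with the shapes
  // of the last body outputs, so re-propagate shapes and re-plan before running.
  CopyTensorsShapeAndType(context, this_subgraph, TfLiteIntArrayView(node->inputs),
                          cond_subgraph, cond_subgraph->inputs(), true);
  cond_subgraph->AllocateTensors();
}
while (true) {
  cond_subgraph->Invoke();
  // ... break when the condition output is false ...
  if (op_data->body_has_dynamic_output_tensors) {
    // Body shapes can change on every iteration: resize its inputs and re-plan.
    CopyTensorsShapeAndType(context, cond_subgraph, cond_subgraph->inputs(),
                            body_subgraph, body_subgraph->inputs(), true);
    body_subgraph->AllocateTensors();
  }
  body_subgraph->Invoke();
  if (op_data->body_has_dynamic_output_tensors) {
    // Feed the (possibly resized) body outputs back into the condition.
    CopyTensorsShapeAndType(context, body_subgraph, body_subgraph->outputs(),
                            cond_subgraph, cond_subgraph->inputs(), true);
    cond_subgraph->AllocateTensors();
  }
}
// Finally, shapes are propagated into the While op's own outputs with the flag
// set to false, which routes the resize through context->ResizeTensor.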
@ -273,7 +298,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
CopyTensorsShapeAndType(
context, body_subgraph, body_subgraph->outputs(),
cond_subgraph, cond_subgraph->inputs(), true));
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
}
@ -287,9 +312,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
if (op_data->body_has_dynamic_output_tensors) {
TF_LITE_ENSURE_OK(
context, CopyTensorsShapeAndType(
context, cond_subgraph, cond_subgraph->inputs(),
this_subgraph, TfLiteIntArrayView(node->outputs), false));
}
TF_LITE_ENSURE_OK(


@ -59,8 +59,6 @@ TEST_F(WhileTest, TestTriangularNumberSequence) {
}
}
TEST_F(WhileTest, TestPadLoop) {
interpreter_.reset(new Interpreter);
interpreter_->AddSubgraphs(2);
@ -70,8 +68,6 @@ TEST_F(WhileTest, TestPadLoop) {
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
@ -82,6 +78,11 @@ TEST_F(WhileTest, TestPadLoop) {
CheckIntTensor(output1, {1}, {4});
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0});
// The extra invocation serves as a regression test: there was a bug where
// invoking a while loop with a dynamically shaped body left the interpreter
// in an uninvokable state.
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
}
} // namespace
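
At the public-API level, the regression guarded by the extra `Invoke()` above looks roughly like this (a sketch; the model path and helper function are illustrative, not part of this change, and any model whose `While` body produces dynamically shaped tensors applies):

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Sketch of the user-visible symptom fixed by this commit.
void InvokeWhileModelTwice() {
  auto model = tflite::FlatBufferModel::BuildFromFile("while_with_dynamic_body.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);

  interpreter->AllocateTensors();
  interpreter->Invoke();  // The first call succeeded even before this fix.
  interpreter->Invoke();  // The second call could previously fail because the
                          // condition subgraph kept stale input shapes.
}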