Fix invoking the While op multiple times when the body subgraph is dynamic.

PiperOrigin-RevId: 257459721
commit 17521bcffd (parent 36e969ee0b)
Author: Yu-Cheng Ling (committed by TensorFlower Gardener)
Date: 2019-07-10 12:37:54 -07:00
2 changed files with 45 additions and 19 deletions
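For orientation, here is a minimal sketch (not part of the commit) of the failure mode being fixed: a `While` op whose body subgraph produces dynamically shaped outputs is invoked more than once. It mirrors the test change below and assumes the same `WhileTest` fixture and helpers (`FillIntTensor`, `CheckIntTensor`); the test name is hypothetical and the cond/body/while subgraph setup is elided.

// Hypothetical regression sketch, assuming the WhileTest fixture used in the
// test diff below; subgraph construction is elided.
TEST_F(WhileTest, InvokeTwiceWithDynamicBody_Sketch) {
  // ... build the cond/body/while subgraphs exactly as TestPadLoop does ...
  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);

  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);

  // Before this fix, the first Invoke() left the interpreter in an
  // uninvokable state when the body produced dynamically shaped outputs,
  // so this second Invoke() would fail.
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
}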


@@ -28,21 +28,36 @@ namespace {
 // Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
 // to `dst_tensor_indices` in `dst_subgraph`.
+//
+// When `resize_subgraph_inputs` is true, the function calls the subgraph's
+// `ResizeInputTensor` function, and it may trigger the memory planner to
+// reallocate memory.
+// When `resize_subgraph_inputs` is false, it implies `context` belongs to
+// `dst_subgraph`. The function calls `context->ResizeTensor`. This happens
+// when resizing the `While` op's outputs.
 template <typename SrcVector, typename DstVector>
 TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
                                      Subgraph* src_subgraph,
                                      const SrcVector& src_tensor_indices,
                                      Subgraph* dst_subgraph,
-                                     const DstVector& dst_tensor_indices) {
+                                     const DstVector& dst_tensor_indices,
+                                     bool resize_subgraph_inputs) {
   TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
                     dst_tensor_indices.size());
   for (int i = 0; i < src_tensor_indices.size(); ++i) {
     const TfLiteTensor* src_tensor =
         src_subgraph->tensor(src_tensor_indices[i]);
-    std::vector<int> dims(src_tensor->dims->data,
-                          src_tensor->dims->data + src_tensor->dims->size);
-    dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
-    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
+    TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
+    if (resize_subgraph_inputs) {
+      std::vector<int> dims(src_tensor->dims->data,
+                            src_tensor->dims->data + src_tensor->dims->size);
+      dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
+    } else {
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, dst_tensor,
+                                         TfLiteIntArrayCopy(src_tensor->dims)));
+    }
     dst_tensor->type = src_tensor->type;
   }
   return kTfLiteOk;
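For reference, a condensed illustration of the two modes the new `resize_subgraph_inputs` flag distinguishes. This is not a verbatim excerpt; it is assembled from the call sites in the hunks that follow, with the flag arguments spelled out as comments.

// Mode 1: resizing another subgraph's inputs. This may trigger that
// subgraph's memory planner to reallocate, so AllocateTensors() follows.
TF_LITE_ENSURE_OK(
    context, CopyTensorsShapeAndType(
                 context, this_subgraph, TfLiteIntArrayView(node->inputs),
                 cond_subgraph, cond_subgraph->inputs(),
                 /*resize_subgraph_inputs=*/true));
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());

// Mode 2: resizing the While op's own outputs. Here `context` belongs to
// `dst_subgraph`, so the copy goes through context->ResizeTensor instead.
TF_LITE_ENSURE_OK(
    context, CopyTensorsShapeAndType(
                 context, cond_subgraph, cond_subgraph->inputs(),
                 this_subgraph, TfLiteIntArrayView(node->outputs),
                 /*resize_subgraph_inputs=*/false));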
@@ -130,9 +145,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Prepare and check the condition subgraph.
   TF_LITE_ENSURE_OK(
-      context, CopyTensorsShapeAndType(context, this_subgraph,
-                                       TfLiteIntArrayView(node->inputs),
-                                       cond_subgraph, cond_subgraph->inputs()));
+      context, CopyTensorsShapeAndType(
+                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
+                   cond_subgraph, cond_subgraph->inputs(), true));
   TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
   TfLiteTensor* cond_output =
       cond_subgraph->tensor(cond_subgraph->outputs()[0]);
@@ -148,9 +163,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Prepare and check the body subgraph.
   TF_LITE_ENSURE_OK(
-      context, CopyTensorsShapeAndType(context, this_subgraph,
-                                       TfLiteIntArrayView(node->inputs),
-                                       body_subgraph, body_subgraph->inputs()));
+      context, CopyTensorsShapeAndType(
+                   context, this_subgraph, TfLiteIntArrayView(node->inputs),
+                   body_subgraph, body_subgraph->inputs(), true));
   TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
   if (body_subgraph->HasDynamicTensors()) {
     op_data->body_has_dynamic_output_tensors = true;
@@ -232,6 +247,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // boundary. Currently we copy the input / output between the subgraphs. This
   // isn't optimized yet and a lot of redundant copies are made.
   // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
+  if (op_data->body_has_dynamic_output_tensors) {
+    // If the body subgraph has dynamic outputs, the inputs of the condition
+    // subgraph may have been changed by the last invocation and may need
+    // resizing.
+    TF_LITE_ENSURE_OK(
+        context, CopyTensorsShapeAndType(
+                     context, this_subgraph, TfLiteIntArrayView(node->inputs),
+                     cond_subgraph, cond_subgraph->inputs(), true));
+    TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
+  }
   TF_LITE_ENSURE_OK(
       context,
       CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
@@ -254,7 +279,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(context,
                         CopyTensorsShapeAndType(
                             context, cond_subgraph, cond_subgraph->inputs(),
-                            body_subgraph, body_subgraph->inputs()));
+                            body_subgraph, body_subgraph->inputs(), true));
       TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
     }
@@ -273,7 +298,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(context,
                         CopyTensorsShapeAndType(
                             context, body_subgraph, body_subgraph->outputs(),
-                            cond_subgraph, cond_subgraph->inputs()));
+                            cond_subgraph, cond_subgraph->inputs(), true));
       TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
     }
@@ -287,9 +312,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
   if (op_data->body_has_dynamic_output_tensors) {
     TF_LITE_ENSURE_OK(
-        context, CopyTensorsShapeAndType(context, cond_subgraph,
-                                         cond_subgraph->inputs(), this_subgraph,
-                                         TfLiteIntArrayView(node->outputs)));
+        context, CopyTensorsShapeAndType(
+                     context, cond_subgraph, cond_subgraph->inputs(),
+                     this_subgraph, TfLiteIntArrayView(node->outputs), false));
   }
   TF_LITE_ENSURE_OK(


@@ -59,8 +59,6 @@ TEST_F(WhileTest, TestTriangularNumberSequence) {
   }
 }
 
-// This requires dynamic sized subgraphs and it's not supported right now.
-// TODO(ycling): Support dynamic sized subgraphs.
 TEST_F(WhileTest, TestPadLoop) {
   interpreter_.reset(new Interpreter);
   interpreter_->AddSubgraphs(2);
@@ -70,8 +68,6 @@ TEST_F(WhileTest, TestPadLoop) {
   interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
   interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
-  // This is not supported yet. The test ensures that it doesn't crash and
-  // raises an error properly.
   ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
 
   FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
@@ -82,6 +78,11 @@ TEST_F(WhileTest, TestPadLoop) {
   CheckIntTensor(output1, {1}, {4});
   TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
   CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0});
+
+  // The extra invocation serves as a regression test: there was a bug where
+  // invoking a while loop with a dynamic-shaped body left the interpreter
+  // state uninvokable.
+  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
 }
 
 }  // namespace