Fix invoking while multiple times when body is dynamic.
PiperOrigin-RevId: 257459721
This commit is contained in:
parent
36e969ee0b
commit
17521bcffd
@ -28,21 +28,36 @@ namespace {
|
|||||||
|
|
||||||
// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
|
// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
|
||||||
// to `dst_tensor_indices` in `dst_subgraph`.
|
// to `dst_tensor_indices` in `dst_subgraph`.
|
||||||
|
//
|
||||||
|
// When `resize_subgraph_inputs` is true, the function calls subgraphs's
|
||||||
|
// `ResizeInputTensor` function, and it may trigger the memory planner to
|
||||||
|
// reallocate memory.
|
||||||
|
// When `resize_subgraph_inputs` is false, it implies `context` belongs to
|
||||||
|
// `dst_subgraph`. The function calls `context->ResizeTensor`. This happens
|
||||||
|
// when resizing `While` op's outputs.
|
||||||
template <typename SrcVector, typename DstVector>
|
template <typename SrcVector, typename DstVector>
|
||||||
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
|
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
|
||||||
Subgraph* src_subgraph,
|
Subgraph* src_subgraph,
|
||||||
const SrcVector& src_tensor_indices,
|
const SrcVector& src_tensor_indices,
|
||||||
Subgraph* dst_subgraph,
|
Subgraph* dst_subgraph,
|
||||||
const DstVector& dst_tensor_indices) {
|
const DstVector& dst_tensor_indices,
|
||||||
|
bool resize_subgraph_inputs) {
|
||||||
TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
|
TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
|
||||||
dst_tensor_indices.size());
|
dst_tensor_indices.size());
|
||||||
for (int i = 0; i < src_tensor_indices.size(); ++i) {
|
for (int i = 0; i < src_tensor_indices.size(); ++i) {
|
||||||
const TfLiteTensor* src_tensor =
|
const TfLiteTensor* src_tensor =
|
||||||
src_subgraph->tensor(src_tensor_indices[i]);
|
src_subgraph->tensor(src_tensor_indices[i]);
|
||||||
std::vector<int> dims(src_tensor->dims->data,
|
|
||||||
src_tensor->dims->data + src_tensor->dims->size);
|
|
||||||
dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
|
|
||||||
TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
|
TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
|
||||||
|
if (resize_subgraph_inputs) {
|
||||||
|
std::vector<int> dims(src_tensor->dims->data,
|
||||||
|
src_tensor->dims->data + src_tensor->dims->size);
|
||||||
|
dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
|
||||||
|
} else {
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(context, dst_tensor,
|
||||||
|
TfLiteIntArrayCopy(src_tensor->dims)));
|
||||||
|
}
|
||||||
dst_tensor->type = src_tensor->type;
|
dst_tensor->type = src_tensor->type;
|
||||||
}
|
}
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
@ -130,9 +145,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
// Prepare and check the condition subgraph.
|
// Prepare and check the condition subgraph.
|
||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, CopyTensorsShapeAndType(context, this_subgraph,
|
context, CopyTensorsShapeAndType(
|
||||||
TfLiteIntArrayView(node->inputs),
|
context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||||
cond_subgraph, cond_subgraph->inputs()));
|
cond_subgraph, cond_subgraph->inputs(), true));
|
||||||
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
||||||
TfLiteTensor* cond_output =
|
TfLiteTensor* cond_output =
|
||||||
cond_subgraph->tensor(cond_subgraph->outputs()[0]);
|
cond_subgraph->tensor(cond_subgraph->outputs()[0]);
|
||||||
@ -148,9 +163,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
|
|
||||||
// Prepare and check the body subgraph.
|
// Prepare and check the body subgraph.
|
||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, CopyTensorsShapeAndType(context, this_subgraph,
|
context, CopyTensorsShapeAndType(
|
||||||
TfLiteIntArrayView(node->inputs),
|
context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||||
body_subgraph, body_subgraph->inputs()));
|
body_subgraph, body_subgraph->inputs(), true));
|
||||||
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
|
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
|
||||||
if (body_subgraph->HasDynamicTensors()) {
|
if (body_subgraph->HasDynamicTensors()) {
|
||||||
op_data->body_has_dynamic_output_tensors = true;
|
op_data->body_has_dynamic_output_tensors = true;
|
||||||
@ -232,6 +247,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// boundary. Currently we copy the input / output between the subgraphs. This
|
// boundary. Currently we copy the input / output between the subgraphs. This
|
||||||
// isn't optimized yet and a lot of redundant copies are made.
|
// isn't optimized yet and a lot of redundant copies are made.
|
||||||
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
||||||
|
|
||||||
|
if (op_data->body_has_dynamic_output_tensors) {
|
||||||
|
// If body subgraph has dynamic outputs, the input of condition subgraph may
|
||||||
|
// be changed in the last invocation and may need resizing.
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, CopyTensorsShapeAndType(
|
||||||
|
context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||||
|
cond_subgraph, cond_subgraph->inputs(), true));
|
||||||
|
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
||||||
|
}
|
||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context,
|
context,
|
||||||
CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||||
@ -254,7 +279,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context,
|
TF_LITE_ENSURE_OK(context,
|
||||||
CopyTensorsShapeAndType(
|
CopyTensorsShapeAndType(
|
||||||
context, cond_subgraph, cond_subgraph->inputs(),
|
context, cond_subgraph, cond_subgraph->inputs(),
|
||||||
body_subgraph, body_subgraph->inputs()));
|
body_subgraph, body_subgraph->inputs(), true));
|
||||||
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
|
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -273,7 +298,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_OK(context,
|
TF_LITE_ENSURE_OK(context,
|
||||||
CopyTensorsShapeAndType(
|
CopyTensorsShapeAndType(
|
||||||
context, body_subgraph, body_subgraph->outputs(),
|
context, body_subgraph, body_subgraph->outputs(),
|
||||||
cond_subgraph, cond_subgraph->inputs()));
|
cond_subgraph, cond_subgraph->inputs(), true));
|
||||||
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -287,9 +312,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
||||||
if (op_data->body_has_dynamic_output_tensors) {
|
if (op_data->body_has_dynamic_output_tensors) {
|
||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
context, CopyTensorsShapeAndType(context, cond_subgraph,
|
context, CopyTensorsShapeAndType(
|
||||||
cond_subgraph->inputs(), this_subgraph,
|
context, cond_subgraph, cond_subgraph->inputs(),
|
||||||
TfLiteIntArrayView(node->outputs)));
|
this_subgraph, TfLiteIntArrayView(node->outputs), false));
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_ENSURE_OK(
|
TF_LITE_ENSURE_OK(
|
||||||
|
@ -59,8 +59,6 @@ TEST_F(WhileTest, TestTriangularNumberSequence) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This requires dynamic sized subgraphs and it's not supported right now.
|
|
||||||
// TODO(ycling): Support dynamic sized subgraphs.
|
|
||||||
TEST_F(WhileTest, TestPadLoop) {
|
TEST_F(WhileTest, TestPadLoop) {
|
||||||
interpreter_.reset(new Interpreter);
|
interpreter_.reset(new Interpreter);
|
||||||
interpreter_->AddSubgraphs(2);
|
interpreter_->AddSubgraphs(2);
|
||||||
@ -70,8 +68,6 @@ TEST_F(WhileTest, TestPadLoop) {
|
|||||||
|
|
||||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
||||||
// This is not supported yet. The test ensures thatit doesn't crash and raises
|
|
||||||
// an error properly.
|
|
||||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||||
|
|
||||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
|
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
|
||||||
@ -82,6 +78,11 @@ TEST_F(WhileTest, TestPadLoop) {
|
|||||||
CheckIntTensor(output1, {1}, {4});
|
CheckIntTensor(output1, {1}, {4});
|
||||||
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
|
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
|
||||||
CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0});
|
CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0});
|
||||||
|
|
||||||
|
// The extra invocation serves as a regiression test: There was a bug that
|
||||||
|
// invoking a while loop with dynamic shaped body makes the interpreter
|
||||||
|
// state uninvokable.
|
||||||
|
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user