Fix invoking while multiple times when body is dynamic.
PiperOrigin-RevId: 257459721
This commit is contained in:
parent
36e969ee0b
commit
17521bcffd
@ -28,21 +28,36 @@ namespace {
|
||||
|
||||
// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph`
|
||||
// to `dst_tensor_indices` in `dst_subgraph`.
|
||||
//
|
||||
// When `resize_subgraph_inputs` is true, the function calls subgraphs's
|
||||
// `ResizeInputTensor` function, and it may trigger the memory planner to
|
||||
// reallocate memory.
|
||||
// When `resize_subgraph_inputs` is false, it implies `context` belongs to
|
||||
// `dst_subgraph`. The function calls `context->ResizeTensor`. This happens
|
||||
// when resizing `While` op's outputs.
|
||||
template <typename SrcVector, typename DstVector>
|
||||
TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context,
|
||||
Subgraph* src_subgraph,
|
||||
const SrcVector& src_tensor_indices,
|
||||
Subgraph* dst_subgraph,
|
||||
const DstVector& dst_tensor_indices) {
|
||||
const DstVector& dst_tensor_indices,
|
||||
bool resize_subgraph_inputs) {
|
||||
TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(),
|
||||
dst_tensor_indices.size());
|
||||
for (int i = 0; i < src_tensor_indices.size(); ++i) {
|
||||
const TfLiteTensor* src_tensor =
|
||||
src_subgraph->tensor(src_tensor_indices[i]);
|
||||
std::vector<int> dims(src_tensor->dims->data,
|
||||
src_tensor->dims->data + src_tensor->dims->size);
|
||||
dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
|
||||
|
||||
TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]);
|
||||
if (resize_subgraph_inputs) {
|
||||
std::vector<int> dims(src_tensor->dims->data,
|
||||
src_tensor->dims->data + src_tensor->dims->size);
|
||||
dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims);
|
||||
} else {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(context, dst_tensor,
|
||||
TfLiteIntArrayCopy(src_tensor->dims)));
|
||||
}
|
||||
dst_tensor->type = src_tensor->type;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@ -130,9 +145,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
// Prepare and check the condition subgraph.
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CopyTensorsShapeAndType(context, this_subgraph,
|
||||
TfLiteIntArrayView(node->inputs),
|
||||
cond_subgraph, cond_subgraph->inputs()));
|
||||
context, CopyTensorsShapeAndType(
|
||||
context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||
cond_subgraph, cond_subgraph->inputs(), true));
|
||||
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
||||
TfLiteTensor* cond_output =
|
||||
cond_subgraph->tensor(cond_subgraph->outputs()[0]);
|
||||
@ -148,9 +163,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
// Prepare and check the body subgraph.
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CopyTensorsShapeAndType(context, this_subgraph,
|
||||
TfLiteIntArrayView(node->inputs),
|
||||
body_subgraph, body_subgraph->inputs()));
|
||||
context, CopyTensorsShapeAndType(
|
||||
context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||
body_subgraph, body_subgraph->inputs(), true));
|
||||
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
|
||||
if (body_subgraph->HasDynamicTensors()) {
|
||||
op_data->body_has_dynamic_output_tensors = true;
|
||||
@ -232,6 +247,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// boundary. Currently we copy the input / output between the subgraphs. This
|
||||
// isn't optimized yet and a lot of redundant copies are made.
|
||||
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
||||
|
||||
if (op_data->body_has_dynamic_output_tensors) {
|
||||
// If body subgraph has dynamic outputs, the input of condition subgraph may
|
||||
// be changed in the last invocation and may need resizing.
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CopyTensorsShapeAndType(
|
||||
context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||
cond_subgraph, cond_subgraph->inputs(), true));
|
||||
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
||||
}
|
||||
TF_LITE_ENSURE_OK(
|
||||
context,
|
||||
CopyTensorsData(context, this_subgraph, TfLiteIntArrayView(node->inputs),
|
||||
@ -254,7 +279,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
CopyTensorsShapeAndType(
|
||||
context, cond_subgraph, cond_subgraph->inputs(),
|
||||
body_subgraph, body_subgraph->inputs()));
|
||||
body_subgraph, body_subgraph->inputs(), true));
|
||||
TF_LITE_ENSURE_OK(context, body_subgraph->AllocateTensors());
|
||||
}
|
||||
|
||||
@ -273,7 +298,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
CopyTensorsShapeAndType(
|
||||
context, body_subgraph, body_subgraph->outputs(),
|
||||
cond_subgraph, cond_subgraph->inputs()));
|
||||
cond_subgraph, cond_subgraph->inputs(), true));
|
||||
TF_LITE_ENSURE_OK(context, cond_subgraph->AllocateTensors());
|
||||
}
|
||||
|
||||
@ -287,9 +312,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
|
||||
if (op_data->body_has_dynamic_output_tensors) {
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CopyTensorsShapeAndType(context, cond_subgraph,
|
||||
cond_subgraph->inputs(), this_subgraph,
|
||||
TfLiteIntArrayView(node->outputs)));
|
||||
context, CopyTensorsShapeAndType(
|
||||
context, cond_subgraph, cond_subgraph->inputs(),
|
||||
this_subgraph, TfLiteIntArrayView(node->outputs), false));
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_OK(
|
||||
|
@ -59,8 +59,6 @@ TEST_F(WhileTest, TestTriangularNumberSequence) {
|
||||
}
|
||||
}
|
||||
|
||||
// This requires dynamic sized subgraphs and it's not supported right now.
|
||||
// TODO(ycling): Support dynamic sized subgraphs.
|
||||
TEST_F(WhileTest, TestPadLoop) {
|
||||
interpreter_.reset(new Interpreter);
|
||||
interpreter_->AddSubgraphs(2);
|
||||
@ -70,8 +68,6 @@ TEST_F(WhileTest, TestPadLoop) {
|
||||
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
|
||||
interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
|
||||
// This is not supported yet. The test ensures thatit doesn't crash and raises
|
||||
// an error properly.
|
||||
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
|
||||
@ -82,6 +78,11 @@ TEST_F(WhileTest, TestPadLoop) {
|
||||
CheckIntTensor(output1, {1}, {4});
|
||||
TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
|
||||
CheckIntTensor(output2, {11}, {0, 0, 0, 5, 7, 0, 0, 0, 0, 0, 0});
|
||||
|
||||
// The extra invocation serves as a regiression test: There was a bug that
|
||||
// invoking a while loop with dynamic shaped body makes the interpreter
|
||||
// state uninvokable.
|
||||
ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user