Reuse the input buffers when possible in TensorList kernels.

This lets the tf runtime reuse the memory for the input tensor to the output
tensor, preventing allocations and copying, and making the normal runtime
of successive push-backs O(n) instead of O(n^2).

PiperOrigin-RevId: 227056083
This commit is contained in:
Alexandre Passos 2018-12-27 13:07:56 -08:00 committed by TensorFlower Gardener
parent fbf52681e2
commit 331bf7a11a

View File

@ -259,14 +259,21 @@ class TensorListPushBack : public OpKernel {
" max_num_elements: ", l->max_num_elements));
}
TensorList output;
output = *l;
output.tensors.push_back(input);
Tensor* result;
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
std::unique_ptr<Tensor> maybe_result = c->forward_input(
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
if (maybe_result != nullptr) {
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.push_back(
input);
} else {
Tensor* result;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
TensorList output;
output = *l;
output.tensors.push_back(input);
result->scalar<Variant>()() = std::move(output);
}
}
private:
@ -384,14 +391,20 @@ class TensorListPopBack : public OpKernel {
errors::InvalidArgument("Trying to pop from an empty list."));
c->set_output(1, l->tensors.back());
TensorList output;
output = *l;
output.tensors.pop_back();
Tensor* result;
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
std::unique_ptr<Tensor> maybe_result = c->forward_input(
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
if (maybe_result != nullptr) {
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.pop_back();
} else {
TensorList output;
output = *l;
output.tensors.pop_back();
Tensor* result;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
}
}
private:
@ -529,14 +542,21 @@ class TensorListSetItem : public OpKernel {
"list index. Item element shape: ",
value.shape().DebugString(),
" list shape: ", l->element_shape.DebugString()));
TensorList output;
output = *l;
output.tensors[index] = value;
Tensor* result;
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
std::unique_ptr<Tensor> maybe_result = c->forward_input(
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
if (maybe_result != nullptr) {
maybe_result->scalar<Variant>()().get<TensorList>()->tensors[index] =
value;
} else {
TensorList output;
output = *l;
output.tensors[index] = value;
Tensor* result;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
}
}
private: