Reuse the input buffers when possible in TensorList kernels.

This lets the tf runtime reuse the memory for the input tensor to the output
tensor, preventing allocations and copying, and making the normal runtime
of successive push-backs O(n) instead of O(n^2).

PiperOrigin-RevId: 227056083
This commit is contained in:
Alexandre Passos 2018-12-27 13:07:56 -08:00 committed by TensorFlower Gardener
parent fbf52681e2
commit 331bf7a11a

View File

@ -259,15 +259,22 @@ class TensorListPushBack : public OpKernel {
" max_num_elements: ", l->max_num_elements));
}
AllocatorAttributes attr;
attr.set_on_host(true);
std::unique_ptr<Tensor> maybe_result = c->forward_input(
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
if (maybe_result != nullptr) {
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.push_back(
input);
} else {
Tensor* result;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
TensorList output;
output = *l;
output.tensors.push_back(input);
Tensor* result;
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
}
}
private:
DataType element_dtype_;
@ -384,15 +391,21 @@ class TensorListPopBack : public OpKernel {
errors::InvalidArgument("Trying to pop from an empty list."));
c->set_output(1, l->tensors.back());
AllocatorAttributes attr;
attr.set_on_host(true);
std::unique_ptr<Tensor> maybe_result = c->forward_input(
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
if (maybe_result != nullptr) {
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.pop_back();
} else {
TensorList output;
output = *l;
output.tensors.pop_back();
Tensor* result;
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
}
}
private:
DataType element_dtype_;
@ -529,15 +542,22 @@ class TensorListSetItem : public OpKernel {
"list index. Item element shape: ",
value.shape().DebugString(),
" list shape: ", l->element_shape.DebugString()));
AllocatorAttributes attr;
attr.set_on_host(true);
std::unique_ptr<Tensor> maybe_result = c->forward_input(
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
if (maybe_result != nullptr) {
maybe_result->scalar<Variant>()().get<TensorList>()->tensors[index] =
value;
} else {
TensorList output;
output = *l;
output.tensors[index] = value;
Tensor* result;
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
result->scalar<Variant>()() = std::move(output);
}
}
private:
DataType element_dtype_;