Reuse the input buffers when possible in TensorList kernels.
This lets the tf runtime reuse the memory for the input tensor to the output tensor, preventing allocations and copying, and making the normal runtime of successive push-backs O(n) instead of O(n^2). PiperOrigin-RevId: 227056083
This commit is contained in:
parent
fbf52681e2
commit
331bf7a11a
@ -259,15 +259,22 @@ class TensorListPushBack : public OpKernel {
|
||||
" max_num_elements: ", l->max_num_elements));
|
||||
}
|
||||
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
std::unique_ptr<Tensor> maybe_result = c->forward_input(
|
||||
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
|
||||
if (maybe_result != nullptr) {
|
||||
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.push_back(
|
||||
input);
|
||||
} else {
|
||||
Tensor* result;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors.push_back(input);
|
||||
Tensor* result;
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DataType element_dtype_;
|
||||
@ -384,15 +391,21 @@ class TensorListPopBack : public OpKernel {
|
||||
errors::InvalidArgument("Trying to pop from an empty list."));
|
||||
|
||||
c->set_output(1, l->tensors.back());
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
std::unique_ptr<Tensor> maybe_result = c->forward_input(
|
||||
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
|
||||
if (maybe_result != nullptr) {
|
||||
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.pop_back();
|
||||
} else {
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors.pop_back();
|
||||
Tensor* result;
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DataType element_dtype_;
|
||||
@ -529,15 +542,22 @@ class TensorListSetItem : public OpKernel {
|
||||
"list index. Item element shape: ",
|
||||
value.shape().DebugString(),
|
||||
" list shape: ", l->element_shape.DebugString()));
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
std::unique_ptr<Tensor> maybe_result = c->forward_input(
|
||||
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
|
||||
if (maybe_result != nullptr) {
|
||||
maybe_result->scalar<Variant>()().get<TensorList>()->tensors[index] =
|
||||
value;
|
||||
} else {
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors[index] = value;
|
||||
Tensor* result;
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DataType element_dtype_;
|
||||
|
Loading…
Reference in New Issue
Block a user