Reuse the input buffers when possible in TensorList kernels.
This lets the TensorFlow runtime reuse the input tensor's memory for the output tensor, avoiding allocations and copies, and making the typical total running time of successive push-backs O(n) instead of O(n^2). PiperOrigin-RevId: 227056083
This commit is contained in:
parent
fbf52681e2
commit
331bf7a11a
@ -259,14 +259,21 @@ class TensorListPushBack : public OpKernel {
|
||||
" max_num_elements: ", l->max_num_elements));
|
||||
}
|
||||
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors.push_back(input);
|
||||
Tensor* result;
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
std::unique_ptr<Tensor> maybe_result = c->forward_input(
|
||||
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
|
||||
if (maybe_result != nullptr) {
|
||||
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.push_back(
|
||||
input);
|
||||
} else {
|
||||
Tensor* result;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors.push_back(input);
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -384,14 +391,20 @@ class TensorListPopBack : public OpKernel {
|
||||
errors::InvalidArgument("Trying to pop from an empty list."));
|
||||
|
||||
c->set_output(1, l->tensors.back());
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors.pop_back();
|
||||
Tensor* result;
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
std::unique_ptr<Tensor> maybe_result = c->forward_input(
|
||||
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
|
||||
if (maybe_result != nullptr) {
|
||||
maybe_result->scalar<Variant>()().get<TensorList>()->tensors.pop_back();
|
||||
} else {
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors.pop_back();
|
||||
Tensor* result;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -529,14 +542,21 @@ class TensorListSetItem : public OpKernel {
|
||||
"list index. Item element shape: ",
|
||||
value.shape().DebugString(),
|
||||
" list shape: ", l->element_shape.DebugString()));
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors[index] = value;
|
||||
Tensor* result;
|
||||
AllocatorAttributes attr;
|
||||
attr.set_on_host(true);
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
std::unique_ptr<Tensor> maybe_result = c->forward_input(
|
||||
0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
|
||||
if (maybe_result != nullptr) {
|
||||
maybe_result->scalar<Variant>()().get<TensorList>()->tensors[index] =
|
||||
value;
|
||||
} else {
|
||||
TensorList output;
|
||||
output = *l;
|
||||
output.tensors[index] = value;
|
||||
Tensor* result;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
|
||||
result->scalar<Variant>()() = std::move(output);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user