Add convenience method to set unowned index in ExecutionInput.

PiperOrigin-RevId: 303747533
Change-Id: I49fda31f8543300982ce5e2a3b737900f4205352
This commit is contained in:
A. Unique TensorFlower 2020-03-30 08:39:28 -07:00 committed by TensorFlower Gardener
parent 37682e5e46
commit 7d0a91fc48

View File

@ -45,6 +45,18 @@ namespace xla {
// TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be
// revisited, with the execute APIs taking data structure which can better model
// shareable buffers.
//
// ExecutionInput buffers are in one of three states:
//
// 1) Owned by the caller and immutable.
// 2) Donated by the caller but returned on error.
// 3) Donated by the caller and freed on error.
//
// Case (1) buffers are stored as MaybeOwningDeviceMemory(DeviceMemoryBase).
// Case (2) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory),
// with their indices present in unowned_indices_.
// Case (3) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory),
// with their indices absent from unowned_indices_.
class ExecutionInput {
public:
ExecutionInput() = default;
@ -80,6 +92,10 @@ class ExecutionInput {
unowned_indices_.push_back(index);
}
void SetUnownedIndex(const ShapeIndex& index) {
unowned_indices_.push_back(index);
}
const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
ShapeTree<MaybeOwningDeviceMemory>* MutableBuffers() { return &buffers_; }
@ -94,6 +110,8 @@ class ExecutionInput {
private:
ShapeTree<MaybeOwningDeviceMemory> buffers_;
// (Unordered) set of indices of buffers that should be returned to the
// caller if an error occurs when enqueuing the computation.
std::vector<ShapeIndex> unowned_indices_;
};