diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 4859759eba5..8a9a96ce363 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -45,6 +45,18 @@ namespace xla { // TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be // revisited, with the execute APIs taking data structure which can better model // shareable buffers. +// +// ExecutionInput buffers are in one of three states: +// +// 1) Owned by the caller and immutable. +// 2) Donated by the caller but returned on error. +// 3) Donated by the caller and freed on error. +// +// Case (1) buffers are stored as MaybeOwningDeviceMemory(DeviceMemoryBase). +// Case (2) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory), +// with their indices present in unowned_indices_. +// Case (3) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory), +// with their indices absent from unowned_indices_. class ExecutionInput { public: ExecutionInput() = default; @@ -80,6 +92,10 @@ class ExecutionInput { unowned_indices_.push_back(index); } + void SetUnownedIndex(const ShapeIndex& index) { + unowned_indices_.push_back(index); + } + const ShapeTree& Buffers() const { return buffers_; } ShapeTree* MutableBuffers() { return &buffers_; } @@ -94,6 +110,8 @@ class ExecutionInput { private: ShapeTree buffers_; + // (Unordered) set of indices of buffers that should be returned to the + // caller if an error occurs when enqueuing the computation. std::vector unowned_indices_; };