Add convenience method to set unowned index in ExecutionInput.
PiperOrigin-RevId: 303747533 Change-Id: I49fda31f8543300982ce5e2a3b737900f4205352
This commit is contained in:
parent
37682e5e46
commit
7d0a91fc48
@ -45,6 +45,18 @@ namespace xla {
|
|||||||
// TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be
|
// TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be
|
||||||
// revisited, with the execute APIs taking data structure which can better model
|
// revisited, with the execute APIs taking data structure which can better model
|
||||||
// shareable buffers.
|
// shareable buffers.
|
||||||
|
//
|
||||||
|
// ExecutionInput buffers are in one of three states:
|
||||||
|
//
|
||||||
|
// 1) Owned by the caller and immutable.
|
||||||
|
// 2) Donated by the caller but returned on error.
|
||||||
|
// 3) Donated by the caller and freed on error.
|
||||||
|
//
|
||||||
|
// Case (1) buffers are stored as MaybeOwningDeviceMemory(DeviceMemoryBase).
|
||||||
|
// Case (2) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory),
|
||||||
|
// with their indices present in unowned_indices_.
|
||||||
|
// Case (3) buffers are stored as MaybeOwningDeviceMemory(OwningDeviceMemory),
|
||||||
|
// with their indices absent from unowned_indices_.
|
||||||
class ExecutionInput {
|
class ExecutionInput {
|
||||||
public:
|
public:
|
||||||
ExecutionInput() = default;
|
ExecutionInput() = default;
|
||||||
@ -80,6 +92,10 @@ class ExecutionInput {
|
|||||||
unowned_indices_.push_back(index);
|
unowned_indices_.push_back(index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetUnownedIndex(const ShapeIndex& index) {
|
||||||
|
unowned_indices_.push_back(index);
|
||||||
|
}
|
||||||
|
|
||||||
const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
|
const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
|
||||||
|
|
||||||
ShapeTree<MaybeOwningDeviceMemory>* MutableBuffers() { return &buffers_; }
|
ShapeTree<MaybeOwningDeviceMemory>* MutableBuffers() { return &buffers_; }
|
||||||
@ -94,6 +110,8 @@ class ExecutionInput {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
ShapeTree<MaybeOwningDeviceMemory> buffers_;
|
ShapeTree<MaybeOwningDeviceMemory> buffers_;
|
||||||
|
// (Unordered) set of indices of buffers that should be returned to the
|
||||||
|
// caller if an error occurs when enqueuing the computation.
|
||||||
std::vector<ShapeIndex> unowned_indices_;
|
std::vector<ShapeIndex> unowned_indices_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user