Add buffer aliasing/donation support to PyLocalClient. Works only for the
TPU backend since XLA does not support aliasing on other backends.

PiperOrigin-RevId: 306267232
Change-Id: Ib2bea9ba79087afdb81caa0e3c80dea62d9b7d7e
A. Unique TensorFlower authored 2020-04-13 11:05:00 -07:00; committed by TensorFlower Gardener
parent 1d939d63fb
commit ab4462cc09
5 changed files with 54 additions and 311 deletions
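
For orientation (not part of the diff): aliasing is requested at computation-build time, and the donation machinery below decides which parameters must be handed over when the executable runs. A minimal, hedged sketch using XlaBuilder::SetUpAlias; treat the exact call shape as an assumption rather than a quote from this change:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Builds f(p0) = p0 + p0 and asks XLA to alias the output with parameter 0,
// so a donated input buffer may be reused for the result.
xla::XlaComputation BuildAliasedComputation() {
  xla::XlaBuilder builder("donate_param0");
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {1024});
  xla::XlaOp p0 = xla::Parameter(&builder, 0, shape, "p0");
  xla::XlaOp sum = xla::Add(p0, p0);
  // Alias the root output ({} = the whole result) with all of parameter 0.
  builder.SetUpAlias(/*output_index=*/{}, /*param_number=*/0,
                     /*param_index=*/{});
  return builder.Build(sum).ValueOrDie();
}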

tensorflow/compiler/xla/python/BUILD

@@ -197,7 +197,6 @@ cc_library(
"//tensorflow/compiler/xla/python/distributed:protocol_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
@@ -208,7 +207,6 @@ cc_library(
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",

tensorflow/compiler/xla/python/local_client.cc

@@ -70,7 +70,6 @@ limitations under the License.
#include <vector>
#include "absl/base/casts.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
@@ -87,7 +86,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -197,21 +195,6 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
num_partitions);
}
StatusOr<absl::flat_hash_set<int>>
PyLocalClient::GetParametersThatMustBeDonated(const LocalExecutable& executable,
bool tuple_inputs) const {
// TODO(b/149489114) support buffer donation on CPU/GPU when XLA supports it.
const HloInputOutputAliasConfig& config =
executable.executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
[](const ShapeIndex& output_index,
const HloInputOutputAliasConfig::Alias& alias) {
return InvalidArgument(
"Buffer aliasing is not supported by XLA for non-TPU backends.");
}));
return absl::flat_hash_set<int>();
}
namespace {
// Ensures that it is safe to deallocate any buffers that have been enqueued in
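
The stub above exists so that non-TPU backends fail loudly when a computation requests aliasing. On a backend that does support donation, the same hook would instead collect the aliased parameter numbers. A hedged sketch of what such an override might look like (no TPU implementation appears in this diff; the ForEachAlias traversal mirrors the ForEachAliasWithStatus call in the stub):

// Hypothetical member of an aliasing-capable PyLocalClient subclass.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
    const LocalExecutable& executable, bool tuple_inputs) const override {
  const HloInputOutputAliasConfig& config =
      executable.executable()->module().input_output_alias_config();
  absl::flat_hash_set<int> parameters_to_donate;
  config.ForEachAlias([&](const ShapeIndex& output_index,
                          const HloInputOutputAliasConfig::Alias& alias) {
    // Any parameter that participates in an alias must be donated.
    parameters_to_donate.insert(alias.parameter_number);
  });
  return parameters_to_donate;
}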
@@ -457,25 +440,14 @@ void PyLocalBuffer::ScopedHold::ConvertUsageHold(
SetError(InvalidArgument("Buffer has been converted"));
}
void PyLocalBuffer::ScopedHold::ConfirmDonation() {
CHECK(ok());
CHECK(type_ == kDonation);
parent_->ConfirmDonation(buffer().get());
SetError(InvalidArgument("Buffer has been donated"));
}
void PyLocalBuffer::ScopedHold::AddToInput(
ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const {
CHECK(ok());
if (type_ == kDonation) {
buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
} else {
CHECK(type_ == kUsage);
buffer()->AddToInputAsImmutable(iterator, end);
}
CHECK(type_ == kUsage);
buffer()->AddToInputAsImmutable(iterator, end);
}
/* static */
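
For readers following the hold machinery: ConfirmDonation is the success half of a two-outcome protocol. A hedged caller-side sketch of that protocol; Enqueue is a hypothetical stand-in for the real execution-enqueue step, not an API in this change:

// Sketch: acquire a donation hold, and either confirm it (buffer becomes
// invalid) or drop it on error (buffer stays valid).
Status RunWithDonation(PyLocalBuffer* buffer) {
  PyLocalBuffer::ScopedHold hold =
      buffer->GetBufferWithHold(PyLocalBuffer::ScopedHold::kDonation);
  TF_RETURN_IF_ERROR(hold.status());
  Status enqueued = Enqueue(hold);  // hypothetical enqueue step
  if (!enqueued.ok()) {
    // Returning destroys `hold`, which drops the donation hold and leaves
    // the buffer usable.
    return enqueued;
  }
  // The computation now owns the device memory; this invalidates `buffer`.
  hold.ConfirmDonation();
  return Status::OK();
}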
@@ -681,8 +653,7 @@ PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
on_host_shape_(std::move(on_host_shape)),
on_device_shape_(std::move(on_device_shape)),
device_(device),
device_buffer_(std::move(device_buffer)),
donation_semaphore_(/*capacity=*/1) {
device_buffer_(std::move(device_buffer)) {
for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
holds_[i] = 0;
}
@@ -695,41 +666,24 @@ PyLocalBuffer::~PyLocalBuffer() {
}
}
void PyLocalBuffer::WaitForOutstandingUsageHolds() {
auto not_in_usage_hold = [&]() {
StatusOr<std::shared_ptr<SharedDeviceBuffer>> PyLocalBuffer::Release(
bool wait_for_operations_to_complete) {
auto no_usage_holds = [&]() {
mu_.AssertHeld();
return holds_[ScopedHold::kUsage] == 0;
};
mu_.Await(absl::Condition(&not_in_usage_hold));
}
void PyLocalBuffer::WaitForOutstandingDonationHold() {
auto not_in_donation_hold = [&]() {
mu_.AssertHeld();
return holds_[ScopedHold::kDonation] == 0;
};
mu_.Await(absl::Condition(&not_in_donation_hold));
}
StatusOr<std::shared_ptr<SharedDeviceBuffer>> PyLocalBuffer::Release(
bool wait_for_operations_to_complete) {
std::shared_ptr<SharedDeviceBuffer> device_buffer;
SharedDeviceBuffer::StreamAndEventContainer events;
{
absl::MutexLock lock(&mu_);
// We first wait for a donation hold to complete if there is one in
// progress. If the donation succeeds via ConfirmDonation() then it will
// set device_buffer_ to nullptr before returning to this thread.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
return std::shared_ptr<SharedDeviceBuffer>();
}
// Set host_value_ and device_buffer_ to null now so that no other thread
// can add a hold while we are in WaitForOutstandingUsageHolds()
// below.
// can add a hold while we are in Await below.
host_value_ = nullptr;
std::swap(device_buffer_, device_buffer);
WaitForOutstandingUsageHolds();
mu_.Await(absl::Condition(&no_usage_holds));
// Now that all holds have completed and no more can be added, we can get
// the final set of usage events.
events = device_buffer->LockUseAndTransferUsageEvents();
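
The inline no_usage_holds predicate replaces the removed WaitForOutstanding*Hold helpers with a direct absl::Mutex::Await. A self-contained illustration of that idiom, independent of this codebase:

#include "absl/synchronization/mutex.h"

// Minimal illustration of the mu_.Await(absl::Condition(...)) pattern used
// in Release(): the calling thread sleeps until the predicate becomes true
// under the same mutex that guards the counted state.
class HoldCounter {
 public:
  void Add() {
    absl::MutexLock lock(&mu_);
    ++holds_;
  }
  void Drop() {
    absl::MutexLock lock(&mu_);
    --holds_;
  }
  void WaitForZero() {
    absl::MutexLock lock(&mu_);
    auto no_holds = [this]() {
      mu_.AssertHeld();
      return holds_ == 0;
    };
    mu_.Await(absl::Condition(&no_holds));
  }

 private:
  absl::Mutex mu_;
  int holds_ = 0;
};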
@@ -799,35 +753,10 @@ bool PyLocalBuffer::IsDeleted() {
StatusOr<std::shared_ptr<SharedDeviceBuffer>>
PyLocalBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
if (type == ScopedHold::kDonation) {
if (device_buffer_ == nullptr) {
return InvalidArgument("Donation requested for invalid buffer");
}
if (holds_[ScopedHold::kExternalReference] > 0) {
return InvalidArgument(
"Donation requested for buffer with external reference");
}
// donation_semaphore_ was acquired in GetBufferWithHold so that only one
// thread at a time can attempt to get a donation hold.
CHECK_EQ(holds_[type], 0);
// First add the donation hold.
++holds_[type];
// Then wait for any usage holds to be dropped or converted. No new usage
// holds can be added until we drop the donation hold so this wait will
// complete eventually.
WaitForOutstandingUsageHolds();
// Because we added a donation hold, nobody could release the buffer while
// we were waiting.
CHECK(device_buffer_ != nullptr);
if (device_buffer_ == nullptr) {
return InvalidArgument("Hold requested on invalid buffer");
} else {
// If there is a donation hold in progress we have to wait before
// acquiring any other kind of hold.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
return InvalidArgument("Hold requested on invalid buffer");
} else {
++holds_[type];
}
++holds_[type];
}
return device_buffer_;
}
@@ -846,40 +775,12 @@ void PyLocalBuffer::ConvertUsageHold(
--holds_[ScopedHold::kUsage];
}
void PyLocalBuffer::ConfirmDonation(SharedDeviceBuffer* device_buffer) {
{
absl::MutexLock lock(&mu_);
CHECK_EQ(holds_[ScopedHold::kUsage], 0);
CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
CHECK_EQ(holds_[ScopedHold::kDonation], 1);
holds_[ScopedHold::kDonation] = 0;
CHECK(device_buffer_.get() == device_buffer);
// As a sanity check ensure no more usage events can be added to the buffer.
device_buffer->LockUseAndTransferUsageEvents();
// Give up ownership of the device memory so we don't free it when the last
// reference to device_buffer_ goes away.
device_buffer->ReleaseDeviceMemory();
// Make *this invalid so it can't be used again. Any threads blocking in
// Release or GetBufferWithHold will see an invalid buffer and return.
host_value_ = nullptr;
device_buffer_.reset();
}
// Unblock another thread, if any, trying to get a donation hold.
donation_semaphore_.Release(1);
}
void PyLocalBuffer::DropHold(ScopedHold::Type type,
SharedDeviceBuffer* buffer) {
absl::MutexLock lock(&mu_);
CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
CHECK_GT(holds_[type], 0);
--holds_[type];
if (type == ScopedHold::kDonation) {
CHECK_EQ(holds_[ScopedHold::kDonation], 0);
CHECK_EQ(holds_[ScopedHold::kUsage], 0);
CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
donation_semaphore_.Release(1);
}
}
Status PyLocalBuffer::CopyToHostAsync() {
@@ -892,8 +793,6 @@ Status PyLocalBuffer::CopyToHostAsync() {
se::Stream* stream = local_device->GetDeviceToHostStream();
{
absl::MutexLock lock(&mu_);
// We can't perform any other action while a donation hold is in progress.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
}
@@ -968,16 +867,9 @@ StatusOr<ShapedBuffer> PyLocalBuffer::AsShapedBuffer() const {
PyLocalBuffer::ScopedHold PyLocalBuffer::GetBufferWithHold(
ScopedHold::Type type) {
if (type == ScopedHold::kDonation) {
// Ensure that at most one donation hold can be in progress at a time.
donation_semaphore_.Acquire(1);
}
absl::MutexLock lock(&mu_);
ScopedHold hold(this, type);
AcquireHoldLocked(&hold);
if (type == ScopedHold::kDonation && !hold.status().ok()) {
donation_semaphore_.Release(1);
}
return hold;
}
@@ -1064,8 +956,6 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
ScopedHold src_device_buffer(this, ScopedHold::kUsage);
{
absl::MutexLock lock(&mu_);
// We can't perform any other action while a donation hold is in progress.
WaitForOutstandingDonationHold();
if (device_buffer_ == nullptr) {
return InvalidArgument("CopyToDevice called on invalid buffer");
}
@@ -1281,19 +1171,6 @@ PyLocalExecutable::PyLocalExecutable(
<< "Inconsistent local device count.";
}
Status PyLocalExecutable::SetUpDonation(PyLocalClient* client,
bool tuple_inputs) {
parameters_that_must_be_donated_.reserve(executables_.size());
for (auto& executable : executables_) {
TF_ASSIGN_OR_RETURN(
absl::flat_hash_set<int> parameters_to_donate,
client->GetParametersThatMustBeDonated(*executable, tuple_inputs));
parameters_that_must_be_donated_.emplace_back(
std::move(parameters_to_donate));
}
return Status::OK();
}
const std::string& PyLocalExecutable::name() const {
Executable* executable = executables_[0]->executable();
if (executable->has_module()) {
@@ -1325,8 +1202,6 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
std::vector<const Shape*> argument_host_shapes;
std::vector<ExecutionInput> execution_inputs;
device_buffers->reserve(argument_handles.size());
const absl::flat_hash_set<int>& parameters_that_must_be_donated =
parameters_that_must_be_donated_[executable_idx];
for (int i = 0; i < argument_handles.size(); ++i) {
PyLocalBuffer* handle = argument_handles[i];
if (handle->device() != device) {
@@ -1335,11 +1210,7 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
"device %s, but replica is assigned to device %s.",
i, replica, handle->device()->DebugString(), device->DebugString());
}
bool must_donate = parameters_that_must_be_donated.find(i) !=
parameters_that_must_be_donated.end();
device_buffers->emplace_back(handle->GetBufferWithHold(
must_donate ? PyLocalBuffer::ScopedHold::kDonation
: PyLocalBuffer::ScopedHold::kUsage));
device_buffers->emplace_back(handle->GetBufferWithUsageHold());
PyLocalBuffer::ScopedHold& device_buffer = device_buffers->back();
if (!device_buffer.ok()) {
return InvalidArgument(
@@ -1347,13 +1218,7 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
"%s",
i, replica, device_buffer.status().ToString());
}
// If we are trying to donate the buffer wait on the usage events as well
// as the definition events to ensure that all reads have been completed
// before the buffer is mutated. Usage holds are excluded during a donation
// hold so we know that the set of usage events won't be modified while we
// are enqueueing.
GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
&events);
GetDeviceBufferDefinitionEvents(*device_buffer, &events);
}
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
@@ -1422,13 +1287,9 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
// If we used a transient tuple for the arguments we donated its root table
// buffer. In that case, and/or if we donated any input buffers that were
// not aliased, the donated buffers are going to be passed back to us via
// the execution output. We need to ensure they aren't freed until after
// execution completes. (Currently XLA does not support aliasing tuple
// tables, so if any donated parameter is a tuple there will be donated but
// unaliased buffers.)
// If we used a transient tuple for the arguments its root table is going to
// be passed back to us via the execution output. We need to ensure it isn't
// freed until after execution completes.
std::vector<se::OwningDeviceMemory> donated_memory =
execution_output.ConsumeToBeReleased();
absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
@@ -1495,14 +1356,6 @@ PyLocalExecutable::ExecuteHelper(
device_state->event_pool().ThenAllocateAndRecordEvent(stream);
if (!event_or.ok()) {
StallStreamOnError(device_state, stream);
for (PyLocalBuffer::ScopedHold& b : device_buffers) {
if (b.type() == PyLocalBuffer::ScopedHold::kDonation) {
// Even though there was an error we need to call ConfirmDonation, which
// renders b invalid, since the computation has been enqueued and b has
// been donated.
b.ConfirmDonation();
}
}
return event_or.status();
}
auto definition_event = std::make_shared<BufferDefinitionEvent>();
@@ -1538,14 +1391,8 @@ PyLocalExecutable::ExecuteHelper(
// ComputeSynchronized allocation model we don't need to retain a reference
// to the device_buffer during execution because by definition the compute
// stream is synchronized past the execution.
if (b.type() == PyLocalBuffer::ScopedHold::kUsage) {
RecordUsage(std::move(b), device_state, device_state, definition_event,
stream,
/*prefer_to_retain_reference=*/false);
} else {
CHECK(b.type() == PyLocalBuffer::ScopedHold::kDonation);
b.ConfirmDonation();
}
RecordUsage(std::move(b), device_state, device_state, definition_event,
stream, /*prefer_to_retain_reference=*/false);
}
return outputs;
@@ -1763,12 +1610,9 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
client->client()->Compile(computation, argument_layout_pointers,
build_options));
auto py_executable = absl::make_unique<PyLocalExecutable>(
return absl::make_unique<PyLocalExecutable>(
std::move(local_executables), options.tuple_arguments,
build_options.device_assignment(), client);
TF_RETURN_IF_ERROR(
py_executable->SetUpDonation(client, options.tuple_arguments));
return py_executable;
}
} // namespace xla

tensorflow/compiler/xla/python/local_client.h

@@ -20,7 +20,6 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
@@ -156,13 +155,6 @@ class PyLocalClient : public std::enable_shared_from_this<PyLocalClient> {
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
// Some platforms allow executables to donate buffers so that they can be
// aliased from inputs to outputs. This function returns the list of
// parameters that must be donated when executable is run. tuple_inputs
// reflects the option that executable was compiled with.
virtual StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const LocalExecutable& executable, bool tuple_inputs) const;
protected:
friend class PyLocalBuffer;
virtual void EnqueueCrossHostReceive(
@@ -207,17 +199,16 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
// Holds a reference from Python to a tuple of device buffers. A PyLocalBuffer
// can be either valid or invalid. An invalid buffer is one that has never been
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
// by donating it to a computation that aliases an input parameter to an
// output). We allow PyLocalBuffer objects to outlive the underlying device
// buffers so we can decouple buffer lifetimes from the corresponding Python
// references if needed. Thread-safe.
// initialized, or a buffer that has been deleted (e.g., by calling Delete). We
// allow PyLocalBuffer objects to outlive the underlying device buffers so we
// can decouple buffer lifetimes from the corresponding Python references if
// needed. Thread-safe.
class PyLocalBuffer {
public:
// Helper class to retain a "hold" on a PyLocalBuffer. A ScopedHold may not
// outlive its parent PyLocalBuffer.
//
// There are three types of hold, as follows:
// There are two types of hold, as follows:
//
// 1) Usage hold: a transient hold while an operation using the buffer is
// being enqueued onto a stream.
@@ -239,29 +230,12 @@ class PyLocalBuffer {
// confident via its own synchronization that modifications do not race with
// reads from the PyLocalBuffer.
//
// 3) Donation hold: a transient hold while an execution that donates the
// buffer is being enqueued onto the compute stream.
// A client acquires a donation hold by calling
// PyLocalBuffer::GetBufferWithHold(kDonation). If the enqueue completes
// successfully the hold should be released using a call to ConfirmDonation
// after which the buffer is invalid. If the ScopedHold is deleted without
// ConfirmDonation being called, e.g., on error, the hold is dropped and the
// buffer remains valid. If the buffer is successfully enqueued the client
// *must* call ConfirmDonation.
//
// Donation holds behave like exclusive write locks: when a donation hold
// has been acquired, any attempt to acquire another hold of any type will
// block until the donation hold is dropped or confirmed. Acquiring a donation
// hold will fail with an error if there is any outstanding external hold, and
// will block if there are any outstanding usage holds until those holds are
// dropped or converted.
//
// Calls to PyLocalBuffer::Release (and transitively to
// PyLocalBuffer::Delete() and ~PyLocalBuffer()) will block until all usage
// and donation holds are either deleted or converted/confirmed.
// holds are either deleted or converted.
class ScopedHold {
public:
enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
enum Type { kUsage = 0, KExternalReference, kMaxValue };
~ScopedHold();
ScopedHold(ScopedHold&& other);
@@ -295,18 +269,12 @@ class PyLocalBuffer {
std::shared_ptr<BufferDefinitionEvent> event,
bool reference_held);
// Confirms that the buffer was successfully donated to an execution.
// Only valid for holds of type kDonation. Causes the buffer to become
// invalid.
void ConfirmDonation();
// Adds the held device buffers in order to 'iterator'. Used to add the
// buffers to an ExecutionInput. We require but do not verify that
// 'iterator' when passed in is pointing to a sub-tuple of the
// ExecutionInput whose on_device_shape matches that of the
// SharedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
// out of bounds. Donates the device buffers if the hold type is kDonation,
// otherwise retains ownership of the device buffers.
// out of bounds.
void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
ExecutionInput* execution_input,
@@ -451,7 +419,7 @@ class PyLocalBuffer {
return GetBufferWithHold(ScopedHold::kUsage);
}
ScopedHold GetBufferWithExternalReference() {
return GetBufferWithHold(ScopedHold::kExternalReference);
return GetBufferWithHold(ScopedHold::KExternalReference);
}
// Copies the buffer to device `dst_device`. Returns an error if the buffer is
@@ -486,21 +454,14 @@ class PyLocalBuffer {
std::shared_ptr<Literal> value;
};
// Blocks in mu_.Await until there are no more usage holds.
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Blocks in mu_.Await until there is no donation hold.
void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Adds a hold of 'type' and returns device_buffer_. Returns an error if
// device_buffer_ is null, or if a donation hold was requested when there is
// an outstanding external hold.
// If device_buffer_ is non-null, adds a hold of 'type' and returns
// device_buffer_. Otherwise returns an error.
StatusOr<std::shared_ptr<SharedDeviceBuffer>> GetBufferForHoldLocked(
ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Adds a hold of hold->type() and initializes `hold` with device_buffer_.
// Initializes hold with an error if device_buffer_ is null, or if a donation
// hold was requested when there is an outstanding external hold.
// If device_buffer_ is non-null, adds a hold of hold->type() and initializes
// `hold` with device_buffer_. Otherwise initializes `hold` with an error
// status.
void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
@@ -510,11 +471,6 @@ class PyLocalBuffer {
std::shared_ptr<BufferDefinitionEvent> event,
bool reference_held);
// Drops a donation hold and makes *this invalid for further use. Does a
// sanity check that buffer==device_buffer_. Called after device_buffer_ was
// successfully donated to an execution.
void ConfirmDonation(SharedDeviceBuffer* device_buffer);
// Drops a hold without taking any other action. Does a sanity check that
// buffer==device_buffer_ or device_buffer_==nullptr.
void DropHold(ScopedHold::Type type, SharedDeviceBuffer* buffer);
@@ -536,8 +492,6 @@ class PyLocalBuffer {
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
// Count of holds on the buffer.
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
// Semaphore used to ensure there is only one outstanding donation hold.
Semaphore donation_semaphore_;
};
struct CompileOptions {
@@ -560,9 +514,7 @@ struct ExecuteOptions {
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Wraps one or more XLA LocalExecutables (one per
// partition, as specified by the build options). If any input/output alias
// has been specified in the computation, the parameter containing the input
// buffer will be donated when passed to the execution.
// partition, as specified by the build options).
class PyLocalExecutable {
public:
static StatusOr<std::unique_ptr<PyLocalExecutable>> Compile(
@@ -627,9 +579,6 @@ class PyLocalExecutable {
const string& name() const;
private:
// Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation.
Status SetUpDonation(PyLocalClient* client, bool tuple_inputs);
StatusOr<ScopedShapedBuffer> EnqueueExecution(
absl::Span<PyLocalBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id,
@@ -645,9 +594,6 @@ class PyLocalExecutable {
PyLocalClient* const client_;
// One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_;
// Per-executable set of parameters that have any aliased buffers and thus
// must be donated when executing the computation.
std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
std::shared_ptr<DeviceAssignment> device_assignment_;
// True if the executables were compiled expecting arguments in a single

tensorflow/compiler/xla/python/shared_device_buffer.cc

@@ -147,21 +147,6 @@ void SharedDeviceBuffer::AddToInputAsImmutable(
}
}
void SharedDeviceBuffer::AddToInputAsDonated(
ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const {
for (const se::DeviceMemoryBase& buf : device_memory_) {
CHECK(*iterator != end);
// Set buffers to be case (2) in the comment on ExecutionInput.
(*iterator)->second = MaybeOwningDeviceMemory(
se::OwningDeviceMemory(buf, device_ordinal_, allocator));
execution_input->SetUnownedIndex((*iterator)->first);
++(*iterator);
}
}
namespace {
using MoveIterator =
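
The difference between AddToInputAsDonated above and AddToInputAsImmutable comes down to which MaybeOwningDeviceMemory constructor is used. A condensed, hedged sketch of that distinction; AddBuffer is an illustrative helper, not part of the change:

#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_tree.h"

// Illustrative helper: hand one device buffer to an ExecutionInput tree
// either as a read-only borrow or as donated (owning) memory.
void AddBuffer(xla::ShapeTree<xla::MaybeOwningDeviceMemory>::iterator* it,
               const se::DeviceMemoryBase& buf, bool donate,
               int device_ordinal, se::DeviceMemoryAllocator* allocator) {
  if (donate) {
    // Owning wrapper: the execution may alias or free this memory, so the
    // original owner must later relinquish it (cf. ReleaseDeviceMemory).
    (*it)->second = xla::MaybeOwningDeviceMemory(
        se::OwningDeviceMemory(buf, device_ordinal, allocator));
  } else {
    // Non-owning wrapper: the execution may only read the memory.
    (*it)->second = xla::MaybeOwningDeviceMemory(buf);
  }
  ++(*it);
}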
@@ -221,24 +206,18 @@ SharedDeviceBuffer::LockUseAndTransferUsageEvents() {
return std::move(usage_events_);
}
void GetDeviceBufferEvents(
const SharedDeviceBuffer& buffer, bool get_usage_events,
void GetDeviceBufferDefinitionEvents(
const SharedDeviceBuffer& buffer,
absl::flat_hash_set<BufferDefinitionEvent*>* events) {
if (get_usage_events) {
for (const auto& e : buffer.usage_events()) {
events->insert(e.event.get());
}
} else {
for (const auto& e : buffer.definition_events()) {
events->insert(e.get());
}
for (const auto& e : buffer.definition_events()) {
events->insert(e.get());
}
}
void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer,
se::Stream* stream) {
absl::flat_hash_set<BufferDefinitionEvent*> events;
GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events);
GetDeviceBufferDefinitionEvents(buffer, &events);
for (BufferDefinitionEvent* event : events) {
event->WaitForEventOnStream(stream);
}

tensorflow/compiler/xla/python/shared_device_buffer.h

@@ -117,18 +117,6 @@ class BufferDefinitionEvent {
// of memory under all of the allocation model semantics.
class SharedDeviceBuffer {
public:
// Helper object to keep track of usage of the buffer on streams.
struct StreamAndEvent {
// A stream the buffer has been used on.
se::Stream* stream;
// An event that is later than the most recent usage of the buffer on
// stream.
std::shared_ptr<BufferDefinitionEvent> event;
// True if and only if a reference to the buffer is kept live until after
// the host knows that event is complete.
bool reference_held;
};
// Converts a ScopedShapedBuffer into a SharedDeviceBuffer. Takes ownership of
// the buffers of the shaped_buffer.
static std::shared_ptr<SharedDeviceBuffer> FromScopedShapedBuffer(
@@ -152,20 +140,6 @@ class SharedDeviceBuffer {
ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const;
// Adds the owned device buffers in order to 'iterator', marking them as
// available to be donated. If donation succeeds, i.e., execution_input is
// subsequently successfully enqueued to a computation,
// this->ReleaseDeviceMemory() must be called to avoid freeing the device
// memory twice. We require but do not verify that 'iterator' when passed in
// is pointing to a sub-tuple of execution_input whose on_device_shape matches
// that of the SharedDeviceBuffer. 'end' is used to check that 'iterator'
// doesn't run out of bounds.
void AddToInputAsDonated(
ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const;
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
int device_ordinal() const { return device_ordinal_; }
absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() {
@@ -178,13 +152,6 @@ class SharedDeviceBuffer {
const {
return definition_events_;
}
absl::Span<const StreamAndEvent> usage_events() const {
return usage_events_;
}
// Relinquishes ownership of the buffer's device memory, e.g., after the
// buffer is passed to a computation that aliases its inputs to outputs.
void ReleaseDeviceMemory() { device_memory_.clear(); }
// Indicates that the buffer has been used on a stream.
//
@@ -199,6 +166,17 @@ class SharedDeviceBuffer {
std::shared_ptr<BufferDefinitionEvent> event,
bool reference_held);
// Helper object to keep track of usage of the buffer on streams.
struct StreamAndEvent {
// A stream the buffer has been used on.
se::Stream* stream;
// An event that is later than the most recent usage of the buffer on
// stream.
std::shared_ptr<BufferDefinitionEvent> event;
// True if and only if a reference to the buffer is kept live until after
// the host knows that event is complete.
bool reference_held;
};
using StreamAndEventContainer = absl::InlinedVector<StreamAndEvent, 3>;
// Returns the set of streams that the buffer was used on, and for each stream
// an event later than the last use of the buffer. After
@@ -243,12 +221,10 @@ class SharedDeviceBuffer {
std::function<void()> on_delete_callback_;
};
// Populates 'events' with the set of buffer events for buffer. If
// get_usage_events=true populates with the latest usage events, otherwise
// populates with the definition events.
void GetDeviceBufferEvents(const SharedDeviceBuffer& buffer,
bool get_usage_events,
absl::flat_hash_set<BufferDefinitionEvent*>* events);
// Populates 'events' with the set of buffer definition events for buffer.
void GetDeviceBufferDefinitionEvents(
const SharedDeviceBuffer& buffer,
absl::flat_hash_set<BufferDefinitionEvent*>* events);
// Waits for all of the definition events in a buffer on 'stream'.
void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer,