Add buffer aliasing/donation support to PyLocalClient. This works only for the
TPU backend, since XLA does not support buffer aliasing on other backends.

PiperOrigin-RevId: 306267232
Change-Id: Ib2bea9ba79087afdb81caa0e3c80dea62d9b7d7e
parent 1d939d63fb
commit ab4462cc09
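For context, buffer donation is driven by input/output aliasing requested on the computation that gets compiled. The snippet below is an illustrative sketch only, not part of this change; it assumes the XlaBuilder aliasing API as commonly used for TPU, and the exact SetUpAlias signature should be checked against xla_builder.h.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Build y = x + x and ask XLA to alias the output with parameter 0, so the
// runtime may donate (reuse) parameter 0's device buffer for the result.
xla::XlaComputation BuildDonatingComputation() {
  xla::XlaBuilder builder("donation_example");
  xla::XlaOp x = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {1024}), "x");
  xla::XlaOp y = xla::Add(x, x);
  // Alias output index {} to parameter 0, index {}.
  builder.SetUpAlias(/*output_index=*/{}, /*param_number=*/0,
                     /*param_index=*/{});
  return builder.Build(y).ValueOrDie();
}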
@ -197,7 +197,6 @@ cc_library(
"//tensorflow/compiler/xla/python/distributed:protocol_proto_cc",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
@ -208,7 +207,6 @@ cc_library(
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -70,7 +70,6 @@ limitations under the License.
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
@ -87,7 +86,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -197,21 +195,6 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
                                 num_partitions);
}

StatusOr<absl::flat_hash_set<int>>
PyLocalClient::GetParametersThatMustBeDonated(const LocalExecutable& executable,
                                              bool tuple_inputs) const {
  // TODO(b/149489114) support buffer donation on CPU/GPU when XLA supports it.
  const HloInputOutputAliasConfig& config =
      executable.executable()->module().input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [](const ShapeIndex& output_index,
         const HloInputOutputAliasConfig::Alias& alias) {
        return InvalidArgument(
            "Buffer aliasing is not supported by XLA for non-TPU backends.");
      }));
  return absl::flat_hash_set<int>();
}

namespace {

// Ensures that it is safe to deallocate any buffers that have been enqueued in
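The base implementation above rejects any alias, since only the TPU backend understands aliasing at this point. A backend that does support it would presumably walk the same HloInputOutputAliasConfig and collect the parameter numbers to donate. The following sketch is hypothetical (the tuple_inputs handling in particular is a guess) and is not the code used by the TPU client:

StatusOr<absl::flat_hash_set<int>> DonatedParametersSketch(
    const LocalExecutable& executable, bool tuple_inputs) {
  absl::flat_hash_set<int> parameters_to_donate;
  const HloInputOutputAliasConfig& config =
      executable.executable()->module().input_output_alias_config();
  config.ForEachAlias([&](const ShapeIndex& output_index,
                          const HloInputOutputAliasConfig::Alias& alias) {
    // When arguments are tupled, every aliased leaf lives in parameter 0.
    parameters_to_donate.insert(tuple_inputs ? 0 : alias.parameter_number);
  });
  return parameters_to_donate;
}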
@ -457,25 +440,14 @@ void PyLocalBuffer::ScopedHold::ConvertUsageHold(
  SetError(InvalidArgument("Buffer has been converted"));
}

void PyLocalBuffer::ScopedHold::ConfirmDonation() {
  CHECK(ok());
  CHECK(type_ == kDonation);
  parent_->ConfirmDonation(buffer().get());
  SetError(InvalidArgument("Buffer has been donated"));
}

void PyLocalBuffer::ScopedHold::AddToInput(
    ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
    ExecutionInput* execution_input,
    se::DeviceMemoryAllocator* allocator) const {
  CHECK(ok());
  if (type_ == kDonation) {
    buffer()->AddToInputAsDonated(iterator, end, execution_input, allocator);
  } else {
    CHECK(type_ == kUsage);
    buffer()->AddToInputAsImmutable(iterator, end);
  }
  CHECK(type_ == kUsage);
  buffer()->AddToInputAsImmutable(iterator, end);
}

/* static */
@ -681,8 +653,7 @@ PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
      on_host_shape_(std::move(on_host_shape)),
      on_device_shape_(std::move(on_device_shape)),
      device_(device),
      device_buffer_(std::move(device_buffer)),
      donation_semaphore_(/*capacity=*/1) {
      device_buffer_(std::move(device_buffer)) {
  for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) {
    holds_[i] = 0;
  }
@ -695,41 +666,24 @@ PyLocalBuffer::~PyLocalBuffer() {
  }
}

void PyLocalBuffer::WaitForOutstandingUsageHolds() {
  auto not_in_usage_hold = [&]() {
StatusOr<std::shared_ptr<SharedDeviceBuffer>> PyLocalBuffer::Release(
    bool wait_for_operations_to_complete) {
  auto no_usage_holds = [&]() {
    mu_.AssertHeld();
    return holds_[ScopedHold::kUsage] == 0;
  };
  mu_.Await(absl::Condition(&not_in_usage_hold));
}

void PyLocalBuffer::WaitForOutstandingDonationHold() {
  auto not_in_donation_hold = [&]() {
    mu_.AssertHeld();
    return holds_[ScopedHold::kDonation] == 0;
  };
  mu_.Await(absl::Condition(&not_in_donation_hold));
}

StatusOr<std::shared_ptr<SharedDeviceBuffer>> PyLocalBuffer::Release(
    bool wait_for_operations_to_complete) {
  std::shared_ptr<SharedDeviceBuffer> device_buffer;
  SharedDeviceBuffer::StreamAndEventContainer events;
  {
    absl::MutexLock lock(&mu_);
    // We first wait for a donation hold to complete if there is one in
    // progress. If the donation succeeds via ConfirmDonation() then it will
    // set device_buffer_ to nullptr before returning to this thread.
    WaitForOutstandingDonationHold();
    if (device_buffer_ == nullptr) {
      return std::shared_ptr<SharedDeviceBuffer>();
    }
    // Set host_value_ and device_buffer_ to null now so that no other thread
    // can add a hold while we are in WaitForOutstandingUsageHolds()
    // below.
    // can add a hold while we are in Await below.
    host_value_ = nullptr;
    std::swap(device_buffer_, device_buffer);
    WaitForOutstandingUsageHolds();
    mu_.Await(absl::Condition(&no_usage_holds));
    // Now that all holds have completed and no more can be added, we can get
    // the final set of usage events.
    events = device_buffer->LockUseAndTransferUsageEvents();
@ -799,35 +753,10 @@ bool PyLocalBuffer::IsDeleted() {
StatusOr<std::shared_ptr<SharedDeviceBuffer>>
PyLocalBuffer::GetBufferForHoldLocked(ScopedHold::Type type) {
  if (type == ScopedHold::kDonation) {
    if (device_buffer_ == nullptr) {
      return InvalidArgument("Donation requested for invalid buffer");
    }
    if (holds_[ScopedHold::kExternalReference] > 0) {
      return InvalidArgument(
          "Donation requested for buffer with external reference");
    }
    // donation_semaphore_ was acquired in GetBufferWithHold so that only one
    // thread at a time can attempt to get a donation hold.
    CHECK_EQ(holds_[type], 0);
    // First add the donation hold.
    ++holds_[type];
    // Then wait for any usage holds to be dropped or converted. No new usage
    // holds can be added until we drop the donation hold so this wait will
    // complete eventually.
    WaitForOutstandingUsageHolds();
    // Because we added a donation hold, nobody could release the buffer while
    // we were waiting.
    CHECK(device_buffer_ != nullptr);
  if (device_buffer_ == nullptr) {
    return InvalidArgument("Hold requested on invalid buffer");
  } else {
    // If there is a donation hold in progress we have to wait before
    // acquiring any other kind of hold.
    WaitForOutstandingDonationHold();
    if (device_buffer_ == nullptr) {
      return InvalidArgument("Hold requested on invalid buffer");
    } else {
      ++holds_[type];
    }
    ++holds_[type];
  }
  return device_buffer_;
}
@ -846,40 +775,12 @@ void PyLocalBuffer::ConvertUsageHold(
  --holds_[ScopedHold::kUsage];
}

void PyLocalBuffer::ConfirmDonation(SharedDeviceBuffer* device_buffer) {
  {
    absl::MutexLock lock(&mu_);
    CHECK_EQ(holds_[ScopedHold::kUsage], 0);
    CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
    CHECK_EQ(holds_[ScopedHold::kDonation], 1);
    holds_[ScopedHold::kDonation] = 0;
    CHECK(device_buffer_.get() == device_buffer);
    // As a sanity check ensure no more usage events can be added to the buffer.
    device_buffer->LockUseAndTransferUsageEvents();
    // Give up ownership of the device memory so we don't free it when the last
    // reference to device_buffer_ goes away.
    device_buffer->ReleaseDeviceMemory();
    // Make *this invalid so it can't be used again. Any threads blocking in
    // Release or GetBufferWithHold will see an invalid buffer and return.
    host_value_ = nullptr;
    device_buffer_.reset();
  }
  // Unblock another thread, if any, trying to get a donation hold.
  donation_semaphore_.Release(1);
}

void PyLocalBuffer::DropHold(ScopedHold::Type type,
                             SharedDeviceBuffer* buffer) {
  absl::MutexLock lock(&mu_);
  CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr);
  CHECK_GT(holds_[type], 0);
  --holds_[type];
  if (type == ScopedHold::kDonation) {
    CHECK_EQ(holds_[ScopedHold::kDonation], 0);
    CHECK_EQ(holds_[ScopedHold::kUsage], 0);
    CHECK_EQ(holds_[ScopedHold::kExternalReference], 0);
    donation_semaphore_.Release(1);
  }
}

Status PyLocalBuffer::CopyToHostAsync() {
@ -892,8 +793,6 @@ Status PyLocalBuffer::CopyToHostAsync() {
  se::Stream* stream = local_device->GetDeviceToHostStream();
  {
    absl::MutexLock lock(&mu_);
    // We can't perform any other action while a donation hold is in progress.
    WaitForOutstandingDonationHold();
    if (device_buffer_ == nullptr) {
      return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
    }
@ -968,16 +867,9 @@ StatusOr<ShapedBuffer> PyLocalBuffer::AsShapedBuffer() const {
PyLocalBuffer::ScopedHold PyLocalBuffer::GetBufferWithHold(
    ScopedHold::Type type) {
  if (type == ScopedHold::kDonation) {
    // Ensure that at most one donation hold can be in progress at a time.
    donation_semaphore_.Acquire(1);
  }
  absl::MutexLock lock(&mu_);
  ScopedHold hold(this, type);
  AcquireHoldLocked(&hold);
  if (type == ScopedHold::kDonation && !hold.status().ok()) {
    donation_semaphore_.Release(1);
  }
  return hold;
}
@ -1064,8 +956,6 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
  ScopedHold src_device_buffer(this, ScopedHold::kUsage);
  {
    absl::MutexLock lock(&mu_);
    // We can't perform any other action while a donation hold is in progress.
    WaitForOutstandingDonationHold();
    if (device_buffer_ == nullptr) {
      return InvalidArgument("CopyToDevice called on invalid buffer");
    }
@ -1281,19 +1171,6 @@ PyLocalExecutable::PyLocalExecutable(
        << "Inconsistent local device count.";
}

Status PyLocalExecutable::SetUpDonation(PyLocalClient* client,
                                        bool tuple_inputs) {
  parameters_that_must_be_donated_.reserve(executables_.size());
  for (auto& executable : executables_) {
    TF_ASSIGN_OR_RETURN(
        absl::flat_hash_set<int> parameters_to_donate,
        client->GetParametersThatMustBeDonated(*executable, tuple_inputs));
    parameters_that_must_be_donated_.emplace_back(
        std::move(parameters_to_donate));
  }
  return Status::OK();
}

const std::string& PyLocalExecutable::name() const {
  Executable* executable = executables_[0]->executable();
  if (executable->has_module()) {
@ -1325,8 +1202,6 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
  std::vector<const Shape*> argument_host_shapes;
  std::vector<ExecutionInput> execution_inputs;
  device_buffers->reserve(argument_handles.size());
  const absl::flat_hash_set<int>& parameters_that_must_be_donated =
      parameters_that_must_be_donated_[executable_idx];
  for (int i = 0; i < argument_handles.size(); ++i) {
    PyLocalBuffer* handle = argument_handles[i];
    if (handle->device() != device) {
@ -1335,11 +1210,7 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
          "device %s, but replica is assigned to device %s.",
          i, replica, handle->device()->DebugString(), device->DebugString());
    }
    bool must_donate = parameters_that_must_be_donated.find(i) !=
                       parameters_that_must_be_donated.end();
    device_buffers->emplace_back(handle->GetBufferWithHold(
        must_donate ? PyLocalBuffer::ScopedHold::kDonation
                    : PyLocalBuffer::ScopedHold::kUsage));
    device_buffers->emplace_back(handle->GetBufferWithUsageHold());
    PyLocalBuffer::ScopedHold& device_buffer = device_buffers->back();
    if (!device_buffer.ok()) {
      return InvalidArgument(
@ -1347,13 +1218,7 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
          "%s",
          i, replica, device_buffer.status().ToString());
    }
    // If we are trying to donate the buffer wait on the usage events as well
    // as the definition events to ensure that all reads have been completed
    // before the buffer is mutated. Usage holds are excluded during a donation
    // hold so we know that the set of usage events won't be modified while we
    // are enqueueing.
    GetDeviceBufferEvents(*device_buffer, /*get_usage_events=*/must_donate,
                          &events);
    GetDeviceBufferDefinitionEvents(*device_buffer, &events);
  }

  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
@ -1422,13 +1287,9 @@ StatusOr<ScopedShapedBuffer> PyLocalExecutable::EnqueueExecution(
  if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
    ExecutionOutput& execution_output = result_buffer_or_status.ValueOrDie();
    // If we used a transient tuple for the arguments we donated its root table
    // buffer. In that case, and/or if we donated any input buffers that were
    // not aliased, the donated buffers are going to be passed back to us via
    // the execution output. We need to ensure they aren't freed until after
    // execution completes. (Currently XLA does not support aliasing tuple
    // tables, so if any donated parameter is a tuple there will be donated but
    // unaliased buffers.)
    // If we used a transient tuple for the arguments its root table is going to
    // be passed back to us via the execution output. We need to ensure it isn't
    // freed until after execution completes.
    std::vector<se::OwningDeviceMemory> donated_memory =
        execution_output.ConsumeToBeReleased();
    absl::InlinedVector<se::DeviceMemoryBase, 3> donated_ptrs;
@ -1495,14 +1356,6 @@ PyLocalExecutable::ExecuteHelper(
      device_state->event_pool().ThenAllocateAndRecordEvent(stream);
  if (!event_or.ok()) {
    StallStreamOnError(device_state, stream);
    for (PyLocalBuffer::ScopedHold& b : device_buffers) {
      if (b.type() == PyLocalBuffer::ScopedHold::kDonation) {
        // Even though there was an error we need to call ConfirmDonation, which
        // renders b invalid, since the computation has been enqueued and b has
        // been donated.
        b.ConfirmDonation();
      }
    }
    return event_or.status();
  }
  auto definition_event = std::make_shared<BufferDefinitionEvent>();
@ -1538,14 +1391,8 @@ PyLocalExecutable::ExecuteHelper(
    // ComputeSynchronized allocation model we don't need to retain a reference
    // to the device_buffer during execution because by definition the compute
    // stream is synchronized past the execution.
    if (b.type() == PyLocalBuffer::ScopedHold::kUsage) {
      RecordUsage(std::move(b), device_state, device_state, definition_event,
                  stream,
                  /*prefer_to_retain_reference=*/false);
    } else {
      CHECK(b.type() == PyLocalBuffer::ScopedHold::kDonation);
      b.ConfirmDonation();
    }
    RecordUsage(std::move(b), device_state, device_state, definition_event,
                stream, /*prefer_to_retain_reference=*/false);
  }

  return outputs;
@ -1763,12 +1610,9 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
      client->client()->Compile(computation, argument_layout_pointers,
                                build_options));

  auto py_executable = absl::make_unique<PyLocalExecutable>(
  return absl::make_unique<PyLocalExecutable>(
      std::move(local_executables), options.tuple_arguments,
      build_options.device_assignment(), client);
  TF_RETURN_IF_ERROR(
      py_executable->SetUpDonation(client, options.tuple_arguments));
  return py_executable;
}

} // namespace xla
@ -20,7 +20,6 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
@ -156,13 +155,6 @@ class PyLocalClient : public std::enable_shared_from_this<PyLocalClient> {
  // function specifies which one the platform expects.
  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

  // Some platforms allow executables to donate buffers so that they can be
  // aliased from inputs to outputs. This function returns the list of
  // parameters that must be donated when executable is run. tuple_inputs
  // reflects the option that executable was compiled with.
  virtual StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
      const LocalExecutable& executable, bool tuple_inputs) const;

 protected:
  friend class PyLocalBuffer;
  virtual void EnqueueCrossHostReceive(
@ -207,17 +199,16 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
// Holds a reference from Python to a tuple of device buffers. A PyLocalBuffer
// can be either valid or invalid. An invalid buffer is one that has never been
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
// by donating it to a computation that aliases an input parameter to an
// output). We allow PyLocalBuffer objects to outlive the underlying device
// buffers so we can decouple buffer lifetimes from the corresponding Python
// references if needed. Thread-safe.
// initialized, or a buffer that has been deleted (e.g., by calling Delete). We
// allow PyLocalBuffer objects to outlive the underlying device buffers so we
// can decouple buffer lifetimes from the corresponding Python references if
// needed. Thread-safe.
class PyLocalBuffer {
 public:
  // Helper class to retain a "hold" on a PyLocalBuffer. A ScopedHold may not
  // outlive its parent PyLocalBuffer.
  //
  // There are three types of hold, as follows:
  // There are two types of hold, as follows:
  //
  // 1) Usage hold: a transient hold while an operation using the buffer is
  //    being enqueued onto a stream.
@ -239,29 +230,12 @@ class PyLocalBuffer {
  //    confident via its own synchronization that modifications do not race
  //    with reads from the PyLocalBuffer.
  //
  // 3) Donation hold: a transient hold while an execution that donates the
  //    buffer is being enqueued onto the compute stream.
  //    A client acquires a donation hold by calling
  //    PyLocalBuffer::GetBufferWithHold(kDonation). If the enqueue completes
  //    successfully the hold should be released using a call to ConfirmDonation
  //    after which the buffer is invalid. If the ScopedHold is deleted without
  //    ConfirmDonation being called, e.g., on error, the hold is dropped and
  //    the buffer remains valid. If the buffer is successfully enqueued the
  //    client *must* call ConfirmDonation.
  //
  // Donation holds behave like exclusive write locks: when a donation hold
  // has been acquired, any attempt to acquire another hold of any type will
  // block until the donation hold is dropped or confirmed. Acquiring a donation
  // hold will fail with an error if there is any outstanding external hold, and
  // will block if there are any outstanding usage holds until those holds are
  // dropped or converted.
  //
  // Calls to PyLocalBuffer::Release (and transitively to
  // PyLocalBuffer::Delete() and ~PyLocalBuffer()) will block until all usage
  // and donation holds are either deleted or converted/confirmed.
  // holds are either deleted or converted.
  class ScopedHold {
   public:
    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
    enum Type { kUsage = 0, KExternalReference, kMaxValue };

    ~ScopedHold();
    ScopedHold(ScopedHold&& other);
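As a self-contained illustration of the hold semantics described above: the toy class below models only the usage/donation interaction (it is not the real PyLocalBuffer bookkeeping, and it omits the external-reference hold and the semaphore this change uses to limit donation attempts to one at a time).

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

class HoldTrackerSketch {
 public:
  void AcquireUsageHold() {
    absl::MutexLock lock(&mu_);
    // New usage holds are blocked while a donation hold is in progress.
    mu_.Await(absl::Condition(
        +[](int* donation_holds) { return *donation_holds == 0; },
        &donation_holds_));
    ++usage_holds_;
  }
  void DropUsageHold() {
    absl::MutexLock lock(&mu_);
    --usage_holds_;
  }
  void AcquireDonationHold() {
    absl::MutexLock lock(&mu_);
    ++donation_holds_;  // Registered first so no new usage holds can start.
    // Wait for in-flight usage holds to be dropped or converted.
    mu_.Await(absl::Condition(
        +[](int* usage_holds) { return *usage_holds == 0; }, &usage_holds_));
  }
  void DropDonationHold() {
    absl::MutexLock lock(&mu_);
    --donation_holds_;
  }

 private:
  absl::Mutex mu_;
  int usage_holds_ ABSL_GUARDED_BY(mu_) = 0;
  int donation_holds_ ABSL_GUARDED_BY(mu_) = 0;
};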
@ -295,18 +269,12 @@ class PyLocalBuffer {
                        std::shared_ptr<BufferDefinitionEvent> event,
                        bool reference_held);

    // Confirms that the buffer was successfully donated to an execution.
    // Only valid for holds of type kDonation. Causes the buffer to become
    // invalid.
    void ConfirmDonation();

    // Adds the held device buffers in order to 'iterator'. Used to add the
    // buffers to an ExecutionInput. We require but do not verify that
    // 'iterator' when passed in is pointing to a sub-tuple of the
    // ExecutionInput whose on_device_shape matches that of the
    // SharedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
    // out of bounds. Donates the device buffers if the hold type is kDonation,
    // otherwise retains ownership of the device buffers.
    // out of bounds.
    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
                    ExecutionInput* execution_input,
@ -451,7 +419,7 @@ class PyLocalBuffer {
    return GetBufferWithHold(ScopedHold::kUsage);
  }
  ScopedHold GetBufferWithExternalReference() {
    return GetBufferWithHold(ScopedHold::kExternalReference);
    return GetBufferWithHold(ScopedHold::KExternalReference);
  }

  // Copies the buffer to device `dst_device`. Returns an error if the buffer is
@ -486,21 +454,14 @@ class PyLocalBuffer {
    std::shared_ptr<Literal> value;
  };

  // Blocks in mu_.Await until there are no more usage holds.
  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Blocks in mu_.Await until there is no donation hold.
  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
  // device_buffer_ is null, or if a donation hold was requested when there is
  // an outstanding external hold.
  // If device_buffer_ is non-null, adds a hold of 'type' and returns
  // device_buffer_. Otherwise returns an error.
  StatusOr<std::shared_ptr<SharedDeviceBuffer>> GetBufferForHoldLocked(
      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
  // Initializes hold with an error if device_buffer_ is null, or if a donation
  // hold was requested when there is an outstanding external hold.
  // If device_buffer_ is non-null, adds a hold of hold->type() and initializes
  // `hold` with device_buffer_. Otherwise initializes `hold` with an error
  // status.
  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
@ -510,11 +471,6 @@ class PyLocalBuffer {
                        std::shared_ptr<BufferDefinitionEvent> event,
                        bool reference_held);

  // Drops a donation hold and makes *this invalid for further use. Does a
  // sanity check that buffer==device_buffer_. Called after device_buffer_ was
  // successfully donated to an execution.
  void ConfirmDonation(SharedDeviceBuffer* device_buffer);

  // Drops a hold without taking any other action. Does a sanity check that
  // buffer==device_buffer_ or device_buffer_==nullptr.
  void DropHold(ScopedHold::Type type, SharedDeviceBuffer* buffer);
@ -536,8 +492,6 @@ class PyLocalBuffer {
  std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
  // Count of holds on the buffer.
  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
  // Semaphore used to ensure there is only one outstanding donation hold.
  Semaphore donation_semaphore_;
};

struct CompileOptions {
@ -560,9 +514,7 @@ struct ExecuteOptions {
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Wraps one or more XLA LocalExecutables (one per
// partition, as specified by the build options). If any input/output alias
// has been specified in the computation, the parameter containing the input
// buffer will be donated when passed to the execution.
// partition, as specified by the build options).
class PyLocalExecutable {
 public:
  static StatusOr<std::unique_ptr<PyLocalExecutable>> Compile(
@ -627,9 +579,6 @@ class PyLocalExecutable {
  const string& name() const;

 private:
  // Initializes information about which arguments to which executables must be
  // donated due to aliases that were specified by the computation.
  Status SetUpDonation(PyLocalClient* client, bool tuple_inputs);
  StatusOr<ScopedShapedBuffer> EnqueueExecution(
      absl::Span<PyLocalBuffer* const> argument_handles, int replica,
      int partition, int executable_idx, const RunId& run_id,
@ -645,9 +594,6 @@ class PyLocalExecutable {
  PyLocalClient* const client_;
  // One executable per partition.
  std::vector<std::shared_ptr<LocalExecutable>> executables_;
  // Per-executable set of parameters that have any aliased buffers and thus
  // must be donated when executing the computation.
  std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
  std::shared_ptr<DeviceAssignment> device_assignment_;

  // True if the executables were compiled expecting arguments in a single
@ -147,21 +147,6 @@ void SharedDeviceBuffer::AddToInputAsImmutable(
  }
}

void SharedDeviceBuffer::AddToInputAsDonated(
    ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
    ExecutionInput* execution_input,
    se::DeviceMemoryAllocator* allocator) const {
  for (const se::DeviceMemoryBase& buf : device_memory_) {
    CHECK(*iterator != end);
    // Set buffers to be case (2) in the comment on ExecutionInput.
    (*iterator)->second = MaybeOwningDeviceMemory(
        se::OwningDeviceMemory(buf, device_ordinal_, allocator));
    execution_input->SetUnownedIndex((*iterator)->first);
    ++(*iterator);
  }
}

namespace {

using MoveIterator =
@ -221,24 +206,18 @@ SharedDeviceBuffer::LockUseAndTransferUsageEvents() {
  return std::move(usage_events_);
}

void GetDeviceBufferEvents(
    const SharedDeviceBuffer& buffer, bool get_usage_events,
void GetDeviceBufferDefinitionEvents(
    const SharedDeviceBuffer& buffer,
    absl::flat_hash_set<BufferDefinitionEvent*>* events) {
  if (get_usage_events) {
    for (const auto& e : buffer.usage_events()) {
      events->insert(e.event.get());
    }
  } else {
    for (const auto& e : buffer.definition_events()) {
      events->insert(e.get());
    }
  for (const auto& e : buffer.definition_events()) {
    events->insert(e.get());
  }
}

void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer,
                                           se::Stream* stream) {
  absl::flat_hash_set<BufferDefinitionEvent*> events;
  GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events);
  GetDeviceBufferDefinitionEvents(buffer, &events);
  for (BufferDefinitionEvent* event : events) {
    event->WaitForEventOnStream(stream);
  }
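Illustrative call sites for the generalized helper, assuming a SharedDeviceBuffer named buffer is in scope; the get_usage_events flag mirrors its use in EnqueueExecution above, where donation additionally waits on outstanding reads:

absl::flat_hash_set<BufferDefinitionEvent*> events;
// Plain read of the buffer: only definition events matter.
GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events);
// Donation: also collect usage events so every prior read has completed
// before the donated memory can be overwritten by the execution.
GetDeviceBufferEvents(buffer, /*get_usage_events=*/true, &events);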
@ -117,18 +117,6 @@ class BufferDefinitionEvent {
// of memory under all of the allocation model semantics.
class SharedDeviceBuffer {
 public:
  // Helper object to keep track of usage of the buffer on streams.
  struct StreamAndEvent {
    // A stream the buffer has been used on.
    se::Stream* stream;
    // An event that is later than the most recent usage of the buffer on
    // stream.
    std::shared_ptr<BufferDefinitionEvent> event;
    // True if and only if a reference to the buffer is kept live until after
    // the host knows that event is complete.
    bool reference_held;
  };

  // Converts a ScopedShapedBuffer into a SharedDeviceBuffer. Takes ownership of
  // the buffers of the shaped_buffer.
  static std::shared_ptr<SharedDeviceBuffer> FromScopedShapedBuffer(
@ -152,20 +140,6 @@ class SharedDeviceBuffer {
      ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
      const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const;

  // Adds the owned device buffers in order to 'iterator', marking them as
  // available to be donated. If donation succeeds, i.e., execution_input is
  // subsequently successfully enqueued to a computation,
  // this->ReleaseDeviceMemory() must be called to avoid freeing the device
  // memory twice. We require but do not verify that 'iterator' when passed in
  // is pointing to a sub-tuple of execution_input whose on_device_shape matches
  // that of the SharedDeviceBuffer. 'end' is used to check that 'iterator'
  // doesn't run out of bounds.
  void AddToInputAsDonated(
      ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
      const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
      ExecutionInput* execution_input,
      se::DeviceMemoryAllocator* allocator) const;

  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
  int device_ordinal() const { return device_ordinal_; }
  absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() {
@ -178,13 +152,6 @@ class SharedDeviceBuffer {
      const {
    return definition_events_;
  }
  absl::Span<const StreamAndEvent> usage_events() const {
    return usage_events_;
  }

  // Relinquishes ownership of the buffer's device memory, e.g., after the
  // buffer is passed to a computation that aliases its inputs to outputs.
  void ReleaseDeviceMemory() { device_memory_.clear(); }

  // Indicates that the buffer has been used on a stream.
  //
@ -199,6 +166,17 @@ class SharedDeviceBuffer {
                    std::shared_ptr<BufferDefinitionEvent> event,
                    bool reference_held);

  // Helper object to keep track of usage of the buffer on streams.
  struct StreamAndEvent {
    // A stream the buffer has been used on.
    se::Stream* stream;
    // An event that is later than the most recent usage of the buffer on
    // stream.
    std::shared_ptr<BufferDefinitionEvent> event;
    // True if and only if a reference to the buffer is kept live until after
    // the host knows that event is complete.
    bool reference_held;
  };
  using StreamAndEventContainer = absl::InlinedVector<StreamAndEvent, 3>;
  // Returns the set of streams that the buffer was used on, and for each stream
  // an event later than the last use of the buffer. After
@ -243,12 +221,10 @@ class SharedDeviceBuffer {
  std::function<void()> on_delete_callback_;
};

// Populates 'events' with the set of buffer events for buffer. If
// get_usage_events=true populates with the latest usage events, otherwise
// populates with the definition events.
void GetDeviceBufferEvents(const SharedDeviceBuffer& buffer,
                           bool get_usage_events,
                           absl::flat_hash_set<BufferDefinitionEvent*>* events);
// Populates 'events' with the set of buffer definition events for buffer.
void GetDeviceBufferDefinitionEvents(
    const SharedDeviceBuffer& buffer,
    absl::flat_hash_set<BufferDefinitionEvent*>* events);

// Waits for all of the definition events in a buffer on 'stream'.
void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer,