[JAX] Add support for asynchronous execution, but leave it disabled by default for now.
[XLA:Python] Add support for asynchronous execution in the Python client. Python isn't famous for being the world's fastest language, so for high performance eager-style dispatch it is helpful to be able to hide Python latency behind device computations by having the Python code dispatch device operations asynchronously. The design here closely follows the design of asynchronous execution in TensorFlow and the TensorFlow/XLA client. We use three main streams: * a compute stream, for running XLA computations, * a host-to-device stream, for transferring data onto the device * a device-to-host stream, for transferring data off the device. Both host-to-device transfers and compute are asynchronous, that is, they return control to Python as soon as any necessary error checking is complete, but before the operation completes. This allows the Python code to enqueue any subsequent operations while the previously enqueued operations complete. Device-to-host transfers are still blocking, in the sense that they stall the host until the host-side data is ready. [XLA] Add LocalExecutable::RunAsync() to obtain async execution on a stream. There is currently no way to achieve this via the LocalClient API, only by using internal XLA APIs. [XLA:GPU] Implement ExecuteAsyncOnStream. It turns out that ExecuteOnStream is already more or less async anyway. PiperOrigin-RevId: 246650968
This commit is contained in:
parent
f9cdb743b4
commit
c7b255ae35
@ -140,7 +140,8 @@ Status LocalExecutable::ValidateExecutionOptions(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
||||
StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>>
|
||||
LocalExecutable::RunHelper(
|
||||
const absl::Span<const ShapedBuffer* const> arguments,
|
||||
ExecutableRunOptions run_options) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -149,7 +150,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
||||
StreamPool::Ptr stream;
|
||||
if (run_options.stream() == nullptr) {
|
||||
// NB! The lifetime of `stream` needs to match the lifetime of
|
||||
// `actual_options` (otherwise we will end up using a returned stream in
|
||||
// `service_options` (otherwise we will end up using a returned stream in
|
||||
// ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
|
||||
// scope.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -167,12 +168,29 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
||||
// backend_->eigen_intra_op_thread_pool().
|
||||
ServiceExecutableRunOptions service_options(run_options,
|
||||
backend_->StreamBorrower());
|
||||
return std::make_pair(service_options, std::move(stream));
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
|
||||
const absl::Span<const ShapedBuffer* const> arguments,
|
||||
ExecutableRunOptions run_options) {
|
||||
TF_ASSIGN_OR_RETURN(auto options_and_stream,
|
||||
RunHelper(arguments, run_options));
|
||||
|
||||
if (executable_->dumping_snapshot()) {
|
||||
return ExecuteAndDump(&service_options, arguments);
|
||||
return ExecuteAndDump(&options_and_stream.first, arguments);
|
||||
}
|
||||
return executable_->ExecuteOnStreamWrapper(
|
||||
&service_options, run_options.execution_profile(), arguments);
|
||||
&options_and_stream.first, run_options.execution_profile(), arguments);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
|
||||
const absl::Span<const ShapedBuffer* const> arguments,
|
||||
ExecutableRunOptions run_options) {
|
||||
TF_ASSIGN_OR_RETURN(auto options_and_stream,
|
||||
RunHelper(arguments, run_options));
|
||||
return executable_->ExecuteAsyncOnStream(&options_and_stream.first,
|
||||
arguments);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
|
||||
|
@ -43,6 +43,12 @@ class LocalExecutable {
|
||||
const absl::Span<const ShapedBuffer* const> arguments,
|
||||
ExecutableRunOptions run_options);
|
||||
|
||||
// Similar to Run(), but need not block the host waiting for the computation
|
||||
// to complete before returning.
|
||||
StatusOr<ScopedShapedBuffer> RunAsync(
|
||||
const absl::Span<const ShapedBuffer* const> arguments,
|
||||
ExecutableRunOptions run_options);
|
||||
|
||||
// Return the options used to build the executable.
|
||||
const ExecutableBuildOptions& build_options() const { return build_options_; }
|
||||
|
||||
@ -86,6 +92,10 @@ class LocalExecutable {
|
||||
// Returns a literal containing the contents of the given ShapedBuffer.
|
||||
StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
|
||||
|
||||
StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>> RunHelper(
|
||||
const absl::Span<const ShapedBuffer* const> arguments,
|
||||
ExecutableRunOptions run_options);
|
||||
|
||||
// The ordinal of the device which this executable was compiled for. The
|
||||
// executable can run on all equivalent devices (as determined by
|
||||
// Backend::devices_equivalent).
|
||||
|
@ -112,6 +112,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:device_memory_allocator",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,51 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Implementation notes:
|
||||
//
|
||||
// Asynchronous execution:
|
||||
// -----------------------
|
||||
//
|
||||
// If 'asynchronous' is set when constructing the client, computations and
|
||||
// host-to-device transfers do not block the host waiting for the operation to
|
||||
// complete but instead return control to the host immediately. This allows
|
||||
// Python logic to overlap with device-side computation.
|
||||
//
|
||||
// For a good user experience, we must be careful only to enqueue operations
|
||||
// that are unlikely to fail; as a rule error checking must be done eagerly
|
||||
// before returning control to the client.
|
||||
//
|
||||
// Multi-stream execution:
|
||||
// -----------------------
|
||||
//
|
||||
// On certain platforms (e.g., TPU), we use a multistream execution design,
|
||||
// where different Streams are used for host-to-device transfers,
|
||||
// device-to-host transfers, and compute. This allows us to overlap transfers on
|
||||
// and off the device with computation.
|
||||
//
|
||||
// Synchronization between streams occurs via BufferDefinitionEvents that
|
||||
// describe when the contents of a logical buffer are known to be valid on
|
||||
// a particular stream.
|
||||
//
|
||||
// Synchronous vs asynchronous deallocation:
|
||||
// -----------------------------------------
|
||||
//
|
||||
// In asynchronous deallocation mode (currently only enabled on TPU), the client
|
||||
// need only keep buffers alive from its perspective until all operations that
|
||||
// touch those buffers have been enqueued.
|
||||
// The allocator and lower-level runtime is responsible for keeping buffers
|
||||
// alive (if that is needed) from the perspective of the device until any
|
||||
// device-side work actually completes. The client's use of the device allocator
|
||||
// thereby corresponds to a view of the tail of the compute stream instead of
|
||||
// its head.
|
||||
//
|
||||
// In synchronous deallocation mode the client is responsible for keeping
|
||||
// buffers alive until all device-side activity that consumes those buffers has
|
||||
// ceased. This is the case for CPU since HostExecutor performs allocation
|
||||
// and deallocation eagerly. In this mode, the client's use of the device
|
||||
// allocator is logically synchronized to the head of the compute stream, not
|
||||
// the tail.
|
||||
|
||||
#include "tensorflow/compiler/xla/python/local_client.h"
|
||||
|
||||
#include <memory>
|
||||
@ -23,6 +68,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/blocking_counter.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
@ -60,40 +106,100 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<py::object> PythonRefManager::ManageReference(
|
||||
const py::object& object) {
|
||||
auto deleter = [this](py::object* x) {
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
python_garbage_.push_back(std::move(*x));
|
||||
}
|
||||
delete x;
|
||||
};
|
||||
return std::shared_ptr<py::object>(new py::object(object), deleter);
|
||||
}
|
||||
|
||||
void PythonRefManager::CollectGarbage() {
|
||||
// TODO(phawkins): ideally we would assert that the GIL is held, but there is
|
||||
// no API to do this across all Python versions.
|
||||
absl::MutexLock lock(&mu_);
|
||||
python_garbage_.clear();
|
||||
}
|
||||
|
||||
Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
|
||||
bool synchronous_deallocation, bool asynchronous)
|
||||
: use_multiple_streams_(use_multiple_streams),
|
||||
synchronous_deallocation_(synchronous_deallocation),
|
||||
asynchronous_(asynchronous) {
|
||||
compute_stream_ = std::make_shared<se::Stream>(executor);
|
||||
compute_stream_->Init();
|
||||
if (use_multiple_streams) {
|
||||
host_to_device_stream_ = std::make_shared<se::Stream>(executor);
|
||||
device_to_host_stream_ = std::make_shared<se::Stream>(executor);
|
||||
callback_stream_ = std::make_shared<se::Stream>(executor);
|
||||
host_to_device_stream_->Init();
|
||||
device_to_host_stream_->Init();
|
||||
callback_stream_->Init();
|
||||
} else {
|
||||
callback_stream_ = host_to_device_stream_ = device_to_host_stream_ =
|
||||
compute_stream_;
|
||||
}
|
||||
worker_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
|
||||
"py_xla_execute");
|
||||
}
|
||||
|
||||
Device::~Device() { compute_stream_->parent()->SynchronizeAllActivity(); }
|
||||
|
||||
void Device::ThenExecuteOnWorkerThread(se::Stream* stream,
|
||||
std::function<void()> callback) const {
|
||||
stream->ThenDoHostCallback(
|
||||
[this, callback]() { worker_thread_->Schedule(std::move(callback)); });
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyLocalClient>> PyLocalClient::Get(
|
||||
const std::string& platform_name) {
|
||||
const std::string& platform_name, const std::string& xla_platform_name,
|
||||
bool asynchronous) {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
PlatformUtil::GetPlatform(platform_name));
|
||||
PlatformUtil::GetPlatform(xla_platform_name));
|
||||
if (platform->VisibleDeviceCount() <= 0) {
|
||||
return InvalidArgument("Platform %s has no visible devices.",
|
||||
platform_name);
|
||||
return InvalidArgument("Platform %s (%s) has no visible devices.",
|
||||
platform_name, xla_platform_name);
|
||||
}
|
||||
LocalClientOptions options;
|
||||
options.set_platform(platform);
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
return absl::make_unique<PyLocalClient>(client);
|
||||
return absl::make_unique<PyLocalClient>(platform_name, client, asynchronous);
|
||||
}
|
||||
|
||||
PyLocalClient::PyLocalClient(LocalClient* client)
|
||||
: client_(client),
|
||||
PyLocalClient::PyLocalClient(std::string platform_name, LocalClient* client,
|
||||
bool asynchronous)
|
||||
: platform_name_(std::move(platform_name)),
|
||||
client_(client),
|
||||
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
|
||||
client->device_count()) {
|
||||
execute_threads_.reserve(client->device_count());
|
||||
devices_.reserve(client->device_count());
|
||||
// TODO(phawkins): enable multistream mode on GPU too.
|
||||
bool use_multiple_streams = (platform_name == "tpu");
|
||||
bool synchronous_deallocation = !use_multiple_streams;
|
||||
for (int i = 0; i < client->device_count(); ++i) {
|
||||
execute_threads_.push_back(absl::make_unique<WorkerThread>(
|
||||
tensorflow::Env::Default(), "py_xla_execute"));
|
||||
se::StreamExecutor* executor =
|
||||
client_->backend().stream_executor(i).ValueOrDie();
|
||||
devices_.push_back(absl::make_unique<Device>(executor, use_multiple_streams,
|
||||
synchronous_deallocation,
|
||||
asynchronous));
|
||||
}
|
||||
}
|
||||
|
||||
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
|
||||
int device_ordinal) {
|
||||
py_ref_manager().CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
return client_->TransferToInfeedLocal(literal, device_ordinal);
|
||||
}
|
||||
|
||||
StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
|
||||
const Shape& shape, int device_ordinal) {
|
||||
py_ref_manager().CollectGarbage();
|
||||
Literal literal;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
@ -105,7 +211,7 @@ StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
|
||||
|
||||
static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
|
||||
const PythonBufferTree& tree, int device_ordinal, PyLocalClient* client,
|
||||
se::Stream* stream) {
|
||||
const Device& device) {
|
||||
DeviceMemoryAllocator* allocator =
|
||||
client->client()->backend().memory_allocator();
|
||||
TransferManager* transfer_manager =
|
||||
@ -115,8 +221,8 @@ static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
shape, allocator, device_ordinal));
|
||||
TF_RETURN_IF_ERROR(
|
||||
transfer_manager->WriteTupleIndexTablesAsync(stream, buffer));
|
||||
TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
|
||||
device.host_to_device_stream(), buffer));
|
||||
|
||||
auto it = tree.leaves.begin();
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
@ -127,13 +233,29 @@ static StatusOr<PyLocalBuffer> TransferHostToDeviceAsync(
|
||||
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
|
||||
client->client()->platform(), device_ordinal);
|
||||
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
|
||||
TF_RETURN_IF_ERROR(
|
||||
transfer_manager->TransferLiteralToDeviceAsync(stream, *it, leaf));
|
||||
if (device.use_multiple_streams() &&
|
||||
!transfer_manager->CanShapedBufferBeAccessedNow(
|
||||
device.host_to_device_stream()->parent(), leaf)) {
|
||||
device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
device.host_to_device_stream(), *it, leaf));
|
||||
++it;
|
||||
}
|
||||
return PyLocalBuffer(
|
||||
shape, PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer)),
|
||||
client);
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event;
|
||||
if (device.use_multiple_streams()) {
|
||||
definition_event = std::make_shared<BufferDefinitionEvent>(
|
||||
device.host_to_device_stream()->parent());
|
||||
definition_event->RecordOnStream(device.host_to_device_stream());
|
||||
}
|
||||
std::shared_ptr<PySharedDeviceBuffer> device_buffer =
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer),
|
||||
definition_event);
|
||||
if (device.synchronous_deallocation()) {
|
||||
device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
|
||||
device_buffer);
|
||||
}
|
||||
return PyLocalBuffer(shape, std::move(device_buffer), client);
|
||||
}
|
||||
|
||||
/* static */
|
||||
@ -143,18 +265,26 @@ StatusOr<PyLocalBuffer> PyLocalBuffer::FromPython(const py::object& argument,
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython");
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
|
||||
|
||||
client->py_ref_manager().CollectGarbage();
|
||||
|
||||
// Take a reference to the buffer to ensure that the inputs in host memory
|
||||
// remain live until the transfer is complete.
|
||||
auto py_buffer_ref = client->py_ref_manager().ManageReference(argument);
|
||||
|
||||
// We are done manipulating Python objects; release the GIL.
|
||||
py::gil_scoped_release gil_release;
|
||||
VLOG(1) << "PyLocalBuffer::FromPython: shape: " << tree.shape.ToString()
|
||||
<< " device ordinal: " << device_ordinal;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
StreamPool::Ptr stream,
|
||||
client->client()->mutable_backend()->BorrowStream(device_ordinal));
|
||||
const Device& device = client->device(device_ordinal);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
PyLocalBuffer buffer,
|
||||
TransferHostToDeviceAsync(tree, device_ordinal, client, stream.get()));
|
||||
stream->BlockHostUntilDone();
|
||||
TransferHostToDeviceAsync(tree, device_ordinal, client, device));
|
||||
|
||||
device.ThenRelease(device.host_to_device_stream(), std::move(py_buffer_ref));
|
||||
if (!device.asynchronous()) {
|
||||
device.host_to_device_stream()->BlockHostUntilDone();
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@ -171,29 +301,25 @@ PyLocalBuffer::FromPythonValues(
|
||||
|
||||
struct H2DTransfer {
|
||||
PythonBufferTree tree;
|
||||
StreamPool::Ptr stream;
|
||||
StatusOr<PyLocalBuffer> buffer;
|
||||
std::shared_ptr<py::object> py_buffer_ref;
|
||||
};
|
||||
|
||||
std::vector<H2DTransfer> transfers(num_arguments);
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(transfers[i].tree,
|
||||
GetPythonBufferTree(arguments[i].first));
|
||||
transfers[i].py_buffer_ref =
|
||||
client->py_ref_manager().ManageReference(arguments[i].first);
|
||||
}
|
||||
client->py_ref_manager().CollectGarbage();
|
||||
// We are done manipulating Python objects; release the GIL.
|
||||
py::gil_scoped_release gil_release;
|
||||
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
int device_ordinal = arguments[i].second;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
transfers[i].stream,
|
||||
client->client()->mutable_backend()->BorrowStream(device_ordinal));
|
||||
}
|
||||
|
||||
auto transfer_h2d = [&](int i) -> StatusOr<PyLocalBuffer> {
|
||||
int device_ordinal = arguments[i].second;
|
||||
return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client,
|
||||
transfers[i].stream.get());
|
||||
client->device(device_ordinal));
|
||||
};
|
||||
|
||||
// We perform the transfers on a thread pool in case XLA needs to do any
|
||||
@ -201,26 +327,27 @@ PyLocalBuffer::FromPythonValues(
|
||||
if (num_arguments == 1) {
|
||||
transfers[0].buffer = transfer_h2d(0);
|
||||
} else {
|
||||
absl::BlockingCounter counter(num_arguments - 1);
|
||||
for (int i = 1; i < num_arguments; ++i) {
|
||||
absl::BlockingCounter counter(num_arguments);
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
client->h2d_transfer_pool()->Schedule([&, i]() {
|
||||
transfers[i].buffer = transfer_h2d(i);
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
// Perform the first transfer on the main thread.
|
||||
transfers[0].buffer = transfer_h2d(0);
|
||||
counter.Wait();
|
||||
}
|
||||
|
||||
// First, wait for all transfers to complete. We wait for all to complete
|
||||
// since currently we maintain the invariant that the device's view of the
|
||||
// state matches the host's view of the state. Returning early would mean that
|
||||
// we might deallocate device-side memory before a transfer completes, which
|
||||
// violates that invariant.
|
||||
// Release our references once the transfers have completed.
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
transfers[i].stream->BlockHostUntilDone();
|
||||
int device_ordinal = arguments[i].second;
|
||||
const Device& device = client->device(device_ordinal);
|
||||
device.ThenRelease(device.host_to_device_stream(),
|
||||
std::move(transfers[i].py_buffer_ref));
|
||||
if (!device.asynchronous()) {
|
||||
device.host_to_device_stream()->BlockHostUntilDone();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(outputs[i], std::move(transfers[i].buffer));
|
||||
}
|
||||
@ -244,21 +371,42 @@ PyLocalBuffer::FromPythonValues(
|
||||
client->client()->backend().memory_allocator();
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::shared_ptr<PySharedDeviceBuffer> tuple_buffer,
|
||||
PySharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager,
|
||||
allocator, device_ordinal));
|
||||
const Device& device = client->device(device_ordinal);
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event;
|
||||
if (device.use_multiple_streams()) {
|
||||
definition_event = std::make_shared<BufferDefinitionEvent>(
|
||||
device.host_to_device_stream()->parent());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<PySharedDeviceBuffer> tuple_buffer,
|
||||
PySharedDeviceBuffer::MakeTuple(
|
||||
device_buffers, transfer_manager, allocator,
|
||||
device_ordinal, definition_event));
|
||||
PyLocalBuffer buffer(ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer,
|
||||
client);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
StreamPool::Ptr stream,
|
||||
client->client()->mutable_backend()->BorrowStream(device_ordinal));
|
||||
// TODO(phawkins): extend TransferManager so we do not need to form a full
|
||||
// ShapedBuffer just to write the root tuple index table.
|
||||
transfer_manager->WriteRootTupleIndexTable(stream.get(),
|
||||
buffer.AsShapedBuffer());
|
||||
stream->BlockHostUntilDone();
|
||||
ShapedBuffer shaped_buffer = buffer.AsShapedBuffer();
|
||||
if (device.use_multiple_streams() &&
|
||||
!transfer_manager->CanShapedBufferBeAccessedNow(
|
||||
device.host_to_device_stream()->parent(), shaped_buffer)) {
|
||||
// Wait for the compute stream so that memory allocations are synchronized.
|
||||
device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
|
||||
}
|
||||
transfer_manager->WriteRootTupleIndexTable(device.host_to_device_stream(),
|
||||
shaped_buffer);
|
||||
if (definition_event) {
|
||||
definition_event->RecordOnStream(device.host_to_device_stream());
|
||||
}
|
||||
|
||||
if (device.synchronous_deallocation()) {
|
||||
device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
|
||||
std::move(tuple_buffer));
|
||||
}
|
||||
if (!device.asynchronous()) {
|
||||
device.host_to_device_stream()->BlockHostUntilDone();
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@ -271,11 +419,21 @@ PyLocalBuffer::PyLocalBuffer(
|
||||
|
||||
StatusOr<py::object> PyLocalBuffer::ToPython() const {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython");
|
||||
auto literal = absl::make_unique<Literal>();
|
||||
auto literal = absl::make_unique<Literal>(on_host_shape());
|
||||
client_->py_ref_manager().CollectGarbage();
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*literal, client_->client()->ShapedBufferToLiteral(AsShapedBuffer()));
|
||||
se::Stream* stream = client_->device(device_buffer_->device_ordinal())
|
||||
.device_to_host_stream();
|
||||
WaitForBufferDefinitionEventsOnStream(*device_buffer_, stream);
|
||||
absl::Notification done;
|
||||
Status status;
|
||||
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
|
||||
stream, AsShapedBuffer(), *literal, [&](Status done_status) {
|
||||
status = done_status;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
}
|
||||
return LiteralToPython(std::move(literal));
|
||||
}
|
||||
@ -303,7 +461,7 @@ StatusOr<std::vector<PyLocalBuffer>> PyLocalBuffer::DestructureTuple() {
|
||||
}
|
||||
|
||||
PyLocalExecutable::PyLocalExecutable(
|
||||
std::unique_ptr<LocalExecutable> executable,
|
||||
std::shared_ptr<LocalExecutable> executable,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client)
|
||||
: executable_(std::move(executable)),
|
||||
device_assignment_(std::move(device_assignment)),
|
||||
@ -319,18 +477,14 @@ std::vector<int> PyLocalExecutable::DeviceOrdinals() const {
|
||||
return device_ordinals;
|
||||
}
|
||||
|
||||
StatusOr<PyLocalBuffer> PyLocalExecutable::Execute(
|
||||
absl::Span<PyLocalBuffer* const> argument_handles) {
|
||||
StatusOr<PyLocalBuffer> PyLocalExecutable::ExecuteHelper(
|
||||
absl::Span<PyLocalBuffer* const> argument_handles, int replica) {
|
||||
const int device_ordinal = device_assignment_(replica, 0);
|
||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
|
||||
if (num_replicas() != 1) {
|
||||
return InvalidArgument(
|
||||
"Attempted to execute computation with %d replicas using Execute()",
|
||||
num_replicas());
|
||||
}
|
||||
const int device_ordinal = device_assignment_(0, 0);
|
||||
VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
|
||||
<< device_ordinal;
|
||||
VLOG(3) << "Replica " << replica
|
||||
<< " mapped to device ordinal for execution: " << device_ordinal;
|
||||
|
||||
absl::flat_hash_set<BufferDefinitionEvent*> events;
|
||||
std::vector<ShapedBuffer> argument_buffers;
|
||||
std::vector<const ShapedBuffer*> argument_buffer_ptrs;
|
||||
argument_buffers.reserve(argument_handles.size());
|
||||
@ -338,23 +492,70 @@ StatusOr<PyLocalBuffer> PyLocalExecutable::Execute(
|
||||
for (auto& handle : argument_handles) {
|
||||
argument_buffers.push_back(handle->AsShapedBuffer());
|
||||
argument_buffer_ptrs.push_back(&argument_buffers.back());
|
||||
GetDeviceBufferDefinitionEvents(*handle->device_buffer(), &events);
|
||||
VLOG(4) << "Argument " << argument_buffers.size() - 1
|
||||
<< " buffer: " << argument_buffers.back().ToString();
|
||||
}
|
||||
|
||||
const Device& device = client_->device(device_ordinal);
|
||||
for (BufferDefinitionEvent* event : events) {
|
||||
event->WaitForEventOnStream(device.compute_stream());
|
||||
}
|
||||
|
||||
ExecutableRunOptions options;
|
||||
options.set_device_ordinal(device_ordinal);
|
||||
options.set_stream(device.compute_stream());
|
||||
options.set_host_to_device_stream(device.host_to_device_stream());
|
||||
options.set_allocator(client_->client()->backend().memory_allocator());
|
||||
options.set_intra_op_thread_pool(
|
||||
client_->client()->backend().eigen_intra_op_thread_pool_device());
|
||||
options.set_device_assignment(&device_assignment_);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffer,
|
||||
executable_->Run(argument_buffer_ptrs, options));
|
||||
StatusOr<ScopedShapedBuffer> result_buffer =
|
||||
executable_->RunAsync(argument_buffer_ptrs, options);
|
||||
|
||||
Shape on_host_shape = result_buffer.on_host_shape();
|
||||
return PyLocalBuffer(
|
||||
on_host_shape,
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(result_buffer)),
|
||||
client_);
|
||||
VLOG(1) << "Replica " << replica << " completed; ok=" << result_buffer.ok();
|
||||
if (!result_buffer.ok()) {
|
||||
LOG(ERROR) << "Execution of replica " << replica
|
||||
<< " failed: " << result_buffer.status();
|
||||
return result_buffer.status();
|
||||
}
|
||||
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event;
|
||||
if (device.use_multiple_streams()) {
|
||||
definition_event = std::make_shared<BufferDefinitionEvent>(
|
||||
device.compute_stream()->parent());
|
||||
definition_event->RecordOnStream(device.compute_stream());
|
||||
}
|
||||
Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape();
|
||||
std::shared_ptr<PySharedDeviceBuffer> out_buffer =
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(
|
||||
std::move(result_buffer.ValueOrDie()), definition_event);
|
||||
|
||||
if (device.synchronous_deallocation()) {
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> buffers;
|
||||
buffers.reserve(argument_handles.size() + 1);
|
||||
for (auto& handle : argument_handles) {
|
||||
buffers.push_back(handle->device_buffer());
|
||||
}
|
||||
buffers.push_back(out_buffer);
|
||||
device.ThenReleaseOnWorkerThread(device.compute_stream(),
|
||||
std::move(buffers));
|
||||
device.ThenReleaseOnWorkerThread(device.compute_stream(), executable_);
|
||||
}
|
||||
if (!device.asynchronous()) {
|
||||
device.compute_stream()->BlockHostUntilDone();
|
||||
}
|
||||
return PyLocalBuffer(on_host_shape, std::move(out_buffer), client_);
|
||||
}
|
||||
|
||||
StatusOr<PyLocalBuffer> PyLocalExecutable::Execute(
|
||||
absl::Span<PyLocalBuffer* const> argument_handles) {
|
||||
if (num_replicas() != 1) {
|
||||
return InvalidArgument(
|
||||
"Attempted to execute computation with %d replicas using Execute()",
|
||||
num_replicas());
|
||||
}
|
||||
return ExecuteHelper(argument_handles, /*replica=*/0);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<PyLocalBuffer>> PyLocalExecutable::ExecutePerReplica(
|
||||
@ -373,53 +574,13 @@ StatusOr<std::vector<PyLocalBuffer>> PyLocalExecutable::ExecutePerReplica(
|
||||
argument_handles.size(), num_devices);
|
||||
}
|
||||
|
||||
VLOG(1) << "Executing with " << num_replicas() << " replicas.";
|
||||
|
||||
auto execute = [this,
|
||||
&argument_handles](int replica) -> StatusOr<PyLocalBuffer> {
|
||||
const int device_ordinal = device_assignment_(replica, 0);
|
||||
VLOG(3) << "Replica " << replica
|
||||
<< " mapped to device ordinal for execution: " << device_ordinal;
|
||||
|
||||
std::vector<ShapedBuffer> argument_buffers;
|
||||
std::vector<const ShapedBuffer*> argument_buffer_ptrs;
|
||||
argument_buffers.reserve(argument_handles[replica].size());
|
||||
argument_buffer_ptrs.reserve(argument_handles[replica].size());
|
||||
for (auto& handle : argument_handles[replica]) {
|
||||
argument_buffers.push_back(handle->AsShapedBuffer());
|
||||
argument_buffer_ptrs.push_back(&argument_buffers.back());
|
||||
}
|
||||
|
||||
ExecutableRunOptions options;
|
||||
options.set_device_ordinal(device_ordinal);
|
||||
options.set_allocator(client_->client()->backend().memory_allocator());
|
||||
options.set_intra_op_thread_pool(
|
||||
client_->client()->backend().eigen_intra_op_thread_pool_device());
|
||||
options.set_device_assignment(&device_assignment_);
|
||||
StatusOr<ScopedShapedBuffer> result_buffer =
|
||||
executable_->Run(argument_buffer_ptrs, options);
|
||||
|
||||
VLOG(1) << "Replica " << replica << " completed; ok=" << result_buffer.ok();
|
||||
if (!result_buffer.ok()) {
|
||||
LOG(ERROR) << "Execution of replica " << replica
|
||||
<< " failed: " << result_buffer.status();
|
||||
return result_buffer.status();
|
||||
}
|
||||
Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape();
|
||||
|
||||
return PyLocalBuffer(on_host_shape,
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(
|
||||
std::move(result_buffer.ValueOrDie())),
|
||||
client_);
|
||||
};
|
||||
|
||||
VLOG(1) << "Executing replicated computation; num_replicas="
|
||||
<< num_replicas();
|
||||
std::vector<StatusOr<PyLocalBuffer>> results(num_replicas());
|
||||
if (num_replicas() == 1) {
|
||||
// Fast-path if there is only one replica — run the computation on the
|
||||
// current thread.
|
||||
results[0] = execute(0);
|
||||
results[0] = ExecuteHelper(argument_handles[0], /*replica=*/0);
|
||||
} else {
|
||||
absl::Mutex mu;
|
||||
int running GUARDED_BY(mu) = num_replicas();
|
||||
@ -427,8 +588,10 @@ StatusOr<std::vector<PyLocalBuffer>> PyLocalExecutable::ExecutePerReplica(
|
||||
Status first_failure_status GUARDED_BY(mu);
|
||||
|
||||
for (int replica = 0; replica < num_replicas(); ++replica) {
|
||||
client_->execute_threads().at(replica)->Schedule([&, replica] {
|
||||
results[replica] = execute(replica);
|
||||
const int device_ordinal = device_assignment_(replica, 0);
|
||||
const Device& device = client_->device(device_ordinal);
|
||||
device.worker_thread()->Schedule([&, replica] {
|
||||
results[replica] = ExecuteHelper(argument_handles[replica], replica);
|
||||
|
||||
absl::MutexLock lock(&mu);
|
||||
--running;
|
||||
@ -541,7 +704,8 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
|
||||
options.num_replicas(), /*computation_count=*/1));
|
||||
|
||||
return absl::make_unique<PyLocalExecutable>(
|
||||
std::move(local_executable), std::move(device_assignment), client);
|
||||
std::shared_ptr<LocalExecutable>(std::move(local_executable)),
|
||||
std::move(device_assignment), client);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_
|
||||
|
||||
#include <deque>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -40,35 +41,165 @@ namespace xla {
|
||||
Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
||||
pybind11::capsule capsule);
|
||||
|
||||
// Class that manages destruction of Python objects.
|
||||
//
|
||||
// We must not destroy Python objects without holding the GIL. However, we
|
||||
// frequently want to hold references to Python objects for the duration of
|
||||
// an asynchronous transfer on a Stream, and release our reference when the
|
||||
// transfer completes.
|
||||
//
|
||||
// This class holds references to Python objects outside a GIL scope, that can
|
||||
// be collected later when the GIL is held by calling CollectGarbage().
|
||||
class PythonRefManager {
|
||||
public:
|
||||
PythonRefManager() = default;
|
||||
|
||||
// Creates a managed std::shared_ptr to an object. When the shared_ptr is
|
||||
// destroyed, the reference to 'object' will be added to python_garbage_,
|
||||
// and collected next time CollectGarbage() is called.
|
||||
std::shared_ptr<pybind11::object> ManageReference(
|
||||
const pybind11::object& object);
|
||||
|
||||
// Releases the contents of python_garbage_. Requires that the GIL is held.
|
||||
// The client calls this method during API entry points where the GIL is held
|
||||
// to free any garbage that has accumulated.
|
||||
void CollectGarbage();
|
||||
|
||||
private:
|
||||
absl::Mutex mu_;
|
||||
std::deque<pybind11::object> python_garbage_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// Class that encapsulates state relating to a device (e.g., a GPU) on which we
|
||||
// can perform computation and transfers.
|
||||
class Device {
|
||||
public:
|
||||
// If use_multiple_streams is true, we allocate separate streams for compute
|
||||
// and transfers. If it is false, we share a single stream for compute and
|
||||
// transfers. The CPU device does not support multiple streams, and this is
|
||||
// a workaround until it does.
|
||||
//
|
||||
// If synchronous_deallocation is true, the host must not free buffers until
|
||||
// compute/transfers that use those buffers have completed. For example, this
|
||||
// typically is the case for the "platform" where compute/transfers are
|
||||
// operations that take place on another thread.
|
||||
//
|
||||
// If asynchronous is false, the host will synchronize to the device after
|
||||
// each execution or transfer. This is intended for debugging only.
|
||||
Device(se::StreamExecutor* executor, bool use_multiple_streams,
|
||||
bool synchronous_deallocation, bool asynchronous);
|
||||
~Device();
|
||||
|
||||
bool use_multiple_streams() const { return use_multiple_streams_; }
|
||||
bool synchronous_deallocation() const { return synchronous_deallocation_; }
|
||||
bool asynchronous() const { return asynchronous_; }
|
||||
se::Stream* compute_stream() const { return compute_stream_.get(); }
|
||||
se::Stream* host_to_device_stream() const {
|
||||
return host_to_device_stream_.get();
|
||||
}
|
||||
se::Stream* device_to_host_stream() const {
|
||||
return device_to_host_stream_.get();
|
||||
}
|
||||
|
||||
// A worker thread, used for replicated computation launches and callbacks.
|
||||
WorkerThread* worker_thread() const { return worker_thread_.get(); }
|
||||
|
||||
// Enqueues a host callback on 'stream', to be executed by worker_thread_.
|
||||
// ThenDoHostCallback is often constrained in what it can do, in particular,
|
||||
// on GPU the callback runs on a thread belonging to the GPU runtime and
|
||||
// cannot perform GPU operations itself.
|
||||
void ThenExecuteOnWorkerThread(se::Stream* stream,
|
||||
std::function<void()> callback) const;
|
||||
|
||||
// Helper for releasing values from a callback at the tail of a stream.
|
||||
// This is only permitted if object's destructor will not free any device
|
||||
// objects, since the callback may be called from a device thread pool on
|
||||
// GPU.
|
||||
template <typename T>
|
||||
void ThenRelease(se::Stream* stream, std::shared_ptr<T> object) const {
|
||||
if (callback_stream_.get() != stream) {
|
||||
callback_stream_->ThenWaitFor(stream);
|
||||
}
|
||||
callback_stream_->ThenDoHostCallback([object]() { /* releases object */ });
|
||||
}
|
||||
|
||||
// Helpers for releasing values on a worker thread at the tail of a stream on
|
||||
// a worker thread.
|
||||
template <typename T>
|
||||
void ThenReleaseOnWorkerThread(se::Stream* stream,
|
||||
std::shared_ptr<T> object) const {
|
||||
// We use a non-smart pointer here because we want to ensure that the worker
|
||||
// thread is the only callee of the shared_ptr destructor, and if we passed
|
||||
// object by lambda capture we have a race where the worker thread might
|
||||
// run and release its reference first.
|
||||
auto* ref = new std::shared_ptr<T>(std::move(object));
|
||||
if (callback_stream_.get() != stream) {
|
||||
callback_stream_->ThenWaitFor(stream);
|
||||
}
|
||||
ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; });
|
||||
}
|
||||
template <typename T>
|
||||
void ThenReleaseOnWorkerThread(se::Stream* stream,
|
||||
std::vector<std::shared_ptr<T>> object) const {
|
||||
auto* ref = new std::vector<std::shared_ptr<T>>(std::move(object));
|
||||
if (callback_stream_.get() != stream) {
|
||||
callback_stream_->ThenWaitFor(stream);
|
||||
}
|
||||
ThenExecuteOnWorkerThread(callback_stream_.get(), [ref]() { delete ref; });
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_multiple_streams_;
|
||||
bool synchronous_deallocation_;
|
||||
bool asynchronous_;
|
||||
std::shared_ptr<se::Stream> compute_stream_;
|
||||
std::shared_ptr<se::Stream> host_to_device_stream_;
|
||||
std::shared_ptr<se::Stream> device_to_host_stream_;
|
||||
|
||||
// Callback stream is used for running short host-side callbacks after device
|
||||
// side events, without preventing the device-side stream from doing useful
|
||||
// work.
|
||||
std::shared_ptr<se::Stream> callback_stream_;
|
||||
|
||||
std::unique_ptr<WorkerThread> worker_thread_;
|
||||
};
|
||||
|
||||
// Encapsulates the state of Python session with XLA.
|
||||
class PyLocalClient {
|
||||
public:
|
||||
// Initializes a local XLA client for `platform_name`. Returns an error if no
|
||||
// such platform exists, or if the platform has no visible devices.
|
||||
static StatusOr<std::unique_ptr<PyLocalClient>> Get(
|
||||
const std::string& platform_name);
|
||||
const std::string& platform_name, const std::string& xla_platform_id,
|
||||
bool asynchronous);
|
||||
|
||||
explicit PyLocalClient(LocalClient* client);
|
||||
explicit PyLocalClient(std::string platform_name, LocalClient* client,
|
||||
bool asynchronous);
|
||||
|
||||
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
|
||||
StatusOr<pybind11::object> TransferFromOutfeed(const Shape& shape,
|
||||
int device_ordinal);
|
||||
|
||||
int device_count() const { return client_->device_count(); }
|
||||
const Device& device(int device_ordinal) const {
|
||||
return *devices_.at(device_ordinal);
|
||||
}
|
||||
LocalClient* client() const { return client_; }
|
||||
|
||||
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
|
||||
return &h2d_transfer_pool_;
|
||||
}
|
||||
const std::vector<std::unique_ptr<WorkerThread>>& execute_threads() {
|
||||
return execute_threads_;
|
||||
}
|
||||
|
||||
PythonRefManager& py_ref_manager() { return py_ref_manager_; }
|
||||
|
||||
private:
|
||||
std::string platform_name_;
|
||||
LocalClient* client_;
|
||||
std::vector<std::unique_ptr<Device>> devices_;
|
||||
|
||||
tensorflow::thread::ThreadPool h2d_transfer_pool_;
|
||||
// We use a single worker thread per device, both for simplicity and because
|
||||
// it avoids a deadlock in tensorflow::thread::ThreadPool (b/130761212).
|
||||
std::vector<std::unique_ptr<WorkerThread>> execute_threads_;
|
||||
|
||||
PythonRefManager py_ref_manager_;
|
||||
};
|
||||
|
||||
// Holds a reference from Python to one or more device buffers.
|
||||
@ -125,7 +256,7 @@ class PyLocalExecutable {
|
||||
const XlaComputation& computation, std::vector<Shape> argument_layouts,
|
||||
const ExecutableBuildOptions* build_options, PyLocalClient* client);
|
||||
|
||||
PyLocalExecutable(std::unique_ptr<LocalExecutable> executable,
|
||||
PyLocalExecutable(std::shared_ptr<LocalExecutable> executable,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client);
|
||||
|
||||
int num_replicas() const {
|
||||
@ -151,7 +282,10 @@ class PyLocalExecutable {
|
||||
void Delete() { executable_ = nullptr; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<LocalExecutable> executable_;
|
||||
StatusOr<PyLocalBuffer> ExecuteHelper(
|
||||
absl::Span<PyLocalBuffer* const> argument_handles, int replica);
|
||||
|
||||
std::shared_ptr<LocalExecutable> executable_;
|
||||
const DeviceAssignment device_assignment_;
|
||||
PyLocalClient* const client_;
|
||||
};
|
||||
|
@ -19,12 +19,38 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
BufferDefinitionEvent::BufferDefinitionEvent(se::StreamExecutor* executor)
|
||||
: event_(executor) {}
|
||||
|
||||
void BufferDefinitionEvent::RecordOnStream(se::Stream* stream) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
CHECK(streams_defined_on_.empty());
|
||||
stream->ThenRecordEvent(&event_);
|
||||
streams_defined_on_.push_back(stream);
|
||||
}
|
||||
|
||||
void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
// The set of defined streams is expected to be very small indeed (usually
|
||||
// 1-2), so a simple linear scan should be fast enough.
|
||||
if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
|
||||
stream) != streams_defined_on_.end()) {
|
||||
// stream is in streams_defined_on_; it doesn't need to be waited on.
|
||||
return;
|
||||
}
|
||||
|
||||
stream->ThenWaitFor(&event_);
|
||||
streams_defined_on_.push_back(stream);
|
||||
}
|
||||
|
||||
static std::shared_ptr<PySharedDeviceBuffer>
|
||||
BufferFromScopedShapedBufferIterator(
|
||||
const Shape& on_device_shape, int device_ordinal,
|
||||
DeviceMemoryAllocator* allocator,
|
||||
ShapeTree<se::DeviceMemoryBase>::iterator* iterator,
|
||||
const ShapeTree<se::DeviceMemoryBase>::iterator& end) {
|
||||
const ShapeTree<se::DeviceMemoryBase>::iterator& end,
|
||||
const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
|
||||
CHECK(*iterator != end);
|
||||
|
||||
OwningDeviceMemory device_memory((*iterator)->second, device_ordinal,
|
||||
@ -39,22 +65,24 @@ BufferFromScopedShapedBufferIterator(
|
||||
for (int i = 0; i < num_children; ++i) {
|
||||
children.push_back(BufferFromScopedShapedBufferIterator(
|
||||
on_device_shape.tuple_shapes(i), device_ordinal, allocator, iterator,
|
||||
end));
|
||||
end, definition_event));
|
||||
}
|
||||
}
|
||||
return std::make_shared<PySharedDeviceBuffer>(
|
||||
on_device_shape, std::move(device_memory), children);
|
||||
on_device_shape, std::move(device_memory), children, definition_event);
|
||||
}
|
||||
|
||||
/* static */ std::shared_ptr<PySharedDeviceBuffer>
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(ScopedShapedBuffer shaped_buffer) {
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(
|
||||
ScopedShapedBuffer shaped_buffer,
|
||||
const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
|
||||
ShapeTree<se::DeviceMemoryBase>::iterator iterator =
|
||||
shaped_buffer.buffers().begin();
|
||||
std::shared_ptr<PySharedDeviceBuffer> output =
|
||||
BufferFromScopedShapedBufferIterator(
|
||||
shaped_buffer.on_device_shape(), shaped_buffer.device_ordinal(),
|
||||
shaped_buffer.memory_allocator(), &iterator,
|
||||
shaped_buffer.buffers().end());
|
||||
shaped_buffer.buffers().end(), definition_event);
|
||||
CHECK(iterator == shaped_buffer.buffers().end());
|
||||
return output;
|
||||
}
|
||||
@ -63,7 +91,8 @@ PySharedDeviceBuffer::FromScopedShapedBuffer(ScopedShapedBuffer shaped_buffer) {
|
||||
PySharedDeviceBuffer::MakeTuple(
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
TransferManager* transfer_manager, DeviceMemoryAllocator* allocator,
|
||||
int device_ordinal) {
|
||||
int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event) {
|
||||
std::vector<Shape> child_shapes;
|
||||
child_shapes.reserve(children.size());
|
||||
for (const auto& child : children) {
|
||||
@ -77,14 +106,15 @@ PySharedDeviceBuffer::MakeTuple(
|
||||
allocator->Allocate(device_ordinal,
|
||||
transfer_manager->GetByteSizeRequirement(shape)));
|
||||
return std::make_shared<PySharedDeviceBuffer>(
|
||||
std::move(shape), std::move(device_memory), std::move(children));
|
||||
std::move(shape), std::move(device_memory), std::move(children),
|
||||
std::move(definition_event));
|
||||
}
|
||||
|
||||
/* static */ StatusOr<std::shared_ptr<PySharedDeviceBuffer>>
|
||||
PySharedDeviceBuffer::MakeArray(Shape on_device_shape,
|
||||
TransferManager* transfer_manager,
|
||||
DeviceMemoryAllocator* allocator,
|
||||
int device_ordinal) {
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
Shape on_device_shape, TransferManager* transfer_manager,
|
||||
DeviceMemoryAllocator* allocator, int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
OwningDeviceMemory device_memory,
|
||||
allocator->Allocate(
|
||||
@ -92,7 +122,8 @@ PySharedDeviceBuffer::MakeArray(Shape on_device_shape,
|
||||
transfer_manager->GetByteSizeRequirement(on_device_shape)));
|
||||
return std::make_shared<PySharedDeviceBuffer>(
|
||||
std::move(on_device_shape), std::move(device_memory),
|
||||
/*children=*/std::vector<std::shared_ptr<PySharedDeviceBuffer>>{});
|
||||
/*children=*/std::vector<std::shared_ptr<PySharedDeviceBuffer>>{},
|
||||
std::move(definition_event));
|
||||
}
|
||||
|
||||
// Populates a buffer tree from a ShapeTree iterator.
|
||||
@ -123,9 +154,31 @@ ShapedBuffer PySharedDeviceBuffer::AsShapedBuffer(
|
||||
|
||||
PySharedDeviceBuffer::PySharedDeviceBuffer(
|
||||
Shape on_device_shape, OwningDeviceMemory device_memory,
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children)
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event)
|
||||
: on_device_shape_(std::move(on_device_shape)),
|
||||
device_memory_(std::move(device_memory)),
|
||||
children_(std::move(children)) {}
|
||||
children_(std::move(children)),
|
||||
definition_event_(std::move(definition_event)) {}
|
||||
|
||||
void GetDeviceBufferDefinitionEvents(
|
||||
const PySharedDeviceBuffer& buffer,
|
||||
absl::flat_hash_set<BufferDefinitionEvent*>* events) {
|
||||
if (buffer.definition_event()) {
|
||||
events->insert(buffer.definition_event().get());
|
||||
}
|
||||
for (const auto& child : buffer.children()) {
|
||||
GetDeviceBufferDefinitionEvents(*child, events);
|
||||
}
|
||||
}
|
||||
|
||||
void WaitForBufferDefinitionEventsOnStream(const PySharedDeviceBuffer& buffer,
|
||||
se::Stream* stream) {
|
||||
absl::flat_hash_set<BufferDefinitionEvent*> events;
|
||||
GetDeviceBufferDefinitionEvents(buffer, &events);
|
||||
for (BufferDefinitionEvent* event : events) {
|
||||
event->WaitForEventOnStream(stream);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/service/owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
@ -24,6 +25,56 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A BufferDefinitionEvent describes whether a buffer is valid from the
|
||||
// viewpoint of each of stream that may access it.
|
||||
//
|
||||
// Each logical buffer in an XLA computation may be defined (i.e., written to)
|
||||
// at most once, although the same physical piece of memory may be reused for
|
||||
// multiple logical buffers. We call the operation that writes the buffer's
|
||||
// value on some stream (e.g., a transfer or compute kernel) the buffer's
|
||||
// definition event.
|
||||
//
|
||||
// After the operation that populates the value of a buffer has been enqueued on
|
||||
// 'stream', RecordOnStream(stream) should also be called to trigger the
|
||||
// definition event after the operation has completed.
|
||||
//
|
||||
// Since different streams are not necessarily synchronized with one another,
|
||||
// if we wish to consume the value of the buffer on a different stream, we
|
||||
// should first call WaitForEventOnStream(stream), which add a cross-stream
|
||||
// from 'stream' to the buffer's definition event, causing 'stream' to pause
|
||||
// until the definition event has been triggered, if needed. Operations on
|
||||
// 'stream' may then assume that the buffer is valid and its contents correspond
|
||||
// to the desired buffer.
|
||||
//
|
||||
// The dependency logic caches the set of streams at the tail of which the
|
||||
// definition event is known to have occurred; waiting for the same event on the
|
||||
// same stream causes no additional waiting.
|
||||
class BufferDefinitionEvent {
|
||||
public:
|
||||
// Creates a new definition event whose event has not yet been triggered.
|
||||
explicit BufferDefinitionEvent(se::StreamExecutor* executor);
|
||||
|
||||
// Records the definition event on the tail of 'stream'.
|
||||
void RecordOnStream(se::Stream* stream);
|
||||
|
||||
// Adds synchronization events to 'stream' that wait for this event to be
|
||||
// defined on 'stream'. Does nothing if the event is already known to have
|
||||
// occurred by the tail of 'stream'.
|
||||
void WaitForEventOnStream(se::Stream* stream);
|
||||
|
||||
private:
|
||||
// An event that is triggered when the content of one or more buffers is
|
||||
// ready. If this event is nullptr, it is assumed that the buffer's content is
|
||||
// always defined.
|
||||
se::Event event_;
|
||||
|
||||
absl::Mutex mu_;
|
||||
|
||||
// A list of all streams for which the buffer's content is known to be defined
|
||||
// at the tail of the queue, i.e., for any newly enqueued command.
|
||||
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// Class that represents a node in a reference-counted DAG of device buffers.
|
||||
// Unlike a ShapedBuffer, which owns none of its buffers, and
|
||||
// ScopedShapedBuffer, which owns an entire buffer tree, the reference counting
|
||||
@ -36,18 +87,21 @@ class PySharedDeviceBuffer {
|
||||
// Converts a ScopedShapedBuffer into a Buffer tree. Takes ownership of the
|
||||
// contents of the shaped_buffer.
|
||||
static std::shared_ptr<PySharedDeviceBuffer> FromScopedShapedBuffer(
|
||||
ScopedShapedBuffer shaped_buffer);
|
||||
ScopedShapedBuffer shaped_buffer,
|
||||
const std::shared_ptr<BufferDefinitionEvent>& definition_event);
|
||||
|
||||
// Makes a tuple buffer. Does not initialize the tuple table.
|
||||
static StatusOr<std::shared_ptr<PySharedDeviceBuffer>> MakeTuple(
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
TransferManager* transfer_manager, DeviceMemoryAllocator* allocator,
|
||||
int device_ordinal);
|
||||
int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event);
|
||||
|
||||
// Makes an uninitialized array buffer.
|
||||
static StatusOr<std::shared_ptr<PySharedDeviceBuffer>> MakeArray(
|
||||
Shape on_device_shape, TransferManager* transfer_manager,
|
||||
DeviceMemoryAllocator* allocator, int device_ordinal);
|
||||
DeviceMemoryAllocator* allocator, int device_ordinal,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event);
|
||||
|
||||
// Builds a ShapedBuffer view onto the buffers of 'tree'. Since
|
||||
// PySharedDeviceBuffer does not maintain the on-host shape, the caller must
|
||||
@ -60,11 +114,16 @@ class PySharedDeviceBuffer {
|
||||
return children_;
|
||||
}
|
||||
const OwningDeviceMemory& device_memory() const { return device_memory_; }
|
||||
int device_ordinal() const { return device_memory_.device_ordinal(); }
|
||||
const std::shared_ptr<BufferDefinitionEvent> definition_event() const {
|
||||
return definition_event_;
|
||||
}
|
||||
|
||||
PySharedDeviceBuffer() = default;
|
||||
PySharedDeviceBuffer(
|
||||
Shape on_device_shape, OwningDeviceMemory device_memory,
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children);
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children,
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event);
|
||||
|
||||
private:
|
||||
// We only represent the on-device shape. The on-host shape may not be
|
||||
@ -73,8 +132,24 @@ class PySharedDeviceBuffer {
|
||||
Shape on_device_shape_;
|
||||
OwningDeviceMemory device_memory_;
|
||||
std::vector<std::shared_ptr<PySharedDeviceBuffer>> children_;
|
||||
|
||||
// An event that is triggered when the content of one or more buffers is
|
||||
// ready during multistream execution. May be nullptr, which is used in the
|
||||
// single-stream execution case where events are not necessary for buffer
|
||||
// event sequencing.
|
||||
std::shared_ptr<BufferDefinitionEvent> definition_event_;
|
||||
};
|
||||
|
||||
// Populates 'events' with the set of buffer definition events for all buffers
|
||||
// in the buffer DAG rooted at 'buffer'.
|
||||
void GetDeviceBufferDefinitionEvents(
|
||||
const PySharedDeviceBuffer& buffer,
|
||||
absl::flat_hash_set<BufferDefinitionEvent*>* events);
|
||||
|
||||
// Waits for all of the buffer definition events in a buffer DAG on 'stream'.
|
||||
void WaitForBufferDefinitionEventsOnStream(const PySharedDeviceBuffer& buffer,
|
||||
se::Stream* stream);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_
|
||||
|
@ -28,10 +28,10 @@ TEST(PySharedDeviceBufferTest, MakeArray) {
|
||||
LocalClient* client = ClientLibrary::LocalClientOrDie();
|
||||
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto buffer,
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto buffer, PySharedDeviceBuffer::MakeArray(
|
||||
shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
EXPECT_EQ(
|
||||
buffer->on_device_shape(),
|
||||
client->backend().transfer_manager()->HostShapeToDeviceShape(shape));
|
||||
@ -48,19 +48,19 @@ TEST(PySharedDeviceBufferTest, MakeTuple) {
|
||||
Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
|
||||
Shape b_shape = ShapeUtil::MakeShape(S8, {77});
|
||||
Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
a_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
b_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto a_buffer, PySharedDeviceBuffer::MakeArray(
|
||||
a_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto b_buffer, PySharedDeviceBuffer::MakeArray(
|
||||
b_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto tuple_buffer,
|
||||
PySharedDeviceBuffer::MakeTuple({a_buffer, b_buffer},
|
||||
client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
PySharedDeviceBuffer::MakeTuple(
|
||||
{a_buffer, b_buffer}, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
EXPECT_EQ(tuple_buffer->on_device_shape(),
|
||||
client->backend().transfer_manager()->HostShapeToDeviceShape(
|
||||
tuple_shape));
|
||||
@ -81,28 +81,28 @@ TEST(PySharedDeviceBufferTest, AsShapedBuffer) {
|
||||
Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
|
||||
Shape c_shape = ShapeUtil::MakeShape(S64, {});
|
||||
Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape});
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
a_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
b_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto a_buffer, PySharedDeviceBuffer::MakeArray(
|
||||
a_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto b_buffer, PySharedDeviceBuffer::MakeArray(
|
||||
b_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto ab_tuple_buffer,
|
||||
PySharedDeviceBuffer::MakeTuple({a_buffer, b_buffer},
|
||||
client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto c_buffer,
|
||||
PySharedDeviceBuffer::MakeArray(
|
||||
c_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
PySharedDeviceBuffer::MakeTuple(
|
||||
{a_buffer, b_buffer}, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto c_buffer, PySharedDeviceBuffer::MakeArray(
|
||||
c_shape, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto abc_tuple_buffer,
|
||||
PySharedDeviceBuffer::MakeTuple({c_buffer, ab_tuple_buffer},
|
||||
client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0));
|
||||
PySharedDeviceBuffer::MakeTuple(
|
||||
{c_buffer, ab_tuple_buffer}, client->backend().transfer_manager(),
|
||||
client->backend().memory_allocator(), 0, nullptr));
|
||||
EXPECT_EQ(abc_tuple_buffer->on_device_shape(),
|
||||
client->backend().transfer_manager()->HostShapeToDeviceShape(
|
||||
abc_tuple_shape));
|
||||
@ -142,7 +142,8 @@ TEST(PySharedDeviceBufferTest, FromScopedShapedBuffer) {
|
||||
ScopedShapedBuffer shaped_buffer,
|
||||
client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
|
||||
std::shared_ptr<PySharedDeviceBuffer> device_buffer =
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(shaped_buffer));
|
||||
PySharedDeviceBuffer::FromScopedShapedBuffer(std::move(shaped_buffer),
|
||||
nullptr);
|
||||
|
||||
EXPECT_EQ(device_buffer->on_device_shape(),
|
||||
client->backend().transfer_manager()->HostShapeToDeviceShape(
|
||||
|
@ -211,7 +211,9 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
// CPU custom-call targets.
|
||||
m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget);
|
||||
|
||||
py::class_<PyLocalClient>(m, "LocalClient")
|
||||
// The LocalClient object allows dynamic attributes to allow external backends
|
||||
// (e.g., TPU) to stash private data in the client.
|
||||
py::class_<PyLocalClient>(m, "LocalClient", py::dynamic_attr())
|
||||
.def_static("Get", &PyLocalClient::Get)
|
||||
.def("DeviceCount", &PyLocalClient::device_count)
|
||||
.def("TransferToInfeed", &PyLocalClient::TransferToInfeed)
|
||||
|
@ -103,15 +103,17 @@ class Backend(object):
|
||||
class LocalBackend(Backend):
|
||||
"""XLA backend implemented using the in-process xla::LocalClient API."""
|
||||
|
||||
def __init__(self, platform=None, xla_platform_id=None):
|
||||
def __init__(self, platform=None, xla_platform_id=None, asynchronous=False):
|
||||
"""Creates a new LocalBackend.
|
||||
|
||||
Args:
|
||||
platform: A string; the user-visible platform name, e.g. 'gpu'.
|
||||
xla_platform_id: A string; XLA's name for the platform, e.g., 'CUDA'.
|
||||
asynchronous: A boolean; should we enable asynchronous execution?
|
||||
(Experimental.)
|
||||
"""
|
||||
super(LocalBackend, self).__init__(platform)
|
||||
self.client = _xla.LocalClient.Get(xla_platform_id)
|
||||
self.client = _xla.LocalClient.Get(platform, xla_platform_id, asynchronous)
|
||||
|
||||
def device_count(self):
|
||||
return self.client.DeviceCount()
|
||||
|
@ -226,10 +226,10 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
|
||||
return &module_globals_.emplace(executor, std::move(globals)).first->second;
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
|
||||
StatusOr<ScopedShapedBuffer> GpuExecutable::Execute(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
HloExecutionProfile* hlo_execution_profile, bool block_host_until_done) {
|
||||
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
|
||||
if (GetRootPointsToSet().IsAmbiguous()) {
|
||||
@ -272,8 +272,6 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
|
||||
buffer_allocations_builder.Build(
|
||||
assignment_.get(), executor->device_ordinal(), memory_allocator));
|
||||
|
||||
bool block_host_until_done =
|
||||
!memory_allocator->AllowsAsynchronousDeallocation();
|
||||
TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations,
|
||||
block_host_until_done,
|
||||
hlo_execution_profile));
|
||||
@ -339,12 +337,22 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
|
||||
return std::move(shaped_buffer);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
return Execute(run_options, arguments, hlo_execution_profile,
|
||||
/*block_host_until_done=*/true);
|
||||
}
|
||||
|
||||
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments) {
|
||||
// TODO(b/30671675): Implement asynchronous execution mode.
|
||||
return Unimplemented(
|
||||
"Asynchronous execution on stream is not yet supported on GPU.");
|
||||
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
// Force synchronous execution if the allocator requires it.
|
||||
bool block_host_until_done =
|
||||
!memory_allocator->AllowsAsynchronousDeallocation();
|
||||
return Execute(run_options, arguments, nullptr, block_host_until_done);
|
||||
}
|
||||
|
||||
const PointsToSet& GpuExecutable::GetRootPointsToSet() const {
|
||||
|
@ -86,6 +86,11 @@ class GpuExecutable : public Executable {
|
||||
absl::Span<const ShapedBuffer* const> arguments) override;
|
||||
|
||||
private:
|
||||
StatusOr<ScopedShapedBuffer> Execute(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile, bool block_host_until_done);
|
||||
|
||||
// If `block_host_until_done` is false, execution will not block the host
|
||||
// until the kernels have completed. This is used as an optimization for
|
||||
// clients, such as Tensorflow, that use a single stream of execution for
|
||||
|
Loading…
Reference in New Issue
Block a user