[XLA:Python] Remove some special cases in the XLA Python bindings that exist to work around the fact that the Host device didn't support events. Enable "multi stream" mode everywhere and delete the code for "single stream" mode.

This has the side effect of using more threading than we did previously on CPU, in particular the "host to device" and "device to host" transfers are done in separate streams (threads). It's not clear that's a bad thing, and more commonality between backends is nice.

[SE:Host] Implement the se::Event API for the host platform, as a minimal wrapper around absl::Notification.

PiperOrigin-RevId: 257594142
This commit is contained in:
Peter Hawkins 2019-07-11 06:06:56 -07:00 committed by TensorFlower Gardener
parent 557759d885
commit 78ee8186db
5 changed files with 123 additions and 112 deletions

View File

@ -30,10 +30,9 @@ limitations under the License.
// Multi-stream execution:
// -----------------------
//
// On certain platforms (e.g., TPU), we use a multistream execution design,
// where different Streams are used for host-to-device transfers,
// device-to-host transfers, and compute. This allows us to overlap transfers on
// and off the device with computation.
// We use a multistream execution design, where different Streams are used for
// host-to-device transfers, device-to-host transfers, and compute. This allows
// us to overlap transfers on and off the device with computation.
//
// Synchronization between streams occurs via BufferDefinitionEvents that
// describe when the contents of a logical buffer are known to be valid on
@ -109,32 +108,24 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
return Status::OK();
}
Device::Device(se::StreamExecutor* executor, bool use_multiple_streams,
bool synchronous_deallocation, bool asynchronous,
bool allow_event_reuse)
: use_multiple_streams_(use_multiple_streams),
synchronous_deallocation_(synchronous_deallocation),
Device::Device(se::StreamExecutor* executor, bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse)
: synchronous_deallocation_(synchronous_deallocation),
asynchronous_(asynchronous),
event_pool_(allow_event_reuse) {
compute_stream_ = std::make_shared<se::Stream>(executor);
compute_stream_ = absl::make_unique<se::Stream>(executor);
host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
device_to_host_stream_ = absl::make_unique<se::Stream>(executor);
callback_stream_ = absl::make_unique<se::Stream>(executor);
compute_stream_->Init();
if (use_multiple_streams) {
host_to_device_stream_ = std::make_shared<se::Stream>(executor);
device_to_host_stream_ = std::make_shared<se::Stream>(executor);
callback_stream_ = std::make_shared<se::Stream>(executor);
host_to_device_stream_->Init();
device_to_host_stream_->Init();
callback_stream_->Init();
device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
auto stream = std::make_shared<se::Stream>(executor);
stream->Init();
device_to_device_streams_.push_back(std::move(stream));
}
} else {
callback_stream_ = host_to_device_stream_ = device_to_host_stream_ =
compute_stream_;
device_to_device_streams_.push_back(compute_stream_);
host_to_device_stream_->Init();
device_to_host_stream_->Init();
callback_stream_->Init();
device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
auto stream = absl::make_unique<se::Stream>(executor);
stream->Init();
device_to_device_streams_.push_back(std::move(stream));
}
worker_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_execute");
@ -256,13 +247,12 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
std::vector<std::unique_ptr<Device>> devices;
devices.reserve(client->device_count());
bool use_multiple_streams = (platform_name != "cpu");
bool synchronous_deallocation = !use_multiple_streams;
bool synchronous_deallocation = platform_name == "cpu";
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutor* executor =
client->backend().stream_executor(i).ValueOrDie();
devices.push_back(absl::make_unique<Device>(
executor, use_multiple_streams, synchronous_deallocation, asynchronous,
executor, synchronous_deallocation, asynchronous,
/*allow_event_reuse=*/gpu_platform));
}
return std::make_shared<PyLocalClient>(platform_name, client,
@ -329,8 +319,7 @@ static StatusOr<std::unique_ptr<PyLocalBuffer>> TransferHostToDeviceAsync(
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
client->client()->platform(), device_ordinal);
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
if (device->use_multiple_streams() &&
!transfer_manager->CanShapedBufferBeAccessedNow(
if (!transfer_manager->CanShapedBufferBeAccessedNow(
device->host_to_device_stream()->parent(), leaf)) {
device->host_to_device_stream()->ThenWaitFor(device->compute_stream());
}
@ -338,15 +327,14 @@ static StatusOr<std::unique_ptr<PyLocalBuffer>> TransferHostToDeviceAsync(
device->host_to_device_stream(), *it, leaf));
++it;
}
std::shared_ptr<BufferDefinitionEvent> definition_event;
if (device->use_multiple_streams()) {
definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
device->event_pool().ThenAllocateAndRecordEvent(
device->host_to_device_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device->host_to_device_stream());
}
auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
device->event_pool().ThenAllocateAndRecordEvent(
device->host_to_device_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device->host_to_device_stream());
std::shared_ptr<SharedDeviceBuffer> device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(buffer),
definition_event);
@ -409,10 +397,8 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
Device& device = client->device(device_ordinal);
std::shared_ptr<BufferDefinitionEvent> definition_event;
if (device.use_multiple_streams()) {
definition_event = std::make_shared<BufferDefinitionEvent>();
}
auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(
std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator,
@ -423,21 +409,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromPython(
// TODO(phawkins): extend TransferManager so we do not need to form a full
// ShapedBuffer just to write the root tuple index table.
TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer());
if (device.use_multiple_streams() &&
!transfer_manager->CanShapedBufferBeAccessedNow(
if (!transfer_manager->CanShapedBufferBeAccessedNow(
device.host_to_device_stream()->parent(), shaped_buffer)) {
// Wait for the compute stream so that memory allocations are synchronized.
device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
}
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
device.host_to_device_stream(), shaped_buffer));
if (definition_event) {
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
device.event_pool().ThenAllocateAndRecordEvent(
device.host_to_device_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device.host_to_device_stream());
}
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
device.event_pool().ThenAllocateAndRecordEvent(
device.host_to_device_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device.host_to_device_stream());
if (device.synchronous_deallocation()) {
device.ThenReleaseOnWorkerThread(device.host_to_device_stream(),
@ -574,8 +558,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape_, client_->allocator(), dst_device_ordinal));
if (dst_device.use_multiple_streams() &&
!transfer_manager->CanShapedBufferBeAccessedNow(
if (!transfer_manager->CanShapedBufferBeAccessedNow(
dst_device.compute_stream()->parent(), dst_buffer)) {
src_device_to_device_stream->ThenWaitFor(dst_device.compute_stream());
}
@ -614,15 +597,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
dst_device.host_to_device_stream());
}
std::shared_ptr<BufferDefinitionEvent> definition_event;
if (dst_device.use_multiple_streams()) {
definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
src_device.event_pool().ThenAllocateAndRecordEvent(
src_device_to_device_stream));
definition_event->SetDefinitionEvent(std::move(event),
src_device_to_device_stream);
}
auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
src_device.event_pool().ThenAllocateAndRecordEvent(
src_device_to_device_stream));
definition_event->SetDefinitionEvent(std::move(event),
src_device_to_device_stream);
std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
@ -738,15 +718,13 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
return result_buffer.status();
}
std::shared_ptr<BufferDefinitionEvent> definition_event;
if (device.use_multiple_streams()) {
definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
device.event_pool().ThenAllocateAndRecordEvent(
device.compute_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device.compute_stream());
}
auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(
EventPool::Handle event,
device.event_pool().ThenAllocateAndRecordEvent(device.compute_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device.compute_stream());
Shape on_host_shape = result_buffer.ValueOrDie().on_host_shape();
std::shared_ptr<SharedDeviceBuffer> out_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(

View File

@ -50,11 +50,6 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
// can perform computation and transfers.
class Device {
public:
// If use_multiple_streams is true, we allocate separate streams for compute
// and transfers. If it is false, we share a single stream for compute and
// transfers. The CPU device does not support multiple streams, and this is
// a workaround until it does.
//
// If synchronous_deallocation is true, the host must not free buffers until
// compute/transfers that use those buffers have completed. For example, this
// typically is the case for the "platform" where compute/transfers are
@ -62,12 +57,10 @@ class Device {
//
// If asynchronous is false, the host will synchronize to the device after
// each execution or transfer. This is intended for debugging only.
Device(se::StreamExecutor* executor, bool use_multiple_streams,
bool synchronous_deallocation, bool asynchronous,
bool allow_event_reuse);
Device(se::StreamExecutor* executor, bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse);
virtual ~Device();
bool use_multiple_streams() const { return use_multiple_streams_; }
bool synchronous_deallocation() const { return synchronous_deallocation_; }
bool asynchronous() const { return asynchronous_; }
@ -142,15 +135,14 @@ class Device {
private:
Status SynchronizeAllActivity();
bool use_multiple_streams_;
bool synchronous_deallocation_;
bool asynchronous_;
EventPool event_pool_;
std::shared_ptr<se::Stream> compute_stream_;
std::shared_ptr<se::Stream> host_to_device_stream_;
std::shared_ptr<se::Stream> device_to_host_stream_;
std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_;
std::unique_ptr<se::Stream> compute_stream_;
std::unique_ptr<se::Stream> host_to_device_stream_;
std::unique_ptr<se::Stream> device_to_host_stream_;
std::vector<std::unique_ptr<se::Stream>> device_to_device_streams_;
// Number of device-to-device streams to create in the multistream case.
static constexpr int kNumDeviceToDeviceStreams = 4;
@ -161,7 +153,7 @@ class Device {
// Callback stream is used for running short host-side callbacks after device
// side events, without preventing the device-side stream from doing useful
// work.
std::shared_ptr<se::Stream> callback_stream_;
std::unique_ptr<se::Stream> callback_stream_;
std::unique_ptr<WorkerThread> worker_thread_;
};

View File

@ -111,6 +111,7 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_pimpl",
"//tensorflow/stream_executor:timer",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/synchronization",
],
alwayslink = True,
)

View File

@ -19,12 +19,14 @@ limitations under the License.
#include <string.h>
#include "absl/synchronization/notification.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/host/host_stream.h"
#include "tensorflow/stream_executor/host/host_timer.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace stream_executor {
namespace host {
@ -167,6 +169,61 @@ bool HostExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
return true;
}
class HostEvent : public internal::EventInterface {
public:
HostEvent() : notification_(std::make_shared<absl::Notification>()) {}
std::shared_ptr<absl::Notification> &notification() { return notification_; }
private:
// We use a std::shared_ptr here because the client may delete the HostEvent
// object while there are still RecordEvent and WaitForEvent callbacks pending
// on a stream.
std::shared_ptr<absl::Notification> notification_;
};
std::unique_ptr<internal::EventInterface>
HostExecutor::CreateEventImplementation() {
return std::unique_ptr<internal::EventInterface>(new HostEvent());
}
static HostEvent *AsHostEvent(Event *event) {
DCHECK(event != nullptr);
return static_cast<HostEvent *>(event->implementation());
}
port::Status HostExecutor::AllocateEvent(Event * /*event*/) {
return port::Status::OK();
}
port::Status HostExecutor::DeallocateEvent(Event * /*event*/) {
return port::Status::OK();
}
port::Status HostExecutor::RecordEvent(Stream *stream, Event *event) {
std::shared_ptr<absl::Notification> notification =
AsHostEvent(event)->notification();
AsHostStream(stream)->EnqueueTask([notification]() {
CHECK(!notification->HasBeenNotified());
notification->Notify();
});
return port::Status::OK();
}
port::Status HostExecutor::WaitForEvent(Stream *stream, Event *event) {
std::shared_ptr<absl::Notification> notification =
AsHostEvent(event)->notification();
AsHostStream(stream)->EnqueueTask(
[notification]() { notification->WaitForNotification(); });
return port::Status::OK();
}
Event::Status HostExecutor::PollForEventStatus(Event *event) {
absl::Notification &notification = *AsHostEvent(event)->notification();
return notification.HasBeenNotified() ? Event::Status::kComplete
: Event::Status::kPending;
}
bool HostExecutor::StartTimer(Stream *stream, Timer *timer) {
dynamic_cast<HostTimer *>(timer->implementation())->Start(stream);
return true;

View File

@ -106,25 +106,11 @@ class HostExecutor : public internal::StreamExecutorInterface {
bool HostCallback(Stream *stream,
std::function<port::Status()> callback) override;
port::Status AllocateEvent(Event *event) override {
return port::Status(port::error::UNIMPLEMENTED, "");
}
port::Status DeallocateEvent(Event *event) override {
return port::Status(port::error::UNIMPLEMENTED, "");
}
port::Status RecordEvent(Stream *stream, Event *event) override {
return port::Status(port::error::UNIMPLEMENTED, "");
}
port::Status WaitForEvent(Stream *stream, Event *event) override {
return port::Status(port::error::UNIMPLEMENTED, "");
}
Event::Status PollForEventStatus(Event *event) override {
return Event::Status::kError;
}
port::Status AllocateEvent(Event *event) override;
port::Status DeallocateEvent(Event *event) override;
port::Status RecordEvent(Stream *stream, Event *event) override;
port::Status WaitForEvent(Stream *stream, Event *event) override;
Event::Status PollForEventStatus(Event *event) override;
bool AllocateStream(Stream *stream) override;
void DeallocateStream(Stream *stream) override;
@ -190,10 +176,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
rng::RngSupport *CreateRng() override;
std::unique_ptr<internal::EventInterface> CreateEventImplementation()
override {
LOG(WARNING) << "Events not currently supported by HostExecutor.";
return nullptr;
}
override;
std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
override {