From 955e356e4c69d3fce4ac2bac5966671e964f9627 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 9 Jul 2018 01:49:04 -0700
Subject: [PATCH] [SE,XLA] Switch to using multiple streams in
 xla_device_context

Instead of having one stream for compute, host-to-device, and device-to-host
transfers, switch to having separate streams, just like the GPU device does.
Add an se::Event field to XlaTensor so that accurate inter-stream dependencies
can be created.

As part of this:
 - Fix TransferManager::TransferLiteralFrom/ToDevice to correctly make
   generated substreams wait on their master stream.
 - Fix Stream::BlockHostUntilDone() to not block on or return substreams. That
   behavior is completely broken: it not only nondeterministically returns
   substreams to the pool but also causes indefinite hangs with the HostStream.

PiperOrigin-RevId: 203726543
---
 .../compiler/jit/kernels/xla_launch_op.cc     |   5 +-
 .../compiler/jit/xla_compile_on_demand_op.cc  |   4 +-
 tensorflow/compiler/jit/xla_cpu_device.cc     |   1 +
 tensorflow/compiler/jit/xla_device.cc         |  70 +++--
 tensorflow/compiler/jit/xla_device.h          |  22 +-
 tensorflow/compiler/jit/xla_device_context.cc | 240 ++++++++++--------
 tensorflow/compiler/jit/xla_device_context.h  |  19 +-
 tensorflow/compiler/jit/xla_gpu_device.cc     |   1 +
 .../compiler/jit/xla_interpreter_device.cc    |   1 +
 tensorflow/compiler/jit/xla_launch_util.cc    |  34 ++-
 tensorflow/compiler/jit/xla_launch_util.h     |   9 +-
 tensorflow/compiler/jit/xla_tensor.cc         |  27 ++
 tensorflow/compiler/jit/xla_tensor.h          |  25 ++
 .../compiler/xla/service/transfer_manager.cc  |   2 +
 tensorflow/stream_executor/event.cc           |  11 +-
 tensorflow/stream_executor/event.h            |   3 +
 tensorflow/stream_executor/stream.cc          |  19 +-
 17 files changed, 344 insertions(+), 149 deletions(-)
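
Note for reviewers: the core of this change is a simple ordering protocol
between streams. A minimal sketch of that protocol follows (illustrative only,
not part of the patch; the stream and buffer names are hypothetical, while the
se::Stream/se::Event calls are the ones used throughout the diff):

    // Producer: enqueue work on one stream, then record an event that fires
    // once that work has completed on the device.
    se::Event definition_event(host_to_device_stream->parent());
    CHECK(definition_event.Init());
    host_to_device_stream->ThenMemcpy(&device_buffer, host_src, total_bytes);
    host_to_device_stream->ThenRecordEvent(&definition_event);

    // Consumer: a different stream waits on the event before reading the
    // buffer. The wait is enqueued on the stream, so the host never blocks.
    compute_stream->ThenWaitFor(&definition_event);

XlaTensor stores such an event, plus the list of streams that have already
synchronized with it, so that transfers and computations running on different
streams stay correctly ordered.
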
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 251a07304ea..338fb5a6f06 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -115,6 +115,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   const XlaDevice::Metadata* metadata = nullptr;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   bool allocate_xla_tensors = s.ok();
+  bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
 
   // Get the platform_id_ for XLA_* devices.
   if (platform_id_ == nullptr) {
@@ -180,8 +181,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  XlaComputationLaunchContext launch_context(client, xla_allocator,
-                                             allocate_xla_tensors);
+  XlaComputationLaunchContext launch_context(
+      client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index baccea2d6a7..d288d37bc75 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
-      client, client->backend().memory_allocator(), true);
+      client, client->backend().memory_allocator(),
+      /*allocate_xla_tensors=*/true,
+      /*use_multiple_streams=*/metadata.UseMultipleStreams());
 
   launch_context.PopulateInputs(ctx, result, variables);
 
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 43648402f65..7e159e31711 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
                                 DEVICE_CPU_XLA_JIT, options, name_prefix,
                                 registration,
                                 /*transfer_as_literal=*/false,
+                                /*use_multiple_streams=*/false,
                                 /*shape_representation_fn=*/{},
                                 /*padded_shape_fn=*/{}, &device));
   devices->push_back(device.release());
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index ed007d603ea..c55eba2f79d 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
     const string& jit_device_name, const SessionOptions& options,
     const string& name_prefix,
     const XlaOpRegistry::DeviceRegistration& registration,
-    bool transfer_as_literal,
+    bool transfer_as_literal, bool use_multiple_streams,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
       strings::StrCat("device: ", device_name, " device"));
 
-  device->reset(new XlaDevice(
-      options, attrs, device_ordinal, DeviceType(jit_device_name),
-      platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
-      padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
+  device->reset(
+      new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
+                    platform.ValueOrDie(), transfer_as_literal,
+                    use_multiple_streams, shape_representation_fn,
+                    padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
   return Status::OK();
 }
 
 XlaDevice::Metadata::Metadata(
     int device_ordinal, se::Platform* platform, const DeviceType& device_type,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    PaddedShapeFn padded_shape_fn)
+    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
       platform_(platform),
       shape_representation_fn_(std::move(shape_representation_fn)),
-      padded_shape_fn_(std::move(padded_shape_fn)) {}
+      padded_shape_fn_(std::move(padded_shape_fn)),
+      use_multiple_streams_(use_multiple_streams) {}
 
 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
 
@@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
 XlaDevice::XlaDevice(
     const SessionOptions& options, const DeviceAttributes& attrs,
     int device_ordinal, const DeviceType& jit_device_name,
-    se::Platform* platform, bool transfer_as_literal,
+    se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     const PaddedShapeFn& padded_shape_fn)
     : LocalDevice(options, attrs),
       xla_metadata_(device_ordinal, platform, jit_device_name,
-                    shape_representation_fn, padded_shape_fn),
+                    shape_representation_fn, padded_shape_fn,
+                    use_multiple_streams),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
       platform_(platform),
+      use_multiple_streams_(use_multiple_streams),
       transfer_as_literal_(transfer_as_literal),
       shape_representation_fn_(shape_representation_fn) {
   VLOG(1) << "Created XLA device " << jit_device_name;
@@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
   return stream_.get();
 }
 
+xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
+  if (!use_multiple_streams_) {
+    return GetStream();
+  }
+  if (!device_to_host_stream_) {
+    xla::Backend* backend = client()->mutable_backend();
+    TF_ASSIGN_OR_RETURN(device_to_host_stream_,
+                        backend->BorrowStream(device_ordinal_));
+  }
+  return device_to_host_stream_.get();
+}
+
+xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
+  if (!use_multiple_streams_) {
+    return GetStream();
+  }
+  if (!host_to_device_stream_) {
+    xla::Backend* backend = client()->mutable_backend();
+    TF_ASSIGN_OR_RETURN(host_to_device_stream_,
+                        backend->BorrowStream(device_ordinal_));
+  }
+  return host_to_device_stream_.get();
+}
+
 Status XlaDevice::CreateAndSetGpuDeviceInfo() {
   if (gpu_device_info_ == nullptr) {
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
     // gpu_device_info_->default_context.
     gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
     gpu_device_info_->stream = stream;
-    gpu_device_info_->default_context = new XlaDeviceContext(
-        stream, client(), transfer_as_literal_, shape_representation_fn_);
+    gpu_device_info_->default_context =
+        new XlaDeviceContext(stream, stream, stream, client(),
+                             transfer_as_literal_, shape_representation_fn_);
     set_tensorflow_gpu_device_info(gpu_device_info_.get());
   }
 
@@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   VLOG(1) << "XlaDevice::FillContextMap";
   device_context_map->resize(graph->num_node_ids());
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
+  TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+                      GetDeviceToHostStream());
+  TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+                      GetHostToDeviceStream());
+
   // Call GetAllocator for the side-effect of ensuring the allocator is created.
   GetAllocator({});
-  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
-                                  shape_representation_fn_);
+  auto ctx = new XlaDeviceContext(
+      stream, host_to_device_stream, device_to_host_stream, client(),
+      transfer_as_literal_, shape_representation_fn_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
     Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
     Notification n;
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-    XlaTransferManager manager(stream, client(), transfer_as_literal_,
-                               shape_representation_fn_);
+    TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+                        GetDeviceToHostStream());
+    TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+                        GetHostToDeviceStream());
+    XlaTransferManager manager(stream, host_to_device_stream,
+                               device_to_host_stream, client(),
+                               transfer_as_literal_, shape_representation_fn_);
     manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                   [&n, &status](const Status& s) {
                                     status = s;
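
Note: passing the compute stream for all three stream arguments, as
CreateAndSetGpuDeviceInfo does above, is the supported degenerate
configuration. A sketch of the two configurations (illustrative; `stream`,
`h2d`, and `d2h` stand for streams obtained from GetStream(),
GetHostToDeviceStream(), and GetDeviceToHostStream() respectively):

    // Single-stream: UseMultipleStreams() is false, no events are recorded,
    // and behavior matches the pre-patch code path.
    auto* ctx1 = new XlaDeviceContext(stream, stream, stream, client(),
                                      transfer_as_literal_,
                                      shape_representation_fn_);

    // Multi-stream: transfers run on dedicated streams and are ordered
    // against the compute stream via XlaTensor definition events.
    auto* ctx2 = new XlaDeviceContext(stream, h2d, d2h, client(),
                                      transfer_as_literal_,
                                      shape_representation_fn_);
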
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 02e88ee6793..fccdb143680 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -57,7 +57,7 @@ class XlaDevice : public LocalDevice {
     Metadata(int device_ordinal, se::Platform* platform,
              const DeviceType& device_type,
              XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-             PaddedShapeFn padded_shape_fn);
+             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);
 
     // The index of the device on this host.
     int device_ordinal() const;
@@ -70,12 +70,15 @@ class XlaDevice : public LocalDevice {
     }
     const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
 
+    bool UseMultipleStreams() const { return use_multiple_streams_; }
+
    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     se::Platform* platform_;  // Not owned.
     XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
     PaddedShapeFn padded_shape_fn_;
+    const bool use_multiple_streams_;
 
     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -89,6 +92,8 @@ class XlaDevice : public LocalDevice {
   // 'transfer_as_literal' is true if device<->host transfers must be done using
   // XLA's TransferLiteral{To,From}Device interface. If false, we can use
   // ThenMemcpy instead.
+  // If 'use_multiple_streams' is true, we create separate streams for
+  // host-to-device and device-to-host communication.
   // If padded_shape_fn is empty, a default implementation that returns
   // the on-host shape is used.
   static Status Create(
@@ -96,7 +101,7 @@ class XlaDevice : public LocalDevice {
       int device_ordinal, const string& jit_device_name,
       const SessionOptions& options, const string& name_prefix,
       const XlaOpRegistry::DeviceRegistration& registration,
-      bool transfer_as_literal,
+      bool transfer_as_literal, bool use_multiple_streams,
       const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
       const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
 
@@ -106,6 +111,7 @@ class XlaDevice : public LocalDevice {
   XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
             int device_ordinal, const DeviceType& jit_device_name,
             se::Platform* platform, bool transfer_as_literal,
+            bool use_multiple_streams,
             const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
             const PaddedShapeFn& padded_shape_fn);
   ~XlaDevice() override;
@@ -126,6 +132,8 @@ class XlaDevice : public LocalDevice {
   xla::LocalClient* client() const;
   const Metadata& metadata() { return xla_metadata_; }
   xla::StatusOr<se::Stream*> GetStream();
+  xla::StatusOr<se::Stream*> GetHostToDeviceStream();
+  xla::StatusOr<se::Stream*> GetDeviceToHostStream();
 
   // If not already set, create and set GpuDeviceInfo.
   // Not thread-safe
@@ -146,6 +154,16 @@ class XlaDevice : public LocalDevice {
   // copying back and forth between CPU and the device, and
   // computations enqueued by XLA.
   xla::Backend::StreamPtr stream_;
+  // If false, only stream_ is valid and all computation and transfers use
+  // stream_. If true, computation is performed by stream_ and transfers are
+  // performed by host_to_device_stream_ and device_to_host_stream_.
+  bool use_multiple_streams_;
+  // If use_multiple_streams_, host-to-device transfers are performed using
+  // this stream.
+  xla::Backend::StreamPtr host_to_device_stream_;
+  // If use_multiple_streams_, device-to-host transfers are performed using
+  // this stream.
+  xla::Backend::StreamPtr device_to_host_stream_;
   // Must we use XLA's transfer manager for correct host<->device transfers? If
   // false, we can use ThenMemcpy() instead.
   bool transfer_as_literal_;
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0188faaf512..04778c00904 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -48,13 +48,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
 void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
 
 XlaTransferManager::XlaTransferManager(
-    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+    se::Stream* compute_stream, se::Stream* host_to_device_stream,
+    se::Stream* device_to_host_stream, xla::LocalClient* client,
+    bool transfer_as_literal,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn)
-    : stream_(stream),
+    : stream_(compute_stream),
+      host_to_device_stream_(host_to_device_stream),
+      device_to_host_stream_(device_to_host_stream),
       client_(client),
       transfer_manager_(client->backend().transfer_manager()),
       transfer_as_literal_(transfer_as_literal),
       shape_representation_fn_(std::move(shape_representation_fn)) {
+  CHECK(host_to_device_stream_ != nullptr);
+  CHECK(device_to_host_stream_ != nullptr);
+  CHECK(stream_ != nullptr);
   if (!shape_representation_fn_) {
     shape_representation_fn_ = [](const TensorShape& shape,
@@ -70,12 +77,19 @@ Status XlaTransferManager::TransferLiteralToDevice(
   xla::BorrowingLiteral literal(
       static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
 
-  const xla::ShapedBuffer& shaped_buffer =
-      XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
   VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
           << shaped_buffer.ToString();
-  return transfer_manager_->TransferLiteralToDevice(stream_, literal,
-                                                    shaped_buffer);
+  TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDevice(
+      host_to_device_stream_, literal, shaped_buffer));
+  if (UseMultipleStreams()) {
+    se::Event event(stream_->parent());
+    TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
+    host_to_device_stream_->ThenRecordEvent(&event);
+    xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+  }
+  return Status::OK();
 }
 
 Status XlaTransferManager::TransferLiteralFromDevice(
@@ -83,9 +97,9 @@ Status XlaTransferManager::TransferLiteralFromDevice(
   const xla::ShapedBuffer& shaped_buffer =
       XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
 
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<xla::Literal> literal,
-      transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
+                      transfer_manager_->TransferLiteralFromDevice(
+                          device_to_host_stream_, shaped_buffer));
   VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
           << shaped_buffer.ToString();
   Tensor tensor;
@@ -103,68 +117,67 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                                Device* device,
                                                Tensor* device_tensor,
                                                StatusCallback done) const {
-  if (cpu_tensor->NumElements() > 0) {
-    VLOG(2) << "CopyCPUTensorToDevice "
-            << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
-            << " "
-            << reinterpret_cast<const void*>(
-                   device_tensor->tensor_data().data())
-            << " " << cpu_tensor->NumElements() << " "
-            << cpu_tensor->shape().DebugString() << " "
-            << device_tensor->shape().DebugString();
-
-    void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
-    const int64 total_bytes = cpu_tensor->TotalBytes();
-
-    XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
-    CHECK(xla_tensor);
-
-    xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_(
-        device_tensor->shape(), device_tensor->dtype());
-    if (!shape_or_status.ok()) {
-      done(shape_or_status.status());
-      return;
-    }
-    TensorShape shape = shape_or_status.ValueOrDie();
-    if (!xla_tensor->has_shaped_buffer()) {
-      Status s = xla_tensor->AllocateShapedBuffer(
-          device_tensor->dtype(), shape, client_,
-          stream_->parent()->device_ordinal());
-      if (!s.ok()) {
-        done(s);
-        return;
-      }
-    }
-
-    Status status;
-    if (transfer_as_literal_) {
-      Tensor reshaped_cpu_tensor;
-      if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
-        done(errors::Internal(
-            "Tensor::CopyFrom failed when copying from CPU to XLA device"));
-        return;
-      }
-      status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
-    } else {
-      se::DeviceMemoryBase dev_dst_ptr =
-          XlaTensor::DeviceMemoryFromTensor(*device_tensor);
-      stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
-      // TODO(hpucha): Make this asynchronous.
-      Status block_status = stream_->BlockHostUntilDone();
-      if (!block_status.ok()) {
-        status = xla::InternalError(
-            "Failed to complete data transfer on stream %p: %s", stream_,
-            block_status.error_message().c_str());
-      }
-    }
-    xla_tensor->set_host_tensor(*cpu_tensor);
-
-    done(status);
+  if (cpu_tensor->NumElements() == 0) {
+    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
+    done(Status::OK());
     return;
   }
-
-  VLOG(2) << "CopyCPUTensorToDevice empty tensor";
-  done(Status::OK());
+  VLOG(2) << "CopyCPUTensorToDevice "
+          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+          << " "
+          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
+          << " " << cpu_tensor->NumElements() << " "
+          << cpu_tensor->shape().DebugString() << " "
+          << device_tensor->shape().DebugString();
+
+  void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
+  const int64 total_bytes = cpu_tensor->TotalBytes();
+
+  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+  CHECK(xla_tensor);
+
+  xla::StatusOr<TensorShape> shape_or_status =
+      shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
+  if (!shape_or_status.ok()) {
+    done(shape_or_status.status());
+    return;
+  }
+  TensorShape shape = shape_or_status.ValueOrDie();
+  if (!xla_tensor->has_shaped_buffer()) {
+    Status s =
+        xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
+                                         stream_->parent()->device_ordinal());
+    if (!s.ok()) {
+      done(s);
+      return;
+    }
+  }
+
+  Status status;
+  if (transfer_as_literal_) {
+    Tensor reshaped_cpu_tensor;
+    if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
+      done(errors::Internal(
+          "Tensor::CopyFrom failed when copying from CPU to XLA device"));
+      return;
+    }
+    status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+  } else {
+    se::DeviceMemoryBase dev_dst_ptr =
+        XlaTensor::DeviceMemoryFromTensor(*device_tensor);
+    host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
+    // TODO(hpucha): Make this asynchronous.
+    Status block_status = host_to_device_stream_->BlockHostUntilDone();
+    if (!block_status.ok()) {
+      status = xla::InternalError(
+          "Failed to complete data transfer on stream %p: %s",
+          host_to_device_stream_, block_status.error_message().c_str());
+    }
+  }
+  xla_tensor->set_host_tensor(*cpu_tensor);
+
+  done(status);
 }
 
 void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
@@ -172,51 +185,64 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                                Device* device,
                                                Tensor* cpu_tensor,
                                                StatusCallback done) {
-  if (device_tensor->NumElements() > 0) {
-    VLOG(2) << "CopyDeviceTensorToCPU "
-            << reinterpret_cast<const void*>(
-                   device_tensor->tensor_data().data())
-            << " "
-            << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
-            << " " << device_tensor->NumElements() << " "
-            << cpu_tensor->shape().DebugString() << " "
-            << device_tensor->shape().DebugString();
-
-    const int64 total_bytes = cpu_tensor->TotalBytes();
-    se::DeviceMemoryBase dev_src_ptr =
-        XlaTensor::DeviceMemoryFromTensor(*device_tensor);
-    void* dst_ptr = DMAHelper::base(cpu_tensor);
-
-    Status status;
-    if (transfer_as_literal_) {
-      status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
-    } else {
-      stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
-      // TODO(hpucha): Make this asynchronous.
-      Status block_status = stream_->BlockHostUntilDone();
-      if (!block_status.ok()) {
-        status = xla::InternalError(
-            "Failed to complete data transfer on stream %p: %s", stream_,
-            block_status.error_message().c_str());
-      }
-    }
-
-    done(status);
+  if (device_tensor->NumElements() == 0) {
+    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
+    done(Status::OK());
     return;
   }
+  VLOG(2) << "CopyDeviceTensorToCPU "
+          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
+          << " "
+          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+          << " " << device_tensor->NumElements() << " "
+          << cpu_tensor->shape().DebugString() << " "
+          << device_tensor->shape().DebugString();
 
-  VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
-  done(Status::OK());
+  const int64 total_bytes = cpu_tensor->TotalBytes();
+  se::DeviceMemoryBase dev_src_ptr =
+      XlaTensor::DeviceMemoryFromTensor(*device_tensor);
+  void* dst_ptr = DMAHelper::base(cpu_tensor);
+  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+
+  if (se::Event* event =
+          xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+    device_to_host_stream_->ThenWaitFor(event);
+    xla_tensor->SetDefinedOn(device_to_host_stream_);
+  }
+
+  Status status;
+  if (transfer_as_literal_) {
+    status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
+  } else {
+    device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
+    // TODO(hpucha): Make this asynchronous.
+    Status block_status = device_to_host_stream_->BlockHostUntilDone();
+    if (!block_status.ok()) {
+      status = xla::InternalError(
+          "Failed to complete data transfer on stream %p: %s",
+          device_to_host_stream_, block_status.error_message().c_str());
+    }
+  }
+
+  done(status);
 }
 
 void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
                                                   Tensor* dst_tensor,
                                                   const StatusCallback& done) {
+  VLOG(2) << "CopyDeviceTensorToDevice "
+          << reinterpret_cast<const void*>(src_tensor.tensor_data().data())
+          << " "
+          << reinterpret_cast<const void*>(dst_tensor->tensor_data().data());
   // TODO(phawkins): replace this code with an asynchronous implementation.
   auto body = [&]() {
     if (src_tensor.NumElements() == 0) {
       return Status::OK();
     }
+    // TODO(jmolloy): We co-opt the device_to_host stream for device-to-device
+    // transfers; perhaps we should have a dedicated device-to-device stream,
+    // or one per device?
+    auto device_to_device_stream = device_to_host_stream_;
     XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
     XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
     CHECK(xla_src && xla_dst)
@@ -229,6 +255,13 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
           xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));
     }
+
+    if (se::Event* event =
+            xla_src->GetDefinitionEvent(device_to_device_stream)) {
+      device_to_device_stream->ThenWaitFor(event);
+      xla_src->SetDefinedOn(device_to_device_stream);
+      TF_RETURN_IF_ERROR(device_to_device_stream->BlockHostUntilDone());
+    }
     TF_RETURN_IF_ERROR(
         xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
             [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
@@ -247,9 +280,12 @@
 }
 
 XlaDeviceContext::XlaDeviceContext(
-    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+    se::Stream* compute_stream, se::Stream* host_to_device_stream,
+    se::Stream* device_to_host_stream, xla::LocalClient* client,
+    bool transfer_as_literal,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn)
-    : manager_(stream, client, transfer_as_literal,
+    : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
+               client, transfer_as_literal,
                std::move(shape_representation_fn)) {}
 
 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
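
Note: the copies above all follow the same wait-before-read discipline. In
sketch form (illustrative; `reader_stream` stands for whichever stream is
about to touch the tensor's buffers):

    // Wait for the tensor's contents to be defined, unless this stream has
    // already synchronized with the definition event.
    if (se::Event* event = xla_tensor->GetDefinitionEvent(reader_stream)) {
      reader_stream->ThenWaitFor(event);
      // Remember the synchronization so later reads on this stream skip it.
      xla_tensor->SetDefinedOn(reader_stream);
    }
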
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index ee346e5653b..c726495f968 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,7 +47,9 @@ class XlaDeviceAllocator : public Allocator {
 class XlaTransferManager {
  public:
   explicit XlaTransferManager(
-      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+      se::Stream* compute_stream, se::Stream* host_to_device_stream,
+      se::Stream* device_to_host_stream, xla::LocalClient* client,
+      bool transfer_as_literal,
       XlaCompiler::ShapeRepresentationFn shape_representation_fn);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@@ -66,10 +68,17 @@ class XlaTransferManager {
                                  Tensor* device_tensor) const;
   Status TransferLiteralFromDevice(Tensor* host_tensor,
                                    const Tensor& device_tensor) const;
+  bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }
 
-  // Stream obtained from a Device, used to transfer tensors between
-  // CPU and device.
+  // The main compute stream of the device, used to synchronize the transfer
+  // streams if they are set.
   se::Stream* stream_;
+  // The stream to use for transferring data from host to device. Can be
+  // identical to stream_, but must not be nullptr.
+  se::Stream* host_to_device_stream_;
+  // The stream to use for transferring data from device to host. Can be
+  // identical to stream_, but must not be nullptr.
+  se::Stream* device_to_host_stream_;
   // For the underlying memory allocator and XLA's TransferManager.
   xla::LocalClient* client_;
   // Transfer manager, for marshalling data to and from the device.
@@ -85,7 +94,9 @@ class XlaTransferManager {
 class XlaDeviceContext : public DeviceContext {
  public:
   explicit XlaDeviceContext(
-      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+      se::Stream* compute_stream, se::Stream* host_to_device_stream,
+      se::Stream* device_to_host_stream, xla::LocalClient* client,
+      bool transfer_as_literal,
       XlaCompiler::ShapeRepresentationFn shape_representation_fn);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index c0d86a28c76..851b118b0c1 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
       XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
                         name_prefix, registration,
                         /*transfer_as_literal=*/false,
+                        /*use_multiple_streams=*/false,
                         /*shape_representation_fn=*/{},
                         /*padded_shape_fn=*/{}, &device);
   if (!status.ok()) {
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 661187f4a87..45745596749 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
                                        DEVICE_INTERPRETER_XLA_JIT, options,
                                        name_prefix, registration,
                                        /*transfer_as_literal=*/false,
+                                       /*use_multiple_streams=*/false,
                                        /*shape_representation_fn=*/{},
                                        /*padded_shape_fn=*/{}, &device));
   devices->push_back(device.release());
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 5ceccc769fa..616c3ed2a26 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -115,14 +115,22 @@ using internal::ExtractSubShapedBuffer;
 
 XlaComputationLaunchContext::XlaComputationLaunchContext(
     xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
-    bool allocate_xla_tensors)
+    bool allocate_xla_tensors, bool use_multiple_streams)
     : client_(client),
       xla_allocator_(xla_allocator),
-      allocate_xla_tensors_(allocate_xla_tensors) {}
+      allocate_xla_tensors_(allocate_xla_tensors),
+      use_multiple_streams_(use_multiple_streams) {
+  if (use_multiple_streams_) {
+    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
+                                    "be allocating XLA tensors!";
+  }
+}
 
 void XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
     const std::map<int, OptionalTensor>& variables) {
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   // Build ShapedBuffers that point directly to the Tensor buffers.
   arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
   arg_buffers_.resize(kernel->xla_input_shapes.size());
@@ -140,6 +148,16 @@ void XlaComputationLaunchContext::PopulateInputs(
       t = &(ctx->input(arg_num));
     }
 
+    if (use_multiple_streams_) {
+      CHECK(stream) << "Must have a stream available when using XLA tensors!";
+      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+      CHECK(xla_tensor);
+      if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
+        stream->ThenWaitFor(event);
+        xla_tensor->SetDefinedOn(stream);
+      }
+    }
+
     const xla::Shape on_device_shape =
         client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
     if (xla::ShapeUtil::IsTuple(on_device_shape)) {
@@ -248,6 +266,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
       if (xla_tensor) {
         xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
             ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+        if (use_multiple_streams_) {
+          se::Event event(stream->parent());
+          CHECK(event.Init());
+          stream->ThenRecordEvent(&event);
+          xla_tensor->SetDefinedOn(stream, std::move(event));
+        }
       } else {
         // xla_tensor wasn't valid, which must mean this is a zero-element
         // tensor.
@@ -302,6 +326,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
       CHECK(xla_tensor);
       xla_tensor->set_shaped_buffer(
           ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
+      if (use_multiple_streams_) {
+        se::Event event(stream->parent());
+        CHECK(event.Init());
+        stream->ThenRecordEvent(&event);
+        xla_tensor->SetDefinedOn(stream, std::move(event));
+      }
       *variable->tensor() = output_tensor;
     } else {
       Tensor output_tensor = XlaTensorBuffer::MakeTensor(
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 4390701ccbd..90531174ff1 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -76,9 +76,15 @@ class XlaComputationLaunchContext {
   // Create a new launch context. 'allocate_xla_tensors' is true if allocated
   // output tensors and variables are always XlaTensors. If false, they are
   // assumed to be "normal" device pointers.
+  // If 'use_multiple_streams' is true, tensors may be defined and used on
+  // multiple streams, and so se::Events must be defined and waited for. If
+  // 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true
+  // because we track inter-stream dependencies through events inside XlaTensor
+  // objects.
   XlaComputationLaunchContext(xla::LocalClient* client,
                               xla::DeviceMemoryAllocator* xla_allocator,
-                              bool allocate_xla_tensors);
+                              bool allocate_xla_tensors,
+                              bool use_multiple_streams);
 
   // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
   // `variables` is a map from TensorFlow argument number to resource variable.
@@ -99,6 +105,7 @@ class XlaComputationLaunchContext {
   xla::LocalClient* client_;
   xla::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
+  bool use_multiple_streams_;
   std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
   std::vector<xla::ShapedBuffer*> arg_ptrs_;
 };
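
Note: a sketch of the constructor contract documented above (illustrative):
use_multiple_streams implies allocate_xla_tensors, because inter-stream
dependencies are tracked by events stored inside XlaTensor objects; the
constructor CHECK-fails on the invalid combination.

    XlaComputationLaunchContext launch_context(
        client, client->backend().memory_allocator(),
        /*allocate_xla_tensors=*/true,
        /*use_multiple_streams=*/true);  // requires allocate_xla_tensors

    // Invalid: would trip the CHECK in the constructor.
    // XlaComputationLaunchContext bad(client, allocator,
    //                                 /*allocate_xla_tensors=*/false,
    //                                 /*use_multiple_streams=*/true);
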
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index 3c44c4ae6df..91a6f3da3f8 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -73,6 +73,33 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
   return Status::OK();
 }
 
+se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
+  if (!definition_event_.has_value()) {
+    return nullptr;
+  }
+
+  // The set of defined streams is expected to be very small indeed (usually
+  // 1-2), so a simple linear scan should be fast enough.
+  if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
+                stream) != streams_defined_on_.end()) {
+    // stream is in streams_defined_on_; it doesn't need to be waited on.
+    return nullptr;
+  }
+
+  return &*definition_event_;
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+  CHECK(!definition_event_.has_value())
+      << "SetDefinedOn must only be called once!";
+  definition_event_ = std::move(event);
+  streams_defined_on_.push_back(stream);
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream) {
+  streams_defined_on_.push_back(stream);
+}
+
 // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
 // device-side tensors, which are either CPU or GPU memory pointers. This works
 // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index c54001a9999..c420fe40e37 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -85,6 +85,24 @@ class XlaTensor {
     host_tensor_.reset(new Tensor(tensor));
   }
 
+  // If the tensor's content is not yet defined on 'stream', and there exists
+  // an se::Event declaring when the tensor's content is defined, return it.
+  // Otherwise, return nullptr. If this function returns nullptr then the
+  // tensor's content can be read on 'stream' without additional
+  // synchronization.
+  se::Event* GetDefinitionEvent(se::Stream* stream);
+
+  // Assert that the tensor's content is defined on 'stream' by the time
+  // 'event' triggers.
+  void SetDefinedOn(se::Stream* stream, se::Event event);
+
+  // Assert that the tensor's content is defined on 'stream'. This version
+  // does not provide an event, and must be called *after* SetDefinedOn(Stream,
+  // Event). This call can be read as an assertion that the definition event
+  // has been waited on by 'stream', so further calls to
+  // GetDefinitionEvent(stream) do not need to also wait on the event.
+  void SetDefinedOn(se::Stream* stream);
+
   // Convert from a raw pointer to an XlaTensor, removing the pointer tag.
   static XlaTensor* FromOpaquePointer(void* ptr);
   // Convert to a raw pointer from an XlaTensor, adding the pointer tag.
@@ -95,6 +113,13 @@ class XlaTensor {
   std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
   // An optional host tensor value.
   std::unique_ptr<Tensor> host_tensor_;
+  // An optional event that is triggered when the tensor's content has been
+  // defined. If this event is nullptr, it is assumed that the tensor's content
+  // is always defined.
+  gtl::optional<se::Event> definition_event_;
+  // A list of all streams for which the tensor's content is defined for any
+  // newly enqueued command.
+  gtl::InlinedVector<se::Stream*, 2> streams_defined_on_;
 };
 
 }  // namespace tensorflow
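
Note: the producer side of the definition-event protocol added above, in
sketch form (illustrative; `stream` is the stream on which the tensor's
buffers are written):

    // After enqueueing the work that defines the tensor's contents, record an
    // event and hand it to the XlaTensor. SetDefinedOn(stream, event) may be
    // called only once per tensor; streams that later wait on the event call
    // the one-argument SetDefinedOn(stream) overload instead.
    se::Event event(stream->parent());
    CHECK(event.Init());
    stream->ThenRecordEvent(&event);
    xla_tensor->SetDefinedOn(stream, std::move(event));

Moving the event into the XlaTensor is what motivates the defaulted move
operations and the moved-from handling added to se::Event below.
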
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 4c5038a009b..7232c658b3f 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -44,6 +44,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
     se::Stream* stream, const ShapedBuffer& device_buffer) {
   StatusOr<std::unique_ptr<Literal>> ret;
   se::Stream* substream = stream->GetOrCreateSubStream();
+  substream->ThenWaitFor(stream);
   auto cleanup = tensorflow::gtl::MakeCleanup(
       [&]() { stream->ReturnSubStream(substream); });
 
@@ -64,6 +65,7 @@ Status TransferManager::TransferLiteralToDevice(
   // Use a substream so that if we are called from a HostCallback we don't
   // deadlock.
   se::Stream* substream = stream->GetOrCreateSubStream();
+  substream->ThenWaitFor(stream);
   auto cleanup = tensorflow::gtl::MakeCleanup(
       [&]() { stream->ReturnSubStream(substream); });
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc
index 50a6edd80bd..52efe771bc3 100644
--- a/tensorflow/stream_executor/event.cc
+++ b/tensorflow/stream_executor/event.cc
@@ -15,9 +15,9 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/event.h"
 
+#include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
-#include "tensorflow/stream_executor/stream.h"
 
 namespace stream_executor {
 
@@ -27,9 +27,12 @@ Event::Event(StreamExecutor* stream_exec)
           stream_exec_->implementation()->CreateEventImplementation()) {}
 
 Event::~Event() {
-  auto status = stream_exec_->DeallocateEvent(this);
-  if (!status.ok()) {
-    LOG(ERROR) << status.error_message();
+  // Deal with a nullptr implementation_, as this event may have been
+  // std::move'd from.
+  if (stream_exec_ && implementation_) {
+    auto status = stream_exec_->DeallocateEvent(this);
+    if (!status.ok()) {
+      LOG(ERROR) << status.error_message();
+    }
   }
 }
 
diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h
index 1f37262c78d..9cc87a7c129 100644
--- a/tensorflow/stream_executor/event.h
+++ b/tensorflow/stream_executor/event.h
@@ -61,6 +61,9 @@ class Event {
   // Returns a pointer to the underlying platform-specific implementation.
   internal::EventInterface* implementation() { return implementation_.get(); }
 
+  Event(Event&&) = default;
+  Event& operator=(Event&&) = default;
+
 private:
   friend class Stream;
 
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 0cd0790a72b..93691831337 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -5228,24 +5228,11 @@ port::Status Stream::BlockHostUntilDone() {
     return status;
   }
 
-  port::Status first_error;
-  {
-    // Wait until all active sub-streams have done their tasks.
-    mutex_lock lock(mu_);
-    for (auto &stream : sub_streams_) {
-      if (!stream.second) {
-        first_error.Update(stream.first->BlockHostUntilDone());
-        // Set this sub-stream as available.
-        stream.second = true;
-      }
-    }
-  }
-
   temporary_memory_manager_.DeallocateFinalizedTemporaries();
 
-  first_error.Update(parent_->BlockHostUntilDone(this));
-  CheckError(first_error.ok());
-  return first_error;
+  port::Status error = parent_->BlockHostUntilDone(this);
+  CheckError(error.ok());
+  return error;
 }
 
 }  // namespace stream_executor
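
Note: the transfer_manager.cc and stream.cc changes are easy to misread in
isolation. A substream is not automatically ordered after work already
enqueued on its parent, so callers that want that ordering must ask for it
explicitly; and after this patch, BlockHostUntilDone() no longer blocks on or
returns substreams at all. In sketch form (illustrative; error handling
abbreviated):

    se::Stream* substream = stream->GetOrCreateSubStream();
    // Without this wait, work enqueued on the substream could race ahead of
    // previously enqueued work on `stream` (the bug fixed above).
    substream->ThenWaitFor(stream);
    // ... enqueue the transfer on the substream ...
    TF_RETURN_IF_ERROR(substream->BlockHostUntilDone());
    stream->ReturnSubStream(substream);
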