[SE,XLA] Switch to using multiple streams in xla_device_context

Instead of having one stream for compute, host-to-device, and device-to-host transfers, switch to having separate streams, just as the GPU device already does.
Add an se::Event field to XlaTensor so that accurate inter-stream dependencies can be created (a minimal sketch of the resulting synchronization pattern follows the list below).

As part of this:
 - Fix TransferManager::TransferLiteralFrom/ToDevice to correctly make generated substreams wait on their master stream.
 - Fix Stream::BlockHostUntilDone() so that it no longer blocks on or returns substreams. The old behavior was broken: it nondeterministically returned substreams to the pool and could hang indefinitely with the HostStream.
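
A minimal sketch of the resulting synchronization pattern, referenced above. Illustrative only: the helper name and buffer arguments are hypothetical and the snippet assumes compilation inside the TensorFlow tree; the se::Stream and se::Event calls are the ones this change uses.

    #include <cstdint>

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/stream_executor/event.h"
    #include "tensorflow/stream_executor/stream.h"

    namespace se = ::stream_executor;

    // Enqueue a host-to-device copy on the dedicated transfer stream and make
    // the compute stream wait for it via an event, without blocking the host.
    se::Event EnqueueTransferAndDependency(se::Stream* host_to_device_stream,
                                           se::Stream* compute_stream,
                                           se::DeviceMemoryBase* device_dst,
                                           const void* host_src,
                                           uint64_t size) {
      host_to_device_stream->ThenMemcpy(device_dst, host_src, size);
      se::Event definition_event(host_to_device_stream->parent());
      CHECK(definition_event.Init());
      host_to_device_stream->ThenRecordEvent(&definition_event);
      // Work later enqueued on the compute stream is ordered after the copy.
      compute_stream->ThenWaitFor(&definition_event);
      return definition_event;  // se::Event is movable after this change.
    }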

PiperOrigin-RevId: 203726543
A. Unique TensorFlower 2018-07-09 01:49:04 -07:00 committed by TensorFlower Gardener
parent caf711b6be
commit 955e356e4c
17 changed files with 344 additions and 149 deletions

View File

@ -115,6 +115,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
const XlaDevice::Metadata* metadata = nullptr;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok();
bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
// Get the platform_id_ for XLA_* devices.
if (platform_id_ == nullptr) {
@ -180,8 +181,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
XlaComputationLaunchContext launch_context(client, xla_allocator,
allocate_xla_tensors);
XlaComputationLaunchContext launch_context(
client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.

View File

@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(), true);
client, client->backend().memory_allocator(),
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variables);

View File

@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
DEVICE_CPU_XLA_JIT, options, name_prefix,
registration,
/*transfer_as_literal=*/false,
/*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());

View File

@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
device->reset(new XlaDevice(
options, attrs, device_ordinal, DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
device->reset(
new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal,
use_multiple_streams, shape_representation_fn,
padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
PaddedShapeFn padded_shape_fn)
PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)),
padded_shape_fn_(std::move(padded_shape_fn)) {}
padded_shape_fn_(std::move(padded_shape_fn)),
use_multiple_streams_(use_multiple_streams) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
shape_representation_fn, padded_shape_fn),
shape_representation_fn, padded_shape_fn,
use_multiple_streams),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
use_multiple_streams_(use_multiple_streams),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
return stream_.get();
}
xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
if (!use_multiple_streams_) {
return GetStream();
}
if (!device_to_host_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(device_to_host_stream_,
backend->BorrowStream(device_ordinal_));
}
return device_to_host_stream_.get();
}
xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
if (!use_multiple_streams_) {
return GetStream();
}
if (!host_to_device_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(host_to_device_stream_,
backend->BorrowStream(device_ordinal_));
}
return host_to_device_stream_.get();
}
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
gpu_device_info_->default_context = new XlaDeviceContext(
stream, client(), transfer_as_literal_, shape_representation_fn_);
gpu_device_info_->default_context =
new XlaDeviceContext(stream, stream, stream, client(),
transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
shape_representation_fn_);
auto ctx = new XlaDeviceContext(
stream, host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
XlaTransferManager manager(stream, client(), transfer_as_literal_,
shape_representation_fn_);
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());
XlaTransferManager manager(stream, host_to_device_stream,
device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;

View File

@ -57,7 +57,7 @@ class XlaDevice : public LocalDevice {
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
PaddedShapeFn padded_shape_fn);
PaddedShapeFn padded_shape_fn, bool use_multiple_streams);
// The index of the device on this host.
int device_ordinal() const;
@ -70,12 +70,15 @@ class XlaDevice : public LocalDevice {
}
const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
bool UseMultipleStreams() const { return use_multiple_streams_; }
private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
PaddedShapeFn padded_shape_fn_;
const bool use_multiple_streams_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@ -89,6 +92,8 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
// If 'use_multiple_streams' is true, we create separate streams for
// host-to-device and device-to-host communication.
// If padded_shape_fn is empty, a default implementation that returns
// the on-host shape is used.
static Status Create(
@ -96,7 +101,7 @@ class XlaDevice : public LocalDevice {
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
@ -106,6 +111,7 @@ class XlaDevice : public LocalDevice {
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
@ -126,6 +132,8 @@ class XlaDevice : public LocalDevice {
xla::LocalClient* client() const;
const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
xla::StatusOr<se::Stream*> GetHostToDeviceStream();
xla::StatusOr<se::Stream*> GetDeviceToHostStream();
// If not already set, create and set GpuDeviceInfo.
// Not thread-safe
@ -146,6 +154,16 @@ class XlaDevice : public LocalDevice {
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
xla::Backend::StreamPtr stream_;
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
xla::Backend::StreamPtr host_to_device_stream_;
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
xla::Backend::StreamPtr device_to_host_stream_;
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;

View File

@ -48,13 +48,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: stream_(stream),
: stream_(compute_stream),
host_to_device_stream_(host_to_device_stream),
device_to_host_stream_(device_to_host_stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
if (!shape_representation_fn_) {
shape_representation_fn_ =
[](const TensorShape& shape,
@ -70,12 +77,19 @@ Status XlaTransferManager::TransferLiteralToDevice(
xla::BorrowingLiteral literal(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(device_tensor)->shaped_buffer();
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
<< shaped_buffer.ToString();
return transfer_manager_->TransferLiteralToDevice(stream_, literal,
shaped_buffer);
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDevice(
host_to_device_stream_, literal, shaped_buffer));
if (UseMultipleStreams()) {
se::Event event(stream_->parent());
TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
host_to_device_stream_->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
}
return Status::OK();
}
Status XlaTransferManager::TransferLiteralFromDevice(
@ -83,9 +97,9 @@ Status XlaTransferManager::TransferLiteralFromDevice(
const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::Literal> literal,
transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer));
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
transfer_manager_->TransferLiteralFromDevice(
device_to_host_stream_, shaped_buffer));
VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
Tensor tensor;
@ -103,68 +117,67 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
StatusCallback done) const {
if (cpu_tensor->NumElements() > 0) {
VLOG(2) << "CopyCPUTensorToDevice "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
<< " " << cpu_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();
void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
const int64 total_bytes = cpu_tensor->TotalBytes();
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_(
device_tensor->shape(), device_tensor->dtype());
if (!shape_or_status.ok()) {
done(shape_or_status.status());
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
if (!s.ok()) {
done(s);
return;
}
}
Status status;
if (transfer_as_literal_) {
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
done(status);
if (cpu_tensor->NumElements() == 0) {
VLOG(2) << "CopyCPUTensorToDevice empty tensor";
done(Status::OK());
return;
}
VLOG(2) << "CopyCPUTensorToDevice empty tensor";
done(Status::OK());
VLOG(2) << "CopyCPUTensorToDevice "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(device_tensor->tensor_data().data())
<< " " << cpu_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();
void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
const int64 total_bytes = cpu_tensor->TotalBytes();
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
xla::StatusOr<TensorShape> shape_or_status =
shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
if (!shape_or_status.ok()) {
done(shape_or_status.status());
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
if (!xla_tensor->has_shaped_buffer()) {
Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
if (!s.ok()) {
done(s);
return;
}
}
Status status;
if (transfer_as_literal_) {
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = host_to_device_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
host_to_device_stream_, block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
done(status);
}
void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
@ -172,51 +185,64 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Device* device,
Tensor* cpu_tensor,
StatusCallback done) {
if (device_tensor->NumElements() > 0) {
VLOG(2) << "CopyDeviceTensorToCPU "
<< reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " " << device_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();
const int64 total_bytes = cpu_tensor->TotalBytes();
se::DeviceMemoryBase dev_src_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
void* dst_ptr = DMAHelper::base(cpu_tensor);
Status status;
if (transfer_as_literal_) {
status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
} else {
stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
}
done(status);
if (device_tensor->NumElements() == 0) {
VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
done(Status::OK());
return;
}
VLOG(2) << "CopyDeviceTensorToCPU "
<< reinterpret_cast<const void*>(device_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " " << device_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();
VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
done(Status::OK());
const int64 total_bytes = cpu_tensor->TotalBytes();
se::DeviceMemoryBase dev_src_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
void* dst_ptr = DMAHelper::base(cpu_tensor);
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
if (se::Event* event =
xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
device_to_host_stream_->ThenWaitFor(event);
xla_tensor->SetDefinedOn(device_to_host_stream_);
}
Status status;
if (transfer_as_literal_) {
status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
} else {
device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
}
done(status);
}
void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
Tensor* dst_tensor,
const StatusCallback& done) {
VLOG(2) << "CopyDeviceTensorToDevice "
<< reinterpret_cast<const void*>(src_tensor.tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(dst_tensor->tensor_data().data());
// TODO(phawkins): replace this code with an asynchronous implementation.
auto body = [&]() {
if (src_tensor.NumElements() == 0) {
return Status::OK();
}
// TODO(jmolloy): We co-opt the device_to_host stream for device to device
// transfers; perhaps we should have a dedicated device to device stream? or
// one per device?
auto device_to_device_stream = device_to_host_stream_;
XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
CHECK(xla_src && xla_dst)
@ -229,6 +255,13 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
stream_->parent()->device_ordinal()));
}
if (se::Event* event =
xla_src->GetDefinitionEvent(device_to_device_stream)) {
device_to_device_stream->ThenWaitFor(event);
xla_src->SetDefinedOn(device_to_device_stream);
TF_RETURN_IF_ERROR(device_to_device_stream->BlockHostUntilDone());
}
TF_RETURN_IF_ERROR(
xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
@ -247,9 +280,12 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
XlaDeviceContext::XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: manager_(stream, client, transfer_as_literal,
: manager_(compute_stream, host_to_device_stream, device_to_host_stream,
client, transfer_as_literal,
std::move(shape_representation_fn)) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,

View File

@ -47,7 +47,9 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@ -66,10 +68,17 @@ class XlaTransferManager {
Tensor* device_tensor) const;
Status TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor) const;
bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }
// Stream obtained from a Device, used to transfer tensors between
// CPU and device.
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
se::Stream* stream_;
// The stream to use for transferring data from host to device. Can be
// identical to stream_, but must not be nullptr.
se::Stream* host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
// identical to stream_, but must not be nullptr.
se::Stream* device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@ -85,7 +94,9 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
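
For reference, a hedged sketch of the two ways the new three-stream constructor can be used; the helper functions are illustrative, not part of this change, and assume compilation inside the TensorFlow tree. Passing the compute stream for all three roles (as xla_device.cc does for the GPU device info above) preserves the single-stream behavior.

    #include <utility>

    #include "tensorflow/compiler/jit/xla_device_context.h"

    namespace tensorflow {

    // Single-stream mode: UseMultipleStreams() is false, no definition events
    // are recorded, and transfers synchronize exactly as before.
    XlaDeviceContext* MakeSingleStreamContext(
        se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
        XlaCompiler::ShapeRepresentationFn shape_fn) {
      return new XlaDeviceContext(stream, stream, stream, client,
                                  transfer_as_literal, std::move(shape_fn));
    }

    // Multi-stream mode: transfers run on their own streams, and inter-stream
    // dependencies are tracked via the se::Event stored in XlaTensor.
    XlaDeviceContext* MakeMultiStreamContext(
        se::Stream* compute_stream, se::Stream* host_to_device_stream,
        se::Stream* device_to_host_stream, xla::LocalClient* client,
        bool transfer_as_literal, XlaCompiler::ShapeRepresentationFn shape_fn) {
      return new XlaDeviceContext(compute_stream, host_to_device_stream,
                                  device_to_host_stream, client,
                                  transfer_as_literal, std::move(shape_fn));
    }

    }  // namespace tensorflow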

View File

@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
/*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device);
if (!status.ok()) {

View File

@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
DEVICE_INTERPRETER_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
/*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());

View File

@ -115,14 +115,22 @@ using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors)
bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
allocate_xla_tensors_(allocate_xla_tensors),
use_multiple_streams_(use_multiple_streams) {
if (use_multiple_streams_) {
CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
"be allocating XLA tensors!";
}
}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
@ -140,6 +148,16 @@ void XlaComputationLaunchContext::PopulateInputs(
t = &(ctx->input(arg_num));
}
if (use_multiple_streams_) {
CHECK(stream) << "Must have a stream available when using XLA tensors!";
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor);
if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
stream->ThenWaitFor(event);
xla_tensor->SetDefinedOn(stream);
}
}
const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
if (xla::ShapeUtil::IsTuple(on_device_shape)) {
@ -248,6 +266,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
se::Event event(stream->parent());
CHECK(event.Init());
stream->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(stream, std::move(event));
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
// tensor.
@ -302,6 +326,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
se::Event event(stream->parent());
CHECK(event.Init());
stream->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(stream, std::move(event));
}
*variable->tensor() = output_tensor;
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(

View File

@ -76,9 +76,15 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
// If 'use_multiple_streams' is true, tensors may be defined and used on
// multiple streams and so se::Events must be defined and waited for. If
// 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true
// because we track inter-stream dependencies through events inside XlaTensor
// objects.
XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
bool allocate_xla_tensors,
bool use_multiple_streams);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
@ -99,6 +105,7 @@ class XlaComputationLaunchContext {
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};
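
A small hedged illustration of the flag invariant documented above; the function name is hypothetical, and only the constructor declared in this header is used.

    #include "tensorflow/compiler/jit/xla_launch_util.h"

    namespace tensorflow {

    void ExampleFlagCombination(xla::LocalClient* client,
                                xla::DeviceMemoryAllocator* allocator) {
      // Valid: multiple streams require XLA tensors, because the definition
      // events that express inter-stream dependencies live inside XlaTensor.
      XlaComputationLaunchContext multi_stream_context(
          client, allocator,
          /*allocate_xla_tensors=*/true,
          /*use_multiple_streams=*/true);
      (void)multi_stream_context;

      // Would CHECK-fail in the constructor: use_multiple_streams without
      // allocate_xla_tensors leaves nowhere to record definition events.
      // XlaComputationLaunchContext bad_context(
      //     client, allocator,
      //     /*allocate_xla_tensors=*/false,
      //     /*use_multiple_streams=*/true);
    }

    }  // namespace tensorflow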

View File

@ -73,6 +73,33 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
return Status::OK();
}
se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
if (!definition_event_.has_value()) {
return nullptr;
}
// The set of defined streams is expected to be very small indeed (usually
// 1-2), so a simple linear scan should be fast enough.
if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
stream) != streams_defined_on_.end()) {
// stream is in streams_defined_on_; it doesn't need to be waited on.
return nullptr;
}
return &*definition_event_;
}
void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
CHECK(!definition_event_.has_value())
<< "SetDefinedOn must only be called once!";
definition_event_ = std::move(event);
streams_defined_on_.push_back(stream);
}
void XlaTensor::SetDefinedOn(se::Stream* stream) {
streams_defined_on_.push_back(stream);
}
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.

View File

@ -85,6 +85,24 @@ class XlaTensor {
host_tensor_.reset(new Tensor(tensor));
}
// If the tensor's content is not yet defined on 'stream', and there exists an
// se::Event declaring when the tensor's content is defined, return it.
// Otherwise, return nullptr. If this function returns nullptr then the
// tensor's content can be read on 'stream' without additional
// synchronization.
se::Event* GetDefinitionEvent(se::Stream* stream);
// Assert that the tensor's content is defined on 'stream' by the time 'event'
// triggers.
void SetDefinedOn(se::Stream* stream, se::Event event);
// Assert that the tensor's content is defined on 'stream'. This version does
// not provide an event, and must be called *after* SetDefinedOn(Stream,
// Event). This call can be read as an assertion that the definition event has
// been waited on by 'stream', so further calls to GetDefinitionEvent(stream)
// do not need to also wait on the event.
void SetDefinedOn(se::Stream* stream);
// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
static XlaTensor* FromOpaquePointer(void* ptr);
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.
@ -95,6 +113,13 @@ class XlaTensor {
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.
gtl::optional<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
gtl::InlinedVector<se::Stream*, 2> streams_defined_on_;
};
} // namespace tensorflow
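
A hedged sketch of the read/write protocol these methods support, mirroring the usage in xla_launch_util.cc and xla_device_context.cc; the helper names are illustrative and assume compilation inside the TensorFlow tree.

    #include <utility>

    #include "tensorflow/compiler/jit/xla_tensor.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/stream_executor/event.h"
    #include "tensorflow/stream_executor/stream.h"

    namespace tensorflow {

    // Reader side: wait once per stream, then mark the stream as defined so
    // later GetDefinitionEvent(stream) calls return nullptr.
    void WaitForTensorOnStream(XlaTensor* xla_tensor, se::Stream* stream) {
      if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
        stream->ThenWaitFor(event);
        xla_tensor->SetDefinedOn(stream);
      }
    }

    // Writer side: after enqueueing the commands that define the tensor's
    // contents on `stream`, publish an event other streams can wait on.
    void PublishTensorDefinedOnStream(XlaTensor* xla_tensor, se::Stream* stream) {
      se::Event event(stream->parent());
      CHECK(event.Init());
      stream->ThenRecordEvent(&event);
      xla_tensor->SetDefinedOn(stream, std::move(event));
    }

    }  // namespace tensorflow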

View File

@ -44,6 +44,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
se::Stream* stream, const ShapedBuffer& device_buffer) {
StatusOr<std::unique_ptr<Literal>> ret;
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
auto cleanup = tensorflow::gtl::MakeCleanup(
[&]() { stream->ReturnSubStream(substream); });
@ -64,6 +65,7 @@ Status TransferManager::TransferLiteralToDevice(
// Use a substream so that if we are called from a HostCallback we don't
// deadlock.
se::Stream* substream = stream->GetOrCreateSubStream();
substream->ThenWaitFor(stream);
auto cleanup = tensorflow::gtl::MakeCleanup(
[&]() { stream->ReturnSubStream(substream); });
TF_RETURN_IF_ERROR(
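
In sketch form, why the added ThenWaitFor matters (illustrative only; the transfer itself is elided): the buffers involved are typically produced or consumed by work already enqueued on the caller's stream, so the borrowed substream must be ordered after that work before it touches them.

    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/lib/gtl/cleanup.h"
    #include "tensorflow/stream_executor/stream.h"

    namespace se = ::stream_executor;

    tensorflow::Status TransferOnSubstream(se::Stream* stream) {
      se::Stream* substream = stream->GetOrCreateSubStream();
      auto cleanup = tensorflow::gtl::MakeCleanup(
          [&] { stream->ReturnSubStream(substream); });
      substream->ThenWaitFor(stream);  // Order after prior work on `stream`.
      // ... enqueue the literal transfer on `substream` here ...
      return substream->BlockHostUntilDone();
    }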

View File

@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/stream.h"
namespace stream_executor {
@ -27,9 +27,12 @@ Event::Event(StreamExecutor* stream_exec)
stream_exec_->implementation()->CreateEventImplementation()) {}
Event::~Event() {
auto status = stream_exec_->DeallocateEvent(this);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
// Deal with nullptr implementation_, as this event may have been std::moved.
if (stream_exec_ && implementation_) {
auto status = stream_exec_->DeallocateEvent(this);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
}
}

View File

@ -61,6 +61,9 @@ class Event {
// Returns a pointer to the underlying platform-specific implementation.
internal::EventInterface* implementation() { return implementation_.get(); }
Event(Event&&) = default;
Event& operator=(Event&&) = default;
private:
friend class Stream;
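
Making se::Event movable is what lets XlaTensor hold it by value and lets the transfer code record an event and then std::move it. A self-contained illustration of the accompanying destructor guard (a stand-in type, not the real se::Event):

    #include <iostream>
    #include <memory>
    #include <utility>

    // A type whose destructor releases a resource must tolerate the moved-from
    // state; the defaulted move leaves the unique_ptr member null, just as
    // Event::~Event now checks implementation_ before deallocating.
    class MovableHandle {
     public:
      explicit MovableHandle(int id) : impl_(new int(id)) {}

      MovableHandle(MovableHandle&&) = default;
      MovableHandle& operator=(MovableHandle&&) = default;

      ~MovableHandle() {
        if (impl_ != nullptr) {
          std::cout << "releasing handle " << *impl_ << "\n";
        }
      }

     private:
      std::unique_ptr<int> impl_;
    };

    int main() {
      MovableHandle a(42);
      MovableHandle b = std::move(a);  // a's impl_ is now null.
      return 0;  // Only b's destructor releases the resource.
    }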

View File

@ -5228,24 +5228,11 @@ port::Status Stream::BlockHostUntilDone() {
return status;
}
port::Status first_error;
{
// Wait until all active sub-streams have done their tasks.
mutex_lock lock(mu_);
for (auto &stream : sub_streams_) {
if (!stream.second) {
first_error.Update(stream.first->BlockHostUntilDone());
// Set this sub-stream as available.
stream.second = true;
}
}
}
temporary_memory_manager_.DeallocateFinalizedTemporaries();
first_error.Update(parent_->BlockHostUntilDone(this));
CheckError(first_error.ok());
return first_error;
port::Status error = parent_->BlockHostUntilDone(this);
CheckError(error.ok());
return error;
}
} // namespace stream_executor
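
With this fix, draining and recycling substreams is entirely the caller's responsibility. A hedged sketch of that contract (illustrative only, mirroring how TransferManager uses substreams above):

    #include "tensorflow/stream_executor/stream.h"

    namespace se = ::stream_executor;

    void UseSubstream(se::Stream* stream) {
      se::Stream* substream = stream->GetOrCreateSubStream();
      // ... enqueue work on `substream` ...
      substream->BlockHostUntilDone().IgnoreError();  // Wait on the substream,
      stream->ReturnSubStream(substream);             // then hand it back.
      // Blocking on `stream` alone no longer waits for, or returns, the
      // substream.
    }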