[SE,XLA] Switch to using multiple streams in xla_device_context
Instead of having one stream for compute, host-to-device and device-to-host
transfers, switch to having separate streams, just like the GPU does. Add an
se::Event field to XlaTensor so that accurate inter-stream dependencies can be
created.

As part of this:
- Fix TransferManager::TransferLiteralFrom/ToDevice to correctly make generated
  substreams wait on their master stream.
- Fix Stream::BlockHostUntilDone() to not block on or return substreams. That
  behavior was completely broken: it not only nondeterministically returned
  substreams to the pool but also caused indefinite hangs with the HostStream.

PiperOrigin-RevId: 203726543
parent caf711b6be
commit 955e356e4c
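Before the per-file hunks, the sketch below shows in one place the synchronization pattern the commit message describes. It is illustrative only: it uses the StreamExecutor and XlaTensor calls that appear in the diff (se::Event, ThenRecordEvent, ThenWaitFor, GetDefinitionEvent, SetDefinedOn), but the function names and include paths are assumptions for exposition, not part of the patch.

// Illustrative sketch, not part of the commit: how a tensor produced on the
// host-to-device stream is made safe to read on the compute stream.
#include "tensorflow/compiler/jit/xla_tensor.h"   // assumed path for XlaTensor
#include "tensorflow/core/platform/logging.h"     // for CHECK
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;

// Producer side (mirrors XlaTransferManager::TransferLiteralToDevice): record
// an event on the transfer stream once the copy has been enqueued, and store
// it in the XlaTensor as the tensor's definition event.
void RecordDefinition(se::Stream* host_to_device_stream,
                      tensorflow::XlaTensor* tensor) {
  se::Event event(host_to_device_stream->parent());
  CHECK(event.Init());
  host_to_device_stream->ThenRecordEvent(&event);
  tensor->SetDefinedOn(host_to_device_stream, std::move(event));
}

// Consumer side (mirrors XlaComputationLaunchContext::PopulateInputs): before
// reading the tensor on another stream, wait for the definition event, then
// mark the tensor as defined on that stream so the wait happens only once.
void WaitForDefinition(se::Stream* compute_stream,
                       tensorflow::XlaTensor* tensor) {
  if (se::Event* event = tensor->GetDefinitionEvent(compute_stream)) {
    compute_stream->ThenWaitFor(event);
    tensor->SetDefinedOn(compute_stream);
  }
}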
@@ -115,6 +115,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   const XlaDevice::Metadata* metadata = nullptr;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   bool allocate_xla_tensors = s.ok();
+  bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();

   // Get the platform_id_ for XLA_* devices.
   if (platform_id_ == nullptr) {
@@ -180,8 +181,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {

   VLOG(1) << "Executing XLA Computation...";

-  XlaComputationLaunchContext launch_context(client, xla_allocator,
-                                             allocate_xla_tensors);
+  XlaComputationLaunchContext launch_context(
+      client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
   launch_context.PopulateInputs(ctx, kernel, variables);

   // Execute the computation.
@@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,

   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
-      client, client->backend().memory_allocator(), true);
+      client, client->backend().memory_allocator(),
+      /*allocate_xla_tensors=*/true,
+      /*use_multiple_streams=*/metadata.UseMultipleStreams());

   launch_context.PopulateInputs(ctx, result, variables);

@@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
                                        DEVICE_CPU_XLA_JIT, options, name_prefix,
                                        registration,
                                        /*transfer_as_literal=*/false,
+                                       /*use_multiple_streams=*/false,
                                        /*shape_representation_fn=*/{},
                                        /*padded_shape_fn=*/{}, &device));
   devices->push_back(device.release());
@@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
     const string& jit_device_name, const SessionOptions& options,
     const string& name_prefix,
     const XlaOpRegistry::DeviceRegistration& registration,
-    bool transfer_as_literal,
+    bool transfer_as_literal, bool use_multiple_streams,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@@ -151,9 +151,10 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
       strings::StrCat("device: ", device_name, " device"));

-  device->reset(new XlaDevice(
-      options, attrs, device_ordinal, DeviceType(jit_device_name),
-      platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
+  device->reset(
+      new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
+                    platform.ValueOrDie(), transfer_as_literal,
+                    use_multiple_streams, shape_representation_fn,
                     padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
   return Status::OK();
 }
@@ -161,12 +162,13 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
 XlaDevice::Metadata::Metadata(
     int device_ordinal, se::Platform* platform, const DeviceType& device_type,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    PaddedShapeFn padded_shape_fn)
+    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
       platform_(platform),
       shape_representation_fn_(std::move(shape_representation_fn)),
-      padded_shape_fn_(std::move(padded_shape_fn)) {}
+      padded_shape_fn_(std::move(padded_shape_fn)),
+      use_multiple_streams_(use_multiple_streams) {}

 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

@@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
 XlaDevice::XlaDevice(
     const SessionOptions& options, const DeviceAttributes& attrs,
     int device_ordinal, const DeviceType& jit_device_name,
-    se::Platform* platform, bool transfer_as_literal,
+    se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
     const PaddedShapeFn& padded_shape_fn)
     : LocalDevice(options, attrs),
       xla_metadata_(device_ordinal, platform, jit_device_name,
-                    shape_representation_fn, padded_shape_fn),
+                    shape_representation_fn, padded_shape_fn,
+                    use_multiple_streams),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
       platform_(platform),
+      use_multiple_streams_(use_multiple_streams),
       transfer_as_literal_(transfer_as_literal),
       shape_representation_fn_(shape_representation_fn) {
   VLOG(1) << "Created XLA device " << jit_device_name;
@@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
   return stream_.get();
 }

+xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
+  if (!use_multiple_streams_) {
+    return GetStream();
+  }
+  if (!device_to_host_stream_) {
+    xla::Backend* backend = client()->mutable_backend();
+    TF_ASSIGN_OR_RETURN(device_to_host_stream_,
+                        backend->BorrowStream(device_ordinal_));
+  }
+  return device_to_host_stream_.get();
+}
+
+xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
+  if (!use_multiple_streams_) {
+    return GetStream();
+  }
+  if (!host_to_device_stream_) {
+    xla::Backend* backend = client()->mutable_backend();
+    TF_ASSIGN_OR_RETURN(host_to_device_stream_,
+                        backend->BorrowStream(device_ordinal_));
+  }
+  return host_to_device_stream_.get();
+}
+
 Status XlaDevice::CreateAndSetGpuDeviceInfo() {
   if (gpu_device_info_ == nullptr) {
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
     // gpu_device_info_->default_context.
     gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
     gpu_device_info_->stream = stream;
-    gpu_device_info_->default_context = new XlaDeviceContext(
-        stream, client(), transfer_as_literal_, shape_representation_fn_);
+    gpu_device_info_->default_context =
+        new XlaDeviceContext(stream, stream, stream, client(),
+                             transfer_as_literal_, shape_representation_fn_);
     set_tensorflow_gpu_device_info(gpu_device_info_.get());
   }

@@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   VLOG(1) << "XlaDevice::FillContextMap";
   device_context_map->resize(graph->num_node_ids());
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
+  TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+                      GetDeviceToHostStream());
+  TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+                      GetHostToDeviceStream());
+
   // Call GetAllocator for the side-effect of ensuring the allocator is created.
   GetAllocator({});
-  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
-                                  shape_representation_fn_);
+  auto ctx = new XlaDeviceContext(
+      stream, host_to_device_stream, device_to_host_stream, client(),
+      transfer_as_literal_, shape_representation_fn_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
     Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
     Notification n;
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-    XlaTransferManager manager(stream, client(), transfer_as_literal_,
-                               shape_representation_fn_);
+    TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+                        GetDeviceToHostStream());
+    TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+                        GetHostToDeviceStream());
+    XlaTransferManager manager(stream, host_to_device_stream,
+                               device_to_host_stream, client(),
+                               transfer_as_literal_, shape_representation_fn_);
     manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                   [&n, &status](const Status& s) {
                                     status = s;
@@ -57,7 +57,7 @@ class XlaDevice : public LocalDevice {
     Metadata(int device_ordinal, se::Platform* platform,
              const DeviceType& device_type,
              XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-             PaddedShapeFn padded_shape_fn);
+             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

     // The index of the device on this host.
     int device_ordinal() const;
@@ -70,12 +70,15 @@ class XlaDevice : public LocalDevice {
     }
     const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

+    bool UseMultipleStreams() const { return use_multiple_streams_; }
+
    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     se::Platform* platform_;  // Not owned.
     XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
     PaddedShapeFn padded_shape_fn_;
+    const bool use_multiple_streams_;

     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -89,6 +92,8 @@ class XlaDevice : public LocalDevice {
   // 'transfer_as_literal' is true if device<->host transfers must be done using
   // XLA's TransferLiteral{To,From}Device interface. If false, we can use
   // ThenMemcpy instead.
+  // If 'use_multiple_streams' is true, we create separate streams for
+  // host-to-device and device-to-host communication.
   // If padded_shape_fn is empty, a default implementation that returns
   // the on-host shape is used.
   static Status Create(
@@ -96,7 +101,7 @@ class XlaDevice : public LocalDevice {
       int device_ordinal, const string& jit_device_name,
       const SessionOptions& options, const string& name_prefix,
       const XlaOpRegistry::DeviceRegistration& registration,
-      bool transfer_as_literal,
+      bool transfer_as_literal, bool use_multiple_streams,
      const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
      const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);

@@ -106,6 +111,7 @@ class XlaDevice : public LocalDevice {
   XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
             int device_ordinal, const DeviceType& jit_device_name,
             se::Platform* platform, bool transfer_as_literal,
+            bool use_multiple_streams,
             const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
             const PaddedShapeFn& padded_shape_fn);
   ~XlaDevice() override;
@@ -126,6 +132,8 @@ class XlaDevice : public LocalDevice {
   xla::LocalClient* client() const;
   const Metadata& metadata() { return xla_metadata_; }
   xla::StatusOr<se::Stream*> GetStream();
+  xla::StatusOr<se::Stream*> GetHostToDeviceStream();
+  xla::StatusOr<se::Stream*> GetDeviceToHostStream();

   // If not already set, create and set GpuDeviceInfo.
   // Not thread-safe
@@ -146,6 +154,16 @@ class XlaDevice : public LocalDevice {
   // copying back and forth between CPU and the device, and
   // computations enqueued by XLA.
   xla::Backend::StreamPtr stream_;
+  // If true, only stream_ is valid and all computation and transfers use
+  // stream_. If false, computation is performed by stream_ and transfers are
+  // performed by host_to_device/device_to_host_stream.
+  bool use_multiple_streams_;
+  // If use_multiple_streams_, host to device transfers are performed using this
+  // stream.
+  xla::Backend::StreamPtr host_to_device_stream_;
+  // If use_multiple_streams_, device to host transfers are performed using this
+  // stream.
+  xla::Backend::StreamPtr device_to_host_stream_;
   // Must we use XLA's transfer manager for correct host<->device transfers? if
   // false, we can use ThenMemcpy() instead.
   bool transfer_as_literal_;
@@ -48,13 +48,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
 void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }

 XlaTransferManager::XlaTransferManager(
-    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+    se::Stream* compute_stream, se::Stream* host_to_device_stream,
+    se::Stream* device_to_host_stream, xla::LocalClient* client,
+    bool transfer_as_literal,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn)
-    : stream_(stream),
+    : stream_(compute_stream),
+      host_to_device_stream_(host_to_device_stream),
+      device_to_host_stream_(device_to_host_stream),
       client_(client),
       transfer_manager_(client->backend().transfer_manager()),
       transfer_as_literal_(transfer_as_literal),
       shape_representation_fn_(std::move(shape_representation_fn)) {
+  CHECK(host_to_device_stream_ != nullptr);
+  CHECK(device_to_host_stream_ != nullptr);
+  CHECK(stream_ != nullptr);
   if (!shape_representation_fn_) {
     shape_representation_fn_ =
         [](const TensorShape& shape,
@@ -70,12 +77,19 @@ Status XlaTransferManager::TransferLiteralToDevice(
   xla::BorrowingLiteral literal(
       static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);

-  const xla::ShapedBuffer& shaped_buffer =
-      XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
   VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
           << shaped_buffer.ToString();
-  return transfer_manager_->TransferLiteralToDevice(stream_, literal,
-                                                    shaped_buffer);
+  TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDevice(
+      host_to_device_stream_, literal, shaped_buffer));
+  if (UseMultipleStreams()) {
+    se::Event event(stream_->parent());
+    TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
+    host_to_device_stream_->ThenRecordEvent(&event);
+    xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+  }
+  return Status::OK();
 }

 Status XlaTransferManager::TransferLiteralFromDevice(
@@ -83,9 +97,9 @@ Status XlaTransferManager::TransferLiteralFromDevice(
   const xla::ShapedBuffer& shaped_buffer =
       XlaTensor::FromTensor(&device_tensor)->shaped_buffer();

-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<xla::Literal> literal,
-      transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
+                      transfer_manager_->TransferLiteralFromDevice(
+                          device_to_host_stream_, shaped_buffer));
   VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
           << shaped_buffer.ToString();
   Tensor tensor;
@@ -103,12 +117,16 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                                Device* device,
                                                Tensor* device_tensor,
                                                StatusCallback done) const {
-  if (cpu_tensor->NumElements() > 0) {
+  if (cpu_tensor->NumElements() == 0) {
+    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
+    done(Status::OK());
+    return;
+  }
+
   VLOG(2) << "CopyCPUTensorToDevice "
           << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
           << " "
-          << reinterpret_cast<const void*>(
-                 device_tensor->tensor_data().data())
+          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
           << " " << cpu_tensor->NumElements() << " "
           << cpu_tensor->shape().DebugString() << " "
           << device_tensor->shape().DebugString();
@@ -119,16 +137,16 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
   XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
   CHECK(xla_tensor);

-  xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_(
-      device_tensor->shape(), device_tensor->dtype());
+  xla::StatusOr<TensorShape> shape_or_status =
+      shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
   if (!shape_or_status.ok()) {
     done(shape_or_status.status());
     return;
   }
   TensorShape shape = shape_or_status.ValueOrDie();
   if (!xla_tensor->has_shaped_buffer()) {
-    Status s = xla_tensor->AllocateShapedBuffer(
-        device_tensor->dtype(), shape, client_,
+    Status s =
+        xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
                                          stream_->parent()->device_ordinal());
     if (!s.ok()) {
       done(s);
@@ -148,23 +166,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
   } else {
     se::DeviceMemoryBase dev_dst_ptr =
         XlaTensor::DeviceMemoryFromTensor(*device_tensor);
-    stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
+    host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
     // TODO(hpucha): Make this asynchronous.
-    Status block_status = stream_->BlockHostUntilDone();
+    Status block_status = host_to_device_stream_->BlockHostUntilDone();
     if (!block_status.ok()) {
       status = xla::InternalError(
-          "Failed to complete data transfer on stream %p: %s", stream_,
-          block_status.error_message().c_str());
+          "Failed to complete data transfer on stream %p: %s",
+          host_to_device_stream_, block_status.error_message().c_str());
     }
   }
   xla_tensor->set_host_tensor(*cpu_tensor);

   done(status);
-    return;
-  }
-
-  VLOG(2) << "CopyCPUTensorToDevice empty tensor";
-  done(Status::OK());
 }

 void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
@@ -172,10 +185,13 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                                Device* device,
                                                Tensor* cpu_tensor,
                                                StatusCallback done) {
-  if (device_tensor->NumElements() > 0) {
+  if (device_tensor->NumElements() == 0) {
+    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
+    done(Status::OK());
+    return;
+  }
   VLOG(2) << "CopyDeviceTensorToCPU "
-          << reinterpret_cast<const void*>(
-                 device_tensor->tensor_data().data())
+          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " " << device_tensor->NumElements() << " "
@@ -186,14 +202,21 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   se::DeviceMemoryBase dev_src_ptr =
       XlaTensor::DeviceMemoryFromTensor(*device_tensor);
   void* dst_ptr = DMAHelper::base(cpu_tensor);
+  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+
+  if (se::Event* event =
+          xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+    device_to_host_stream_->ThenWaitFor(event);
+    xla_tensor->SetDefinedOn(device_to_host_stream_);
+  }

   Status status;
   if (transfer_as_literal_) {
     status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
   } else {
-    stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
+    device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
     // TODO(hpucha): Make this asynchronous.
-    Status block_status = stream_->BlockHostUntilDone();
+    Status block_status = device_to_host_stream_->BlockHostUntilDone();
     if (!block_status.ok()) {
       status = xla::InternalError(
           "Failed to complete data transfer on stream %p: %s", stream_,
@@ -202,21 +225,24 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   }

   done(status);
-    return;
-  }
-
-  VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
-  done(Status::OK());
 }

 void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
                                                   Tensor* dst_tensor,
                                                   const StatusCallback& done) {
+  VLOG(2) << "CopyDeviceTensorToDevice "
+          << reinterpret_cast<const void*>(src_tensor.tensor_data().data())
+          << " "
+          << reinterpret_cast<const void*>(dst_tensor->tensor_data().data());
   // TODO(phawkins): replace this code with an asynchronous implementation.
   auto body = [&]() {
     if (src_tensor.NumElements() == 0) {
       return Status::OK();
     }
+    // TODO(jmolloy): We co-opt the device_to_host stream for device to device
+    // transfers; perhaps we should have a dedicated device to device stream? or
+    // one per device?
+    auto device_to_device_stream = device_to_host_stream_;
     XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
     XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
     CHECK(xla_src && xla_dst)
@@ -229,6 +255,13 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
           xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));
     }
+
+    if (se::Event* event =
+            xla_src->GetDefinitionEvent(device_to_device_stream)) {
+      device_to_device_stream->ThenWaitFor(event);
+      xla_src->SetDefinedOn(device_to_device_stream);
+      TF_RETURN_IF_ERROR(device_to_device_stream->BlockHostUntilDone());
+    }
     TF_RETURN_IF_ERROR(
         xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
             [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
@@ -247,9 +280,12 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
 }

 XlaDeviceContext::XlaDeviceContext(
-    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+    se::Stream* compute_stream, se::Stream* host_to_device_stream,
+    se::Stream* device_to_host_stream, xla::LocalClient* client,
+    bool transfer_as_literal,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn)
-    : manager_(stream, client, transfer_as_literal,
+    : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
+               client, transfer_as_literal,
                std::move(shape_representation_fn)) {}

 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@@ -47,7 +47,9 @@ class XlaDeviceAllocator : public Allocator {
 class XlaTransferManager {
  public:
   explicit XlaTransferManager(
-      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+      se::Stream* compute_stream, se::Stream* host_to_device_stream,
+      se::Stream* device_to_host_stream, xla::LocalClient* client,
+      bool transfer_as_literal,
       XlaCompiler::ShapeRepresentationFn shape_representation_fn);

   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@@ -66,10 +68,17 @@ class XlaTransferManager {
                                  Tensor* device_tensor) const;
   Status TransferLiteralFromDevice(Tensor* host_tensor,
                                    const Tensor& device_tensor) const;
+  bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }

-  // Stream obtained from a Device, used to transfer tensors between
-  // CPU and device.
+  // The main compute stream of the device, used to synchronize the transfer
+  // streams if they are set.
   se::Stream* stream_;
+  // The stream to use for transferring data from host to device. Can be
+  // idential to stream_, but must not be nullptr.
+  se::Stream* host_to_device_stream_;
+  // The stream to use for transferring data from device to host. Can be
+  // idential to stream_, but must not be nullptr.
+  se::Stream* device_to_host_stream_;
   // For the underlying memory allocator and XLA's TransferManager.
   xla::LocalClient* client_;
   // Transfer manager, for marshalling data to and from the device.
@@ -85,7 +94,9 @@ class XlaTransferManager {
 class XlaDeviceContext : public DeviceContext {
  public:
   explicit XlaDeviceContext(
-      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+      se::Stream* compute_stream, se::Stream* host_to_device_stream,
+      se::Stream* device_to_host_stream, xla::LocalClient* client,
+      bool transfer_as_literal,
       XlaCompiler::ShapeRepresentationFn shape_representation_fn);

   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
       XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
                         name_prefix, registration,
                         /*transfer_as_literal=*/false,
+                        /*use_multiple_streams=*/false,
                         /*shape_representation_fn=*/{},
                         /*padded_shape_fn=*/{}, &device);
   if (!status.ok()) {
@@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
                                        DEVICE_INTERPRETER_XLA_JIT, options,
                                        name_prefix, registration,
                                        /*transfer_as_literal=*/false,
+                                       /*use_multiple_streams=*/false,
                                        /*shape_representation_fn=*/{},
                                        /*padded_shape_fn=*/{}, &device));
   devices->push_back(device.release());
@@ -115,14 +115,22 @@ using internal::ExtractSubShapedBuffer;

 XlaComputationLaunchContext::XlaComputationLaunchContext(
     xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
-    bool allocate_xla_tensors)
+    bool allocate_xla_tensors, bool use_multiple_streams)
     : client_(client),
       xla_allocator_(xla_allocator),
-      allocate_xla_tensors_(allocate_xla_tensors) {}
+      allocate_xla_tensors_(allocate_xla_tensors),
+      use_multiple_streams_(use_multiple_streams) {
+  if (use_multiple_streams_) {
+    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
+                                    "be allocating XLA tensors!";
+  }
+}

 void XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
     const std::map<int, OptionalTensor>& variables) {
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
   // Build ShapedBuffers that point directly to the Tensor buffers.
   arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
   arg_buffers_.resize(kernel->xla_input_shapes.size());
@@ -140,6 +148,16 @@ void XlaComputationLaunchContext::PopulateInputs(
       t = &(ctx->input(arg_num));
     }

+    if (use_multiple_streams_) {
+      CHECK(stream) << "Must have a stream available when using XLA tensors!";
+      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+      CHECK(xla_tensor);
+      if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
+        stream->ThenWaitFor(event);
+        xla_tensor->SetDefinedOn(stream);
+      }
+    }
+
     const xla::Shape on_device_shape =
         client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
     if (xla::ShapeUtil::IsTuple(on_device_shape)) {
@@ -248,6 +266,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
       if (xla_tensor) {
         xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
             ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+        if (use_multiple_streams_) {
+          se::Event event(stream->parent());
+          CHECK(event.Init());
+          stream->ThenRecordEvent(&event);
+          xla_tensor->SetDefinedOn(stream, std::move(event));
+        }
       } else {
         // xla_tensor wasn't valid, which must mean this is a zero-element
         // tensor.
@@ -302,6 +326,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
       CHECK(xla_tensor);
       xla_tensor->set_shaped_buffer(
           ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
+      if (use_multiple_streams_) {
+        se::Event event(stream->parent());
+        CHECK(event.Init());
+        stream->ThenRecordEvent(&event);
+        xla_tensor->SetDefinedOn(stream, std::move(event));
+      }
       *variable->tensor() = output_tensor;
     } else {
       Tensor output_tensor = XlaTensorBuffer::MakeTensor(
@@ -76,9 +76,15 @@ class XlaComputationLaunchContext {
   // Create a new launch context. 'allocate_xla_tensors' is true if allocated
   // output tensors and variables are always XlaTensors. If false they are
   // assumed to be "normal" device pointers.
+  // If 'use_multiple_streams' is true, tensors may be defined and used on
+  // multiple streams and so se::Events must be defined and waited for. If
+  // 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true
+  // because we track inter-stream dependencies through events inside XlaTensor
+  // objects.
   XlaComputationLaunchContext(xla::LocalClient* client,
                               xla::DeviceMemoryAllocator* xla_allocator,
-                              bool allocate_xla_tensors);
+                              bool allocate_xla_tensors,
+                              bool use_multiple_streams);

   // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
   // `variables` is a map from TensorFlow argument number to resource variable.
@@ -99,6 +105,7 @@ class XlaComputationLaunchContext {
   xla::LocalClient* client_;
   xla::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
+  bool use_multiple_streams_;
   std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
   std::vector<xla::ShapedBuffer*> arg_ptrs_;
 };
@@ -73,6 +73,33 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
   return Status::OK();
 }

+se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
+  if (!definition_event_.has_value()) {
+    return nullptr;
+  }
+
+  // The set of defined streams is expected to be very small indeed (usually
+  // 1-2), so a simple linear scan should be fast enough.
+  if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
+                stream) != streams_defined_on_.end()) {
+    // stream is in streams_defined_on_; it doesn't need to be waited on.
+    return nullptr;
+  }
+
+  return &*definition_event_;
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+  CHECK(!definition_event_.has_value())
+      << "SetDefinedOn must only be called once!";
+  definition_event_ = std::move(event);
+  streams_defined_on_.push_back(stream);
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream) {
+  streams_defined_on_.push_back(stream);
+}
+
 // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
 // device-side tensors, which are either CPU or GPU memory pointers. This works
 // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.
@@ -85,6 +85,24 @@ class XlaTensor {
     host_tensor_.reset(new Tensor(tensor));
   }

+  // If the tensor's content is not yet defined on 'stream', and there exists an
+  // se::Event declaring when the tensor's content is defined, return it.
+  // Otherwise, return nullptr. If this function returns nullptr then the
+  // tensor's content can be read on 'stream' without additional
+  // synchronization.
+  se::Event* GetDefinitionEvent(se::Stream* stream);
+
+  // Assert that the tensor's content is defined on 'stream' by the time 'event'
+  // triggers.
+  void SetDefinedOn(se::Stream* stream, se::Event event);
+
+  // Assert that the tensor's content is defined on 'stream'. This version does
+  // not provide an event, and must be called *after* SetDefinedOn(Stream,
+  // Event). This call can be read as an assertion that the definition event has
+  // been waited on by 'stream', so further calls to GetDefinitionEvent(stream)
+  // do not need to also wait on the event.
+  void SetDefinedOn(se::Stream* stream);
+
   // Convert from a raw pointer to an XlaTensor, removing the pointer tag.
   static XlaTensor* FromOpaquePointer(void* ptr);
   // Convert to a raw pointer from an XlaTensor, adding the pointer tag.
@@ -95,6 +113,13 @@ class XlaTensor {
   std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
   // An optional host tensor value.
   std::unique_ptr<Tensor> host_tensor_;
+  // An optional event that is triggered when the tensor's content has been
+  // defined. If this event is nullptr, it is assumed that the tensor's content
+  // is always defined.
+  gtl::optional<se::Event> definition_event_;
+  // A list of all streams for which the tensor's content is defined for any
+  // newly enqueued command.
+  gtl::InlinedVector<se::Stream*, 2> streams_defined_on_;
 };

 }  // namespace tensorflow
@@ -44,6 +44,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
     se::Stream* stream, const ShapedBuffer& device_buffer) {
   StatusOr<std::unique_ptr<Literal>> ret;
   se::Stream* substream = stream->GetOrCreateSubStream();
+  substream->ThenWaitFor(stream);
   auto cleanup = tensorflow::gtl::MakeCleanup(
       [&]() { stream->ReturnSubStream(substream); });

@@ -64,6 +65,7 @@ Status TransferManager::TransferLiteralToDevice(
   // Use a substream so that if we are called from a HostCallback we don't
   // deadlock.
   se::Stream* substream = stream->GetOrCreateSubStream();
+  substream->ThenWaitFor(stream);
   auto cleanup = tensorflow::gtl::MakeCleanup(
       [&]() { stream->ReturnSubStream(substream); });
   TF_RETURN_IF_ERROR(
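The two TransferManager hunks above are the "substreams wait on their master stream" fix called out in the commit message. The sketch below shows the resulting pattern under stated assumptions: it uses only the Stream and MakeCleanup calls visible in the diff, while the function name RunOnSubstream and the enqueue callback are illustrative stand-ins, not part of the patch.

// Sketch only: borrow a substream, order it after everything already enqueued
// on the master stream, run the transfer, and return the substream to the
// pool. Without ThenWaitFor(stream), the transfer could start before the work
// that produced the buffer on `stream` has finished.
#include <functional>

#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/stream_executor/stream.h"

namespace se = stream_executor;

se::port::Status RunOnSubstream(
    se::Stream* stream, const std::function<void(se::Stream*)>& enqueue) {
  // Use a substream so a transfer issued from a host callback on `stream`
  // cannot deadlock against `stream` itself.
  se::Stream* substream = stream->GetOrCreateSubStream();
  // The fix: the substream waits for its master stream before doing anything.
  substream->ThenWaitFor(stream);
  // Return the substream to the pool on every exit path.
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&] { stream->ReturnSubStream(substream); });

  enqueue(substream);  // Caller enqueues the actual copy here.
  return substream->BlockHostUntilDone();
}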
@@ -15,9 +15,9 @@ limitations under the License.

 #include "tensorflow/stream_executor/event.h"

+#include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
-#include "tensorflow/stream_executor/stream.h"

 namespace stream_executor {

@@ -27,10 +27,13 @@ Event::Event(StreamExecutor* stream_exec)
           stream_exec_->implementation()->CreateEventImplementation()) {}

 Event::~Event() {
+  // Deal with nullptr implementation_, as this event may have been std::moved.
+  if (stream_exec_ && implementation_) {
     auto status = stream_exec_->DeallocateEvent(this);
     if (!status.ok()) {
       LOG(ERROR) << status.error_message();
     }
+  }
 }

 bool Event::Init() {
@@ -61,6 +61,9 @@ class Event {
   // Returns a pointer to the underlying platform-specific implementation.
   internal::EventInterface* implementation() { return implementation_.get(); }

+  Event(Event&&) = default;
+  Event& operator=(Event&&) = default;
+
  private:
   friend class Stream;

@@ -5228,24 +5228,11 @@ port::Status Stream::BlockHostUntilDone() {
     return status;
   }

-  port::Status first_error;
-  {
-    // Wait until all active sub-streams have done their tasks.
-    mutex_lock lock(mu_);
-    for (auto &stream : sub_streams_) {
-      if (!stream.second) {
-        first_error.Update(stream.first->BlockHostUntilDone());
-        // Set this sub-stream as available.
-        stream.second = true;
-      }
-    }
-  }
-
   temporary_memory_manager_.DeallocateFinalizedTemporaries();

-  first_error.Update(parent_->BlockHostUntilDone(this));
-  CheckError(first_error.ok());
-  return first_error;
+  port::Status error = parent_->BlockHostUntilDone(this);
+  CheckError(error.ok());
+  return error;
 }

 }  // namespace stream_executor