diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index cdbe69d617e..c01f906fe85 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -140,9 +140,9 @@ tf_cc_test(
 )

 cc_library(
-    name = "device_state",
-    srcs = ["device_state.cc"],
-    hdrs = ["device_state.h"],
+    name = "local_device_state",
+    srcs = ["local_device_state.cc"],
+    hdrs = ["local_device_state.h"],
     deps = [
         ":event_pool",
         ":semaphore",
@@ -161,7 +161,7 @@ cc_library(
     srcs = ["local_client.cc"],
     hdrs = ["local_client.h"],
     deps = [
-        ":device_state",
+        ":local_device_state",
         ":shared_device_buffer",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc
index ef8ff4275a6..237f10c39ae 100644
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -105,6 +105,13 @@ limitations under the License.

 namespace xla {

+StatusOr<LocalDeviceState*> Device::GetLocalDeviceState() const {
+  if (local_device_state_) {
+    return local_device_state_.get();
+  }
+  return InvalidArgument("Device %s is not a local device.", DebugString());
+}
+
 std::string CpuDevice::DebugString() const {
   return absl::StrCat("CPU_", id());
 }
@@ -115,7 +122,7 @@ std::string GpuDevice::DebugString() const {

 static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
     se::Platform* platform,
-    absl::Span<const std::unique_ptr<DeviceState>> device_states,
+    absl::Span<const std::shared_ptr<Device>> local_devices,
     LocalClient* client, double memory_fraction, bool preallocate) {
   CHECK_GT(client->backend().device_count(), 0);
   std::vector<se::MultiDeviceAdapter::AllocatorWithStream> allocators;
@@ -148,19 +155,24 @@ static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
         /*allow_growth=*/!preallocate,
         absl::StrCat("GPU_", device_ordinal, "_bfc"));
     allocators.emplace_back(std::move(gpu_bfc_allocator),
-                            device_states.at(device_ordinal)->compute_stream());
+                            local_devices.at(device_ordinal)
+                                ->local_device_state()
+                                ->compute_stream());
   }
   return absl::make_unique<se::MultiDeviceAdapter>(platform,
                                                    std::move(allocators));
 }

-static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
-                                          int id, int local_device_ordinal) {
+static std::shared_ptr<Device> MakeDevice(
+    const std::string& platform_name, int id,
+    std::unique_ptr<LocalDeviceState> local_device_state) {
   if (platform_name == "cpu") {
-    return std::make_shared<CpuDevice>(id, local_device_ordinal, platform_name);
+    return std::make_shared<CpuDevice>(id, std::move(local_device_state),
+                                       platform_name);
   } else {
     CHECK_EQ(platform_name, "gpu");
-    return std::make_shared<GpuDevice>(id, local_device_ordinal, platform_name);
+    return std::make_shared<GpuDevice>(id, std::move(local_device_state),
+                                       platform_name);
   }
 }

@@ -179,16 +191,15 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
                       ClientLibrary::GetOrCreateLocalClient(options));
   bool gpu_platform = platform_name == "gpu";
-  std::vector<std::unique_ptr<DeviceState>> device_states;
   std::vector<std::shared_ptr<Device>> devices;
   bool synchronous_deallocation = platform_name == "cpu";
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutor* executor =
         client->backend().stream_executor(i).ValueOrDie();
-    device_states.push_back(absl::make_unique<DeviceState>(
+    auto device_state = absl::make_unique<LocalDeviceState>(
         executor,
synchronous_deallocation, asynchronous, - /*allow_event_reuse=*/gpu_platform)); - devices.push_back(MakeDevice(platform_name, i, i)); + /*allow_event_reuse=*/gpu_platform); + devices.push_back(MakeDevice(platform_name, i, std::move(device_state))); } std::unique_ptr<se::DeviceMemoryAllocator> allocator; @@ -196,7 +207,7 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get( if (gpu_platform) { if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) { TF_ASSIGN_OR_RETURN(allocator, - CreateBFCAllocator(platform, device_states, client, + CreateBFCAllocator(platform, devices, client, allocator_config.memory_fraction, allocator_config.preallocate)); } @@ -217,21 +228,18 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get( return std::make_shared<PyLocalClient>( platform_name, client, std::move(devices), /*host_id=*/0, - std::move(device_states), std::move(allocator), - std::move(host_memory_allocator)); + std::move(allocator), std::move(host_memory_allocator)); } PyLocalClient::PyLocalClient( std::string platform_name, LocalClient* client, std::vector<std::shared_ptr<Device>> devices, int host_id, - std::vector<std::unique_ptr<DeviceState>> device_states, std::unique_ptr<se::DeviceMemoryAllocator> allocator, std::unique_ptr<tensorflow::Allocator> host_memory_allocator) : platform_name_(std::move(platform_name)), client_(client), devices_(std::move(devices)), host_id_(host_id), - device_states_(std::move(device_states)), owned_allocator_(std::move(allocator)), host_memory_allocator_(std::move(host_memory_allocator)), h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer", @@ -242,15 +250,16 @@ PyLocalClient::PyLocalClient( allocator_ = client_->backend().memory_allocator(); } - local_devices_.resize(device_states_.size()); for (const std::shared_ptr<Device>& device : devices_) { CHECK(id_to_device_.insert({device->id(), device}).second) << "Duplicate device id: " << device->id(); - if (device->local_device_ordinal() != -1) { - int idx = device->local_device_ordinal(); + if (device->local_device_state()) { + int idx = device->local_device_state()->device_ordinal(); + if (idx >= local_devices_.size()) { + local_devices_.resize(idx + 1); + } CHECK(local_devices_[idx] == nullptr) << idx; - CHECK_LT(idx, local_devices_.size()); local_devices_[idx] = device; } } @@ -274,17 +283,19 @@ PyLocalClient::DeserializeExecutable( } Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal, - int device_ordinal) { - TF_RETURN_IF_ERROR( - CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferToInfeed")); - return client_->TransferToInfeedLocal(literal, device_ordinal); + std::shared_ptr<Device> device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + return client_->TransferToInfeedLocal(literal, + local_device->device_ordinal()); } -StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape, - int device_ordinal) { - TF_RETURN_IF_ERROR( - CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferFromOutfeed")); - return client_->TransferFromOutfeedLocal(shape, device_ordinal); +StatusOr<Literal> PyLocalClient::TransferFromOutfeed( + const Shape& shape, std::shared_ptr<Device> device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + return client_->TransferFromOutfeedLocal(shape, + local_device->device_ordinal()); } StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment( @@ -293,36 +304,26 @@ StatusOr<DeviceAssignment> 
PyLocalClient::GetDefaultDeviceAssignment( num_replicas, /*computation_count=*/1); } -Status PyLocalClient::CheckDeviceOrdinal(int device_ordinal, - absl::string_view caller_name) { - if (device_ordinal < 0 || device_ordinal >= local_device_count()) { - return InvalidArgument( - "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name, - device_ordinal, local_device_count()); - } - return Status::OK(); -} - /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals( std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape, std::shared_ptr<void> leaves_reference, - std::shared_ptr<PyLocalClient> client, int device_ordinal) { + std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals"); VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString() - << " device ordinal: " << device_ordinal; - TF_RETURN_IF_ERROR(client->CheckDeviceOrdinal(device_ordinal, - "PyLocalBuffer::FromLiterals")); - DeviceState* device = &client->device_state(device_ordinal); + << " device: " << device->DebugString(); + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); TransferManager* transfer_manager = client->client()->backend().transfer_manager(); se::DeviceMemoryAllocator* allocator = client->allocator(); TF_ASSIGN_OR_RETURN( Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(tuple_shape)); - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer, - transfer_manager->AllocateScopedShapedBuffer( - compact_shape, allocator, device_ordinal)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer scoped_buffer, + transfer_manager->AllocateScopedShapedBuffer( + compact_shape, allocator, local_device->device_ordinal())); // Make the host to device stream wait for the newly allocated buffer to be // available on the compute stream. We schedule this wait synchronously; while @@ -331,8 +332,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals( // computations that depend on this transfer being enqueued on the compute // stream. if (!transfer_manager->CanShapedBufferBeAccessedNow( - device->host_to_device_stream()->parent(), scoped_buffer)) { - device->host_to_device_stream()->ThenWaitFor(device->compute_stream()); + local_device->host_to_device_stream()->parent(), scoped_buffer)) { + local_device->host_to_device_stream()->ThenWaitFor( + local_device->compute_stream()); } std::shared_ptr<BufferDefinitionEvent> definition_event = @@ -344,16 +346,15 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals( // TODO(makro): Use move capture once C++ 14 features are available. auto leaves = std::make_shared<std::vector<BorrowingLiteral>>( std::move(leaves_literals)); - auto transfer_h2d = [client, transfer_manager, device, device_ordinal, - device_buffer, compact_shape, leaves, - leaves_reference]() { + auto transfer_h2d = [client, transfer_manager, local_device, device_buffer, + compact_shape, leaves, leaves_reference]() { // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to // report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to // memory that has already been allocated, and a possible Event allocation. 
ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape); TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync( - device->host_to_device_stream(), buffer)); + local_device->host_to_device_stream(), buffer)); std::vector<std::shared_ptr<void>> staging_buffers; staging_buffers.reserve(leaves->size()); auto it = leaves->begin(); @@ -363,7 +364,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals( ShapedBuffer leaf( indexed_shape.shape, transfer_manager->HostShapeToDeviceShape(indexed_shape.shape), - client->client()->platform(), device_ordinal); + client->client()->platform(), local_device->device_ordinal()); leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {}); // If applicable on the backend, stage the transfer via host memory @@ -379,51 +380,53 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals( BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()), it->shape()); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( - device->host_to_device_stream(), literal, leaf)); + local_device->host_to_device_stream(), literal, leaf)); staging_buffers.push_back(std::move(staging_buffer)); } else { // Otherwise, just transfer the literal. TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( - device->host_to_device_stream(), *it, leaf)); + local_device->host_to_device_stream(), *it, leaf)); } ++it; } EventPool::Handle event = - device->event_pool() - .ThenAllocateAndRecordEvent(device->host_to_device_stream()) + local_device->event_pool() + .ThenAllocateAndRecordEvent(local_device->host_to_device_stream()) .ValueOrDie(); // Sets the buffer definition event. Note: this has the side effect of // unblocking any host threads that may have been waiting to consume the // buffer. 
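Throughout these local_client.cc hunks the change follows one pattern: callers hold a std::shared_ptr<Device>, resolve it once to a LocalDeviceState, and take streams and the device ordinal from that object instead of threading a raw device_ordinal through. A minimal sketch of that pattern, using only APIs introduced or already exercised by this patch; the helper name HostToDeviceStreamFor is hypothetical and not part of the change:

#include <memory>

#include "tensorflow/compiler/xla/python/local_client.h"

namespace xla {

// Hypothetical helper, for illustration only: resolve a Device to the stream
// used for host-to-device transfers on its LocalDeviceState.
StatusOr<se::Stream*> HostToDeviceStreamFor(
    const std::shared_ptr<Device>& device) {
  // GetLocalDeviceState() returns InvalidArgument if `device` is not local to
  // this host; local_device_state() would return nullptr instead.
  StatusOr<LocalDeviceState*> local_device = device->GetLocalDeviceState();
  if (!local_device.ok()) {
    return local_device.status();
  }
  return local_device.ValueOrDie()->host_to_device_stream();
}

}  // namespace xla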
device_buffer->definition_event()->SetDefinitionEvent( - std::move(event), device->host_to_device_stream()); + std::move(event), local_device->host_to_device_stream()); - if (device->synchronous_deallocation()) { - device->ThenRelease(device->host_to_device_stream(), device_buffer); + if (local_device->synchronous_deallocation()) { + local_device->ThenRelease(local_device->host_to_device_stream(), + device_buffer); } - device->ThenRelease( - device->host_to_device_stream(), + local_device->ThenRelease( + local_device->host_to_device_stream(), std::make_pair(leaves_reference, std::move(staging_buffers))); }; client->h2d_transfer_pool()->Schedule(transfer_h2d); - return absl::make_unique<PyLocalBuffer>( - compact_shape, std::move(device_buffer), std::move(client)); + return absl::make_unique<PyLocalBuffer>(compact_shape, + std::move(device_buffer), + std::move(client), std::move(device)); } /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple( const std::vector<PyLocalBuffer*> buffers, - std::shared_ptr<PyLocalClient> client, int device_ordinal) { - TF_RETURN_IF_ERROR( - client->CheckDeviceOrdinal(device_ordinal, "PyLocalBuffer::MakeTuple")); + std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); std::vector<Shape> host_shapes; std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers; host_shapes.reserve(buffers.size()); device_buffers.reserve(buffers.size()); for (const PyLocalBuffer* buffer : buffers) { - TF_RET_CHECK(buffer->device_ordinal() == device_ordinal); + TF_RET_CHECK(buffer->device().get() == device.get()); std::shared_ptr<SharedDeviceBuffer> device_buffer = buffer->DeviceBuffer(); if (!device_buffer) { return InvalidArgument( @@ -436,45 +439,48 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals( se::DeviceMemoryAllocator* allocator = client->allocator(); TransferManager* transfer_manager = client->client()->backend().transfer_manager(); - DeviceState& device = client->device_state(device_ordinal); auto definition_event = std::make_shared<BufferDefinitionEvent>(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr<SharedDeviceBuffer> tuple_buffer, - SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator, - device_ordinal, definition_event)); + TF_ASSIGN_OR_RETURN(std::shared_ptr<SharedDeviceBuffer> tuple_buffer, + SharedDeviceBuffer::MakeTuple( + device_buffers, transfer_manager, allocator, + local_device->device_ordinal(), definition_event)); auto buffer = absl::make_unique<PyLocalBuffer>( - ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client)); + ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client), + std::move(device)); // TODO(phawkins): extend TransferManager so we do not need to form a full // ShapedBuffer just to write the root tuple index table. TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer()); if (!transfer_manager->CanShapedBufferBeAccessedNow( - device.host_to_device_stream()->parent(), shaped_buffer)) { + local_device->host_to_device_stream()->parent(), shaped_buffer)) { // Wait for the compute stream so that memory allocations are synchronized. 
- device.host_to_device_stream()->ThenWaitFor(device.compute_stream()); + local_device->host_to_device_stream()->ThenWaitFor( + local_device->compute_stream()); } TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable( - device.host_to_device_stream(), shaped_buffer)); + local_device->host_to_device_stream(), shaped_buffer)); TF_ASSIGN_OR_RETURN(EventPool::Handle event, - device.event_pool().ThenAllocateAndRecordEvent( - device.host_to_device_stream())); + local_device->event_pool().ThenAllocateAndRecordEvent( + local_device->host_to_device_stream())); definition_event->SetDefinitionEvent(std::move(event), - device.host_to_device_stream()); + local_device->host_to_device_stream()); - if (device.synchronous_deallocation()) { - device.ThenRelease(device.host_to_device_stream(), std::move(tuple_buffer)); + if (local_device->synchronous_deallocation()) { + local_device->ThenRelease(local_device->host_to_device_stream(), + std::move(tuple_buffer)); } return buffer; } PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, std::shared_ptr<SharedDeviceBuffer> device_buffer, - std::shared_ptr<PyLocalClient> client) + std::shared_ptr<PyLocalClient> client, + std::shared_ptr<Device> device) : client_(std::move(client)), on_host_shape_(std::move(on_host_shape)), - device_ordinal_(device_buffer->device_ordinal()), + device_(std::move(device)), device_buffer_(std::move(device_buffer)) {} void PyLocalBuffer::Delete() { @@ -499,8 +505,7 @@ Status PyLocalBuffer::CopyToHostAsync() { } host_value = host_value_ = std::make_shared<HostValue>(); } - se::Stream* stream = - client_->device_state(device_ordinal_).device_to_host_stream(); + se::Stream* stream = device_->local_device_state()->device_to_host_stream(); WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); host_value->value = std::make_shared<Literal>(on_host_shape_); TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer()); @@ -564,36 +569,38 @@ PyLocalBuffer::DestructureTuple() { for (int64 i = 0; i < num_children; ++i) { results.push_back(absl::make_unique<PyLocalBuffer>( on_host_shape_.tuple_shapes(i), device_buffer_->children().at(i), - client_)); + client_, device_)); } return results; } StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice( - int dst_device_ordinal) { + std::shared_ptr<Device> dst_device) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice"); std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer(); - if (dst_device_ordinal == device_ordinal_) { - return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer, - client_); - } - int transfer_device_ordinal = client_->EnqueueD2DTransfersOnSrcStream() - ? device_ordinal_ - : dst_device_ordinal; - DeviceState& transfer_device = client_->device_state(transfer_device_ordinal); - const DeviceState& dst_device = client_->device_state(dst_device_ordinal); + TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, + dst_device->GetLocalDeviceState()); - se::Stream* transfer_stream = transfer_device.GetDeviceToDeviceStream(); + if (dst_device.get() == device_.get()) { + return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer, + client_, device_); + } + LocalDeviceState* transfer_local_device = + client_->EnqueueD2DTransfersOnSrcStream() ? 
device_->local_device_state() + : dst_local_device; + + se::Stream* transfer_stream = + transfer_local_device->GetDeviceToDeviceStream(); TransferManager* transfer_manager = client_->client()->backend().transfer_manager(); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer dst_buffer, - transfer_manager->AllocateScopedShapedBuffer( - on_host_shape_, client_->allocator(), dst_device_ordinal)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer, + transfer_manager->AllocateScopedShapedBuffer( + on_host_shape_, client_->allocator(), + dst_local_device->device_ordinal())); if (!transfer_manager->CanShapedBufferBeAccessedNow( - dst_device.compute_stream()->parent(), dst_buffer)) { - transfer_stream->ThenWaitFor(dst_device.compute_stream()); + dst_local_device->compute_stream()->parent(), dst_buffer)) { + transfer_stream->ThenWaitFor(dst_local_device->compute_stream()); } TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer()); @@ -607,37 +614,39 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice( TF_RET_CHECK(input_buffer.size() == output_buffer.size()) << "input: " << input_buffer.size() << " output: " << output_buffer.size(); - TF_RETURN_IF_ERROR(transfer_device.ThenMemcpyDeviceToDevice( - transfer_stream, dst_device.compute_stream(), input_buffer, + TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice( + transfer_stream, dst_local_device->compute_stream(), input_buffer, output_buffer)); } // We hold on to the `src_device_buffer` until the transfer is finished. - transfer_device.ThenRelease(transfer_stream, std::move(src_device_buffer)); + transfer_local_device->ThenRelease(transfer_stream, + std::move(src_device_buffer)); // Write new tuple buffers. The destination buffers have different addresses, // so we must construct tuple buffers from scratch instead of copying them. if (dst_buffer.on_device_shape().IsTuple()) { TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync( - dst_device.host_to_device_stream(), dst_buffer)); + dst_local_device->host_to_device_stream(), dst_buffer)); // We need a single definition event, so make the device to device stream // wait for the stream that wrote the tuple index tables on the destination // device. 
- transfer_stream->ThenWaitFor(dst_device.host_to_device_stream()); + transfer_stream->ThenWaitFor(dst_local_device->host_to_device_stream()); } auto definition_event = std::make_shared<BufferDefinitionEvent>(); TF_ASSIGN_OR_RETURN( EventPool::Handle event, - transfer_device.event_pool().ThenAllocateAndRecordEvent(transfer_stream)); + transfer_local_device->event_pool().ThenAllocateAndRecordEvent( + transfer_stream)); definition_event->SetDefinitionEvent(std::move(event), transfer_stream); std::shared_ptr<SharedDeviceBuffer> dst_device_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer), definition_event); return absl::make_unique<PyLocalBuffer>( - on_host_shape_, std::move(dst_device_buffer), client_); + on_host_shape_, std::move(dst_device_buffer), client_, dst_device); } Status PyLocalBuffer::BlockHostUntilReady() { @@ -694,7 +703,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper( const int device_id = (*device_assignment_)(replica, 0); std::shared_ptr<Device> device = LookupDevice(*client_, device_id); CHECK_EQ(device->host_id(), client_->host_id()); - int device_ordinal = device->local_device_ordinal(); + int device_ordinal = device->local_device_state()->device_ordinal(); tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -729,7 +738,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper( << " buffer: " << argument_buffers.back().ToString(); } - DeviceState* device_state = &client_->device_state(device_ordinal); + LocalDeviceState* device_state = &client_->device_state(device_ordinal); // The choice of where we wait is arbitrary; the reason for the wait is pacing // to avoid problems such as memory fragmentation and running ahead too far, // not for correctness. Placing it before the executable launch allows the @@ -782,7 +791,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper( device_state->compute_stream(), std::make_tuple(executable_, compute_reservation, device_assignment_)); return absl::make_unique<PyLocalBuffer>(on_host_shape, std::move(out_buffer), - client_); + client_, device); } StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute( @@ -833,8 +842,7 @@ PyLocalExecutable::ExecutePerReplica( for (int i = 0; i < num_local_replicas; ++i) { const int replica = local_replicas_[i]; std::shared_ptr<Device> device = local_devices_[i]; - const DeviceState& device_state = - client_->device_state(device->local_device_ordinal()); + const LocalDeviceState& device_state = *device->local_device_state(); device_state.execute_thread()->Schedule([&, replica, i] { results[i] = ExecuteHelper(argument_handles[i], replica, run_id); diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 3f13f62241f..e0a21ad6f1e 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -27,7 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/python/device_state.h" +#include "tensorflow/compiler/xla/python/local_device_state.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -43,10 +43,10 @@ class PyLocalExecutable; class Device { public: - explicit Device(int id, int local_device_ordinal, + explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state, absl::string_view platform_name, int host_id = 0) : id_(id), - local_device_ordinal_(local_device_ordinal), + local_device_state_(std::move(local_device_state)), host_id_(host_id), platform_name_(platform_name) {} virtual ~Device() {} @@ -56,13 +56,17 @@ class Device { // hosts' devices. This is the ID that should be used in a DeviceAssignment. int id() const { return id_; } - // If this is a device local to this host, the local index of this device as - // according to the underlying backend. Unlike id(), this will always be in - // the range [0, num_local_devices), and can be used with the xla::LocalClient - // and xla::Backend APIs. - // - // -1 if this device is not local to this host. - int local_device_ordinal() const { return local_device_ordinal_; } + // If this is a device local to this host, returns a LocalDeviceState object + // that can be used to manipulate the device. Returns nullptr if the device is + // not local to this host. + LocalDeviceState* local_device_state() const { + return local_device_state_.get(); + } + + // If this is a device local to this host, returns a LocalDeviceState object + // that can be used to manipulate the device. Returns an error if the device + // is not local to this host. + StatusOr<LocalDeviceState*> GetLocalDeviceState() const; // The ID of this device's host. This is always 0 on single-host platforms. 
int host_id() const { return host_id_; } @@ -73,7 +77,7 @@ class Device { private: const int id_; - const int local_device_ordinal_; + const std::unique_ptr<LocalDeviceState> local_device_state_; const int host_id_; const std::string platform_name_; }; @@ -123,13 +127,14 @@ class PyLocalClient { explicit PyLocalClient( std::string platform_name, LocalClient* client, std::vector<std::shared_ptr<Device>> devices, int host_id, - std::vector<std::unique_ptr<DeviceState>> device_states, std::unique_ptr<se::DeviceMemoryAllocator> allocator, std::unique_ptr<tensorflow::Allocator> host_memory_allocator); virtual ~PyLocalClient() = default; - Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal); - StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal); + Status TransferToInfeed(const LiteralSlice& literal, + std::shared_ptr<Device> device); + StatusOr<Literal> TransferFromOutfeed(const Shape& shape, + std::shared_ptr<Device> device); virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment( int num_replicas) const; @@ -146,8 +151,8 @@ class PyLocalClient { int host_id() const { return host_id_; } const std::string& platform_name() const { return platform_name_; } - DeviceState& device_state(int device_ordinal) const { - return *device_states_.at(device_ordinal); + LocalDeviceState& device_state(int device_ordinal) const { + return *local_devices_.at(device_ordinal)->local_device_state(); } LocalClient* client() const { return client_; } @@ -178,10 +183,6 @@ class PyLocalClient { const std::string& serialized, std::shared_ptr<PyLocalClient> this_shared) const; - // Returns a bad status containing `caller_name` if `device_ordinal` doesn't - // correspond to a local device. - Status CheckDeviceOrdinal(int device_ordinal, absl::string_view caller_name); - protected: std::string platform_name_; LocalClient* client_; @@ -194,8 +195,6 @@ class PyLocalClient { std::vector<std::shared_ptr<Device>> local_devices_; int host_id_; - // Device states local to this host. Indexed by local device ordinal. 
- std::vector<std::unique_ptr<DeviceState>> device_states_; se::DeviceMemoryAllocator* allocator_; std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_; @@ -219,16 +218,16 @@ class PyLocalBuffer { static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals( std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape, std::shared_ptr<void> leaves_reference, - std::shared_ptr<PyLocalClient> client, int device_ordinal); + std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device); static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple( const std::vector<PyLocalBuffer*> buffers, - std::shared_ptr<PyLocalClient> client, int device_ordinal); + std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device); - PyLocalBuffer() = default; PyLocalBuffer(Shape on_host_shape, std::shared_ptr<SharedDeviceBuffer> device_buffer, - std::shared_ptr<PyLocalClient> client); + std::shared_ptr<PyLocalClient> client, + std::shared_ptr<Device> device); PyLocalBuffer(const PyLocalBuffer&) = delete; PyLocalBuffer(PyLocalBuffer&&) = delete; @@ -236,7 +235,7 @@ class PyLocalBuffer { PyLocalBuffer& operator=(PyLocalBuffer&&) = delete; const Shape& on_host_shape() const { return on_host_shape_; } - int device_ordinal() const { return device_ordinal_; } + std::shared_ptr<Device> device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } std::shared_ptr<PyLocalClient> client() const { return client_; } @@ -266,8 +265,9 @@ class PyLocalBuffer { // Destructures a tuple-valued PyLocalBuffer into its constituent elements. StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> DestructureTuple(); - // Copies the buffer to device `dst_device_ordinal`. - StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(int dst_device_ordinal); + // Copies the buffer to device `dst_device`. + StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice( + std::shared_ptr<Device> dst_device); // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. @@ -276,7 +276,7 @@ class PyLocalBuffer { private: const std::shared_ptr<PyLocalClient> client_; const Shape on_host_shape_; - const int device_ordinal_; + const std::shared_ptr<Device> device_; mutable absl::Mutex mu_; std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_); diff --git a/tensorflow/compiler/xla/python/device_state.cc b/tensorflow/compiler/xla/python/local_device_state.cc similarity index 81% rename from tensorflow/compiler/xla/python/device_state.cc rename to tensorflow/compiler/xla/python/local_device_state.cc index 3403d882e92..6b8d09d4ffa 100644 --- a/tensorflow/compiler/xla/python/device_state.cc +++ b/tensorflow/compiler/xla/python/local_device_state.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/device_state.h" +#include "tensorflow/compiler/xla/python/local_device_state.h" #include <memory> #include <vector> @@ -24,12 +24,13 @@ limitations under the License. 
namespace xla { -DeviceState::DeviceState(se::StreamExecutor* executor, - bool synchronous_deallocation, bool asynchronous, - bool allow_event_reuse) +LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, + bool synchronous_deallocation, + bool asynchronous, bool allow_event_reuse) : synchronous_deallocation_(synchronous_deallocation), event_pool_(allow_event_reuse), - compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) { + compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1), + executor_(executor) { compute_stream_ = absl::make_unique<se::Stream>(executor); host_to_device_stream_ = absl::make_unique<se::Stream>(executor); device_to_host_stream_ = absl::make_unique<se::Stream>(executor); @@ -50,14 +51,14 @@ DeviceState::DeviceState(se::StreamExecutor* executor, "py_xla_callback"); } -DeviceState::~DeviceState() { +LocalDeviceState::~LocalDeviceState() { Status status = SynchronizeAllActivity(); if (!status.ok()) { LOG(ERROR) << "Error when closing device: " << status; } } -Status DeviceState::SynchronizeAllActivity() { +Status LocalDeviceState::SynchronizeAllActivity() { Status status; // TODO(phawkins): in theory the call to SynchronizeAllActivity below should // suffice. However on the Host platform SynchronizeAllActivity is a dummy @@ -73,10 +74,9 @@ Status DeviceState::SynchronizeAllActivity() { return status; } -Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, - se::Stream* dst_stream, - se::DeviceMemoryBase src_buffer, - se::DeviceMemoryBase dst_buffer) { +Status LocalDeviceState::ThenMemcpyDeviceToDevice( + se::Stream* transfer_stream, se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { // The default implementation simply calls ThenMemcpyD2D, and assumes that // the buffer addresses identify the devices. This does not work // on all platforms; this method is virtual so it can be overridden. @@ -84,14 +84,14 @@ Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, return Status::OK(); } -void DeviceState::ThenExecuteOnCallbackThread( +void LocalDeviceState::ThenExecuteOnCallbackThread( se::Stream* stream, std::function<void()> callback) const { stream->ThenDoHostCallback([this, callback]() mutable { callback_thread_->Schedule(std::move(callback)); }); } -se::Stream* DeviceState::GetDeviceToDeviceStream() { +se::Stream* LocalDeviceState::GetDeviceToDeviceStream() { absl::MutexLock lock(&mu_); int i = next_device_to_device_stream_; next_device_to_device_stream_ = diff --git a/tensorflow/compiler/xla/python/device_state.h b/tensorflow/compiler/xla/python/local_device_state.h similarity index 88% rename from tensorflow/compiler/xla/python/device_state.h rename to tensorflow/compiler/xla/python/local_device_state.h index 3772c03fc59..fe9b9bd61b3 100644 --- a/tensorflow/compiler/xla/python/device_state.h +++ b/tensorflow/compiler/xla/python/local_device_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ #include <memory> #include <vector> @@ -29,9 +29,9 @@ limitations under the License. 
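The header hunks below give LocalDeviceState a device_ordinal() accessor by holding on to the StreamExecutor it was constructed with, which is what lets Device drop its separate local_device_ordinal field. A hedged construction sketch, assuming only the constructor and accessors shown in this header; the helper name MakeCpuLocalDeviceState and the CPU-style flag values are illustrative, mirroring the CPU path in PyLocalClient::Get:

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Illustrative only: build a LocalDeviceState the way the CPU path does
// (synchronous deallocation, no event reuse). The helper name is hypothetical.
std::unique_ptr<LocalDeviceState> MakeCpuLocalDeviceState(
    se::StreamExecutor* executor, bool asynchronous) {
  auto state = absl::make_unique<LocalDeviceState>(
      executor, /*synchronous_deallocation=*/true, asynchronous,
      /*allow_event_reuse=*/false);
  // The ordinal is no longer threaded around as a separate integer; it is
  // read back from the executor via device_ordinal().
  VLOG(1) << "Created LocalDeviceState for ordinal "
          << state->device_ordinal();
  return state;
}

}  // namespace xla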
namespace xla { // Class that encapsulates state relating to a device (e.g., a GPU) on which we -// can perform computation and transfers. DeviceState objects only exist for -// devices local to this host. -class DeviceState { +// can perform computation and transfers. LocalDeviceState objects only exist +// for devices local to this host. +class LocalDeviceState { public: // If synchronous_deallocation is true, the host must not free buffers until // compute/transfers that use those buffers have completed. For example, this @@ -40,9 +40,12 @@ class DeviceState { // // If asynchronous is false, the host will synchronize to the device after // each execution or transfer. This is intended for debugging only. - DeviceState(se::StreamExecutor* executor, bool synchronous_deallocation, - bool asynchronous, bool allow_event_reuse); - virtual ~DeviceState(); + LocalDeviceState(se::StreamExecutor* executor, bool synchronous_deallocation, + bool asynchronous, bool allow_event_reuse); + virtual ~LocalDeviceState(); + + // StreamExecutor (local) device ordinal. + int device_ordinal() const { return executor_->device_ordinal(); } bool synchronous_deallocation() const { return synchronous_deallocation_; } @@ -104,6 +107,7 @@ class DeviceState { // stream by the host ahead of the device. Semaphore compute_semaphore_; + se::StreamExecutor* executor_; std::unique_ptr<se::Stream> compute_stream_; std::unique_ptr<se::Stream> host_to_device_stream_; std::unique_ptr<se::Stream> device_to_host_stream_; @@ -132,4 +136,4 @@ class DeviceState { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index d5d492de054..13e0d147e86 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -19,7 +19,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/python:device_state", "//tensorflow/compiler/xla/python:local_client", "//tensorflow/compiler/xla/python:semaphore", "//tensorflow/compiler/xla/python/tpu_driver", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index b9ca2a7e1a7..f0c93772ffe 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -39,10 +39,9 @@ std::string TpuDevice::DebugString() const { } static std::shared_ptr<Device> MakeDevice(const std::string& platform_name, - int id, int local_device_ordinal) { + int id) { CHECK_EQ(platform_name, "tpu"); - CHECK_EQ(id, local_device_ordinal); // Every device must be local for now. 
- return std::make_shared<TpuDevice>(id, local_device_ordinal, "tpu"); + return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu"); } StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get( @@ -67,7 +66,7 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get( LOG(INFO) << "Creating " << num_cores << " TPU device(s)."; devices.reserve(num_cores); for (int i = 0; i < num_cores; ++i) { - devices.push_back(MakeDevice("tpu", i, i)); + devices.push_back(MakeDevice("tpu", i)); } return std::make_shared<PyTpuClient>("tpu", std::move(client), @@ -87,8 +86,8 @@ PyTpuClient::PyTpuClient(std::string platform_name, CHECK(id_to_device_.insert({device->id(), device}).second) << "Duplicate device id: " << device->id(); - if (device->local_device_ordinal() != -1) { - int idx = device->local_device_ordinal(); + if (device->id() != -1) { + int idx = device->id(); CHECK(local_devices_[idx] == nullptr) << idx; CHECK_LT(idx, local_devices_.size()); local_devices_[idx] = device; @@ -509,7 +508,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( const int device_id = device_assignment_(replica, 0); std::shared_ptr<Device> device = LookupDevice(*client_, device_id); CHECK_EQ(device->host_id(), client_->host_id()); - int device_ordinal = device->local_device_ordinal(); + int device_ordinal = device->id(); tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute"); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -742,7 +741,7 @@ PyTpuExecutable::ExecutePerReplica( const int device_id = (*device_assignment)(replica, 0); std::shared_ptr<Device> device = LookupDevice(*client, device_id); CHECK_EQ(device->host_id(), client->host_id()); - int device_ordinal = device->local_device_ordinal(); + int device_ordinal = device->id(); loaded_programs[replica] = client->driver()->LoadProgram( device_ordinal, compiled_program.get(), {}); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 7624a14943f..49d4182b719 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -24,7 +24,6 @@ limitations under the License. 
#include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/python/device_state.h" #include "tensorflow/compiler/xla/python/local_client.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 60886416a62..2b7082d40c9 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -96,9 +96,9 @@ PYBIND11_MODULE(tpu_client_extension, m) { std::make_move_iterator(tree.leaves.end())); py::gil_scoped_release gil_release; - return PyTpuBuffer::FromLiterals( - std::move(leaves), tree.shape, std::move(py_buffer_ref), - std::move(client), device->local_device_ordinal()); + return PyTpuBuffer::FromLiterals(std::move(leaves), tree.shape, + std::move(py_buffer_ref), + std::move(client), device->id()); }) .def_static( "from_python", @@ -135,8 +135,8 @@ PYBIND11_MODULE(tpu_client_extension, m) { "Cannot make tuple on device '%s' with '%s' backend", device->DebugString(), client->platform_name()); } - return PyTpuBuffer::MakeTuple( - buffers, client, device->local_device_ordinal()); + return PyTpuBuffer::MakeTuple(buffers, client, + device->id()); }) .def_static("make_tuple", &PyTpuBuffer::MakeTuple) .def("copy_to_device", @@ -144,7 +144,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { CHECK(dst_device != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return buffer->CopyToDevice(dst_device->local_device_ordinal()); + return buffer->CopyToDevice(dst_device->id()); }) .def("copy_to_device", [](PyTpuBuffer* buffer, int dst_device_ordinal) { @@ -193,7 +193,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { [](const PyTpuExecutable& executable) { std::vector<int> device_ordinals; for (std::shared_ptr<Device> device : executable.local_devices()) { - device_ordinals.push_back(device->local_device_ordinal()); + device_ordinals.push_back(device->id()); } return device_ordinals; }) diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index f1776763796..b5eb6fa47da 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -142,6 +142,16 @@ Status PyRegisterCustomCallTarget(const std::string& fn_name, return Status::OK(); } +StatusOr<std::shared_ptr<Device>> LookupDeviceOrdinal( + PyLocalClient* client, int device_ordinal, absl::string_view caller_name) { + if (device_ordinal < 0 || device_ordinal >= client->local_device_count()) { + return InvalidArgument( + "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name, + device_ordinal, client->local_device_count()); + } + return client->local_devices()[device_ordinal]; +} + } // namespace PYBIND11_MODULE(xla_extension, m) { @@ -381,13 +391,27 @@ PYBIND11_MODULE(xla_extension, m) { } return result; }) + // TODO(phawkins): delete overload that accepts a device_ordinal after + // all callers have been updated to pass a Device. 
.def("TransferToInfeed", [](PyLocalClient* client, const LiteralSlice& literal, int device_ordinal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return client->TransferToInfeed(literal, device_ordinal); + TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device, + LookupDeviceOrdinal(client, device_ordinal, + "TransferToInfeed")); + return client->TransferToInfeed(literal, device); }) + .def("TransferToInfeed", + [](PyLocalClient* client, const LiteralSlice& literal, + std::shared_ptr<Device> device) { + GlobalPyRefManager()->CollectGarbage(); + py::gil_scoped_release gil_release; + return client->TransferToInfeed(literal, device); + }) + // TODO(phawkins): delete overload that accepts a device_ordinal after + // all callers have been updated to pass a Device. .def("TransferFromOutfeed", [](PyLocalClient* client, const Shape& shape, int device_ordinal) -> StatusOr<py::object> { @@ -395,8 +419,24 @@ PYBIND11_MODULE(xla_extension, m) { std::shared_ptr<Literal> literal_shared; { py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed( - shape, device_ordinal)); + TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device, + LookupDeviceOrdinal(client, device_ordinal, + "TransferFromOutfeed")); + TF_ASSIGN_OR_RETURN(Literal literal, + client->TransferFromOutfeed(shape, device)); + literal_shared = std::make_shared<Literal>(std::move(literal)); + } + return LiteralToPython(std::move(literal_shared)); + }) + .def("TransferFromOutfeed", + [](PyLocalClient* client, const Shape& shape, + std::shared_ptr<Device> device) -> StatusOr<py::object> { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr<Literal> literal_shared; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(Literal literal, + client->TransferFromOutfeed(shape, device)); literal_shared = std::make_shared<Literal>(std::move(literal)); } return LiteralToPython(std::move(literal_shared)); @@ -440,7 +480,7 @@ PYBIND11_MODULE(xla_extension, m) { py::gil_scoped_release gil_release; return PyLocalBuffer::FromLiterals( std::move(leaves), tree.shape, std::move(py_buffer_ref), - std::move(client), device->local_device_ordinal()); + std::move(client), std::move(device)); }) .def_static("make_tuple", [](const std::vector<PyLocalBuffer*> buffers, @@ -454,15 +494,15 @@ PYBIND11_MODULE(xla_extension, m) { "Cannot make tuple on device '%s' with '%s' backend", device->DebugString(), client->platform_name()); } - return PyLocalBuffer::MakeTuple( - buffers, client, device->local_device_ordinal()); + return PyLocalBuffer::MakeTuple(buffers, std::move(client), + std::move(device)); }) .def("copy_to_device", [](PyLocalBuffer* buffer, std::shared_ptr<Device> dst_device) { CHECK(dst_device != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - return buffer->CopyToDevice(dst_device->local_device_ordinal()); + return buffer->CopyToDevice(std::move(dst_device)); }) .def("delete", &PyLocalBuffer::Delete) .def("destructure", &PyLocalBuffer::DestructureTuple) @@ -485,10 +525,7 @@ PYBIND11_MODULE(xla_extension, m) { return LiteralToPython(std::move(literal)); }) .def("shape", &PyLocalBuffer::on_host_shape) - .def("device", - [](PyLocalBuffer* buffer) -> std::shared_ptr<Device> { - return buffer->client()->local_devices()[buffer->device_ordinal()]; - }) + .def("device", &PyLocalBuffer::device) .def("platform", &PyLocalBuffer::platform_name) .def("is_deleted", [](const PyLocalBuffer& buffer) { diff --git 
a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c7f36a56912..82cab92443c 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -444,7 +444,7 @@ def shape_from_pyval(pyval): return convert(pyval) -def transfer_to_infeed(value, device_ordinal=0): +def transfer_to_infeed(value, device=None): """Transfers the given value into the XLA infeed queue. XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with @@ -454,29 +454,31 @@ def transfer_to_infeed(value, device_ordinal=0): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - device_ordinal: the device to infeed the value to. Each device has a + device: the device to infeed the value to. Each device has a distinct infeed queue. """ # TODO(phawkins): support non-default backends. backend = get_local_backend() - backend.client.TransferToInfeed(value, device_ordinal) + device = device or backend.local_devices()[0] + backend.client.TransferToInfeed(value, device) -def transfer_from_outfeed(shape, device_ordinal=0): - """Transfers a literal of the given shape from `device_ordinal`'s outfeed. +def transfer_from_outfeed(shape, device=None): + """Transfers a literal of the given shape from `device`'s outfeed. Args: shape: The shape of the value to transfer from outfeed. - device_ordinal: The device ordinal to transfer the outfeed value from. Each - device has a distinct outfeed queue.. + device: The device from which to transfer the outfeed value. Each device has + a distinct outfeed queue.. Returns: The literal value that is produced from the outfeed queue. """ # TODO(phawkins): support non-default backends. backend = get_local_backend() + device = device or backend.local_devices()[0] return backend.client.TransferFromOutfeed( - shape.with_major_to_minor_layout_if_absent(), device_ordinal) + shape.with_major_to_minor_layout_if_absent(), device) DeviceAssignment = _xla.DeviceAssignment
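Taken together, the patch makes Device the placement handle across the C++ and Python clients: buffers remember the Device they live on, transfers and copies take a Device, and locality is expressed by whether a Device carries a LocalDeviceState. A usage sketch against the post-patch C++ API follows; the helper name CopyToAnotherLocalDevice is hypothetical, and the caller is assumed to have obtained its device list from the client:

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Hypothetical helper: copy `buffer` to some other device on this host. The
// device list is assumed to come from the client's device list.
StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToAnotherLocalDevice(
    PyLocalBuffer* buffer,
    const std::vector<std::shared_ptr<Device>>& devices) {
  for (const std::shared_ptr<Device>& device : devices) {
    // local_device_state() is nullptr for devices attached to another host;
    // GetLocalDeviceState() would return an InvalidArgument status instead.
    if (device->local_device_state() != nullptr &&
        device.get() != buffer->device().get()) {
      return buffer->CopyToDevice(device);
    }
  }
  return InvalidArgument("No other local device available for buffer on %s",
                         buffer->device()->DebugString());
}

}  // namespace xla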