[XLA:Python] Refactor Device and DeviceState classes.

* Remove all device_ordinals from the local_client.h API; instead, change the APIs to expect a Device object.
* Rename DeviceState to LocalDeviceState, which perhaps better clarifies its role: it holds objects pertaining to a locally attached device.
* Make the LocalDeviceState owned by Device (see the sketch below).

PiperOrigin-RevId: 285394262
Change-Id: I8a4f14bb03be31c4667c85e7d39a4ebcee4c6f40
Author: Peter Hawkins (2019-12-13 07:16:10 -08:00); committed by TensorFlower Gardener
parent 035050412d
commit 9a4295cb3d
11 changed files with 256 additions and 208 deletions
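
As a quick orientation, here is a condensed, hedged sketch of the ownership pattern the bullets above describe (simplified stand-ins, not the actual headers): Device optionally owns a LocalDeviceState, and callers obtain the state, including its ordinal, from the Device instead of passing an int device_ordinal through the API.

// Hedged sketch only; simplified from the real classes in local_client.h and
// local_device_state.h below.
#include <memory>
#include <utility>

// Stand-in for the per-device objects (streams, event pool, and so on) that
// the real LocalDeviceState owns for a locally attached device.
class LocalDeviceState {
 public:
  explicit LocalDeviceState(int device_ordinal)
      : device_ordinal_(device_ordinal) {}
  // StreamExecutor (local) device ordinal, mirroring the real accessor.
  int device_ordinal() const { return device_ordinal_; }

 private:
  const int device_ordinal_;
};

class Device {
 public:
  // local_device_state may be null for a device attached to another host.
  Device(int id, std::unique_ptr<LocalDeviceState> local_device_state)
      : id_(id), local_device_state_(std::move(local_device_state)) {}

  int id() const { return id_; }

  // Returns nullptr if this device is not local to this host.
  LocalDeviceState* local_device_state() const {
    return local_device_state_.get();
  }

 private:
  const int id_;
  const std::unique_ptr<LocalDeviceState> local_device_state_;
};

The diff below applies exactly this shape: the old int local_device_ordinal_ field becomes a std::unique_ptr<LocalDeviceState> member, and a StatusOr-returning GetLocalDeviceState() is added for callers that require the device to be local.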


@ -140,9 +140,9 @@ tf_cc_test(
)
cc_library(
name = "device_state",
srcs = ["device_state.cc"],
hdrs = ["device_state.h"],
name = "local_device_state",
srcs = ["local_device_state.cc"],
hdrs = ["local_device_state.h"],
deps = [
":event_pool",
":semaphore",
@ -161,7 +161,7 @@ cc_library(
srcs = ["local_client.cc"],
hdrs = ["local_client.h"],
deps = [
":device_state",
":local_device_state",
":shared_device_buffer",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",


@ -105,6 +105,13 @@ limitations under the License.
namespace xla {
StatusOr<LocalDeviceState*> Device::GetLocalDeviceState() const {
if (local_device_state_) {
return local_device_state_.get();
}
return InvalidArgument("Device %s is not a local device.", DebugString());
}
std::string CpuDevice::DebugString() const {
return absl::StrCat("CPU_", id());
}
@ -115,7 +122,7 @@ std::string GpuDevice::DebugString() const {
static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
se::Platform* platform,
absl::Span<const std::unique_ptr<DeviceState>> device_states,
absl::Span<const std::shared_ptr<Device>> local_devices,
LocalClient* client, double memory_fraction, bool preallocate) {
CHECK_GT(client->backend().device_count(), 0);
std::vector<se::MultiDeviceAdapter::AllocatorWithStream> allocators;
@ -148,19 +155,24 @@ static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
/*allow_growth=*/!preallocate,
absl::StrCat("GPU_", device_ordinal, "_bfc"));
allocators.emplace_back(std::move(gpu_bfc_allocator),
device_states.at(device_ordinal)->compute_stream());
local_devices.at(device_ordinal)
->local_device_state()
->compute_stream());
}
return absl::make_unique<se::MultiDeviceAdapter>(platform,
std::move(allocators));
}
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id, int local_device_ordinal) {
static std::shared_ptr<Device> MakeDevice(
const std::string& platform_name, int id,
std::unique_ptr<LocalDeviceState> local_device_state) {
if (platform_name == "cpu") {
return std::make_shared<CpuDevice>(id, local_device_ordinal, platform_name);
return std::make_shared<CpuDevice>(id, std::move(local_device_state),
platform_name);
} else {
CHECK_EQ(platform_name, "gpu");
return std::make_shared<GpuDevice>(id, local_device_ordinal, platform_name);
return std::make_shared<GpuDevice>(id, std::move(local_device_state),
platform_name);
}
}
@ -179,16 +191,15 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
ClientLibrary::GetOrCreateLocalClient(options));
bool gpu_platform = platform_name == "gpu";
std::vector<std::unique_ptr<DeviceState>> device_states;
std::vector<std::shared_ptr<Device>> devices;
bool synchronous_deallocation = platform_name == "cpu";
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutor* executor =
client->backend().stream_executor(i).ValueOrDie();
device_states.push_back(absl::make_unique<DeviceState>(
auto device_state = absl::make_unique<LocalDeviceState>(
executor, synchronous_deallocation, asynchronous,
/*allow_event_reuse=*/gpu_platform));
devices.push_back(MakeDevice(platform_name, i, i));
/*allow_event_reuse=*/gpu_platform);
devices.push_back(MakeDevice(platform_name, i, std::move(device_state)));
}
std::unique_ptr<se::DeviceMemoryAllocator> allocator;
@ -196,7 +207,7 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
if (gpu_platform) {
if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) {
TF_ASSIGN_OR_RETURN(allocator,
CreateBFCAllocator(platform, device_states, client,
CreateBFCAllocator(platform, devices, client,
allocator_config.memory_fraction,
allocator_config.preallocate));
}
@ -217,21 +228,18 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
return std::make_shared<PyLocalClient>(
platform_name, client, std::move(devices), /*host_id=*/0,
std::move(device_states), std::move(allocator),
std::move(host_memory_allocator));
std::move(allocator), std::move(host_memory_allocator));
}
PyLocalClient::PyLocalClient(
std::string platform_name, LocalClient* client,
std::vector<std::shared_ptr<Device>> devices, int host_id,
std::vector<std::unique_ptr<DeviceState>> device_states,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator)
: platform_name_(std::move(platform_name)),
client_(client),
devices_(std::move(devices)),
host_id_(host_id),
device_states_(std::move(device_states)),
owned_allocator_(std::move(allocator)),
host_memory_allocator_(std::move(host_memory_allocator)),
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
@ -242,15 +250,16 @@ PyLocalClient::PyLocalClient(
allocator_ = client_->backend().memory_allocator();
}
local_devices_.resize(device_states_.size());
for (const std::shared_ptr<Device>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device}).second)
<< "Duplicate device id: " << device->id();
if (device->local_device_ordinal() != -1) {
int idx = device->local_device_ordinal();
if (device->local_device_state()) {
int idx = device->local_device_state()->device_ordinal();
if (idx >= local_devices_.size()) {
local_devices_.resize(idx + 1);
}
CHECK(local_devices_[idx] == nullptr) << idx;
CHECK_LT(idx, local_devices_.size());
local_devices_[idx] = device;
}
}
@ -274,17 +283,19 @@ PyLocalClient::DeserializeExecutable(
}
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
int device_ordinal) {
TF_RETURN_IF_ERROR(
CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferToInfeed"));
return client_->TransferToInfeedLocal(literal, device_ordinal);
std::shared_ptr<Device> device) {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
return client_->TransferToInfeedLocal(literal,
local_device->device_ordinal());
}
StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape,
int device_ordinal) {
TF_RETURN_IF_ERROR(
CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferFromOutfeed"));
return client_->TransferFromOutfeedLocal(shape, device_ordinal);
StatusOr<Literal> PyLocalClient::TransferFromOutfeed(
const Shape& shape, std::shared_ptr<Device> device) {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
return client_->TransferFromOutfeedLocal(shape,
local_device->device_ordinal());
}
StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
@ -293,36 +304,26 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
num_replicas, /*computation_count=*/1);
}
Status PyLocalClient::CheckDeviceOrdinal(int device_ordinal,
absl::string_view caller_name) {
if (device_ordinal < 0 || device_ordinal >= local_device_count()) {
return InvalidArgument(
"%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
device_ordinal, local_device_count());
}
return Status::OK();
}
/* static */
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
<< " device ordinal: " << device_ordinal;
TF_RETURN_IF_ERROR(client->CheckDeviceOrdinal(device_ordinal,
"PyLocalBuffer::FromLiterals"));
DeviceState* device = &client->device_state(device_ordinal);
<< " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
se::DeviceMemoryAllocator* allocator = client->allocator();
TF_ASSIGN_OR_RETURN(
Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
transfer_manager->AllocateScopedShapedBuffer(
compact_shape, allocator, device_ordinal));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer scoped_buffer,
transfer_manager->AllocateScopedShapedBuffer(
compact_shape, allocator, local_device->device_ordinal()));
// Make the host to device stream wait for the newly allocated buffer to be
// available on the compute stream. We schedule this wait synchronously; while
@ -331,8 +332,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
// computations that depend on this transfer being enqueued on the compute
// stream.
if (!transfer_manager->CanShapedBufferBeAccessedNow(
device->host_to_device_stream()->parent(), scoped_buffer)) {
device->host_to_device_stream()->ThenWaitFor(device->compute_stream());
local_device->host_to_device_stream()->parent(), scoped_buffer)) {
local_device->host_to_device_stream()->ThenWaitFor(
local_device->compute_stream());
}
std::shared_ptr<BufferDefinitionEvent> definition_event =
@ -344,16 +346,15 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
// TODO(makro): Use move capture once C++ 14 features are available.
auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
std::move(leaves_literals));
auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
device_buffer, compact_shape, leaves,
leaves_reference]() {
auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
compact_shape, leaves, leaves_reference]() {
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
// report failures from a callback. However, the operations here are
// unlikely to fail and not recoverable even if we were to fail: DMAs to
// memory that has already been allocated, and a possible Event allocation.
ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape);
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
device->host_to_device_stream(), buffer));
local_device->host_to_device_stream(), buffer));
std::vector<std::shared_ptr<void>> staging_buffers;
staging_buffers.reserve(leaves->size());
auto it = leaves->begin();
@ -363,7 +364,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
ShapedBuffer leaf(
indexed_shape.shape,
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
client->client()->platform(), device_ordinal);
client->client()->platform(), local_device->device_ordinal());
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
// If applicable on the backend, stage the transfer via host memory
@ -379,51 +380,53 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
it->shape());
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
device->host_to_device_stream(), literal, leaf));
local_device->host_to_device_stream(), literal, leaf));
staging_buffers.push_back(std::move(staging_buffer));
} else {
// Otherwise, just transfer the literal.
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
device->host_to_device_stream(), *it, leaf));
local_device->host_to_device_stream(), *it, leaf));
}
++it;
}
EventPool::Handle event =
device->event_pool()
.ThenAllocateAndRecordEvent(device->host_to_device_stream())
local_device->event_pool()
.ThenAllocateAndRecordEvent(local_device->host_to_device_stream())
.ValueOrDie();
// Sets the buffer definition event. Note: this has the side effect of
// unblocking any host threads that may have been waiting to consume the
// buffer.
device_buffer->definition_event()->SetDefinitionEvent(
std::move(event), device->host_to_device_stream());
std::move(event), local_device->host_to_device_stream());
if (device->synchronous_deallocation()) {
device->ThenRelease(device->host_to_device_stream(), device_buffer);
if (local_device->synchronous_deallocation()) {
local_device->ThenRelease(local_device->host_to_device_stream(),
device_buffer);
}
device->ThenRelease(
device->host_to_device_stream(),
local_device->ThenRelease(
local_device->host_to_device_stream(),
std::make_pair(leaves_reference, std::move(staging_buffers)));
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
return absl::make_unique<PyLocalBuffer>(
compact_shape, std::move(device_buffer), std::move(client));
return absl::make_unique<PyLocalBuffer>(compact_shape,
std::move(device_buffer),
std::move(client), std::move(device));
}
/* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client, int device_ordinal) {
TF_RETURN_IF_ERROR(
client->CheckDeviceOrdinal(device_ordinal, "PyLocalBuffer::MakeTuple"));
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
std::vector<Shape> host_shapes;
std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
host_shapes.reserve(buffers.size());
device_buffers.reserve(buffers.size());
for (const PyLocalBuffer* buffer : buffers) {
TF_RET_CHECK(buffer->device_ordinal() == device_ordinal);
TF_RET_CHECK(buffer->device().get() == device.get());
std::shared_ptr<SharedDeviceBuffer> device_buffer = buffer->DeviceBuffer();
if (!device_buffer) {
return InvalidArgument(
@ -436,45 +439,48 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
se::DeviceMemoryAllocator* allocator = client->allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
DeviceState& device = client->device_state(device_ordinal);
auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(
std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator,
device_ordinal, definition_event));
TF_ASSIGN_OR_RETURN(std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
SharedDeviceBuffer::MakeTuple(
device_buffers, transfer_manager, allocator,
local_device->device_ordinal(), definition_event));
auto buffer = absl::make_unique<PyLocalBuffer>(
ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client));
ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client),
std::move(device));
// TODO(phawkins): extend TransferManager so we do not need to form a full
// ShapedBuffer just to write the root tuple index table.
TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer());
if (!transfer_manager->CanShapedBufferBeAccessedNow(
device.host_to_device_stream()->parent(), shaped_buffer)) {
local_device->host_to_device_stream()->parent(), shaped_buffer)) {
// Wait for the compute stream so that memory allocations are synchronized.
device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
local_device->host_to_device_stream()->ThenWaitFor(
local_device->compute_stream());
}
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
device.host_to_device_stream(), shaped_buffer));
local_device->host_to_device_stream(), shaped_buffer));
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
device.event_pool().ThenAllocateAndRecordEvent(
device.host_to_device_stream()));
local_device->event_pool().ThenAllocateAndRecordEvent(
local_device->host_to_device_stream()));
definition_event->SetDefinitionEvent(std::move(event),
device.host_to_device_stream());
local_device->host_to_device_stream());
if (device.synchronous_deallocation()) {
device.ThenRelease(device.host_to_device_stream(), std::move(tuple_buffer));
if (local_device->synchronous_deallocation()) {
local_device->ThenRelease(local_device->host_to_device_stream(),
std::move(tuple_buffer));
}
return buffer;
}
PyLocalBuffer::PyLocalBuffer(Shape on_host_shape,
std::shared_ptr<SharedDeviceBuffer> device_buffer,
std::shared_ptr<PyLocalClient> client)
std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device)
: client_(std::move(client)),
on_host_shape_(std::move(on_host_shape)),
device_ordinal_(device_buffer->device_ordinal()),
device_(std::move(device)),
device_buffer_(std::move(device_buffer)) {}
void PyLocalBuffer::Delete() {
@ -499,8 +505,7 @@ Status PyLocalBuffer::CopyToHostAsync() {
}
host_value = host_value_ = std::make_shared<HostValue>();
}
se::Stream* stream =
client_->device_state(device_ordinal_).device_to_host_stream();
se::Stream* stream = device_->local_device_state()->device_to_host_stream();
WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
host_value->value = std::make_shared<Literal>(on_host_shape_);
TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer());
@ -564,36 +569,38 @@ PyLocalBuffer::DestructureTuple() {
for (int64 i = 0; i < num_children; ++i) {
results.push_back(absl::make_unique<PyLocalBuffer>(
on_host_shape_.tuple_shapes(i), device_buffer_->children().at(i),
client_));
client_, device_));
}
return results;
}
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
int dst_device_ordinal) {
std::shared_ptr<Device> dst_device) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
if (dst_device_ordinal == device_ordinal_) {
return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
client_);
}
int transfer_device_ordinal = client_->EnqueueD2DTransfersOnSrcStream()
? device_ordinal_
: dst_device_ordinal;
DeviceState& transfer_device = client_->device_state(transfer_device_ordinal);
const DeviceState& dst_device = client_->device_state(dst_device_ordinal);
TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
dst_device->GetLocalDeviceState());
se::Stream* transfer_stream = transfer_device.GetDeviceToDeviceStream();
if (dst_device.get() == device_.get()) {
return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
client_, device_);
}
LocalDeviceState* transfer_local_device =
client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
: dst_local_device;
se::Stream* transfer_stream =
transfer_local_device->GetDeviceToDeviceStream();
TransferManager* transfer_manager =
client_->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape_, client_->allocator(), dst_device_ordinal));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape_, client_->allocator(),
dst_local_device->device_ordinal()));
if (!transfer_manager->CanShapedBufferBeAccessedNow(
dst_device.compute_stream()->parent(), dst_buffer)) {
transfer_stream->ThenWaitFor(dst_device.compute_stream());
dst_local_device->compute_stream()->parent(), dst_buffer)) {
transfer_stream->ThenWaitFor(dst_local_device->compute_stream());
}
TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
@ -607,37 +614,39 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
TF_RET_CHECK(input_buffer.size() == output_buffer.size())
<< "input: " << input_buffer.size()
<< " output: " << output_buffer.size();
TF_RETURN_IF_ERROR(transfer_device.ThenMemcpyDeviceToDevice(
transfer_stream, dst_device.compute_stream(), input_buffer,
TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
transfer_stream, dst_local_device->compute_stream(), input_buffer,
output_buffer));
}
// We hold on to the `src_device_buffer` until the transfer is finished.
transfer_device.ThenRelease(transfer_stream, std::move(src_device_buffer));
transfer_local_device->ThenRelease(transfer_stream,
std::move(src_device_buffer));
// Write new tuple buffers. The destination buffers have different addresses,
// so we must construct tuple buffers from scratch instead of copying them.
if (dst_buffer.on_device_shape().IsTuple()) {
TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
dst_device.host_to_device_stream(), dst_buffer));
dst_local_device->host_to_device_stream(), dst_buffer));
// We need a single definition event, so make the device to device stream
// wait for the stream that wrote the tuple index tables on the destination
// device.
transfer_stream->ThenWaitFor(dst_device.host_to_device_stream());
transfer_stream->ThenWaitFor(dst_local_device->host_to_device_stream());
}
auto definition_event = std::make_shared<BufferDefinitionEvent>();
TF_ASSIGN_OR_RETURN(
EventPool::Handle event,
transfer_device.event_pool().ThenAllocateAndRecordEvent(transfer_stream));
transfer_local_device->event_pool().ThenAllocateAndRecordEvent(
transfer_stream));
definition_event->SetDefinitionEvent(std::move(event), transfer_stream);
std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
definition_event);
return absl::make_unique<PyLocalBuffer>(
on_host_shape_, std::move(dst_device_buffer), client_);
on_host_shape_, std::move(dst_device_buffer), client_, dst_device);
}
Status PyLocalBuffer::BlockHostUntilReady() {
@ -694,7 +703,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
const int device_id = (*device_assignment_)(replica, 0);
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
CHECK_EQ(device->host_id(), client_->host_id());
int device_ordinal = device->local_device_ordinal();
int device_ordinal = device->local_device_state()->device_ordinal();
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: " << device_ordinal;
@ -729,7 +738,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
<< " buffer: " << argument_buffers.back().ToString();
}
DeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
// The choice of where we wait is arbitrary; the reason for the wait is pacing
// to avoid problems such as memory fragmentation and running ahead too far,
// not for correctness. Placing it before the executable launch allows the
@ -782,7 +791,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
device_state->compute_stream(),
std::make_tuple(executable_, compute_reservation, device_assignment_));
return absl::make_unique<PyLocalBuffer>(on_host_shape, std::move(out_buffer),
client_);
client_, device);
}
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute(
@ -833,8 +842,7 @@ PyLocalExecutable::ExecutePerReplica(
for (int i = 0; i < num_local_replicas; ++i) {
const int replica = local_replicas_[i];
std::shared_ptr<Device> device = local_devices_[i];
const DeviceState& device_state =
client_->device_state(device->local_device_ordinal());
const LocalDeviceState& device_state = *device->local_device_state();
device_state.execute_thread()->Schedule([&, replica, i] {
results[i] = ExecuteHelper(argument_handles[i], replica, run_id);


@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/python/device_state.h"
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@ -43,10 +43,10 @@ class PyLocalExecutable;
class Device {
public:
explicit Device(int id, int local_device_ordinal,
explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
absl::string_view platform_name, int host_id = 0)
: id_(id),
local_device_ordinal_(local_device_ordinal),
local_device_state_(std::move(local_device_state)),
host_id_(host_id),
platform_name_(platform_name) {}
virtual ~Device() {}
@ -56,13 +56,17 @@ class Device {
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
int id() const { return id_; }
// If this is a device local to this host, the local index of this device as
// according to the underlying backend. Unlike id(), this will always be in
// the range [0, num_local_devices), and can be used with the xla::LocalClient
// and xla::Backend APIs.
//
// -1 if this device is not local to this host.
int local_device_ordinal() const { return local_device_ordinal_; }
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns nullptr if the device is
// not local to this host.
LocalDeviceState* local_device_state() const {
return local_device_state_.get();
}
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns an error if the device
// is not local to this host.
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
// The ID of this device's host. This is always 0 on single-host platforms.
int host_id() const { return host_id_; }
@ -73,7 +77,7 @@ class Device {
private:
const int id_;
const int local_device_ordinal_;
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const std::string platform_name_;
};
@ -123,13 +127,14 @@ class PyLocalClient {
explicit PyLocalClient(
std::string platform_name, LocalClient* client,
std::vector<std::shared_ptr<Device>> devices, int host_id,
std::vector<std::unique_ptr<DeviceState>> device_states,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator);
virtual ~PyLocalClient() = default;
Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
Status TransferToInfeed(const LiteralSlice& literal,
std::shared_ptr<Device> device);
StatusOr<Literal> TransferFromOutfeed(const Shape& shape,
std::shared_ptr<Device> device);
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas) const;
@ -146,8 +151,8 @@ class PyLocalClient {
int host_id() const { return host_id_; }
const std::string& platform_name() const { return platform_name_; }
DeviceState& device_state(int device_ordinal) const {
return *device_states_.at(device_ordinal);
LocalDeviceState& device_state(int device_ordinal) const {
return *local_devices_.at(device_ordinal)->local_device_state();
}
LocalClient* client() const { return client_; }
@ -178,10 +183,6 @@ class PyLocalClient {
const std::string& serialized,
std::shared_ptr<PyLocalClient> this_shared) const;
// Returns a bad status containing `caller_name` if `device_ordinal` doesn't
// correspond to a local device.
Status CheckDeviceOrdinal(int device_ordinal, absl::string_view caller_name);
protected:
std::string platform_name_;
LocalClient* client_;
@ -194,8 +195,6 @@ class PyLocalClient {
std::vector<std::shared_ptr<Device>> local_devices_;
int host_id_;
// Device states local to this host. Indexed by local device ordinal.
std::vector<std::unique_ptr<DeviceState>> device_states_;
se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
@ -219,16 +218,16 @@ class PyLocalBuffer {
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
std::shared_ptr<PyLocalClient> client, int device_ordinal);
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client, int device_ordinal);
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
PyLocalBuffer() = default;
PyLocalBuffer(Shape on_host_shape,
std::shared_ptr<SharedDeviceBuffer> device_buffer,
std::shared_ptr<PyLocalClient> client);
std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device);
PyLocalBuffer(const PyLocalBuffer&) = delete;
PyLocalBuffer(PyLocalBuffer&&) = delete;
@ -236,7 +235,7 @@ class PyLocalBuffer {
PyLocalBuffer& operator=(PyLocalBuffer&&) = delete;
const Shape& on_host_shape() const { return on_host_shape_; }
int device_ordinal() const { return device_ordinal_; }
std::shared_ptr<Device> device() const { return device_; }
const std::string& platform_name() const { return client_->platform_name(); }
std::shared_ptr<PyLocalClient> client() const { return client_; }
@ -266,8 +265,9 @@ class PyLocalBuffer {
// Destructures a tuple-valued PyLocalBuffer into its constituent elements.
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> DestructureTuple();
// Copies the buffer to device `dst_device_ordinal`.
StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(int dst_device_ordinal);
// Copies the buffer to device `dst_device`.
StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(
std::shared_ptr<Device> dst_device);
// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.
@ -276,7 +276,7 @@ class PyLocalBuffer {
private:
const std::shared_ptr<PyLocalClient> client_;
const Shape on_host_shape_;
const int device_ordinal_;
const std::shared_ptr<Device> device_;
mutable absl::Mutex mu_;
std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/device_state.h"
#include "tensorflow/compiler/xla/python/local_device_state.h"
#include <memory>
#include <vector>
@ -24,12 +24,13 @@ limitations under the License.
namespace xla {
DeviceState::DeviceState(se::StreamExecutor* executor,
bool synchronous_deallocation, bool asynchronous,
bool allow_event_reuse)
LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse)
: synchronous_deallocation_(synchronous_deallocation),
event_pool_(allow_event_reuse),
compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) {
compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1),
executor_(executor) {
compute_stream_ = absl::make_unique<se::Stream>(executor);
host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
device_to_host_stream_ = absl::make_unique<se::Stream>(executor);
@ -50,14 +51,14 @@ DeviceState::DeviceState(se::StreamExecutor* executor,
"py_xla_callback");
}
DeviceState::~DeviceState() {
LocalDeviceState::~LocalDeviceState() {
Status status = SynchronizeAllActivity();
if (!status.ok()) {
LOG(ERROR) << "Error when closing device: " << status;
}
}
Status DeviceState::SynchronizeAllActivity() {
Status LocalDeviceState::SynchronizeAllActivity() {
Status status;
// TODO(phawkins): in theory the call to SynchronizeAllActivity below should
// suffice. However on the Host platform SynchronizeAllActivity is a dummy
@ -73,10 +74,9 @@ Status DeviceState::SynchronizeAllActivity() {
return status;
}
Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer,
se::DeviceMemoryBase dst_buffer) {
Status LocalDeviceState::ThenMemcpyDeviceToDevice(
se::Stream* transfer_stream, se::Stream* dst_stream,
se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
// The default implementation simply calls ThenMemcpyD2D, and assumes that
// the buffer addresses identify the devices. This does not work
// on all platforms; this method is virtual so it can be overridden.
@ -84,14 +84,14 @@ Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
return Status::OK();
}
void DeviceState::ThenExecuteOnCallbackThread(
void LocalDeviceState::ThenExecuteOnCallbackThread(
se::Stream* stream, std::function<void()> callback) const {
stream->ThenDoHostCallback([this, callback]() mutable {
callback_thread_->Schedule(std::move(callback));
});
}
se::Stream* DeviceState::GetDeviceToDeviceStream() {
se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
absl::MutexLock lock(&mu_);
int i = next_device_to_device_stream_;
next_device_to_device_stream_ =


@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
#include <memory>
#include <vector>
@ -29,9 +29,9 @@ limitations under the License.
namespace xla {
// Class that encapsulates state relating to a device (e.g., a GPU) on which we
// can perform computation and transfers. DeviceState objects only exist for
// devices local to this host.
class DeviceState {
// can perform computation and transfers. LocalDeviceState objects only exist
// for devices local to this host.
class LocalDeviceState {
public:
// If synchronous_deallocation is true, the host must not free buffers until
// compute/transfers that use those buffers have completed. For example, this
@ -40,9 +40,12 @@ class DeviceState {
//
// If asynchronous is false, the host will synchronize to the device after
// each execution or transfer. This is intended for debugging only.
DeviceState(se::StreamExecutor* executor, bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse);
virtual ~DeviceState();
LocalDeviceState(se::StreamExecutor* executor, bool synchronous_deallocation,
bool asynchronous, bool allow_event_reuse);
virtual ~LocalDeviceState();
// StreamExecutor (local) device ordinal.
int device_ordinal() const { return executor_->device_ordinal(); }
bool synchronous_deallocation() const { return synchronous_deallocation_; }
@ -104,6 +107,7 @@ class DeviceState {
// stream by the host ahead of the device.
Semaphore compute_semaphore_;
se::StreamExecutor* executor_;
std::unique_ptr<se::Stream> compute_stream_;
std::unique_ptr<se::Stream> host_to_device_stream_;
std::unique_ptr<se::Stream> device_to_host_stream_;
@ -132,4 +136,4 @@ class DeviceState {
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_


@ -19,7 +19,6 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/python:device_state",
"//tensorflow/compiler/xla/python:local_client",
"//tensorflow/compiler/xla/python:semaphore",
"//tensorflow/compiler/xla/python/tpu_driver",


@ -39,10 +39,9 @@ std::string TpuDevice::DebugString() const {
}
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id, int local_device_ordinal) {
int id) {
CHECK_EQ(platform_name, "tpu");
CHECK_EQ(id, local_device_ordinal); // Every device must be local for now.
return std::make_shared<TpuDevice>(id, local_device_ordinal, "tpu");
return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu");
}
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
@ -67,7 +66,7 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
devices.reserve(num_cores);
for (int i = 0; i < num_cores; ++i) {
devices.push_back(MakeDevice("tpu", i, i));
devices.push_back(MakeDevice("tpu", i));
}
return std::make_shared<PyTpuClient>("tpu", std::move(client),
@ -87,8 +86,8 @@ PyTpuClient::PyTpuClient(std::string platform_name,
CHECK(id_to_device_.insert({device->id(), device}).second)
<< "Duplicate device id: " << device->id();
if (device->local_device_ordinal() != -1) {
int idx = device->local_device_ordinal();
if (device->id() != -1) {
int idx = device->id();
CHECK(local_devices_[idx] == nullptr) << idx;
CHECK_LT(idx, local_devices_.size());
local_devices_[idx] = device;
@ -509,7 +508,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
const int device_id = device_assignment_(replica, 0);
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
CHECK_EQ(device->host_id(), client_->host_id());
int device_ordinal = device->local_device_ordinal();
int device_ordinal = device->id();
tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: " << device_ordinal;
@ -742,7 +741,7 @@ PyTpuExecutable::ExecutePerReplica(
const int device_id = (*device_assignment)(replica, 0);
std::shared_ptr<Device> device = LookupDevice(*client, device_id);
CHECK_EQ(device->host_id(), client->host_id());
int device_ordinal = device->local_device_ordinal();
int device_ordinal = device->id();
loaded_programs[replica] = client->driver()->LoadProgram(
device_ordinal, compiled_program.get(), {});
}


@ -24,7 +24,6 @@ limitations under the License.
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/python/device_state.h"
#include "tensorflow/compiler/xla/python/local_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"


@ -96,9 +96,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
std::make_move_iterator(tree.leaves.end()));
py::gil_scoped_release gil_release;
return PyTpuBuffer::FromLiterals(
std::move(leaves), tree.shape, std::move(py_buffer_ref),
std::move(client), device->local_device_ordinal());
return PyTpuBuffer::FromLiterals(std::move(leaves), tree.shape,
std::move(py_buffer_ref),
std::move(client), device->id());
})
.def_static(
"from_python",
@ -135,8 +135,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
"Cannot make tuple on device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
return PyTpuBuffer::MakeTuple(
buffers, client, device->local_device_ordinal());
return PyTpuBuffer::MakeTuple(buffers, client,
device->id());
})
.def_static("make_tuple", &PyTpuBuffer::MakeTuple)
.def("copy_to_device",
@ -144,7 +144,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
CHECK(dst_device != nullptr);
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
return buffer->CopyToDevice(dst_device->local_device_ordinal());
return buffer->CopyToDevice(dst_device->id());
})
.def("copy_to_device",
[](PyTpuBuffer* buffer, int dst_device_ordinal) {
@ -193,7 +193,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
[](const PyTpuExecutable& executable) {
std::vector<int> device_ordinals;
for (std::shared_ptr<Device> device : executable.local_devices()) {
device_ordinals.push_back(device->local_device_ordinal());
device_ordinals.push_back(device->id());
}
return device_ordinals;
})


@ -142,6 +142,16 @@ Status PyRegisterCustomCallTarget(const std::string& fn_name,
return Status::OK();
}
StatusOr<std::shared_ptr<Device>> LookupDeviceOrdinal(
PyLocalClient* client, int device_ordinal, absl::string_view caller_name) {
if (device_ordinal < 0 || device_ordinal >= client->local_device_count()) {
return InvalidArgument(
"%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
device_ordinal, client->local_device_count());
}
return client->local_devices()[device_ordinal];
}
} // namespace
PYBIND11_MODULE(xla_extension, m) {
@ -381,13 +391,27 @@ PYBIND11_MODULE(xla_extension, m) {
}
return result;
})
// TODO(phawkins): delete overload that accepts a device_ordinal after
// all callers have been updated to pass a Device.
.def("TransferToInfeed",
[](PyLocalClient* client, const LiteralSlice& literal,
int device_ordinal) {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
return client->TransferToInfeed(literal, device_ordinal);
TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
LookupDeviceOrdinal(client, device_ordinal,
"TransferToInfeed"));
return client->TransferToInfeed(literal, device);
})
.def("TransferToInfeed",
[](PyLocalClient* client, const LiteralSlice& literal,
std::shared_ptr<Device> device) {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
return client->TransferToInfeed(literal, device);
})
// TODO(phawkins): delete overload that accepts a device_ordinal after
// all callers have been updated to pass a Device.
.def("TransferFromOutfeed",
[](PyLocalClient* client, const Shape& shape,
int device_ordinal) -> StatusOr<py::object> {
@ -395,8 +419,24 @@ PYBIND11_MODULE(xla_extension, m) {
std::shared_ptr<Literal> literal_shared;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed(
shape, device_ordinal));
TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
LookupDeviceOrdinal(client, device_ordinal,
"TransferFromOutfeed"));
TF_ASSIGN_OR_RETURN(Literal literal,
client->TransferFromOutfeed(shape, device));
literal_shared = std::make_shared<Literal>(std::move(literal));
}
return LiteralToPython(std::move(literal_shared));
})
.def("TransferFromOutfeed",
[](PyLocalClient* client, const Shape& shape,
std::shared_ptr<Device> device) -> StatusOr<py::object> {
GlobalPyRefManager()->CollectGarbage();
std::shared_ptr<Literal> literal_shared;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(Literal literal,
client->TransferFromOutfeed(shape, device));
literal_shared = std::make_shared<Literal>(std::move(literal));
}
return LiteralToPython(std::move(literal_shared));
@ -440,7 +480,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::gil_scoped_release gil_release;
return PyLocalBuffer::FromLiterals(
std::move(leaves), tree.shape, std::move(py_buffer_ref),
std::move(client), device->local_device_ordinal());
std::move(client), std::move(device));
})
.def_static("make_tuple",
[](const std::vector<PyLocalBuffer*> buffers,
@ -454,15 +494,15 @@ PYBIND11_MODULE(xla_extension, m) {
"Cannot make tuple on device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
return PyLocalBuffer::MakeTuple(
buffers, client, device->local_device_ordinal());
return PyLocalBuffer::MakeTuple(buffers, std::move(client),
std::move(device));
})
.def("copy_to_device",
[](PyLocalBuffer* buffer, std::shared_ptr<Device> dst_device) {
CHECK(dst_device != nullptr);
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
return buffer->CopyToDevice(dst_device->local_device_ordinal());
return buffer->CopyToDevice(std::move(dst_device));
})
.def("delete", &PyLocalBuffer::Delete)
.def("destructure", &PyLocalBuffer::DestructureTuple)
@ -485,10 +525,7 @@ PYBIND11_MODULE(xla_extension, m) {
return LiteralToPython(std::move(literal));
})
.def("shape", &PyLocalBuffer::on_host_shape)
.def("device",
[](PyLocalBuffer* buffer) -> std::shared_ptr<Device> {
return buffer->client()->local_devices()[buffer->device_ordinal()];
})
.def("device", &PyLocalBuffer::device)
.def("platform", &PyLocalBuffer::platform_name)
.def("is_deleted",
[](const PyLocalBuffer& buffer) {


@ -444,7 +444,7 @@ def shape_from_pyval(pyval):
return convert(pyval)
def transfer_to_infeed(value, device_ordinal=0):
def transfer_to_infeed(value, device=None):
"""Transfers the given value into the XLA infeed queue.
XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
@ -454,29 +454,31 @@ def transfer_to_infeed(value, device_ordinal=0):
Args:
value: the value that the caller would like to enqueue into the XLA infeed
queue
device_ordinal: the device to infeed the value to. Each device has a
device: the device to infeed the value to. Each device has a
distinct infeed queue.
"""
# TODO(phawkins): support non-default backends.
backend = get_local_backend()
backend.client.TransferToInfeed(value, device_ordinal)
device = device or backend.local_devices()[0]
backend.client.TransferToInfeed(value, device)
def transfer_from_outfeed(shape, device_ordinal=0):
"""Transfers a literal of the given shape from `device_ordinal`'s outfeed.
def transfer_from_outfeed(shape, device=None):
"""Transfers a literal of the given shape from `device`'s outfeed.
Args:
shape: The shape of the value to transfer from outfeed.
device_ordinal: The device ordinal to transfer the outfeed value from. Each
device has a distinct outfeed queue.
device: The device from which to transfer the outfeed value. Each device has
a distinct outfeed queue.
Returns:
The literal value that is produced from the outfeed queue.
"""
# TODO(phawkins): support non-default backends.
backend = get_local_backend()
device = device or backend.local_devices()[0]
return backend.client.TransferFromOutfeed(
shape.with_major_to_minor_layout_if_absent(), device_ordinal)
shape.with_major_to_minor_layout_if_absent(), device)
DeviceAssignment = _xla.DeviceAssignment
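
To close, a hedged sketch of how a C++ call site migrates under this change. PyLocalClient::local_devices() and the Device-based TransferToInfeed overload both appear in the diff above; the helper itself is hypothetical and, for brevity, does not guard against an empty device list.

// Hypothetical migration example (assumes the updated local_client.h is
// included and that we are inside namespace xla).
Status InfeedToFirstLocalDevice(PyLocalClient* client,
                                const LiteralSlice& literal) {
  // Before this change the call took an ordinal:
  //   client->TransferToInfeed(literal, /*device_ordinal=*/0);
  std::shared_ptr<Device> device = client->local_devices().front();
  return client->TransferToInfeed(literal, std::move(device));
}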