[XLA:Python] Refactor Device and DeviceState classes.

* Remove all device_ordinals from the local_client.h API. Instead, change APIs to expect a Device object.
* Rename DeviceState to LocalDeviceState, which perhaps better clarifies its role: it holds objects pertaining to a locally-attached device.
* Make the LocalDeviceState owned by Device.

PiperOrigin-RevId: 285394262
Change-Id: I8a4f14bb03be31c4667c85e7d39a4ebcee4c6f40
commit 9a4295cb3d (parent 035050412d)
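For orientation, the refactored API surface looks roughly like this (a condensed sketch assembled from the hunks below, not a verbatim excerpt; only the members this change touches are shown):

// Sketch of the post-refactor classes (see the local_client.h hunks below).
class LocalDeviceState;  // renamed from DeviceState; holds streams, event pool, etc.

class Device {
 public:
  Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
         absl::string_view platform_name, int host_id = 0);

  // Returns nullptr if this device is not local to this host.
  LocalDeviceState* local_device_state() const;

  // Returns an error instead of nullptr; pairs with TF_ASSIGN_OR_RETURN.
  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;

 private:
  const std::unique_ptr<LocalDeviceState> local_device_state_;  // owned by Device
};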
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -140,9 +140,9 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "device_state",
-    srcs = ["device_state.cc"],
-    hdrs = ["device_state.h"],
+    name = "local_device_state",
+    srcs = ["local_device_state.cc"],
+    hdrs = ["local_device_state.h"],
     deps = [
         ":event_pool",
         ":semaphore",
@@ -161,7 +161,7 @@ cc_library(
     srcs = ["local_client.cc"],
     hdrs = ["local_client.h"],
     deps = [
-        ":device_state",
+        ":local_device_state",
         ":shared_device_buffer",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:literal",
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -105,6 +105,13 @@ limitations under the License.
 
 namespace xla {
 
+StatusOr<LocalDeviceState*> Device::GetLocalDeviceState() const {
+  if (local_device_state_) {
+    return local_device_state_.get();
+  }
+  return InvalidArgument("Device %s is not a local device.", DebugString());
+}
+
 std::string CpuDevice::DebugString() const {
   return absl::StrCat("CPU_", id());
 }
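Call sites that previously validated a raw device ordinal now unwrap the StatusOr instead; a minimal caller sketch (EnqueueOnDevice is a hypothetical function, not part of this change):

// Hypothetical caller: fails cleanly when `device` is not host-local.
Status EnqueueOnDevice(const std::shared_ptr<Device>& device) {
  TF_ASSIGN_OR_RETURN(LocalDeviceState* local_device,
                      device->GetLocalDeviceState());
  // local_device->compute_stream(), event_pool(), etc. are now safe to use.
  return Status::OK();
}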
@@ -115,7 +122,7 @@ std::string GpuDevice::DebugString() const {
 
 static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
     se::Platform* platform,
-    absl::Span<const std::unique_ptr<DeviceState>> device_states,
+    absl::Span<const std::shared_ptr<Device>> local_devices,
     LocalClient* client, double memory_fraction, bool preallocate) {
   CHECK_GT(client->backend().device_count(), 0);
   std::vector<se::MultiDeviceAdapter::AllocatorWithStream> allocators;
@@ -148,19 +155,24 @@ static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
         /*allow_growth=*/!preallocate,
         absl::StrCat("GPU_", device_ordinal, "_bfc"));
     allocators.emplace_back(std::move(gpu_bfc_allocator),
-                            device_states.at(device_ordinal)->compute_stream());
+                            local_devices.at(device_ordinal)
+                                ->local_device_state()
+                                ->compute_stream());
   }
   return absl::make_unique<se::MultiDeviceAdapter>(platform,
                                                    std::move(allocators));
 }
 
-static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
-                                          int id, int local_device_ordinal) {
+static std::shared_ptr<Device> MakeDevice(
+    const std::string& platform_name, int id,
+    std::unique_ptr<LocalDeviceState> local_device_state) {
   if (platform_name == "cpu") {
-    return std::make_shared<CpuDevice>(id, local_device_ordinal, platform_name);
+    return std::make_shared<CpuDevice>(id, std::move(local_device_state),
+                                       platform_name);
   } else {
     CHECK_EQ(platform_name, "gpu");
-    return std::make_shared<GpuDevice>(id, local_device_ordinal, platform_name);
+    return std::make_shared<GpuDevice>(id, std::move(local_device_state),
+                                       platform_name);
   }
 }
@@ -179,16 +191,15 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
                       ClientLibrary::GetOrCreateLocalClient(options));
 
   bool gpu_platform = platform_name == "gpu";
-  std::vector<std::unique_ptr<DeviceState>> device_states;
   std::vector<std::shared_ptr<Device>> devices;
   bool synchronous_deallocation = platform_name == "cpu";
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutor* executor =
         client->backend().stream_executor(i).ValueOrDie();
-    device_states.push_back(absl::make_unique<DeviceState>(
+    auto device_state = absl::make_unique<LocalDeviceState>(
         executor, synchronous_deallocation, asynchronous,
-        /*allow_event_reuse=*/gpu_platform));
-    devices.push_back(MakeDevice(platform_name, i, i));
+        /*allow_event_reuse=*/gpu_platform);
+    devices.push_back(MakeDevice(platform_name, i, std::move(device_state)));
   }
 
   std::unique_ptr<se::DeviceMemoryAllocator> allocator;
@@ -196,7 +207,7 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
   if (gpu_platform) {
     if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) {
       TF_ASSIGN_OR_RETURN(allocator,
-                          CreateBFCAllocator(platform, device_states, client,
+                          CreateBFCAllocator(platform, devices, client,
                                              allocator_config.memory_fraction,
                                              allocator_config.preallocate));
     }
@@ -217,21 +228,18 @@ StatusOr<std::shared_ptr<PyLocalClient>> PyLocalClient::Get(
 
   return std::make_shared<PyLocalClient>(
       platform_name, client, std::move(devices), /*host_id=*/0,
-      std::move(device_states), std::move(allocator),
-      std::move(host_memory_allocator));
+      std::move(allocator), std::move(host_memory_allocator));
 }
 
 PyLocalClient::PyLocalClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::shared_ptr<Device>> devices, int host_id,
-    std::vector<std::unique_ptr<DeviceState>> device_states,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator)
     : platform_name_(std::move(platform_name)),
      client_(client),
      devices_(std::move(devices)),
      host_id_(host_id),
-      device_states_(std::move(device_states)),
      owned_allocator_(std::move(allocator)),
      host_memory_allocator_(std::move(host_memory_allocator)),
      h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
@@ -242,15 +250,16 @@ PyLocalClient::PyLocalClient(
    allocator_ = client_->backend().memory_allocator();
  }
 
-  local_devices_.resize(device_states_.size());
   for (const std::shared_ptr<Device>& device : devices_) {
     CHECK(id_to_device_.insert({device->id(), device}).second)
         << "Duplicate device id: " << device->id();
 
-    if (device->local_device_ordinal() != -1) {
-      int idx = device->local_device_ordinal();
+    if (device->local_device_state()) {
+      int idx = device->local_device_state()->device_ordinal();
+      if (idx >= local_devices_.size()) {
+        local_devices_.resize(idx + 1);
+      }
       CHECK(local_devices_[idx] == nullptr) << idx;
-      CHECK_LT(idx, local_devices_.size());
       local_devices_[idx] = device;
     }
   }
@@ -274,17 +283,19 @@ PyLocalClient::DeserializeExecutable(
 }
 
 Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
-                                       int device_ordinal) {
-  TF_RETURN_IF_ERROR(
-      CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferToInfeed"));
-  return client_->TransferToInfeedLocal(literal, device_ordinal);
+                                       std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
+  return client_->TransferToInfeedLocal(literal,
+                                        local_device->device_ordinal());
 }
 
-StatusOr<Literal> PyLocalClient::TransferFromOutfeed(const Shape& shape,
-                                                     int device_ordinal) {
-  TF_RETURN_IF_ERROR(
-      CheckDeviceOrdinal(device_ordinal, "PyLocalClient::TransferFromOutfeed"));
-  return client_->TransferFromOutfeedLocal(shape, device_ordinal);
+StatusOr<Literal> PyLocalClient::TransferFromOutfeed(
+    const Shape& shape, std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
+  return client_->TransferFromOutfeedLocal(shape,
+                                           local_device->device_ordinal());
 }
 
 StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
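With these signatures, infeed and outfeed are addressed by Device handle rather than by ordinal; a short usage sketch (assuming `client` and `device` came from PyLocalClient::Get):

// Sketch: infeed/outfeed through the Device-based API.
TF_RETURN_IF_ERROR(client->TransferToInfeed(literal, device));
TF_ASSIGN_OR_RETURN(Literal result,
                    client->TransferFromOutfeed(shape, device));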
@@ -293,36 +304,26 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
       num_replicas, /*computation_count=*/1);
 }
 
-Status PyLocalClient::CheckDeviceOrdinal(int device_ordinal,
-                                         absl::string_view caller_name) {
-  if (device_ordinal < 0 || device_ordinal >= local_device_count()) {
-    return InvalidArgument(
-        "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
-        device_ordinal, local_device_count());
-  }
-  return Status::OK();
-}
-
 /* static */
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
     std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
     std::shared_ptr<void> leaves_reference,
-    std::shared_ptr<PyLocalClient> client, int device_ordinal) {
+    std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
   VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
-          << " device ordinal: " << device_ordinal;
-  TF_RETURN_IF_ERROR(client->CheckDeviceOrdinal(device_ordinal,
-                                                "PyLocalBuffer::FromLiterals"));
-  DeviceState* device = &client->device_state(device_ordinal);
+          << " device: " << device->DebugString();
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
   se::DeviceMemoryAllocator* allocator = client->allocator();
   TF_ASSIGN_OR_RETURN(
       Shape compact_shape,
       transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
-  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer,
-                      transfer_manager->AllocateScopedShapedBuffer(
-                          compact_shape, allocator, device_ordinal));
+  TF_ASSIGN_OR_RETURN(
+      ScopedShapedBuffer scoped_buffer,
+      transfer_manager->AllocateScopedShapedBuffer(
+          compact_shape, allocator, local_device->device_ordinal()));
 
   // Make the host to device stream wait for the newly allocated buffer to be
   // available on the compute stream. We schedule this wait synchronously; while
@@ -331,8 +332,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
   // computations that depend on this transfer being enqueued on the compute
   // stream.
   if (!transfer_manager->CanShapedBufferBeAccessedNow(
-          device->host_to_device_stream()->parent(), scoped_buffer)) {
-    device->host_to_device_stream()->ThenWaitFor(device->compute_stream());
+          local_device->host_to_device_stream()->parent(), scoped_buffer)) {
+    local_device->host_to_device_stream()->ThenWaitFor(
+        local_device->compute_stream());
   }
 
   std::shared_ptr<BufferDefinitionEvent> definition_event =
@@ -344,16 +346,15 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
   // TODO(makro): Use move capture once C++ 14 features are available.
   auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
       std::move(leaves_literals));
-  auto transfer_h2d = [client, transfer_manager, device, device_ordinal,
-                       device_buffer, compact_shape, leaves,
-                       leaves_reference]() {
+  auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
+                       compact_shape, leaves, leaves_reference]() {
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
    // report failures from a callback. However, the operations here are
    // unlikely to fail and not recoverable even if we were to fail: DMAs to
    // memory that has already been allocated, and a possible Event allocation.
    ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape);
    TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
-        device->host_to_device_stream(), buffer));
+        local_device->host_to_device_stream(), buffer));
    std::vector<std::shared_ptr<void>> staging_buffers;
    staging_buffers.reserve(leaves->size());
    auto it = leaves->begin();
@@ -363,7 +364,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
      ShapedBuffer leaf(
          indexed_shape.shape,
          transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
-          client->client()->platform(), device_ordinal);
+          client->client()->platform(), local_device->device_ordinal());
      leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
 
      // If applicable on the backend, stage the transfer via host memory
@@ -379,51 +380,53 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
        BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
                                 it->shape());
        TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            device->host_to_device_stream(), literal, leaf));
+            local_device->host_to_device_stream(), literal, leaf));
        staging_buffers.push_back(std::move(staging_buffer));
      } else {
        // Otherwise, just transfer the literal.
        TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            device->host_to_device_stream(), *it, leaf));
+            local_device->host_to_device_stream(), *it, leaf));
      }
      ++it;
    }
 
    EventPool::Handle event =
-        device->event_pool()
-            .ThenAllocateAndRecordEvent(device->host_to_device_stream())
+        local_device->event_pool()
+            .ThenAllocateAndRecordEvent(local_device->host_to_device_stream())
            .ValueOrDie();
 
    // Sets the buffer definition event. Note: this has the side effect of
    // unblocking any host threads that may have been waiting to consume the
    // buffer.
    device_buffer->definition_event()->SetDefinitionEvent(
-        std::move(event), device->host_to_device_stream());
+        std::move(event), local_device->host_to_device_stream());
 
-    if (device->synchronous_deallocation()) {
-      device->ThenRelease(device->host_to_device_stream(), device_buffer);
+    if (local_device->synchronous_deallocation()) {
+      local_device->ThenRelease(local_device->host_to_device_stream(),
+                                device_buffer);
    }
 
-    device->ThenRelease(
-        device->host_to_device_stream(),
+    local_device->ThenRelease(
+        local_device->host_to_device_stream(),
        std::make_pair(leaves_reference, std::move(staging_buffers)));
  };
  client->h2d_transfer_pool()->Schedule(transfer_h2d);
-  return absl::make_unique<PyLocalBuffer>(
-      compact_shape, std::move(device_buffer), std::move(client));
+  return absl::make_unique<PyLocalBuffer>(compact_shape,
+                                          std::move(device_buffer),
+                                          std::move(client), std::move(device));
 }
 
 /* static */ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::MakeTuple(
     const std::vector<PyLocalBuffer*> buffers,
-    std::shared_ptr<PyLocalClient> client, int device_ordinal) {
-  TF_RETURN_IF_ERROR(
-      client->CheckDeviceOrdinal(device_ordinal, "PyLocalBuffer::MakeTuple"));
+    std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
   std::vector<Shape> host_shapes;
   std::vector<std::shared_ptr<SharedDeviceBuffer>> device_buffers;
   host_shapes.reserve(buffers.size());
   device_buffers.reserve(buffers.size());
   for (const PyLocalBuffer* buffer : buffers) {
-    TF_RET_CHECK(buffer->device_ordinal() == device_ordinal);
+    TF_RET_CHECK(buffer->device().get() == device.get());
     std::shared_ptr<SharedDeviceBuffer> device_buffer = buffer->DeviceBuffer();
     if (!device_buffer) {
       return InvalidArgument(
@@ -436,45 +439,48 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
   se::DeviceMemoryAllocator* allocator = client->allocator();
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
-  DeviceState& device = client->device_state(device_ordinal);
 
   auto definition_event = std::make_shared<BufferDefinitionEvent>();
-  TF_ASSIGN_OR_RETURN(
-      std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
-      SharedDeviceBuffer::MakeTuple(device_buffers, transfer_manager, allocator,
-                                    device_ordinal, definition_event));
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
+                      SharedDeviceBuffer::MakeTuple(
+                          device_buffers, transfer_manager, allocator,
+                          local_device->device_ordinal(), definition_event));
   auto buffer = absl::make_unique<PyLocalBuffer>(
-      ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client));
+      ShapeUtil::MakeTupleShape(host_shapes), tuple_buffer, std::move(client),
+      std::move(device));
 
   // TODO(phawkins): extend TransferManager so we do not need to form a full
   // ShapedBuffer just to write the root tuple index table.
   TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer->AsShapedBuffer());
   if (!transfer_manager->CanShapedBufferBeAccessedNow(
-          device.host_to_device_stream()->parent(), shaped_buffer)) {
+          local_device->host_to_device_stream()->parent(), shaped_buffer)) {
     // Wait for the compute stream so that memory allocations are synchronized.
-    device.host_to_device_stream()->ThenWaitFor(device.compute_stream());
+    local_device->host_to_device_stream()->ThenWaitFor(
+        local_device->compute_stream());
   }
   TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
-      device.host_to_device_stream(), shaped_buffer));
+      local_device->host_to_device_stream(), shaped_buffer));
 
   TF_ASSIGN_OR_RETURN(EventPool::Handle event,
-                      device.event_pool().ThenAllocateAndRecordEvent(
-                          device.host_to_device_stream()));
+                      local_device->event_pool().ThenAllocateAndRecordEvent(
+                          local_device->host_to_device_stream()));
   definition_event->SetDefinitionEvent(std::move(event),
-                                       device.host_to_device_stream());
+                                       local_device->host_to_device_stream());
 
-  if (device.synchronous_deallocation()) {
-    device.ThenRelease(device.host_to_device_stream(), std::move(tuple_buffer));
+  if (local_device->synchronous_deallocation()) {
+    local_device->ThenRelease(local_device->host_to_device_stream(),
+                              std::move(tuple_buffer));
   }
   return buffer;
 }
 
 PyLocalBuffer::PyLocalBuffer(Shape on_host_shape,
                              std::shared_ptr<SharedDeviceBuffer> device_buffer,
-                             std::shared_ptr<PyLocalClient> client)
+                             std::shared_ptr<PyLocalClient> client,
+                             std::shared_ptr<Device> device)
     : client_(std::move(client)),
       on_host_shape_(std::move(on_host_shape)),
-      device_ordinal_(device_buffer->device_ordinal()),
+      device_(std::move(device)),
      device_buffer_(std::move(device_buffer)) {}
 
 void PyLocalBuffer::Delete() {
@@ -499,8 +505,7 @@ Status PyLocalBuffer::CopyToHostAsync() {
    }
    host_value = host_value_ = std::make_shared<HostValue>();
  }
-  se::Stream* stream =
-      client_->device_state(device_ordinal_).device_to_host_stream();
+  se::Stream* stream = device_->local_device_state()->device_to_host_stream();
  WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
  host_value->value = std::make_shared<Literal>(on_host_shape_);
  TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer());
@@ -564,36 +569,38 @@ PyLocalBuffer::DestructureTuple() {
  for (int64 i = 0; i < num_children; ++i) {
    results.push_back(absl::make_unique<PyLocalBuffer>(
        on_host_shape_.tuple_shapes(i), device_buffer_->children().at(i),
-        client_));
+        client_, device_));
  }
  return results;
 }
 
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
-    int dst_device_ordinal) {
+    std::shared_ptr<Device> dst_device) {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice");
   std::shared_ptr<SharedDeviceBuffer> src_device_buffer = DeviceBuffer();
-  if (dst_device_ordinal == device_ordinal_) {
-    return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
-                                            client_);
-  }
-  int transfer_device_ordinal = client_->EnqueueD2DTransfersOnSrcStream()
-                                    ? device_ordinal_
-                                    : dst_device_ordinal;
-  DeviceState& transfer_device = client_->device_state(transfer_device_ordinal);
-  const DeviceState& dst_device = client_->device_state(dst_device_ordinal);
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
+                      dst_device->GetLocalDeviceState());
 
-  se::Stream* transfer_stream = transfer_device.GetDeviceToDeviceStream();
+  if (dst_device.get() == device_.get()) {
+    return absl::make_unique<PyLocalBuffer>(on_host_shape_, src_device_buffer,
+                                            client_, device_);
+  }
+  LocalDeviceState* transfer_local_device =
+      client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
+                                                : dst_local_device;
+
+  se::Stream* transfer_stream =
+      transfer_local_device->GetDeviceToDeviceStream();
 
   TransferManager* transfer_manager =
       client_->client()->backend().transfer_manager();
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer dst_buffer,
-      transfer_manager->AllocateScopedShapedBuffer(
-          on_host_shape_, client_->allocator(), dst_device_ordinal));
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
+                      transfer_manager->AllocateScopedShapedBuffer(
+                          on_host_shape_, client_->allocator(),
+                          dst_local_device->device_ordinal()));
   if (!transfer_manager->CanShapedBufferBeAccessedNow(
-          dst_device.compute_stream()->parent(), dst_buffer)) {
-    transfer_stream->ThenWaitFor(dst_device.compute_stream());
+          dst_local_device->compute_stream()->parent(), dst_buffer)) {
+    transfer_stream->ThenWaitFor(dst_local_device->compute_stream());
   }
   TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
 
@@ -607,37 +614,39 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
    TF_RET_CHECK(input_buffer.size() == output_buffer.size())
        << "input: " << input_buffer.size()
        << " output: " << output_buffer.size();
-    TF_RETURN_IF_ERROR(transfer_device.ThenMemcpyDeviceToDevice(
-        transfer_stream, dst_device.compute_stream(), input_buffer,
+    TF_RETURN_IF_ERROR(transfer_local_device->ThenMemcpyDeviceToDevice(
+        transfer_stream, dst_local_device->compute_stream(), input_buffer,
        output_buffer));
  }
 
  // We hold on to the `src_device_buffer` until the transfer is finished.
-  transfer_device.ThenRelease(transfer_stream, std::move(src_device_buffer));
+  transfer_local_device->ThenRelease(transfer_stream,
+                                     std::move(src_device_buffer));
 
  // Write new tuple buffers. The destination buffers have different addresses,
  // so we must construct tuple buffers from scratch instead of copying them.
  if (dst_buffer.on_device_shape().IsTuple()) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteTupleIndexTablesAsync(
-        dst_device.host_to_device_stream(), dst_buffer));
+        dst_local_device->host_to_device_stream(), dst_buffer));
 
    // We need a single definition event, so make the device to device stream
    // wait for the stream that wrote the tuple index tables on the destination
    // device.
-    transfer_stream->ThenWaitFor(dst_device.host_to_device_stream());
+    transfer_stream->ThenWaitFor(dst_local_device->host_to_device_stream());
  }
 
  auto definition_event = std::make_shared<BufferDefinitionEvent>();
  TF_ASSIGN_OR_RETURN(
      EventPool::Handle event,
-      transfer_device.event_pool().ThenAllocateAndRecordEvent(transfer_stream));
+      transfer_local_device->event_pool().ThenAllocateAndRecordEvent(
+          transfer_stream));
  definition_event->SetDefinitionEvent(std::move(event), transfer_stream);
 
  std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
      SharedDeviceBuffer::FromScopedShapedBuffer(std::move(dst_buffer),
                                                 definition_event);
  return absl::make_unique<PyLocalBuffer>(
-      on_host_shape_, std::move(dst_device_buffer), client_);
+      on_host_shape_, std::move(dst_device_buffer), client_, dst_device);
 }
 
 Status PyLocalBuffer::BlockHostUntilReady() {
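CopyToDevice likewise takes the destination Device; a caller sketch (assuming `buffer` and `dst_device` belong to the same client):

// Sketch: device-to-device copy addressed by Device handle.
TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> copy,
                    buffer->CopyToDevice(dst_device));
// Copying to the buffer's own device returns an aliasing buffer, no transfer.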
@@ -694,7 +703,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
   const int device_id = (*device_assignment_)(replica, 0);
   std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
   CHECK_EQ(device->host_id(), client_->host_id());
-  int device_ordinal = device->local_device_ordinal();
+  int device_ordinal = device->local_device_state()->device_ordinal();
   tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
   VLOG(3) << "Replica " << replica
           << " mapped to device ordinal for execution: " << device_ordinal;
@@ -729,7 +738,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
             << " buffer: " << argument_buffers.back().ToString();
   }
 
-  DeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
   // The choice of where we wait is arbitrary; the reason for the wait is pacing
   // to avoid problems such as memory fragmentation and running ahead too far,
   // not for correctness. Placing it before the executable launch allows the
@@ -782,7 +791,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
       device_state->compute_stream(),
       std::make_tuple(executable_, compute_reservation, device_assignment_));
   return absl::make_unique<PyLocalBuffer>(on_host_shape, std::move(out_buffer),
-                                          client_);
+                                          client_, device);
 }
 
 StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::Execute(
@@ -833,8 +842,7 @@ PyLocalExecutable::ExecutePerReplica(
   for (int i = 0; i < num_local_replicas; ++i) {
     const int replica = local_replicas_[i];
     std::shared_ptr<Device> device = local_devices_[i];
-    const DeviceState& device_state =
-        client_->device_state(device->local_device_ordinal());
+    const LocalDeviceState& device_state = *device->local_device_state();
     device_state.execute_thread()->Schedule([&, replica, i] {
       results[i] = ExecuteHelper(argument_handles[i], replica, run_id);
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -27,7 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
-#include "tensorflow/compiler/xla/python/device_state.h"
+#include "tensorflow/compiler/xla/python/local_device_state.h"
 #include "tensorflow/compiler/xla/python/shared_device_buffer.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -43,10 +43,10 @@ class PyLocalExecutable;
 
 class Device {
  public:
-  explicit Device(int id, int local_device_ordinal,
+  explicit Device(int id, std::unique_ptr<LocalDeviceState> local_device_state,
                   absl::string_view platform_name, int host_id = 0)
       : id_(id),
-        local_device_ordinal_(local_device_ordinal),
+        local_device_state_(std::move(local_device_state)),
         host_id_(host_id),
         platform_name_(platform_name) {}
   virtual ~Device() {}
@@ -56,13 +56,17 @@ class Device {
   // hosts' devices. This is the ID that should be used in a DeviceAssignment.
   int id() const { return id_; }
 
-  // If this is a device local to this host, the local index of this device as
-  // according to the underlying backend. Unlike id(), this will always be in
-  // the range [0, num_local_devices), and can be used with the xla::LocalClient
-  // and xla::Backend APIs.
-  //
-  // -1 if this device is not local to this host.
-  int local_device_ordinal() const { return local_device_ordinal_; }
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns nullptr if the device is
+  // not local to this host.
+  LocalDeviceState* local_device_state() const {
+    return local_device_state_.get();
+  }
+
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns an error if the device
+  // is not local to this host.
+  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
 
   // The ID of this device's host. This is always 0 on single-host platforms.
   int host_id() const { return host_id_; }
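For call sites that prefer to branch on locality rather than propagate an error, the nullptr-returning accessor supports the pattern (sketch):

// Sketch: branching on device locality.
if (LocalDeviceState* local = device->local_device_state()) {
  // Device is attached to this host; its streams and event pool are usable.
} else {
  // Device belongs to another host; only id(), host_id(), etc. are meaningful.
}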
@@ -73,7 +77,7 @@ class Device {
 
  private:
   const int id_;
-  const int local_device_ordinal_;
+  const std::unique_ptr<LocalDeviceState> local_device_state_;
   const int host_id_;
   const std::string platform_name_;
 };
@@ -123,13 +127,14 @@ class PyLocalClient {
   explicit PyLocalClient(
       std::string platform_name, LocalClient* client,
       std::vector<std::shared_ptr<Device>> devices, int host_id,
-      std::vector<std::unique_ptr<DeviceState>> device_states,
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator);
   virtual ~PyLocalClient() = default;
 
-  Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal);
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_ordinal);
+  Status TransferToInfeed(const LiteralSlice& literal,
+                          std::shared_ptr<Device> device);
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape,
+                                        std::shared_ptr<Device> device);
 
   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas) const;
@@ -146,8 +151,8 @@ class PyLocalClient {
   int host_id() const { return host_id_; }
   const std::string& platform_name() const { return platform_name_; }
 
-  DeviceState& device_state(int device_ordinal) const {
-    return *device_states_.at(device_ordinal);
+  LocalDeviceState& device_state(int device_ordinal) const {
+    return *local_devices_.at(device_ordinal)->local_device_state();
   }
 
   LocalClient* client() const { return client_; }
@@ -178,10 +183,6 @@ class PyLocalClient {
       const std::string& serialized,
       std::shared_ptr<PyLocalClient> this_shared) const;
 
-  // Returns a bad status containing `caller_name` if `device_ordinal` doesn't
-  // correspond to a local device.
-  Status CheckDeviceOrdinal(int device_ordinal, absl::string_view caller_name);
-
  protected:
   std::string platform_name_;
   LocalClient* client_;
@@ -194,8 +195,6 @@ class PyLocalClient {
   std::vector<std::shared_ptr<Device>> local_devices_;
   int host_id_;
 
-  // Device states local to this host. Indexed by local device ordinal.
-  std::vector<std::unique_ptr<DeviceState>> device_states_;
   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
 
@@ -219,16 +218,16 @@ class PyLocalBuffer {
   static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
       std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
       std::shared_ptr<void> leaves_reference,
-      std::shared_ptr<PyLocalClient> client, int device_ordinal);
+      std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
 
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
       const std::vector<PyLocalBuffer*> buffers,
-      std::shared_ptr<PyLocalClient> client, int device_ordinal);
+      std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
 
   PyLocalBuffer() = default;
   PyLocalBuffer(Shape on_host_shape,
                 std::shared_ptr<SharedDeviceBuffer> device_buffer,
-                std::shared_ptr<PyLocalClient> client);
+                std::shared_ptr<PyLocalClient> client,
+                std::shared_ptr<Device> device);
 
   PyLocalBuffer(const PyLocalBuffer&) = delete;
   PyLocalBuffer(PyLocalBuffer&&) = delete;
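The buffer factories follow the same Device-based pattern; a usage sketch for the new signatures (`literals`, `shape`, and `leaves_ref` stand in for caller-provided values):

// Sketch: constructing buffers against a Device instead of an ordinal.
TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> buf,
                    PyLocalBuffer::FromLiterals(std::move(literals), shape,
                                                leaves_ref, client, device));
TF_ASSIGN_OR_RETURN(std::unique_ptr<PyLocalBuffer> tuple,
                    PyLocalBuffer::MakeTuple({buf.get()}, client, device));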
@@ -236,7 +235,7 @@ class PyLocalBuffer {
   PyLocalBuffer& operator=(PyLocalBuffer&&) = delete;
 
   const Shape& on_host_shape() const { return on_host_shape_; }
-  int device_ordinal() const { return device_ordinal_; }
+  std::shared_ptr<Device> device() const { return device_; }
   const std::string& platform_name() const { return client_->platform_name(); }
   std::shared_ptr<PyLocalClient> client() const { return client_; }
 
@@ -266,8 +265,9 @@ class PyLocalBuffer {
   // Destructures a tuple-valued PyLocalBuffer into its constituent elements.
   StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> DestructureTuple();
 
-  // Copies the buffer to device `dst_device_ordinal`.
-  StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(int dst_device_ordinal);
+  // Copies the buffer to device `dst_device`.
+  StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(
+      std::shared_ptr<Device> dst_device);
 
   // Blocks the host until the buffer's value has been computed and is ready for
   // immediate use on the device. Useful in particular for timing benchmarks.
@@ -276,7 +276,7 @@ class PyLocalBuffer {
  private:
   const std::shared_ptr<PyLocalClient> client_;
   const Shape on_host_shape_;
-  const int device_ordinal_;
+  const std::shared_ptr<Device> device_;
   mutable absl::Mutex mu_;
   std::shared_ptr<SharedDeviceBuffer> device_buffer_ GUARDED_BY(mu_);
--- a/tensorflow/compiler/xla/python/device_state.cc
+++ b/tensorflow/compiler/xla/python/local_device_state.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/python/device_state.h"
+#include "tensorflow/compiler/xla/python/local_device_state.h"
 
 #include <memory>
 #include <vector>
@@ -24,12 +24,13 @@ limitations under the License.
 
 namespace xla {
 
-DeviceState::DeviceState(se::StreamExecutor* executor,
-                         bool synchronous_deallocation, bool asynchronous,
-                         bool allow_event_reuse)
+LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
+                                   bool synchronous_deallocation,
+                                   bool asynchronous, bool allow_event_reuse)
     : synchronous_deallocation_(synchronous_deallocation),
       event_pool_(allow_event_reuse),
-      compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) {
+      compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1),
+      executor_(executor) {
   compute_stream_ = absl::make_unique<se::Stream>(executor);
   host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
   device_to_host_stream_ = absl::make_unique<se::Stream>(executor);
@@ -50,14 +51,14 @@ DeviceState::DeviceState(se::StreamExecutor* executor,
                          "py_xla_callback");
 }
 
-DeviceState::~DeviceState() {
+LocalDeviceState::~LocalDeviceState() {
   Status status = SynchronizeAllActivity();
   if (!status.ok()) {
     LOG(ERROR) << "Error when closing device: " << status;
   }
 }
 
-Status DeviceState::SynchronizeAllActivity() {
+Status LocalDeviceState::SynchronizeAllActivity() {
   Status status;
   // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
   // suffice. However on the Host platform SynchronizeAllActivity is a dummy
@@ -73,10 +74,9 @@ Status DeviceState::SynchronizeAllActivity() {
   return status;
 }
 
-Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
-                                             se::Stream* dst_stream,
-                                             se::DeviceMemoryBase src_buffer,
-                                             se::DeviceMemoryBase dst_buffer) {
+Status LocalDeviceState::ThenMemcpyDeviceToDevice(
+    se::Stream* transfer_stream, se::Stream* dst_stream,
+    se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
   // The default implementation simply calls ThenMemcpyD2D, and assumes that
   // the buffer addresses identify the devices. This does not work
   // on all platforms; this method is virtual so it can be overridden.
@@ -84,14 +84,14 @@ Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
   return Status::OK();
 }
 
-void DeviceState::ThenExecuteOnCallbackThread(
+void LocalDeviceState::ThenExecuteOnCallbackThread(
     se::Stream* stream, std::function<void()> callback) const {
   stream->ThenDoHostCallback([this, callback]() mutable {
     callback_thread_->Schedule(std::move(callback));
   });
 }
 
-se::Stream* DeviceState::GetDeviceToDeviceStream() {
+se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
   absl::MutexLock lock(&mu_);
   int i = next_device_to_device_stream_;
   next_device_to_device_stream_ =
--- a/tensorflow/compiler/xla/python/device_state.h
+++ b/tensorflow/compiler/xla/python/local_device_state.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
-#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
 
 #include <memory>
 #include <vector>
@@ -29,9 +29,9 @@ limitations under the License.
 namespace xla {
 
 // Class that encapsulates state relating to a device (e.g., a GPU) on which we
-// can perform computation and transfers. DeviceState objects only exist for
-// devices local to this host.
-class DeviceState {
+// can perform computation and transfers. LocalDeviceState objects only exist
+// for devices local to this host.
+class LocalDeviceState {
  public:
   // If synchronous_deallocation is true, the host must not free buffers until
   // compute/transfers that use those buffers have completed. For example, this
@@ -40,9 +40,12 @@ class DeviceState {
   //
   // If asynchronous is false, the host will synchronize to the device after
   // each execution or transfer. This is intended for debugging only.
-  DeviceState(se::StreamExecutor* executor, bool synchronous_deallocation,
-              bool asynchronous, bool allow_event_reuse);
-  virtual ~DeviceState();
+  LocalDeviceState(se::StreamExecutor* executor, bool synchronous_deallocation,
+                   bool asynchronous, bool allow_event_reuse);
+  virtual ~LocalDeviceState();
+
+  // StreamExecutor (local) device ordinal.
+  int device_ordinal() const { return executor_->device_ordinal(); }
 
   bool synchronous_deallocation() const { return synchronous_deallocation_; }
 
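Since LocalDeviceState now records its StreamExecutor, the ordinal is derived rather than stored separately; a construction sketch (`executor` obtained from the backend, as in PyLocalClient::Get above):

// Sketch: the ordinal comes from the executor the state was built with.
auto state = absl::make_unique<LocalDeviceState>(
    executor, /*synchronous_deallocation=*/false, /*asynchronous=*/true,
    /*allow_event_reuse=*/true);
int ordinal = state->device_ordinal();  // same as executor->device_ordinal()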
@@ -104,6 +107,7 @@ class DeviceState {
   // stream by the host ahead of the device.
   Semaphore compute_semaphore_;
 
+  se::StreamExecutor* executor_;
   std::unique_ptr<se::Stream> compute_stream_;
   std::unique_ptr<se::Stream> host_to_device_stream_;
   std::unique_ptr<se::Stream> device_to_host_stream_;
@@ -132,4 +136,4 @@ class DeviceState {
 
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_
+#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_
--- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
@@ -19,7 +19,6 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:executable_build_options",
-        "//tensorflow/compiler/xla/python:device_state",
         "//tensorflow/compiler/xla/python:local_client",
         "//tensorflow/compiler/xla/python:semaphore",
         "//tensorflow/compiler/xla/python/tpu_driver",
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
@@ -39,10 +39,9 @@ std::string TpuDevice::DebugString() const {
 }
 
 static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
-                                          int id, int local_device_ordinal) {
+                                          int id) {
   CHECK_EQ(platform_name, "tpu");
-  CHECK_EQ(id, local_device_ordinal);  // Every device must be local for now.
-  return std::make_shared<TpuDevice>(id, local_device_ordinal, "tpu");
+  return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu");
 }
 
 StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
@@ -67,7 +66,7 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
   LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
   devices.reserve(num_cores);
   for (int i = 0; i < num_cores; ++i) {
-    devices.push_back(MakeDevice("tpu", i, i));
+    devices.push_back(MakeDevice("tpu", i));
   }
 
   return std::make_shared<PyTpuClient>("tpu", std::move(client),
@@ -87,8 +86,8 @@ PyTpuClient::PyTpuClient(std::string platform_name,
     CHECK(id_to_device_.insert({device->id(), device}).second)
         << "Duplicate device id: " << device->id();
 
-    if (device->local_device_ordinal() != -1) {
-      int idx = device->local_device_ordinal();
+    if (device->id() != -1) {
+      int idx = device->id();
       CHECK(local_devices_[idx] == nullptr) << idx;
       CHECK_LT(idx, local_devices_.size());
       local_devices_[idx] = device;
@@ -509,7 +508,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
   const int device_id = device_assignment_(replica, 0);
   std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
   CHECK_EQ(device->host_id(), client_->host_id());
-  int device_ordinal = device->local_device_ordinal();
+  int device_ordinal = device->id();
   tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
   VLOG(3) << "Replica " << replica
           << " mapped to device ordinal for execution: " << device_ordinal;
@@ -742,7 +741,7 @@ PyTpuExecutable::ExecutePerReplica(
     const int device_id = (*device_assignment)(replica, 0);
     std::shared_ptr<Device> device = LookupDevice(*client, device_id);
     CHECK_EQ(device->host_id(), client->host_id());
-    int device_ordinal = device->local_device_ordinal();
+    int device_ordinal = device->id();
     loaded_programs[replica] = client->driver()->LoadProgram(
         device_ordinal, compiled_program.get(), {});
   }
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
@@ -24,7 +24,6 @@ limitations under the License.
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
-#include "tensorflow/compiler/xla/python/device_state.h"
 #include "tensorflow/compiler/xla/python/local_client.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
@@ -96,9 +96,9 @@ PYBIND11_MODULE(tpu_client_extension, m) {
                 std::make_move_iterator(tree.leaves.end()));
 
             py::gil_scoped_release gil_release;
-            return PyTpuBuffer::FromLiterals(
-                std::move(leaves), tree.shape, std::move(py_buffer_ref),
-                std::move(client), device->local_device_ordinal());
+            return PyTpuBuffer::FromLiterals(std::move(leaves), tree.shape,
+                                             std::move(py_buffer_ref),
+                                             std::move(client), device->id());
           })
       .def_static(
           "from_python",
@@ -135,8 +135,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
                   "Cannot make tuple on device '%s' with '%s' backend",
                   device->DebugString(), client->platform_name());
             }
-            return PyTpuBuffer::MakeTuple(
-                buffers, client, device->local_device_ordinal());
+            return PyTpuBuffer::MakeTuple(buffers, client,
+                                          device->id());
           })
       .def_static("make_tuple", &PyTpuBuffer::MakeTuple)
       .def("copy_to_device",
@@ -144,7 +144,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
             CHECK(dst_device != nullptr);
             GlobalPyRefManager()->CollectGarbage();
             py::gil_scoped_release gil_release;
-            return buffer->CopyToDevice(dst_device->local_device_ordinal());
+            return buffer->CopyToDevice(dst_device->id());
           })
       .def("copy_to_device",
            [](PyTpuBuffer* buffer, int dst_device_ordinal) {
@@ -193,7 +193,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
           [](const PyTpuExecutable& executable) {
             std::vector<int> device_ordinals;
             for (std::shared_ptr<Device> device : executable.local_devices()) {
-              device_ordinals.push_back(device->local_device_ordinal());
+              device_ordinals.push_back(device->id());
             }
             return device_ordinals;
           })
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -142,6 +142,16 @@ Status PyRegisterCustomCallTarget(const std::string& fn_name,
   return Status::OK();
 }
 
+StatusOr<std::shared_ptr<Device>> LookupDeviceOrdinal(
+    PyLocalClient* client, int device_ordinal, absl::string_view caller_name) {
+  if (device_ordinal < 0 || device_ordinal >= client->local_device_count()) {
+    return InvalidArgument(
+        "%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
+        device_ordinal, client->local_device_count());
+  }
+  return client->local_devices()[device_ordinal];
+}
+
 }  // namespace
 
 PYBIND11_MODULE(xla_extension, m) {
@@ -381,13 +391,27 @@ PYBIND11_MODULE(xla_extension, m) {
             }
             return result;
           })
+      // TODO(phawkins): delete overload that accepts a device_ordinal after
+      // all callers have been updated to pass a Device.
       .def("TransferToInfeed",
            [](PyLocalClient* client, const LiteralSlice& literal,
              int device_ordinal) {
             GlobalPyRefManager()->CollectGarbage();
             py::gil_scoped_release gil_release;
-             return client->TransferToInfeed(literal, device_ordinal);
+             TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
+                                 LookupDeviceOrdinal(client, device_ordinal,
+                                                     "TransferToInfeed"));
+             return client->TransferToInfeed(literal, device);
           })
+      .def("TransferToInfeed",
+           [](PyLocalClient* client, const LiteralSlice& literal,
+              std::shared_ptr<Device> device) {
+             GlobalPyRefManager()->CollectGarbage();
+             py::gil_scoped_release gil_release;
+             return client->TransferToInfeed(literal, device);
+           })
+      // TODO(phawkins): delete overload that accepts a device_ordinal after
+      // all callers have been updated to pass a Device.
       .def("TransferFromOutfeed",
            [](PyLocalClient* client, const Shape& shape,
              int device_ordinal) -> StatusOr<py::object> {
@@ -395,8 +419,24 @@ PYBIND11_MODULE(xla_extension, m) {
            std::shared_ptr<Literal> literal_shared;
            {
              py::gil_scoped_release gil_release;
-               TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed(
-                                                        shape, device_ordinal));
+               TF_ASSIGN_OR_RETURN(std::shared_ptr<Device> device,
+                                   LookupDeviceOrdinal(client, device_ordinal,
+                                                       "TransferFromOutfeed"));
+               TF_ASSIGN_OR_RETURN(Literal literal,
+                                   client->TransferFromOutfeed(shape, device));
              literal_shared = std::make_shared<Literal>(std::move(literal));
            }
            return LiteralToPython(std::move(literal_shared));
          })
+      .def("TransferFromOutfeed",
+           [](PyLocalClient* client, const Shape& shape,
+              std::shared_ptr<Device> device) -> StatusOr<py::object> {
+             GlobalPyRefManager()->CollectGarbage();
+             std::shared_ptr<Literal> literal_shared;
+             {
+               py::gil_scoped_release gil_release;
+               TF_ASSIGN_OR_RETURN(Literal literal,
+                                   client->TransferFromOutfeed(shape, device));
+               literal_shared = std::make_shared<Literal>(std::move(literal));
+             }
+             return LiteralToPython(std::move(literal_shared));
@@ -440,7 +480,7 @@ PYBIND11_MODULE(xla_extension, m) {
            py::gil_scoped_release gil_release;
            return PyLocalBuffer::FromLiterals(
                std::move(leaves), tree.shape, std::move(py_buffer_ref),
-                std::move(client), device->local_device_ordinal());
+                std::move(client), std::move(device));
          })
       .def_static("make_tuple",
                   [](const std::vector<PyLocalBuffer*> buffers,
@@ -454,15 +494,15 @@ PYBIND11_MODULE(xla_extension, m) {
                       "Cannot make tuple on device '%s' with '%s' backend",
                       device->DebugString(), client->platform_name());
                 }
-                return PyLocalBuffer::MakeTuple(
-                    buffers, client, device->local_device_ordinal());
+                return PyLocalBuffer::MakeTuple(buffers, std::move(client),
+                                                std::move(device));
              })
       .def("copy_to_device",
            [](PyLocalBuffer* buffer, std::shared_ptr<Device> dst_device) {
            CHECK(dst_device != nullptr);
            GlobalPyRefManager()->CollectGarbage();
            py::gil_scoped_release gil_release;
-             return buffer->CopyToDevice(dst_device->local_device_ordinal());
+             return buffer->CopyToDevice(std::move(dst_device));
          })
       .def("delete", &PyLocalBuffer::Delete)
       .def("destructure", &PyLocalBuffer::DestructureTuple)
@@ -485,10 +525,7 @@ PYBIND11_MODULE(xla_extension, m) {
            return LiteralToPython(std::move(literal));
          })
       .def("shape", &PyLocalBuffer::on_host_shape)
-      .def("device",
-           [](PyLocalBuffer* buffer) -> std::shared_ptr<Device> {
-             return buffer->client()->local_devices()[buffer->device_ordinal()];
-           })
+      .def("device", &PyLocalBuffer::device)
       .def("platform", &PyLocalBuffer::platform_name)
       .def("is_deleted",
            [](const PyLocalBuffer& buffer) {
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -444,7 +444,7 @@ def shape_from_pyval(pyval):
   return convert(pyval)
 
 
-def transfer_to_infeed(value, device_ordinal=0):
+def transfer_to_infeed(value, device=None):
   """Transfers the given value into the XLA infeed queue.
 
   XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
@@ -454,29 +454,31 @@ def transfer_to_infeed(value, device_ordinal=0):
   Args:
     value: the value that the caller would like to enqueue into the XLA infeed
       queue
-    device_ordinal: the device to infeed the value to. Each device has a
+    device: the device to infeed the value to. Each device has a
       distinct infeed queue.
   """
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
-  backend.client.TransferToInfeed(value, device_ordinal)
+  device = device or backend.local_devices()[0]
+  backend.client.TransferToInfeed(value, device)
 
 
-def transfer_from_outfeed(shape, device_ordinal=0):
-  """Transfers a literal of the given shape from `device_ordinal`'s outfeed.
+def transfer_from_outfeed(shape, device=None):
+  """Transfers a literal of the given shape from `device`'s outfeed.
 
   Args:
     shape: The shape of the value to transfer from outfeed.
-    device_ordinal: The device ordinal to transfer the outfeed value from. Each
-      device has a distinct outfeed queue..
+    device: The device from which to transfer the outfeed value. Each device
+      has a distinct outfeed queue.
 
   Returns:
     The literal value that is produced from the outfeed queue.
   """
   # TODO(phawkins): support non-default backends.
   backend = get_local_backend()
+  device = device or backend.local_devices()[0]
   return backend.client.TransferFromOutfeed(
-      shape.with_major_to_minor_layout_if_absent(), device_ordinal)
+      shape.with_major_to_minor_layout_if_absent(), device)
 
 
 DeviceAssignment = _xla.DeviceAssignment
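At the Python level, passing a Device keeps the helpers backward compatible with callers that omit the argument; a usage sketch (import path assumed from the file's location in the tree):

# Sketch: infeed/outfeed with an explicit Device handle.
import numpy as np
from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend()
device = backend.local_devices()[0]  # new-style Device handle
value = np.float32(42.0)
xla_client.transfer_to_infeed(value, device=device)
result = xla_client.transfer_from_outfeed(
    xla_client.shape_from_pyval(value), device=device)
# Omitting `device` falls back to backend.local_devices()[0] in both helpers.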