Extract a PjRtDevice interface.
- Extract a pure interface
- Current implementation is renamed to PjRtStreamExecutorDevice.
- Rename local_device to addressable_device.
TODO: split into pjrt_stream_executor.{h,cc}.
PiperOrigin-RevId: 346249293
Change-Id: Icf3cf5fe876a71e172b8ba0fced5fc4c8ca1cc93
This commit is contained in:
parent
1b3b6d7470
commit
55691acd36
@ -26,8 +26,8 @@ static const char kCpuPlatformName[] = "cpu";
|
|||||||
|
|
||||||
CpuDevice::CpuDevice(int id,
|
CpuDevice::CpuDevice(int id,
|
||||||
std::unique_ptr<LocalDeviceState> local_device_state)
|
std::unique_ptr<LocalDeviceState> local_device_state)
|
||||||
: PjRtDevice(id, std::move(local_device_state),
|
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
|
||||||
/*device_kind=*/kCpuPlatformName) {}
|
/*device_kind=*/kCpuPlatformName) {}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||||
|
|||||||
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
class CpuDevice : public PjRtDevice {
|
class CpuDevice : public PjRtStreamExecutorDevice {
|
||||||
public:
|
public:
|
||||||
CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
|
CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state);
|
||||||
};
|
};
|
||||||
|
|||||||
@ -306,8 +306,8 @@ Status BuildDistributedDevices(
|
|||||||
GpuDevice::GpuDevice(int id,
|
GpuDevice::GpuDevice(int id,
|
||||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
std::unique_ptr<LocalDeviceState> local_device_state,
|
||||||
std::string device_kind, int node_id)
|
std::string device_kind, int node_id)
|
||||||
: PjRtDevice(id, std::move(local_device_state), std::move(device_kind),
|
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
|
||||||
node_id) {}
|
std::move(device_kind), node_id) {}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
|
StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
|
||||||
bool asynchronous, const GpuAllocatorConfig& allocator_config,
|
bool asynchronous, const GpuAllocatorConfig& allocator_config,
|
||||||
|
|||||||
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
class GpuDevice : public PjRtDevice {
|
class GpuDevice : public PjRtStreamExecutorDevice {
|
||||||
public:
|
public:
|
||||||
GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||||
std::string device_kind, int node_id);
|
std::string device_kind, int node_id);
|
||||||
|
|||||||
@ -26,8 +26,8 @@ static const char kInterpreterPlatformName[] = "interpreter";
|
|||||||
|
|
||||||
InterpreterDevice::InterpreterDevice(
|
InterpreterDevice::InterpreterDevice(
|
||||||
int id, std::unique_ptr<LocalDeviceState> local_device_state)
|
int id, std::unique_ptr<LocalDeviceState> local_device_state)
|
||||||
: PjRtDevice(id, std::move(local_device_state),
|
: PjRtStreamExecutorDevice(id, std::move(local_device_state),
|
||||||
/*device_kind=*/kInterpreterPlatformName) {}
|
/*device_kind=*/kInterpreterPlatformName) {}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
|
StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
|
||||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||||
|
|||||||
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
class InterpreterDevice : public PjRtDevice {
|
class InterpreterDevice : public PjRtStreamExecutorDevice {
|
||||||
public:
|
public:
|
||||||
InterpreterDevice(int id,
|
InterpreterDevice(int id,
|
||||||
std::unique_ptr<LocalDeviceState> local_device_state);
|
std::unique_ptr<LocalDeviceState> local_device_state);
|
||||||
|
|||||||
@ -114,21 +114,22 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
PjRtPlatformId PjRtDevice::platform_id() const {
|
PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
|
||||||
return client_->platform_id();
|
return client_->platform_id();
|
||||||
}
|
}
|
||||||
const std::string& PjRtDevice::platform_name() const {
|
const std::string& PjRtStreamExecutorDevice::platform_name() const {
|
||||||
return client_->platform_name();
|
return client_->platform_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const {
|
StatusOr<LocalDeviceState*> PjRtStreamExecutorDevice::GetLocalDeviceState()
|
||||||
|
const {
|
||||||
if (local_device_state_) {
|
if (local_device_state_) {
|
||||||
return local_device_state_.get();
|
return local_device_state_.get();
|
||||||
}
|
}
|
||||||
return InvalidArgument("Device %s is not a local device.", DebugString());
|
return InvalidArgument("Device %s is not a local device.", DebugString());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PjRtDevice::DebugString() const {
|
std::string PjRtStreamExecutorDevice::DebugString() const {
|
||||||
return absl::StrCat(platform_name(), ":", id());
|
return absl::StrCat(platform_name(), ":", id());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -153,14 +154,15 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
|
|||||||
devices[replica].size(), replica, devices[0].size());
|
devices[replica].size(), replica, devices[0].size());
|
||||||
}
|
}
|
||||||
for (int partition = 0; partition < devices[replica].size(); ++partition) {
|
for (int partition = 0; partition < devices[replica].size(); ++partition) {
|
||||||
if (devices[0][0]->platform_id() !=
|
if (devices[0][0]->client()->platform_id() !=
|
||||||
devices[replica][partition]->platform_id()) {
|
devices[replica][partition]->client()->platform_id()) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Device assignment passed to Compile() must have devices of a "
|
"Device assignment passed to Compile() must have devices of a "
|
||||||
"single kind, got %s for replica 0 partition 0 and %s for replica "
|
"single kind, got %s for replica 0 partition 0 and %s for replica "
|
||||||
"%d partition %d.",
|
"%d partition %d.",
|
||||||
devices[0][0]->platform_name(),
|
devices[0][0]->client()->platform_name(),
|
||||||
devices[replica][partition]->platform_name(), replica, partition);
|
devices[replica][partition]->client()->platform_name(), replica,
|
||||||
|
partition);
|
||||||
}
|
}
|
||||||
xla_assignment(replica, partition) = devices[replica][partition]->id();
|
xla_assignment(replica, partition) = devices[replica][partition]->id();
|
||||||
}
|
}
|
||||||
@ -215,15 +217,16 @@ PjRtClient::PjRtClient(
|
|||||||
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
|
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
|
||||||
<< "Duplicate device id: " << device->id();
|
<< "Duplicate device id: " << device->id();
|
||||||
|
|
||||||
if (device->IsLocalDevice()) {
|
if (device->IsAddressable()) {
|
||||||
int idx = device->local_device_id();
|
int idx = device->local_hardware_id();
|
||||||
if (idx >= local_devices_.size()) {
|
if (idx >= local_devices_.size()) {
|
||||||
local_devices_.resize(idx + 1);
|
local_devices_.resize(idx + 1);
|
||||||
}
|
}
|
||||||
CHECK(local_devices_[idx] == nullptr) << idx;
|
CHECK(local_devices_[idx] == nullptr) << idx;
|
||||||
local_devices_[idx] = device.get();
|
local_devices_[idx] = device.get();
|
||||||
}
|
}
|
||||||
device->SetClient(this);
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device.get())
|
||||||
|
->SetClient(this);
|
||||||
}
|
}
|
||||||
for (int idx = 0; idx < local_devices_.size(); ++idx) {
|
for (int idx = 0; idx < local_devices_.size(); ++idx) {
|
||||||
CHECK(local_devices_[idx] != nullptr) << idx;
|
CHECK(local_devices_[idx] != nullptr) << idx;
|
||||||
@ -554,7 +557,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
|
|||||||
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
|
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||||
device->GetLocalDeviceState());
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
->GetLocalDeviceState());
|
||||||
int64 size = ShapeUtil::ByteSizeOf(shape);
|
int64 size = ShapeUtil::ByteSizeOf(shape);
|
||||||
|
|
||||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||||
@ -721,7 +725,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
|||||||
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
|
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
|
||||||
<< shape.ToString() << " device: " << device->DebugString();
|
<< shape.ToString() << " device: " << device->DebugString();
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||||
device->GetLocalDeviceState());
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
->GetLocalDeviceState());
|
||||||
|
|
||||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||||
TF_ASSIGN_OR_RETURN(Shape compact_shape,
|
TF_ASSIGN_OR_RETURN(Shape compact_shape,
|
||||||
@ -739,7 +744,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
|
|||||||
VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
|
VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
|
||||||
<< literal.shape().ToString() << " device: " << device->DebugString();
|
<< literal.shape().ToString() << " device: " << device->DebugString();
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||||
device->GetLocalDeviceState());
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
->GetLocalDeviceState());
|
||||||
|
|
||||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
@ -801,7 +807,9 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto local_device_or = device->GetLocalDeviceState();
|
auto local_device_or =
|
||||||
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
->GetLocalDeviceState();
|
||||||
if (!local_device_or.ok()) {
|
if (!local_device_or.ok()) {
|
||||||
notifier(local_device_or.status());
|
notifier(local_device_or.status());
|
||||||
return;
|
return;
|
||||||
@ -828,27 +836,29 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Transfer the given literal to the infeed queue of the given local device.
|
// Transfer the given literal to the infeed queue of the given local device.
|
||||||
Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
|
Status PjRtStreamExecutorDevice::TransferToInfeed(
|
||||||
|
const LiteralSlice& literal) const {
|
||||||
// Only support infeed to local device.
|
// Only support infeed to local device.
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||||
return local_device->client()->TransferToInfeedLocal(
|
return local_device->client()->TransferToInfeedLocal(
|
||||||
literal, local_device->device_ordinal());
|
literal, local_device->device_ordinal());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
|
StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
|
||||||
|
const Shape& shape) const {
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||||
return local_device->client()->TransferFromOutfeedLocal(
|
return local_device->client()->TransferFromOutfeedLocal(
|
||||||
shape, local_device->device_ordinal());
|
shape, local_device->device_ordinal());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const {
|
StatusOr<PjRtDevice*> PjRtClient::LookupAddressableDevice(int device_id) const {
|
||||||
for (auto* device : local_devices_) {
|
for (auto* device : local_devices_) {
|
||||||
if (local_device_id == device->local_device_id()) {
|
if (device_id == device->local_hardware_id()) {
|
||||||
return device;
|
return device;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return InvalidArgument("No matching device found for local_device_id %d",
|
return InvalidArgument("No matching device found for device_id %d",
|
||||||
local_device_id);
|
device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||||
@ -919,7 +929,9 @@ StatusOr<std::shared_ptr<TrackedDeviceBuffer>> PjRtBuffer::Release(
|
|||||||
// the final set of usage events.
|
// the final set of usage events.
|
||||||
events = device_buffer->LockUseAndTransferUsageEvents();
|
events = device_buffer->LockUseAndTransferUsageEvents();
|
||||||
}
|
}
|
||||||
LocalDeviceState* local_device_state = device_->local_device_state();
|
LocalDeviceState* local_device_state =
|
||||||
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||||
|
->local_device_state();
|
||||||
if (wait_for_operations_to_complete) {
|
if (wait_for_operations_to_complete) {
|
||||||
// Block the host until all usage events have completed. Usage events
|
// Block the host until all usage events have completed. Usage events
|
||||||
// dominate definition events, so this also waits for the buffer to be
|
// dominate definition events, so this also waits for the buffer to be
|
||||||
@ -1080,7 +1092,9 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
|
|||||||
}
|
}
|
||||||
ScopedHold device_buffer(this, ScopedHold::kUsage);
|
ScopedHold device_buffer(this, ScopedHold::kUsage);
|
||||||
std::shared_ptr<HostValue> host_value;
|
std::shared_ptr<HostValue> host_value;
|
||||||
LocalDeviceState* local_device = device_->local_device_state();
|
LocalDeviceState* local_device =
|
||||||
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||||
|
->local_device_state();
|
||||||
se::Stream* stream = local_device->GetDeviceToHostStream();
|
se::Stream* stream = local_device->GetDeviceToHostStream();
|
||||||
const xla::Layout& host_layout =
|
const xla::Layout& host_layout =
|
||||||
layout.has_value() ? layout.value() : on_host_shape_.layout();
|
layout.has_value() ? layout.value() : on_host_shape_.layout();
|
||||||
@ -1241,8 +1255,9 @@ PjRtBuffer::CopyToDeviceHelper(
|
|||||||
// StallStreamOnError only makes sure the destination device is ok, so
|
// StallStreamOnError only makes sure the destination device is ok, so
|
||||||
// make sure that the src buffer remains valid until after any transfers
|
// make sure that the src buffer remains valid until after any transfers
|
||||||
// have completed.
|
// have completed.
|
||||||
device_->local_device_state()->ThenRelease(transfer_stream,
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||||
src_device_buffer);
|
->local_device_state()
|
||||||
|
->ThenRelease(transfer_stream, src_device_buffer);
|
||||||
}
|
}
|
||||||
return copy_event_or.status();
|
return copy_event_or.status();
|
||||||
}
|
}
|
||||||
@ -1268,11 +1283,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
|
|||||||
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
|
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
|
TF_ASSIGN_OR_RETURN(
|
||||||
dst_device->GetLocalDeviceState());
|
LocalDeviceState * dst_local_device,
|
||||||
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
|
||||||
|
->GetLocalDeviceState());
|
||||||
LocalDeviceState* transfer_local_device =
|
LocalDeviceState* transfer_local_device =
|
||||||
client_->EnqueueD2DTransfersOnSrcStream() ? device_->local_device_state()
|
client_->EnqueueD2DTransfersOnSrcStream()
|
||||||
: dst_local_device;
|
? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||||
|
->local_device_state()
|
||||||
|
: dst_local_device;
|
||||||
CHECK_EQ(dst_local_device->allocation_model(),
|
CHECK_EQ(dst_local_device->allocation_model(),
|
||||||
transfer_local_device->allocation_model());
|
transfer_local_device->allocation_model());
|
||||||
|
|
||||||
@ -1310,7 +1329,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
|
|||||||
// alternative is to ensure, before freeing the buffer, that the compute
|
// alternative is to ensure, before freeing the buffer, that the compute
|
||||||
// stream is synchronized past the transfer, but it seems better to hold onto
|
// stream is synchronized past the transfer, but it seems better to hold onto
|
||||||
// the buffer too long than to stall the compute stream.
|
// the buffer too long than to stall the compute stream.
|
||||||
RecordUsage(std::move(src_device_buffer), device_->local_device_state(),
|
RecordUsage(std::move(src_device_buffer),
|
||||||
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||||
|
->local_device_state(),
|
||||||
transfer_local_device, event, transfer_stream,
|
transfer_local_device, event, transfer_stream,
|
||||||
/*prefer_to_retain_reference=*/true);
|
/*prefer_to_retain_reference=*/true);
|
||||||
|
|
||||||
@ -1332,7 +1353,9 @@ Status PjRtBuffer::BlockHostUntilReady() {
|
|||||||
}
|
}
|
||||||
device_buffer = device_buffer_;
|
device_buffer = device_buffer_;
|
||||||
}
|
}
|
||||||
LocalDeviceState* local_device_state = device_->local_device_state();
|
LocalDeviceState* local_device_state =
|
||||||
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
|
||||||
|
->local_device_state();
|
||||||
std::unique_ptr<se::Stream> stream;
|
std::unique_ptr<se::Stream> stream;
|
||||||
for (auto& event : device_buffer->definition_events()) {
|
for (auto& event : device_buffer->definition_events()) {
|
||||||
if (!event->IsComplete()) {
|
if (!event->IsComplete()) {
|
||||||
@ -1628,7 +1651,9 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
|
|||||||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||||
PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
|
PjRtDevice* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment) const {
|
std::shared_ptr<DeviceAssignment> device_assignment) const {
|
||||||
int device_ordinal = device->local_device_state()->device_ordinal();
|
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
->local_device_state()
|
||||||
|
->device_ordinal();
|
||||||
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
|
||||||
tensorflow::profiler::TraceMeConsumer activity(
|
tensorflow::profiler::TraceMeConsumer activity(
|
||||||
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
|
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
|
||||||
@ -1814,7 +1839,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
CHECK_EQ(device->host_id(), client_->host_id());
|
CHECK_EQ(device->host_id(), client_->host_id());
|
||||||
int device_ordinal = device->local_device_state()->device_ordinal();
|
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
->local_device_state()
|
||||||
|
->device_ordinal();
|
||||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
|
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
|
||||||
VLOG(3) << "Replica " << replica << ", partition " << partition
|
VLOG(3) << "Replica " << replica << ", partition " << partition
|
||||||
<< " mapped to device ordinal for execution: " << device_ordinal;
|
<< " mapped to device ordinal for execution: " << device_ordinal;
|
||||||
@ -1922,7 +1949,9 @@ PjRtStreamExecutorExecutable::Execute(
|
|||||||
const int replica = addressable_device_logical_ids_[i].replica;
|
const int replica = addressable_device_logical_ids_[i].replica;
|
||||||
const int partition = addressable_device_logical_ids_[i].partition;
|
const int partition = addressable_device_logical_ids_[i].partition;
|
||||||
PjRtDevice* device = addressable_devices_[i];
|
PjRtDevice* device = addressable_devices_[i];
|
||||||
const LocalDeviceState& device_state = *device->local_device_state();
|
const LocalDeviceState& device_state =
|
||||||
|
*tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
|
||||||
|
->local_device_state();
|
||||||
device_state.execute_thread()->Schedule([&, replica, partition, i] {
|
device_state.execute_thread()->Schedule([&, replica, partition, i] {
|
||||||
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
|
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
|
||||||
run_id, options);
|
run_id, options);
|
||||||
@ -2254,7 +2283,10 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
|||||||
|
|
||||||
if (build_options.device_ordinal() < 0) {
|
if (build_options.device_ordinal() < 0) {
|
||||||
build_options.set_device_ordinal(
|
build_options.set_device_ordinal(
|
||||||
addressable_devices.front()->local_device_state()->device_ordinal());
|
tensorflow::down_cast<PjRtStreamExecutorDevice*>(
|
||||||
|
addressable_devices.front())
|
||||||
|
->local_device_state()
|
||||||
|
->device_ordinal());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -45,6 +45,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/fingerprint.h"
|
#include "tensorflow/core/platform/fingerprint.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -67,16 +68,56 @@ class PjRtClient;
|
|||||||
|
|
||||||
class PjRtDevice {
|
class PjRtDevice {
|
||||||
public:
|
public:
|
||||||
explicit PjRtDevice(int id,
|
virtual ~PjRtDevice() {}
|
||||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
|
||||||
std::string device_kind, int host_id = 0)
|
// Return the client that owns this device.
|
||||||
|
virtual PjRtClient* client() const = 0;
|
||||||
|
|
||||||
|
// Whether the client can issue commands to this device.
|
||||||
|
virtual bool IsAddressable() const = 0;
|
||||||
|
|
||||||
|
// The ID of this device. IDs are unique among devices of this type
|
||||||
|
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
|
||||||
|
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
|
||||||
|
virtual int id() const = 0;
|
||||||
|
|
||||||
|
// The task ID of this device according to TpuTopology. This is not the same
|
||||||
|
// as PjRtClient::host_id() in a multi-task setting, where each client can see
|
||||||
|
// devices from all tasks, but only a subset of them are addressable and have
|
||||||
|
// the same task_id as the client.
|
||||||
|
virtual int host_id() const = 0;
|
||||||
|
|
||||||
|
// Opaque hardware ID, e.g., the CUDA device number, useful for identifying
|
||||||
|
// which GPU when interacting with non-JAX code. In general, not guaranteed to
|
||||||
|
// be dense, and -1 if undefined.
|
||||||
|
virtual int local_hardware_id() const = 0;
|
||||||
|
|
||||||
|
// A vendor-dependent string that uniquely identifies the kind of device,
|
||||||
|
// e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
|
||||||
|
// compatible for compilation.
|
||||||
|
virtual const std::string& device_kind() const = 0;
|
||||||
|
|
||||||
|
virtual std::string DebugString() const = 0;
|
||||||
|
|
||||||
|
// Transfer the given literal to the infeed queue.
|
||||||
|
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
|
||||||
|
|
||||||
|
// Transfer and return a value of the given shape from the outfeed queue.
|
||||||
|
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class PjRtStreamExecutorDevice : public PjRtDevice {
|
||||||
|
public:
|
||||||
|
explicit PjRtStreamExecutorDevice(
|
||||||
|
int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||||
|
std::string device_kind, int host_id = 0)
|
||||||
: id_(id),
|
: id_(id),
|
||||||
local_device_id_(
|
device_ordinal_(
|
||||||
local_device_state ? local_device_state->device_ordinal() : -1),
|
local_device_state ? local_device_state->device_ordinal() : -1),
|
||||||
local_device_state_(std::move(local_device_state)),
|
local_device_state_(std::move(local_device_state)),
|
||||||
host_id_(host_id),
|
host_id_(host_id),
|
||||||
device_kind_(std::move(device_kind)) {}
|
device_kind_(std::move(device_kind)) {}
|
||||||
virtual ~PjRtDevice() {}
|
~PjRtStreamExecutorDevice() override {}
|
||||||
|
|
||||||
// Must set client exactly once.
|
// Must set client exactly once.
|
||||||
void SetClient(PjRtClient* client) {
|
void SetClient(PjRtClient* client) {
|
||||||
@ -84,14 +125,25 @@ class PjRtDevice {
|
|||||||
client_ = client;
|
client_ = client;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Task ID. This is always 0 in a single-task setup.
|
||||||
|
int host_id() const override { return host_id_; }
|
||||||
|
|
||||||
|
// Return `platform_id` from client.
|
||||||
|
PjRtPlatformId platform_id() const;
|
||||||
|
|
||||||
|
// Return `platform_name` from client.
|
||||||
|
const std::string& platform_name() const;
|
||||||
|
|
||||||
|
PjRtClient* client() const override { return client_; }
|
||||||
|
|
||||||
// The ID of this device. IDs are unique among devices of this type
|
// The ID of this device. IDs are unique among devices of this type
|
||||||
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
|
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
|
||||||
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
|
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
|
||||||
int id() const { return id_; }
|
int id() const override { return id_; }
|
||||||
|
|
||||||
bool IsLocalDevice() const { return local_device_id_ != -1; }
|
bool IsAddressable() const override { return device_ordinal_ != -1; }
|
||||||
|
|
||||||
int local_device_id() const { return local_device_id_; }
|
int local_hardware_id() const override { return device_ordinal_; }
|
||||||
|
|
||||||
// If this is a device local to this host, returns a LocalDeviceState object
|
// If this is a device local to this host, returns a LocalDeviceState object
|
||||||
// that can be used to manipulate the device. Returns nullptr if the device is
|
// that can be used to manipulate the device. Returns nullptr if the device is
|
||||||
@ -105,32 +157,21 @@ class PjRtDevice {
|
|||||||
// is not local to this host.
|
// is not local to this host.
|
||||||
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
|
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
|
||||||
|
|
||||||
// The ID of this device's host. This is always 0 on single-host platforms.
|
|
||||||
int host_id() const { return host_id_; }
|
|
||||||
|
|
||||||
// Return `platform_id` from client.
|
|
||||||
PjRtPlatformId platform_id() const;
|
|
||||||
|
|
||||||
// Return `platform_name` from client.
|
|
||||||
const std::string& platform_name() const;
|
|
||||||
|
|
||||||
// A vendor-dependent string that uniquely identifies the kind of device.
|
// A vendor-dependent string that uniquely identifies the kind of device.
|
||||||
const std::string& device_kind() const { return device_kind_; }
|
const std::string& device_kind() const override { return device_kind_; }
|
||||||
|
|
||||||
virtual std::string DebugString() const;
|
std::string DebugString() const override;
|
||||||
|
|
||||||
PjRtClient* client() const { return client_; }
|
|
||||||
|
|
||||||
// Transfer the given literal to the infeed queue of the given local device.
|
// Transfer the given literal to the infeed queue of the given local device.
|
||||||
virtual Status TransferToInfeed(const LiteralSlice& literal) const;
|
Status TransferToInfeed(const LiteralSlice& literal) const override;
|
||||||
|
|
||||||
// Transfer and return a value of the given shape from the outfeed of the
|
// Transfer and return a value of the given shape from the outfeed of the
|
||||||
// given device.
|
// given device.
|
||||||
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
|
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const int id_;
|
const int id_;
|
||||||
const int local_device_id_; // -1 means not local.
|
const int device_ordinal_; // -1 means not local.
|
||||||
const std::unique_ptr<LocalDeviceState> local_device_state_;
|
const std::unique_ptr<LocalDeviceState> local_device_state_;
|
||||||
const int host_id_;
|
const int host_id_;
|
||||||
const std::string device_kind_;
|
const std::string device_kind_;
|
||||||
@ -196,9 +237,7 @@ class PjRtClient {
|
|||||||
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
|
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
|
||||||
return devices_;
|
return devices_;
|
||||||
}
|
}
|
||||||
const std::vector<PjRtDevice*>& local_devices() const {
|
absl::Span<PjRtDevice* const> local_devices() const { return local_devices_; }
|
||||||
return local_devices_;
|
|
||||||
}
|
|
||||||
const std::map<int, PjRtDevice*>& id_to_device() const {
|
const std::map<int, PjRtDevice*>& id_to_device() const {
|
||||||
return id_to_device_;
|
return id_to_device_;
|
||||||
}
|
}
|
||||||
@ -207,11 +246,13 @@ class PjRtClient {
|
|||||||
const std::string& platform_name() const { return platform_name_; }
|
const std::string& platform_name() const { return platform_name_; }
|
||||||
|
|
||||||
LocalDeviceState& device_state(int device_ordinal) const {
|
LocalDeviceState& device_state(int device_ordinal) const {
|
||||||
return *local_devices_.at(device_ordinal)->local_device_state();
|
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
|
||||||
|
local_devices_.at(device_ordinal))
|
||||||
|
->local_device_state();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a local PjRtDevice for a given `local_device_id`.
|
// Return an addressable PjRtDevice for a given `device_id`.
|
||||||
virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const;
|
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(int device_id) const;
|
||||||
|
|
||||||
LocalClient* client() const { return client_; }
|
LocalClient* client() const { return client_; }
|
||||||
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
|
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
|
||||||
@ -791,6 +832,7 @@ class PjRtExecutable {
|
|||||||
virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
|
virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
|
||||||
const = 0;
|
const = 0;
|
||||||
|
|
||||||
|
// An addressable_device is one which the client can issue commands to.
|
||||||
// addressable_devices()[i] is the Device to which
|
// addressable_devices()[i] is the Device to which
|
||||||
// addressable_device_logical_ids()[i] is assigned.
|
// addressable_device_logical_ids()[i] is assigned.
|
||||||
virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
|
virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
|
||||||
|
|||||||
@ -26,14 +26,14 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
class PjRtTpuDevice : public PjRtDevice {
|
class PjRtTpuDevice : public PjRtStreamExecutorDevice {
|
||||||
public:
|
public:
|
||||||
PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
|
PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
|
||||||
std::unique_ptr<LocalDeviceState> local_device_state,
|
std::unique_ptr<LocalDeviceState> local_device_state,
|
||||||
int host_id, const std::array<int, 3>& coords,
|
int host_id, const std::array<int, 3>& coords,
|
||||||
std::string device_kind)
|
std::string device_kind)
|
||||||
: PjRtDevice(core.Id(), std::move(local_device_state),
|
: PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
|
||||||
std::move(device_kind), host_id),
|
std::move(device_kind), host_id),
|
||||||
core_(core),
|
core_(core),
|
||||||
coords_(coords) {}
|
coords_(coords) {}
|
||||||
|
|
||||||
|
|||||||
@ -214,11 +214,9 @@ StatusOr<std::vector<int64>> StridesToLayout(absl::Span<int64 const> dims,
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
|
StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
|
||||||
const se::Platform* platform =
|
if (device.client()->platform_id() == kCpuId) {
|
||||||
device.local_device_state()->executor()->platform();
|
|
||||||
if (platform->id() == se::host::kHostPlatformId) {
|
|
||||||
return kDLCPU;
|
return kDLCPU;
|
||||||
} else if (platform->id() == se::cuda::kCudaPlatformId) {
|
} else if (device.client()->platform_id() == kGpuId) {
|
||||||
return kDLGPU;
|
return kDLGPU;
|
||||||
}
|
}
|
||||||
return InvalidArgument("Device %s cannot be used as a DLPack device.",
|
return InvalidArgument("Device %s cannot be used as a DLPack device.",
|
||||||
@ -228,7 +226,7 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
|
|||||||
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
|
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) {
|
||||||
DLContext context;
|
DLContext context;
|
||||||
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
|
TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device));
|
||||||
context.device_id = device.local_device_id();
|
context.device_id = device.local_hardware_id();
|
||||||
return context;
|
return context;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,14 +239,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
|
|||||||
"DLPack CPU device type mismatch with PjRtClient platform %s",
|
"DLPack CPU device type mismatch with PjRtClient platform %s",
|
||||||
client.platform_name());
|
client.platform_name());
|
||||||
}
|
}
|
||||||
return client.LookupLocalDevice(context.device_id);
|
return client.LookupAddressableDevice(context.device_id);
|
||||||
case kDLGPU:
|
case kDLGPU:
|
||||||
if (client.platform_id() != kGpuId) {
|
if (client.platform_id() != kGpuId) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"DLPack GPU device type mismatch with PjRtClient platform %s",
|
"DLPack GPU device type mismatch with PjRtClient platform %s",
|
||||||
client.platform_name());
|
client.platform_name());
|
||||||
}
|
}
|
||||||
return client.LookupLocalDevice(context.device_id);
|
return client.LookupAddressableDevice(context.device_id);
|
||||||
default:
|
default:
|
||||||
return InvalidArgument("Unknown/unsupported DLPack device type %d",
|
return InvalidArgument("Unknown/unsupported DLPack device type %d",
|
||||||
context.device_type);
|
context.device_type);
|
||||||
@ -297,7 +295,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
|
|||||||
pack->tensor.manager_ctx = pack.get();
|
pack->tensor.manager_ctx = pack.get();
|
||||||
pack->tensor.deleter = DLPackTensorDeleter;
|
pack->tensor.deleter = DLPackTensorDeleter;
|
||||||
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
|
TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
|
||||||
dt.ctx.device_id = buffer->buffer()->device()->local_device_id();
|
dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
|
||||||
dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
|
dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
|
||||||
TF_ASSIGN_OR_RETURN(dt.dtype,
|
TF_ASSIGN_OR_RETURN(dt.dtype,
|
||||||
PrimitiveTypeToDLDataType(
|
PrimitiveTypeToDLDataType(
|
||||||
|
|||||||
@ -342,11 +342,7 @@ StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
|
|||||||
const PjRtDevice* device, const Shape& shape) {
|
const PjRtDevice* device, const Shape& shape) {
|
||||||
std::shared_ptr<Literal> literal_shared;
|
std::shared_ptr<Literal> literal_shared;
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
|
||||||
device->GetLocalDeviceState());
|
|
||||||
TF_ASSIGN_OR_RETURN(Literal literal,
|
|
||||||
local_device->client()->TransferFromOutfeedLocal(
|
|
||||||
shape, local_device->device_ordinal()));
|
|
||||||
|
|
||||||
return absl::make_unique<Literal>(std::move(literal));
|
return absl::make_unique<Literal>(std::move(literal));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -86,8 +86,8 @@ StatusOr<std::uintptr_t> PyBuffer::UnsafeBufferPointer() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
|
StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
|
||||||
if (buffer_->device()->local_device_state()->executor()->platform_kind() !=
|
// TODO(zhangqiaorjc): Differentiate between NVidia and other GPUs.
|
||||||
se::PlatformKind::kCuda) {
|
if (buffer_->client()->platform_id() != kGpuId) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"__cuda_array_interface__ is only defined for NVidia GPU buffers.");
|
"__cuda_array_interface__ is only defined for NVidia GPU buffers.");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -37,6 +37,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:computation_placer",
|
"//tensorflow/compiler/xla/service:computation_placer",
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
"//tensorflow/core/framework:allocator",
|
"//tensorflow/core/framework:allocator",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
"//tensorflow/core/platform:env",
|
"//tensorflow/core/platform:env",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
|||||||
@ -37,8 +37,8 @@ namespace xla {
|
|||||||
|
|
||||||
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||||
int core_on_chip)
|
int core_on_chip)
|
||||||
: xla::PjRtDevice(id, /*local_device_state=*/nullptr,
|
: xla::PjRtStreamExecutorDevice(id, /*local_device_state=*/nullptr,
|
||||||
/*device_kind=*/"Cloud TPU", host_id),
|
/*device_kind=*/"Cloud TPU", host_id),
|
||||||
coords_(coords),
|
coords_(coords),
|
||||||
core_on_chip_(core_on_chip) {}
|
core_on_chip_(core_on_chip) {}
|
||||||
|
|
||||||
@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable(
|
|||||||
<< "Inserting duplicate replica:" << replica;
|
<< "Inserting duplicate replica:" << replica;
|
||||||
executables_[replica] =
|
executables_[replica] =
|
||||||
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
|
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
|
||||||
addressable_device_logical_ids_.emplace_back(replica, partition);
|
local_logical_device_ids_.emplace_back(replica, partition);
|
||||||
local_devices_.push_back(device);
|
local_devices_.push_back(device);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices(
|
|||||||
// long time and we want all cores to be scheduled in parallel.
|
// long time and we want all cores to be scheduled in parallel.
|
||||||
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
|
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
|
||||||
&execute_semaphore]() {
|
&execute_semaphore]() {
|
||||||
const int replica = addressable_device_logical_ids_[i].first;
|
const int replica = local_logical_device_ids_[i].first;
|
||||||
const int partition = addressable_device_logical_ids_[i].second;
|
const int partition = local_logical_device_ids_[i].second;
|
||||||
RunId run_id;
|
RunId run_id;
|
||||||
auto result = ExecuteHelper(argument_handles, argument_handles[i],
|
auto result = ExecuteHelper(argument_handles, argument_handles[i],
|
||||||
replica, partition, run_id);
|
replica, partition, run_id);
|
||||||
|
|||||||
@ -32,13 +32,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/threadpool.h"
|
#include "tensorflow/core/platform/threadpool.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
constexpr char kTpuPlatform[] = "tpu";
|
constexpr char kTpuPlatform[] = "tpu";
|
||||||
|
|
||||||
class TpuDevice : public PjRtDevice {
|
class TpuDevice : public PjRtStreamExecutorDevice {
|
||||||
public:
|
public:
|
||||||
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
|
||||||
int core_on_chip);
|
int core_on_chip);
|
||||||
@ -298,9 +299,8 @@ class PyTpuExecutable {
|
|||||||
return device_assignment_;
|
return device_assignment_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::pair<int, int>>& addressable_device_logical_ids()
|
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
|
||||||
const {
|
return local_logical_device_ids_;
|
||||||
return addressable_device_logical_ids_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
|
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
|
||||||
@ -341,14 +341,16 @@ class PyTpuExecutable {
|
|||||||
|
|
||||||
// The replica and partition indices of device_assignment_ to be run by this
|
// The replica and partition indices of device_assignment_ to be run by this
|
||||||
// client. On single-host platforms without partitioning, this is all replicas
|
// client. On single-host platforms without partitioning, this is all replicas
|
||||||
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
|
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
|
||||||
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
// on multi-host platforms.
|
||||||
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
|
// If there are 4 replicas and 2 partitions on a single host platform, size of
|
||||||
std::vector<std::pair<int, int>> addressable_device_logical_ids_;
|
// local_logical_device_ids_ is 4*2 = 8.
|
||||||
|
std::vector<std::pair<int, int>> local_logical_device_ids_;
|
||||||
|
|
||||||
// local_devices_[i] is the Device to which addressable_device_logical_ids_[i]
|
// local_devices_[i] is the Device to which local_logical_device_ids_[i] is
|
||||||
// is assigned. shared_ptrs instead of unique_ptrs to play well with the
|
// assigned.
|
||||||
// Python bindings (see xla.cc).
|
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
|
||||||
|
// (see xla.cc).
|
||||||
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
|
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
|
||||||
|
|
||||||
xla::Shape result_shape_;
|
xla::Shape result_shape_;
|
||||||
|
|||||||
@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
|||||||
|
|
||||||
py::class_<PyTpuExecutable>(m, "TpuExecutable")
|
py::class_<PyTpuExecutable>(m, "TpuExecutable")
|
||||||
.def("local_logical_device_ids",
|
.def("local_logical_device_ids",
|
||||||
&PyTpuExecutable::addressable_device_logical_ids)
|
&PyTpuExecutable::local_logical_device_ids)
|
||||||
.def("local_devices", &PyTpuExecutable::local_devices)
|
.def("local_devices", &PyTpuExecutable::local_devices)
|
||||||
.def_property_readonly("client", &PyTpuExecutable::client)
|
.def_property_readonly("client", &PyTpuExecutable::client)
|
||||||
.def("size_of_generated_code_in_bytes",
|
.def("size_of_generated_code_in_bytes",
|
||||||
|
|||||||
@ -149,7 +149,10 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
.def_property_readonly("host_id", &PjRtDevice::host_id,
|
.def_property_readonly("host_id", &PjRtDevice::host_id,
|
||||||
"Integer ID of this device's host.\n\n"
|
"Integer ID of this device's host.\n\n"
|
||||||
"This is always 0 except on multi-host platforms.")
|
"This is always 0 except on multi-host platforms.")
|
||||||
.def_property_readonly("platform", &PjRtDevice::platform_name)
|
.def_property_readonly("platform",
|
||||||
|
[](const PjRtDevice& device) {
|
||||||
|
return device.client()->platform_name();
|
||||||
|
})
|
||||||
.def_property_readonly("device_kind", &PjRtDevice::device_kind)
|
.def_property_readonly("device_kind", &PjRtDevice::device_kind)
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"client",
|
"client",
|
||||||
@ -381,10 +384,10 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
[](PyExecutable* exec) {
|
[](PyExecutable* exec) {
|
||||||
auto span = exec->addressable_device_logical_ids();
|
auto span = exec->addressable_device_logical_ids();
|
||||||
// Not on dispatch critical path, so ok to have heap allocation.
|
// Not on dispatch critical path, so ok to have heap allocation.
|
||||||
std::vector<std::pair<int, int>> addressable_device_logical_ids;
|
std::vector<std::pair<int, int>> addressable_device_logic_ids;
|
||||||
addressable_device_logical_ids.reserve(span.size());
|
addressable_device_logic_ids.reserve(span.size());
|
||||||
for (const auto& logical_device_id : span) {
|
for (const auto& logical_device_id : span) {
|
||||||
addressable_device_logical_ids.push_back(std::make_pair(
|
addressable_device_logic_ids.push_back(std::make_pair(
|
||||||
logical_device_id.replica, logical_device_id.partition));
|
logical_device_id.replica, logical_device_id.partition));
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user