commit 93dfb9b68f (parent b9187102b6)

Extract a PjRtClient interface.

- Extract a pure interface.
- Current implementation is renamed to PjRtStreamExecutorClient.

TODO: split into a pjrt_stream_executor.{h,cc}

PiperOrigin-RevId: 346470160
Change-Id: I6235d9eed58cfd281c59f441dfed67ea3b9035ed
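For orientation, here is a minimal sketch of the hierarchy this change produces. The names come from the diff below, but the signatures are abbreviated and most members are elided, so treat this as a reading aid rather than the real header:

    // Sketch only: the full declarations are in the pjrt_client.h hunks below.
    class PjRtClient {
     public:
      virtual ~PjRtClient() = default;
      // Every method on the portable interface becomes pure virtual.
      virtual int device_count() const = 0;
      virtual absl::Span<PjRtDevice* const> devices() const = 0;
      virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
          const XlaComputation& computation, CompileOptions options) = 0;
      // ... remaining pure-virtual methods elided ...
    };

    // The old PjRtClient body, renamed. It implements the interface and keeps
    // the StreamExecutor-specific accessors (client(), allocator(), ...).
    class PjRtStreamExecutorClient : public PjRtClient {
     public:
      int device_count() const override { return devices_.size(); }
      // ...
    };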
@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutorConfig config;
     config.ordinal = i;
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
     devices.push_back(std::move(device));
   }
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       kCpuName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
@@ -35,9 +35,9 @@ namespace xla {
 namespace {
 
 // A custom PjRtClient that overrides the device assignment method.
-class GpuClient : public xla::PjRtClient {
+class GpuClient : public xla::PjRtStreamExecutorClient {
  public:
-  using xla::PjRtClient::PjRtClient;
+  using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
 
   xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
     return assignment;
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }
 
 // Builds an xla::LocalClient for the GPU platform.
@@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
   return result.first->second;
 }
 
-std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
+std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (auto& local_device : local_device_states) {
     int device_ordinal = local_device->device_ordinal();
     const se::DeviceDescription& description =
@@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
 Status BuildDistributedDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
-    std::vector<std::unique_ptr<PjRtDevice>>* devices,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
     gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
   LocalTopologyProto local_topology;
   local_topology.set_node_id(node_id);
@@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
   auto host_memory_allocator =
       GetGpuHostAllocator(local_device_states.front()->executor());
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
   if (distributed_client) {
     TF_RETURN_IF_ERROR(BuildDistributedDevices(
@@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   se::StreamExecutor* executor =
       client->backend().stream_executor(0).ValueOrDie();
   auto device_state = absl::make_unique<LocalDeviceState>(
@@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
       absl::make_unique<InterpreterDevice>(0, std::move(device_state));
   devices.push_back(std::move(device));
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       "interpreter", client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
@@ -184,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
   }
 };
 
-PjRtClient::PjRtClient(
+PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
-    std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
@@ -195,7 +195,7 @@ PjRtClient::PjRtClient(
       platform_name_(std::move(platform_name)),
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
-      devices_(std::move(devices)),
+      owned_devices_(std::move(devices)),
       host_id_(host_id),
       owned_allocator_(std::move(allocator)),
       should_stage_host_to_device_transfers_(
@@ -213,7 +213,9 @@ PjRtClient::PjRtClient(
     host_memory_allocator_ = std::make_unique<CpuAllocator>();
   }
 
-  for (const std::unique_ptr<PjRtDevice>& device : devices_) {
+  for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
+       owned_devices_) {
+    devices_.push_back(device.get());
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();
 
@@ -225,21 +227,21 @@ PjRtClient::PjRtClient(
       CHECK(local_devices_[idx] == nullptr) << idx;
       local_devices_[idx] = device.get();
     }
-    tensorflow::down_cast<PjRtStreamExecutorDevice*>(device.get())
-        ->SetClient(this);
+    device->SetClient(this);
   }
   for (int idx = 0; idx < local_devices_.size(); ++idx) {
     CHECK(local_devices_[idx] != nullptr) << idx;
   }
 }
 
-StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
+StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
   return client_->backend().computation_placer()->AssignDevices(num_replicas,
                                                                 num_partitions);
 }
 
-std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
+std::unique_ptr<HloCostAnalysis>
+PjRtStreamExecutorClient::GetHloCostAnalysis() {
   return absl::make_unique<HloCostAnalysis>(
       client_->backend().compiler()->ShapeSizeBytesFunction());
 }
@@ -349,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
     return InvalidArgument("Can't make a buffer from an empty tuple");
   }
 
+  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer dst_buffer,
-      transfer_manager->AllocateScopedShapedBuffer(
-          on_host_shape, client->allocator(), local_device->device_ordinal()));
+      se_client->client()->backend().transfer_manager();
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
+                      transfer_manager->AllocateScopedShapedBuffer(
+                          on_host_shape, se_client->allocator(),
+                          local_device->device_ordinal()));
   if (local_device->allocation_model() ==
       LocalDeviceState::kComputeSynchronized) {
     if (copy_stream == nullptr) {
@@ -546,13 +549,15 @@ void PjRtBuffer::ScopedHold::AddToInput(
 
 bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
-  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   if (shape.IsTuple()) {
     return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
@@ -712,17 +717,19 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
   return py_buffer;
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
-    const Shape& shape, PjRtDevice* device) {
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
+                                                    PjRtDevice* device) {
   return CreateUninitializedBuffer(shape, device, nullptr);
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(
     const Shape& shape, PjRtDevice* device,
     std::shared_ptr<BufferSequencingEvent> definition_event) {
   tensorflow::profiler::TraceMe traceme(
-      "PjRtClient::CreateUninitializedBuffer");
-  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+      "PjRtStreamExecutorClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
          << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -738,10 +745,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
                           definition_event);
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
-    const LiteralSlice& literal, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
-  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
+                                                PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
          << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -798,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
   return py_buffer;
 }
 
-void PjRtClient::MakeCrossHostReceiveBuffers(
+void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
     absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
@@ -851,14 +860,15 @@ StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
       shape, local_device->device_ordinal());
 }
 
-StatusOr<PjRtDevice*> PjRtClient::LookupAddressableDevice(int device_id) const {
+StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
+    int local_hardware_id) const {
   for (auto* device : local_devices_) {
-    if (device_id == device->local_hardware_id()) {
+    if (local_hardware_id == device->local_hardware_id()) {
      return device;
    }
  }
-  return InvalidArgument("No matching device found for device_id %d",
-                         device_id);
+  return InvalidArgument("No matching device found for local_hardware_id %d",
+                         local_hardware_id);
 }
 
 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -883,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
 }
 
 int64 PjRtBuffer::OnDeviceSizeInBytes() const {
-  return client_->client()
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
      ->backend()
      .transfer_manager()
      ->GetByteSizeRequirement(on_device_shape_);
@@ -1136,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
   host_value->value = std::make_shared<Literal>(host_shape);
   ShapedBuffer shaped_buffer =
       device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
-  client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
-      stream, shaped_buffer, host_value->value.get(),
-      [host_value](Status done_status) {
-        host_value->status = done_status;
-        host_value->ready.Notify();
-      });
+  tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
+      ->backend()
+      .transfer_manager()
+      ->TransferLiteralFromDevice(stream, shaped_buffer,
+                                  host_value->value.get(),
+                                  [host_value](Status done_status) {
+                                    host_value->status = done_status;
+                                    host_value->ready.Notify();
+                                  });
 
   auto usage_event = std::make_shared<BufferSequencingEvent>();
   StatusOr<EventPool::Handle> event_or =
@@ -1170,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
 
 StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1280,7 +1295,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
     return dst_device->client()->BufferFromHostBuffer(
         literal->untyped_data(), literal->shape(),
-        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
+        PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
+        dst_device);
   }
 
   TF_ASSIGN_OR_RETURN(
@@ -1288,7 +1304,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
       tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
          ->GetLocalDeviceState());
   LocalDeviceState* transfer_local_device =
-      client_->EnqueueD2DTransfersOnSrcStream()
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->EnqueueD2DTransfersOnSrcStream()
          ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
                ->local_device_state()
          : dst_local_device;
@@ -1339,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
 }
 
 Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
-  return client_->CopyToRemoteDevice(this, serialized_descriptor);
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->CopyToRemoteDevice(this, serialized_descriptor);
 }
 
 Status PjRtBuffer::BlockHostUntilReady() {
@@ -1401,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
   Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
 
-  se::DeviceMemoryAllocator* allocator = client->allocator();
+  se::DeviceMemoryAllocator* allocator =
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
+          ->client()
+          ->backend()
+          .transfer_manager();
   se::Stream* stream = local_device->host_to_device_stream();
   TF_ASSIGN_OR_RETURN(
       se::OwningDeviceMemory root_table_memory,
@@ -1467,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
       /*prefer_to_retain_reference=*/false);
   return pjrt_buffer;
 }
 
-static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
-  auto it = client.id_to_device().find(device_id);
-  CHECK(it != client.id_to_device().end())
-      << "Unknown device id: " << device_id;
-  return it->second;
-}
-
 }  // namespace
 
 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@@ -1482,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     bool parameter_is_tupled_arguments,
     std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-    std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
+    std::vector<PjRtDevice*> addressable_devices,
+    PjRtStreamExecutorClient* client)
     : client_(client),
       device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1505,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
             << device_assignment_->ToString();
     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
-    CHECK_LE(addressable_devices_.size(), client_->local_device_count())
+    CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
        << "Inconsistent local device count.";
    num_partitions = device_assignment_->computation_count();
  }
@@ -1607,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
   std::vector<ExecutionInput> execution_inputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   // Lift tuple_handle outside the conditional so that the event it returns is
   // not destroyed until after the loop below that waits on events.
   absl::optional<TupleHandle> tuple_handle;
@@ -1630,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
           execution_input.MutableBuffers()->begin();
      ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
          execution_input.MutableBuffers()->end();
-      device_buffers[i].AddToInput(&input_iterator, iterator_end,
-                                   &execution_input, client_->allocator());
+      device_buffers[i].AddToInput(
+          &input_iterator, iterator_end, &execution_input,
+          tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->allocator());
      CHECK(input_iterator == iterator_end);
    }
  }
@@ -1654,7 +1671,7 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                           ->local_device_state()
                           ->device_ordinal();
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
       run_id.ToInt());
@@ -1790,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
     std::shared_ptr<BufferSequencingEvent> definition_event,
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
     int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
@@ -1827,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   if (device == nullptr) {
     CHECK(device_assignment_ != nullptr);
     const int device_id = (*device_assignment_)(replica, partition);
-    device = LookupDevice(*client_, device_id);
+    TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
     device_assignment = device_assignment_;
   } else {
     CHECK(device_assignment_ == nullptr);
@@ -1863,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   ScopedShapedBuffer result_buffer =
       result_buffer_or_status.ConsumeValueOrDie();
 
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   se::Stream* stream = device_state->compute_stream();
   StatusOr<EventPool::Handle> event_or =
       device_state->event_pool().ThenAllocateAndRecordEvent(stream);
@@ -2160,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
 }  // namespace
 
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
     const XlaComputation& computation, CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
 
   ExecutableBuildOptions& build_options = options.executable_build_options;
   if (!build_options.device_allocator()) {
@@ -2182,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     num_partitions = 1;
   } else {
     if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
+      VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
+                 "device_assignment.";
      TF_ASSIGN_OR_RETURN(
          DeviceAssignment device_assignment,
          GetDefaultDeviceAssignment(build_options.num_replicas(),
                                     build_options.num_partitions()));
      build_options.set_device_assignment(device_assignment);
    }
-    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
+    VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
            << build_options.device_assignment().ToString();
    num_replicas = build_options.device_assignment().replica_count();
    num_partitions = build_options.device_assignment().computation_count();
@@ -2263,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     for (int replica = 0; replica < num_replicas; ++replica) {
       for (int partition = 0; partition < num_partitions; ++partition) {
         int device_id = (*device_assignment)(replica, partition);
-        PjRtDevice* device = LookupDevice(*this, device_id);
+        TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
        if (device->host_id() != host_id()) {
          VLOG(3) << "Non-local device: " << device_id;
          continue;
@@ -2283,10 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
 
     if (build_options.device_ordinal() < 0) {
       build_options.set_device_ordinal(
-          tensorflow::down_cast<PjRtStreamExecutorDevice*>(
-              addressable_devices.front())
-              ->local_device_state()
-              ->device_ordinal());
+          addressable_devices.front()->local_hardware_id());
    }
  }
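A pattern worth noting in the pjrt_client.cc hunks above: `client_` members are now typed as the portable `PjRtClient*`, so code paths that need StreamExecutor-specific accessors first convert with `tensorflow::down_cast`. A minimal illustrative sketch (the free function is hypothetical; the cast and accessor are the ones the diff uses):

    // Hypothetical helper showing the cast pattern from the hunks above.
    // tensorflow::down_cast behaves like static_cast, with a debug-mode check.
    se::DeviceMemoryAllocator* AllocatorOf(PjRtClient* client) {
      return tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
          ->allocator();  // allocator() exists only on the derived class.
    }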
@ -219,86 +219,62 @@ class PjRtExecutable;
|
|||||||
// alive as long as any of the other runtime objects are alive.
|
// alive as long as any of the other runtime objects are alive.
|
||||||
class PjRtClient {
|
class PjRtClient {
|
||||||
public:
|
public:
|
||||||
// `allocator` may null, in which case the platform default allocator is used.
|
|
||||||
explicit PjRtClient(
|
|
||||||
std::string platform_name, LocalClient* client,
|
|
||||||
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
|
|
||||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
|
||||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
|
||||||
bool should_stage_host_to_device_transfers,
|
|
||||||
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
|
|
||||||
virtual ~PjRtClient() = default;
|
virtual ~PjRtClient() = default;
|
||||||
|
|
||||||
|
// TODO(zhangqiaorjc): Rename to task_id.
|
||||||
|
// Return the task id of this client. In single-task setting, always 0.
|
||||||
|
virtual int host_id() const = 0;
|
||||||
|
|
||||||
|
// Return the number of devices in the entire computation. In multi-headed
|
||||||
|
// client setting, some are addressable by this client, some are not. In a
|
||||||
|
// single-client setting, this is equal to the number of addressable devices.
|
||||||
|
virtual int device_count() const = 0;
|
||||||
|
|
||||||
|
// Return number of addressable devices. Addressable devices are those that
|
||||||
|
// the client can issue commands to.
|
||||||
|
virtual int addressable_device_count() const = 0;
|
||||||
|
|
||||||
|
// Return all devices in the entire computation, including addressable and
|
||||||
|
// non-addressable devices.
|
||||||
|
virtual absl::Span<PjRtDevice* const> devices() const = 0;
|
||||||
|
|
||||||
|
// TODO(zhangqiaorjc): Rename to addressable_devices.
|
||||||
|
// Return only addressable devices.
|
||||||
|
virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
|
||||||
|
|
||||||
|
// Lookup any PjRtDevice for a given PjRtDevice::id().
|
||||||
|
virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
|
||||||
|
|
||||||
|
// Return an addressable PjRtDevice for a given
|
||||||
|
// PjRtDevice::local_hardware_id().
|
||||||
|
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
|
||||||
|
int local_hardware_id) const = 0;
|
||||||
|
|
||||||
|
// Return an ID that identifies the platform (CPU/GPU/TPU).
|
||||||
|
virtual PjRtPlatformId platform_id() const = 0;
|
||||||
|
|
||||||
|
// Returns a string that identifies the platform (CPU/GPU/TPU).
|
||||||
|
virtual const std::string& platform_name() const = 0;
|
||||||
|
|
||||||
|
// Return a device-specific default device assignment, e.g., GPU and TPU may
|
||||||
|
// be different.
|
||||||
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||||
int num_replicas, int num_partitions) const;
|
int num_replicas, int num_partitions) const = 0;
|
||||||
|
|
||||||
int device_count() const { return devices_.size(); }
|
|
||||||
int local_device_count() const { return local_devices_.size(); }
|
|
||||||
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
|
|
||||||
return devices_;
|
|
||||||
}
|
|
||||||
absl::Span<PjRtDevice* const> local_devices() const { return local_devices_; }
|
|
||||||
const std::map<int, PjRtDevice*>& id_to_device() const {
|
|
||||||
return id_to_device_;
|
|
||||||
}
|
|
||||||
int host_id() const { return host_id_; }
|
|
||||||
PjRtPlatformId platform_id() const { return platform_id_; }
|
|
||||||
const std::string& platform_name() const { return platform_name_; }
|
|
||||||
|
|
||||||
LocalDeviceState& device_state(int device_ordinal) const {
|
|
||||||
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
|
|
||||||
local_devices_.at(device_ordinal))
|
|
||||||
->local_device_state();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return an addressable PjRtDevice for a given `device_id`.
|
|
||||||
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(int device_id) const;
|
|
||||||
|
|
||||||
LocalClient* client() const { return client_; }
|
|
||||||
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
|
|
||||||
tensorflow::Allocator* host_memory_allocator() const {
|
|
||||||
return host_memory_allocator_.get();
|
|
||||||
}
|
|
||||||
bool should_stage_host_to_device_transfers() const {
|
|
||||||
return should_stage_host_to_device_transfers_;
|
|
||||||
}
|
|
||||||
|
|
||||||
gpu::GpuExecutableRunOptions* gpu_run_options() const {
|
|
||||||
return gpu_run_options_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
|
|
||||||
return &h2d_transfer_pool_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Most platforms expect device-to-device transfers to be enqueued on the
|
|
||||||
// source d2d stream, but some platforms use the destination d2d stream. This
|
|
||||||
// function specifies which one the platform expects.
|
|
||||||
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
|
||||||
|
|
||||||
// Generates a unique fingerprint for `executable`.
|
|
||||||
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
|
||||||
const PjRtExecutable& executable) const {
|
|
||||||
return absl::optional<std::string>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a backend-specific HLO cost analysis visitor.
|
// Returns a backend-specific HLO cost analysis visitor.
|
||||||
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
|
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;
|
||||||
|
|
||||||
|
// Compile `computation` with given `options`.
|
||||||
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||||
const XlaComputation& computation, CompileOptions options);
|
const XlaComputation& computation, CompileOptions options) = 0;
|
||||||
|
|
||||||
|
// Generates a unique fingerprint for `executable`, may be absl::nullopt.
|
||||||
|
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||||
|
const PjRtExecutable& executable) const = 0;
|
||||||
|
|
||||||
// Creates a buffer on the device without initializing or copying any data.
|
// Creates a buffer on the device without initializing or copying any data.
|
||||||
// An optional `definition_event` may be speficied that can be used to
|
|
||||||
// ensure the buffer isn't referenced until some external mechanism has
|
|
||||||
// initialized the data.
|
|
||||||
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
|
|
||||||
// future backends and so callers should avoid wherever possible.
|
|
||||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||||
const Shape& shape, PjRtDevice* device);
|
const Shape& shape, PjRtDevice* device) = 0;
|
||||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
|
||||||
const Shape& shape, PjRtDevice* device,
|
|
||||||
std::shared_ptr<BufferSequencingEvent> definition_event);
|
|
||||||
|
|
||||||
// Describes the semantics the caller to BufferFromHostBuffer expects from the
|
// Describes the semantics the caller to BufferFromHostBuffer expects from the
|
||||||
// runtime, in a total order from most restrictive to least restrictive.
|
// runtime, in a total order from most restrictive to least restrictive.
|
||||||
@ -330,13 +306,13 @@ class PjRtClient {
|
|||||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
|
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
|
||||||
const void* data, const Shape& shape,
|
const void* data, const Shape& shape,
|
||||||
HostBufferSemantics host_buffer_semantics,
|
HostBufferSemantics host_buffer_semantics,
|
||||||
std::shared_ptr<void> buffer_reference, PjRtDevice* device);
|
std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
|
||||||
|
|
||||||
// Note that literal must remain in scope until the transfer has completed, so
|
// Note that literal must remain in scope until the transfer has completed, so
|
||||||
// the caller should, for example, wait for BlockHostUntilReady() completes on
|
// the caller should, for example, wait for BlockHostUntilReady() completes on
|
||||||
// the return value before letting literal go out of scope.
|
// the return value before letting literal go out of scope.
|
||||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
|
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
|
||||||
const LiteralSlice& literal, PjRtDevice* device);
|
const LiteralSlice& literal, PjRtDevice* device) = 0;
|
||||||
|
|
||||||
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
||||||
// cross host transfers using `client` on `device'. `shapes` must be the exact
|
// cross host transfers using `client` on `device'. `shapes` must be the exact
|
||||||
@ -349,18 +325,140 @@ class PjRtClient {
|
|||||||
// buffers will become ready until *all* of the sends have completed.
|
// buffers will become ready until *all* of the sends have completed.
|
||||||
virtual void MakeCrossHostReceiveBuffers(
|
virtual void MakeCrossHostReceiveBuffers(
|
||||||
absl::Span<const Shape> shapes, PjRtDevice* device,
|
absl::Span<const Shape> shapes, PjRtDevice* device,
|
||||||
PjRtCrossHostRecvNotifier&& notifier);
|
PjRtCrossHostRecvNotifier&& notifier) = 0;
|
||||||
|
|
||||||
virtual StatusOr<ChannelHandle> CreateChannelHandle() {
|
// Create ChannelHandles for XLA send/recv.
|
||||||
|
virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
|
||||||
|
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
|
||||||
|
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class PjRtStreamExecutorClient : public PjRtClient {
|
||||||
|
public:
|
||||||
|
// `allocator` may null, in which case the platform default allocator is used.
|
||||||
|
explicit PjRtStreamExecutorClient(
|
||||||
|
std::string platform_name, LocalClient* client,
|
||||||
|
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
|
||||||
|
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||||
|
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||||
|
bool should_stage_host_to_device_transfers,
|
||||||
|
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
|
||||||
|
~PjRtStreamExecutorClient() override = default;
|
||||||
|
|
||||||
|
int host_id() const override { return host_id_; }
|
||||||
|
|
||||||
|
int device_count() const override { return devices_.size(); }
|
||||||
|
int addressable_device_count() const override {
|
||||||
|
return local_devices_.size();
|
||||||
|
}
|
||||||
|
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
|
||||||
|
absl::Span<PjRtDevice* const> local_devices() const override {
|
||||||
|
return local_devices_;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
|
||||||
|
auto it = id_to_device_.find(device_id);
|
||||||
|
if (it != id_to_device_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
return InvalidArgument("No matching device found for device_id %d",
|
||||||
|
device_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<PjRtDevice*> LookupAddressableDevice(
|
||||||
|
int local_hardware_id) const override;
|
||||||
|
|
||||||
|
PjRtPlatformId platform_id() const override { return platform_id_; }
|
||||||
|
const std::string& platform_name() const override { return platform_name_; }
|
||||||
|
|
||||||
|
// Most platforms expect device-to-device transfers to be enqueued on the
|
||||||
|
// source d2d stream, but some platforms use the destination d2d stream. This
|
||||||
|
// function specifies which one the platform expects.
|
||||||
|
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
||||||
|
|
||||||
|
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||||
|
int num_replicas, int num_partitions) const override;
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||||
|
const XlaComputation& computation, CompileOptions options) override;
|
||||||
|
|
||||||
|
// Generates a unique fingerprint for `executable`.
|
||||||
|
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||||
|
const PjRtExecutable& executable) const override {
|
||||||
|
return absl::optional<std::string>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a backend-specific HLO cost analysis visitor.
|
||||||
|
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
|
||||||
|
|
||||||
|
// Creates a buffer on the device without initializing or copying any data.
|
||||||
|
// An optional `definition_event` may be speficied that can be used to
|
||||||
|
// ensure the buffer isn't referenced until some external mechanism has
|
||||||
|
// initialized the data.
|
||||||
|
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
|
||||||
|
// future backends and so callers should avoid wherever possible.
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||||
|
const Shape& shape, PjRtDevice* device) override;
|
||||||
|
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||||
|
const Shape& shape, PjRtDevice* device,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event);
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
|
||||||
|
const void* data, const Shape& shape,
|
||||||
|
HostBufferSemantics host_buffer_semantics,
|
||||||
|
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
|
||||||
|
|
||||||
|
// Note that literal must remain in scope until the transfer has completed, so
|
||||||
|
// the caller should, for example, wait for BlockHostUntilReady() completes on
|
||||||
|
// the return value before letting literal go out of scope.
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
|
||||||
|
const LiteralSlice& literal, PjRtDevice* device) override;
|
||||||
|
|
||||||
|
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
||||||
|
// cross host transfers using `client` on `device'. `shapes` must be the exact
|
||||||
|
// shapes, with identical layouts, corresponding to the buffers that will be
|
||||||
|
// sent. When resources for the transfer are available, notifier will be
|
||||||
|
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
|
||||||
|
// shape in `shapes`. Each struct contains a buffer that will contain the
|
||||||
|
// received value, and an opaque string that should be transmitted to the
|
||||||
|
// sending host and used in a call to CopyToRemoteDevice. None of the recv
|
||||||
|
// buffers will become ready until *all* of the sends have completed.
|
||||||
|
void MakeCrossHostReceiveBuffers(
|
||||||
|
absl::Span<const Shape> shapes, PjRtDevice* device,
|
||||||
|
PjRtCrossHostRecvNotifier&& notifier) override;
|
||||||
|
|
||||||
|
StatusOr<ChannelHandle> CreateChannelHandle() override {
|
||||||
return client()->CreateChannelHandle();
|
return client()->CreateChannelHandle();
|
||||||
}
|
}
|
||||||
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
|
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
|
||||||
return client()->CreateDeviceToHostChannelHandle();
|
return client()->CreateDeviceToHostChannelHandle();
|
||||||
}
|
}
|
||||||
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
|
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
|
||||||
return client()->CreateHostToDeviceChannelHandle();
|
return client()->CreateHostToDeviceChannelHandle();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LocalDeviceState& device_state(int device_ordinal) const {
|
||||||
|
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
|
||||||
|
local_devices_.at(device_ordinal))
|
||||||
|
->local_device_state();
|
||||||
|
}
|
||||||
|
LocalClient* client() const { return client_; }
|
||||||
|
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
|
||||||
|
tensorflow::Allocator* host_memory_allocator() const {
|
||||||
|
return host_memory_allocator_.get();
|
||||||
|
}
|
||||||
|
bool should_stage_host_to_device_transfers() const {
|
||||||
|
return should_stage_host_to_device_transfers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
gpu::GpuExecutableRunOptions* gpu_run_options() const {
|
||||||
|
return gpu_run_options_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
|
||||||
|
return &h2d_transfer_pool_;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
friend class PjRtBuffer;
|
friend class PjRtBuffer;
|
||||||
virtual void EnqueueCrossHostReceive(
|
virtual void EnqueueCrossHostReceive(
|
||||||
@ -383,7 +481,9 @@ class PjRtClient {
|
|||||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
||||||
|
|
||||||
// Includes all devices, including non-local devices on multi-host platforms.
|
// Includes all devices, including non-local devices on multi-host platforms.
|
||||||
std::vector<std::unique_ptr<PjRtDevice>> devices_;
|
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
|
||||||
|
// Pointers to `owned_devices_`.
|
||||||
|
std::vector<PjRtDevice*> devices_;
|
||||||
// Maps Device::id() to the corresponding Device. Includes all devices.
|
// Maps Device::id() to the corresponding Device. Includes all devices.
|
||||||
std::map<int, PjRtDevice*> id_to_device_;
|
std::map<int, PjRtDevice*> id_to_device_;
|
||||||
// Local devices indexed by local device ordinal.
|
// Local devices indexed by local device ordinal.
|
||||||
@ -550,7 +650,7 @@ class PjRtBuffer {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
friend class PjRtBuffer;
|
friend class PjRtBuffer;
|
||||||
friend class PjRtClient;
|
friend class PjRtStreamExecutorClient;
|
||||||
|
|
||||||
// Helper struct that makes it possible to move a ScopedHold through a
|
// Helper struct that makes it possible to move a ScopedHold through a
|
||||||
// closure.
|
// closure.
|
||||||
@ -810,7 +910,7 @@ class PjRtExecutable {
|
|||||||
virtual PjRtClient* client() const = 0;
|
virtual PjRtClient* client() const = 0;
|
||||||
|
|
||||||
// Unique name for this executable, e.g., HloModule name.
|
// Unique name for this executable, e.g., HloModule name.
|
||||||
virtual const string& name() const = 0;
|
virtual const std::string& name() const = 0;
|
||||||
|
|
||||||
virtual int num_replicas() const = 0;
|
virtual int num_replicas() const = 0;
|
||||||
|
|
||||||
@ -875,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
|||||||
bool parameter_is_tupled_arguments,
|
bool parameter_is_tupled_arguments,
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||||
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
||||||
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
|
std::vector<PjRtDevice*> addressable_devices,
|
||||||
|
PjRtStreamExecutorClient* client);
|
||||||
|
|
||||||
~PjRtStreamExecutorExecutable() override = default;
|
~PjRtStreamExecutorExecutable() override = default;
|
||||||
|
|
||||||
PjRtClient* client() const override { return client_; }
|
PjRtStreamExecutorClient* client() const override { return client_; }
|
||||||
|
|
||||||
const string& name() const override;
|
const std::string& name() const override;
|
||||||
|
|
||||||
int num_replicas() const override {
|
int num_replicas() const override {
|
||||||
return executables_[0]->build_options().num_replicas();
|
return executables_[0]->build_options().num_replicas();
|
||||||
@ -940,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class PjRtClient;
|
friend class PjRtStreamExecutorClient;
|
||||||
// Initializes information about which arguments to which executables must be
|
// Initializes information about which arguments to which executables must be
|
||||||
// donated due to aliases that were specified by the computation.
|
// donated due to aliases that were specified by the computation.
|
||||||
Status SetUpDonation(bool tuple_inputs);
|
Status SetUpDonation(bool tuple_inputs);
|
||||||
@@ -975,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the
   // executable itself.
-  PjRtClient* const client_;
+  PjRtStreamExecutorClient* const client_;
   // One executable per partition.
   std::vector<std::shared_ptr<LocalExecutable>> executables_;
   // Per-executable set of parameters that have any aliased buffers and thus
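The comment in this hunk ("the process being executed can outlive the executable itself") describes a lifetime pattern worth spelling out: shared_ptrs are copied into the asynchronous work, so the executables stay alive even if their owner is destroyed mid-run. A dependency-free sketch, with a detached thread standing in for the asynchronous stream:

    #include <chrono>
    #include <iostream>
    #include <memory>
    #include <thread>

    struct LocalExec {
      void Run() const { std::cout << "still alive while running\n"; }
    };

    void LaunchAsync(std::shared_ptr<LocalExec> exec) {
      // The lambda copies the shared_ptr, extending the object's lifetime
      // past the caller's scope.
      std::thread([exec] { exec->Run(); }).detach();
    }

    int main() {
      {
        auto exec = std::make_shared<LocalExec>();
        LaunchAsync(exec);
      }  // The owning scope ends here, but the async work holds a reference.
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }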
@@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
   return Status::OK();
 }

-class PjRtTpuClient : public PjRtClient {
+class PjRtTpuClient : public PjRtStreamExecutorClient {
  public:
   PjRtTpuClient(LocalClient* client,
-                std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
+                std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+                int host_id);

   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
       const PjRtExecutable& executable) const override;
 };

-PjRtTpuClient::PjRtTpuClient(LocalClient* client,
-                             std::vector<std::unique_ptr<PjRtDevice>> devices,
-                             int host_id)
-    : PjRtClient(kTpuName, client, std::move(devices), host_id,
+PjRtTpuClient::PjRtTpuClient(
+    LocalClient* client,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
+    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
                                /*allocator=*/nullptr,
                                /*host_memory_allocator=*/nullptr,
                                /*should_stage_host_to_device_transfers=*/false,
                                /*gpu_run_options=*/nullptr) {}

 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
@@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
                                                  num_partitions);
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }

 StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
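The pattern in this hunk, "try a backend-specific assignment, otherwise defer to the base class", generalizes; a hedged sketch with simplified stand-in types (not the real DeviceAssignment API):

    #include <iostream>
    #include <optional>

    struct Assignment { int replicas = 0; int partitions = 0; };

    class BaseClient {
     public:
      virtual ~BaseClient() = default;
      // Default global assignment, analogous to
      // PjRtStreamExecutorClient::GetDefaultDeviceAssignment.
      virtual Assignment GetDefaultDeviceAssignment(int r, int p) const {
        return {r, p};
      }
    };

    class TpuLikeClient : public BaseClient {
     public:
      Assignment GetDefaultDeviceAssignment(int r, int p) const override {
        if (auto local = TryLocalTopology(r, p)) return *local;
        // Fall back to the default global device assignment, as above.
        return BaseClient::GetDefaultDeviceAssignment(r, p);
      }

     private:
      // Hypothetical predicate: succeed only when the job fits on one host.
      std::optional<Assignment> TryLocalTopology(int r, int p) const {
        if (r * p <= 8) return Assignment{r, p};
        return std::nullopt;
      }
    };

    int main() {
      TpuLikeClient client;
      Assignment a = client.GetDefaultDeviceAssignment(2, 2);
      std::cout << a.replicas << "x" << a.partitions << "\n";
    }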
@@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
   return absl::optional<std::string>(tpu_executable->fingerprint());
 }

-StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
+StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
     LocalClient* client,
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   tf_tpu::TpuTopologyExternal topology =
       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();

@@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
   callback_ = callback;
   max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
   for (const auto& client : clients) {
-    for (const auto& device : client->devices()) {
-      devices_.push_back(device.get());
+    for (auto device : client->devices()) {
+      devices_.push_back(device);
     }
   }
   CHECK_GT(devices_.size(), 0);
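The small change above (dropping .get()) reflects an interface change implied elsewhere in this commit: client->devices() no longer hands out references to owning unique_ptrs but, assumed here, a view of raw, non-owning PjRtDevice* pointers, with the client retaining ownership. A simplified sketch of that accessor shape (a plain vector stands in for whatever view type the real API returns):

    #include <memory>
    #include <vector>

    struct Device { int id = 0; };

    class Client {
     public:
      void AddDevice(std::unique_ptr<Device> d) {
        device_ptrs_.push_back(d.get());
        owned_.push_back(std::move(d));
      }
      // Non-owning view; callers copy the pointers, never the ownership.
      const std::vector<Device*>& devices() const { return device_ptrs_; }

     private:
      std::vector<std::unique_ptr<Device>> owned_;  // client keeps ownership
      std::vector<Device*> device_ptrs_;
    };

    int main() {
      Client client;
      client.AddDevice(std::make_unique<Device>());
      std::vector<Device*> devices;
      for (auto* device : client.devices()) {
        devices.push_back(device);  // no .get(): already a raw pointer
      }
    }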
@@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)

 std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
   std::vector<ClientAndPtr<PjRtDevice>> devices;
-  devices.reserve(pjrt_client_->devices().size());
-  for (const auto& device : pjrt_client_->devices()) {
-    devices.push_back(WrapWithClient(shared_from_this(), device.get()));
+  auto span = pjrt_client_->devices();
+  devices.reserve(span.size());
+  for (PjRtDevice* device : span) {
+    devices.push_back(WrapWithClient(shared_from_this(), device));
   }
   return devices;
 }
@@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
     result[r].resize(num_partitions);
     for (int p = 0; p < num_partitions; ++p) {
       int device_id = device_assignment(r, p);
-      auto iter = pjrt_client_->id_to_device().find(device_id);
-      CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-      result[r][p] = WrapWithClient(shared_from_this(), iter->second);
+      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                          pjrt_client_->LookupDevice(device_id));
+      result[r][p] = WrapWithClient(shared_from_this(), device);
     }
   }
   return result;
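Replacing the raw id_to_device() map probe plus CHECK with LookupDevice turns an unknown device id into a recoverable error instead of a process crash. The diff does not show LookupDevice's implementation; a hypothetical sketch of what such a helper could look like, with plain exceptions standing in for the real StatusOr/InvalidArgument machinery:

    #include <map>
    #include <stdexcept>
    #include <string>

    struct Device { int id = 0; };

    class Client {
     public:
      // Hypothetical LookupDevice: the real API would return
      // StatusOr<PjRtDevice*> and callers would TF_ASSIGN_OR_RETURN on it.
      Device* LookupDevice(int device_id) const {
        auto it = id_to_device_.find(device_id);
        if (it == id_to_device_.end()) {
          throw std::invalid_argument("No matching device found for device_id " +
                                      std::to_string(device_id));
        }
        return it->second;
      }

     private:
      std::map<int, Device*> id_to_device_;  // previously exposed to callers
    };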
@@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
   std::vector<ClientAndPtr<PjRtDevice>> result;
   for (int i = 0; i < num_replicas; ++i) {
     int device_id = device_assignment(i, 0);
-    auto iter = pjrt_client_->id_to_device().find(device_id);
-    CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-    result.push_back(WrapWithClient(shared_from_this(), iter->second));
+    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                        pjrt_client_->LookupDevice(device_id));
+    result.push_back(WrapWithClient(shared_from_this(), device));
   }
   return result;
 }
@@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
     device = pjrt_client_->local_devices().front();
   }
   CHECK(device != nullptr);
-  auto iter = pjrt_client_->id_to_device().find(device->id());
-  if (iter->second != device) {
+  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
+                      pjrt_client_->LookupDevice(device->id()));
+  if (found_device != device) {
     return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                            device->DebugString(),
                            pjrt_client_->platform_name());
@@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
   const std::string& platform_name() const {
     return pjrt_client_->platform_name();
   }
-  int local_device_count() const { return pjrt_client_->local_device_count(); }
+  int addressable_device_count() const {
+    return pjrt_client_->addressable_device_count();
+  }
   int device_count() const { return pjrt_client_->device_count(); }
   int host_id() const { return pjrt_client_->host_id(); }

@@ -240,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
       .def("device_count", &PyClient::device_count)
-      .def("local_device_count", &PyClient::local_device_count)
+      .def("local_device_count", &PyClient::addressable_device_count)
       .def("devices", &PyClient::Devices)
      .def("local_devices", &PyClient::LocalDevices)
       .def("host_id", &PyClient::host_id)
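One detail worth noting in the binding hunk: the Python-visible method keeps its old name, local_device_count, while the C++ member it binds is the renamed addressable_device_count, so existing Python callers are unaffected. A minimal pybind11 sketch of that stable-name pattern (toy class, not the real xla_extension module):

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    struct ToyClient {
      int addressable_device_count() const { return 4; }  // toy value
    };

    PYBIND11_MODULE(toy_extension, m) {
      py::class_<ToyClient>(m, "Client")
          .def(py::init<>())
          // Old Python name, new C++ implementation.
          .def("local_device_count", &ToyClient::addressable_device_count);
    }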