Extract a PjRtClient interface.
- Extract a pure interface.
- The current implementation is renamed to PjRtStreamExecutorClient.

TODO: split into a pjrt_stream_executor.{h,cc}.

PiperOrigin-RevId: 346470160
Change-Id: I6235d9eed58cfd281c59f441dfed67ea3b9035ed
commit 93dfb9b68f
parent b9187102b6
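In outline, the change splits the old concrete class in two: PjRtClient becomes a pure-virtual interface, and the existing implementation is renamed PjRtStreamExecutorClient. A condensed sketch of the resulting shape (only a handful of the methods; the full declarations are in the pjrt_client.h hunks below):

  // Pure interface: platform-agnostic callers depend only on this.
  class PjRtClient {
   public:
    virtual ~PjRtClient() = default;
    virtual int device_count() const = 0;
    virtual absl::Span<PjRtDevice* const> devices() const = 0;
    virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
    virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
        const XlaComputation& computation, CompileOptions options) = 0;
  };

  // Renamed implementation: all StreamExecutor-specific state and helpers
  // (LocalClient, allocators, device_state(), ...) move down here.
  class PjRtStreamExecutorClient : public PjRtClient {
   public:
    int device_count() const override { return devices_.size(); }
    // ...remaining overrides, see the header diff below...
  };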
@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutorConfig config;
     config.ordinal = i;
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
     devices.push_back(std::move(device));
   }
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       kCpuName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
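Factories such as GetCpuClient still hand back a std::unique_ptr<PjRtClient>, so callers written against the interface are unaffected by the rename. A minimal, hypothetical usage sketch (it assumes an XlaComputation named computation and that a default-constructed CompileOptions is acceptable):

  // Build a CPU client and compile against the abstract interface only.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                      GetCpuClient(/*asynchronous=*/true));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      client->Compile(computation, CompileOptions()));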
@@ -35,9 +35,9 @@ namespace xla {
 namespace {
 
 // A custom PjRtClient that overrides the device assignment method.
-class GpuClient : public xla::PjRtClient {
+class GpuClient : public xla::PjRtStreamExecutorClient {
  public:
-  using xla::PjRtClient::PjRtClient;
+  using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
 
   xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
     return assignment;
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }
 
 // Builds an xla::LocalClient for the GPU platform.
@@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
   return result.first->second;
 }
 
-std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
+std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (auto& local_device : local_device_states) {
     int device_ordinal = local_device->device_ordinal();
     const se::DeviceDescription& description =
@@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
 Status BuildDistributedDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
-    std::vector<std::unique_ptr<PjRtDevice>>* devices,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
     gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
   LocalTopologyProto local_topology;
   local_topology.set_node_id(node_id);
@@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
   auto host_memory_allocator =
       GetGpuHostAllocator(local_device_states.front()->executor());
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
   if (distributed_client) {
     TF_RETURN_IF_ERROR(BuildDistributedDevices(
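GpuClient above, and PjRtTpuClient later in this change, both specialize the renamed implementation rather than the new interface. A condensed sketch of that recipe (MyPlatformClient is a hypothetical name; the overridden method mirrors GpuClient):

  class MyPlatformClient : public xla::PjRtStreamExecutorClient {
   public:
    // Inherit the stream-executor constructors, as GpuClient does.
    using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;

    // Override platform-specific policy; fall back to
    // PjRtStreamExecutorClient::GetDefaultDeviceAssignment otherwise.
    xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
        int num_replicas, int num_partitions) const override;
  };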
@@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
 
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   se::StreamExecutor* executor =
       client->backend().stream_executor(0).ValueOrDie();
   auto device_state = absl::make_unique<LocalDeviceState>(
@@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
       absl::make_unique<InterpreterDevice>(0, std::move(device_state));
   devices.push_back(std::move(device));
 
-  return std::make_unique<PjRtClient>(
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
       "interpreter", client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+      /*gpu_run_options=*/nullptr));
 }
 
 }  // namespace xla
@@ -184,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
   }
 };
 
-PjRtClient::PjRtClient(
+PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
-    std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
@@ -195,7 +195,7 @@ PjRtClient::PjRtClient(
     : platform_name_(std::move(platform_name)),
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
-      devices_(std::move(devices)),
+      owned_devices_(std::move(devices)),
       host_id_(host_id),
       owned_allocator_(std::move(allocator)),
       should_stage_host_to_device_transfers_(
@@ -213,7 +213,9 @@ PjRtClient::PjRtClient(
     host_memory_allocator_ = std::make_unique<CpuAllocator>();
   }
 
-  for (const std::unique_ptr<PjRtDevice>& device : devices_) {
+  for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
+       owned_devices_) {
+    devices_.push_back(device.get());
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();
 
@@ -225,21 +227,21 @@ PjRtClient::PjRtClient(
       CHECK(local_devices_[idx] == nullptr) << idx;
       local_devices_[idx] = device.get();
     }
-    tensorflow::down_cast<PjRtStreamExecutorDevice*>(device.get())
-        ->SetClient(this);
+    device->SetClient(this);
   }
   for (int idx = 0; idx < local_devices_.size(); ++idx) {
     CHECK(local_devices_[idx] != nullptr) << idx;
   }
 }
 
-StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
+StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
   return client_->backend().computation_placer()->AssignDevices(num_replicas,
                                                                 num_partitions);
 }
 
-std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
+std::unique_ptr<HloCostAnalysis>
+PjRtStreamExecutorClient::GetHloCostAnalysis() {
   return absl::make_unique<HloCostAnalysis>(
       client_->backend().compiler()->ShapeSizeBytesFunction());
 }
@@ -349,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
     return InvalidArgument("Can't make a buffer from an empty tuple");
   }
 
+  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer dst_buffer,
-      transfer_manager->AllocateScopedShapedBuffer(
-          on_host_shape, client->allocator(), local_device->device_ordinal()));
+      se_client->client()->backend().transfer_manager();
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
+                      transfer_manager->AllocateScopedShapedBuffer(
+                          on_host_shape, se_client->allocator(),
+                          local_device->device_ordinal()));
   if (local_device->allocation_model() ==
       LocalDeviceState::kComputeSynchronized) {
     if (copy_stream == nullptr) {
@@ -546,13 +549,15 @@ void PjRtBuffer::ScopedHold::AddToInput(
 
 bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
-  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   if (shape.IsTuple()) {
     return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
@@ -712,17 +717,19 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
   return py_buffer;
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
-    const Shape& shape, PjRtDevice* device) {
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
+                                                    PjRtDevice* device) {
   return CreateUninitializedBuffer(shape, device, nullptr);
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(
     const Shape& shape, PjRtDevice* device,
     std::shared_ptr<BufferSequencingEvent> definition_event) {
   tensorflow::profiler::TraceMe traceme(
-      "PjRtClient::CreateUninitializedBuffer");
-  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+      "PjRtStreamExecutorClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
          << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -738,10 +745,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
                                     definition_event);
 }
 
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
-    const LiteralSlice& literal, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
-  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
+                                                PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
          << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -798,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
   return py_buffer;
 }
 
-void PjRtClient::MakeCrossHostReceiveBuffers(
+void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
     absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
@@ -851,14 +860,15 @@ StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
       shape, local_device->device_ordinal());
 }
 
-StatusOr<PjRtDevice*> PjRtClient::LookupAddressableDevice(int device_id) const {
+StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
+    int local_hardware_id) const {
   for (auto* device : local_devices_) {
-    if (device_id == device->local_hardware_id()) {
+    if (local_hardware_id == device->local_hardware_id()) {
       return device;
     }
   }
-  return InvalidArgument("No matching device found for device_id %d",
-                         device_id);
+  return InvalidArgument("No matching device found for local_hardware_id %d",
+                         local_hardware_id);
 }
 
 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -883,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
 }
 
 int64 PjRtBuffer::OnDeviceSizeInBytes() const {
-  return client_->client()
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
       ->backend()
       .transfer_manager()
       ->GetByteSizeRequirement(on_device_shape_);
@@ -1136,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
     host_value->value = std::make_shared<Literal>(host_shape);
     ShapedBuffer shaped_buffer =
         device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
-    client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
-        stream, shaped_buffer, host_value->value.get(),
-        [host_value](Status done_status) {
-          host_value->status = done_status;
-          host_value->ready.Notify();
-        });
+    tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+        ->client()
+        ->backend()
+        .transfer_manager()
+        ->TransferLiteralFromDevice(stream, shaped_buffer,
+                                    host_value->value.get(),
+                                    [host_value](Status done_status) {
+                                      host_value->status = done_status;
+                                      host_value->ready.Notify();
+                                    });
 
   auto usage_event = std::make_shared<BufferSequencingEvent>();
   StatusOr<EventPool::Handle> event_or =
@@ -1170,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
 
 StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1280,7 +1295,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
     return dst_device->client()->BufferFromHostBuffer(
         literal->untyped_data(), literal->shape(),
-        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
+        PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
+        dst_device);
   }
 
   TF_ASSIGN_OR_RETURN(
@@ -1288,7 +1304,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
       tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
           ->GetLocalDeviceState());
   LocalDeviceState* transfer_local_device =
-      client_->EnqueueD2DTransfersOnSrcStream()
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->EnqueueD2DTransfersOnSrcStream()
           ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
                 ->local_device_state()
          : dst_local_device;
@@ -1339,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
 }
 
 Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
-  return client_->CopyToRemoteDevice(this, serialized_descriptor);
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->CopyToRemoteDevice(this, serialized_descriptor);
 }
 
 Status PjRtBuffer::BlockHostUntilReady() {
@@ -1401,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
   Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
 
-  se::DeviceMemoryAllocator* allocator = client->allocator();
+  se::DeviceMemoryAllocator* allocator =
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
+          ->client()
+          ->backend()
+          .transfer_manager();
   se::Stream* stream = local_device->host_to_device_stream();
   TF_ASSIGN_OR_RETURN(
       se::OwningDeviceMemory root_table_memory,
@@ -1467,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
                                     /*prefer_to_retain_reference=*/false);
   return pjrt_buffer;
 }
-
-static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
-  auto it = client.id_to_device().find(device_id);
-  CHECK(it != client.id_to_device().end())
-      << "Unknown device id: " << device_id;
-  return it->second;
-}
-
 }  // namespace
 
 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@@ -1482,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     bool parameter_is_tupled_arguments,
     std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-    std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
+    std::vector<PjRtDevice*> addressable_devices,
+    PjRtStreamExecutorClient* client)
     : client_(client),
       device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1505,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
             << device_assignment_->ToString();
     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
-    CHECK_LE(addressable_devices_.size(), client_->local_device_count())
+    CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
         << "Inconsistent local device count.";
     num_partitions = device_assignment_->computation_count();
   }
@@ -1607,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
   std::vector<ExecutionInput> execution_inputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   // Lift tuple_handle outside the conditional so that the event it returns is
   // not destroyed until after the loop below that waits on events.
   absl::optional<TupleHandle> tuple_handle;
@@ -1630,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
           execution_input.MutableBuffers()->begin();
       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
           execution_input.MutableBuffers()->end();
-      device_buffers[i].AddToInput(&input_iterator, iterator_end,
-                                   &execution_input, client_->allocator());
+      device_buffers[i].AddToInput(
+          &input_iterator, iterator_end, &execution_input,
+          tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->allocator());
       CHECK(input_iterator == iterator_end);
     }
   }
@@ -1654,7 +1671,7 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                            ->local_device_state()
                            ->device_ordinal();
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
       run_id.ToInt());
@@ -1790,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
     std::shared_ptr<BufferSequencingEvent> definition_event,
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
     int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
@@ -1827,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   if (device == nullptr) {
     CHECK(device_assignment_ != nullptr);
     const int device_id = (*device_assignment_)(replica, partition);
-    device = LookupDevice(*client_, device_id);
+    TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
     device_assignment = device_assignment_;
   } else {
     CHECK(device_assignment_ == nullptr);
@@ -1863,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   ScopedShapedBuffer result_buffer =
       result_buffer_or_status.ConsumeValueOrDie();
 
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   se::Stream* stream = device_state->compute_stream();
   StatusOr<EventPool::Handle> event_or =
       device_state->event_pool().ThenAllocateAndRecordEvent(stream);
@@ -2160,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
 }  // namespace
 
-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
     const XlaComputation& computation, CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
 
   ExecutableBuildOptions& build_options = options.executable_build_options;
   if (!build_options.device_allocator()) {
@@ -2182,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     num_partitions = 1;
   } else {
     if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
+      VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
+                 "device_assignment.";
      TF_ASSIGN_OR_RETURN(
          DeviceAssignment device_assignment,
          GetDefaultDeviceAssignment(build_options.num_replicas(),
                                     build_options.num_partitions()));
      build_options.set_device_assignment(device_assignment);
    }
-    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
+    VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
            << build_options.device_assignment().ToString();
     num_replicas = build_options.device_assignment().replica_count();
     num_partitions = build_options.device_assignment().computation_count();
@@ -2263,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
   for (int replica = 0; replica < num_replicas; ++replica) {
     for (int partition = 0; partition < num_partitions; ++partition) {
       int device_id = (*device_assignment)(replica, partition);
-      PjRtDevice* device = LookupDevice(*this, device_id);
+      TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
       if (device->host_id() != host_id()) {
         VLOG(3) << "Non-local device: " << device_id;
         continue;
@@ -2283,10 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
 
   if (build_options.device_ordinal() < 0) {
     build_options.set_device_ordinal(
-        tensorflow::down_cast<PjRtStreamExecutorDevice*>(
-            addressable_devices.front())
-            ->local_device_state()
-            ->device_ordinal());
+        addressable_devices.front()->local_hardware_id());
   }
 }
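A pattern recurs throughout pjrt_client.cc in this change: accessors such as client(), allocator() and device_state() are no longer on the PjRtClient interface, so code paths that are inherently stream-executor-specific first narrow the stored PjRtClient*. A condensed restatement of the casts above:

  // client_ has static type PjRtClient*; the StreamExecutor details now live
  // on the subclass, hence the down_cast at each such call site.
  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client_);
  se::DeviceMemoryAllocator* allocator = se_client->allocator();
  TransferManager* transfer_manager =
      se_client->client()->backend().transfer_manager();

The TODO in the commit message (splitting these definitions out into pjrt_stream_executor.{h,cc}) would move such call sites next to the class they cast to.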
@@ -219,86 +219,62 @@ class PjRtExecutable;
 // alive as long as any of the other runtime objects are alive.
 class PjRtClient {
  public:
-  // `allocator` may null, in which case the platform default allocator is used.
-  explicit PjRtClient(
-      std::string platform_name, LocalClient* client,
-      std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
-      std::unique_ptr<se::DeviceMemoryAllocator> allocator,
-      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
-      bool should_stage_host_to_device_transfers,
-      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
   virtual ~PjRtClient() = default;
 
+  // TODO(zhangqiaorjc): Rename to task_id.
+  // Return the task id of this client. In single-task setting, always 0.
+  virtual int host_id() const = 0;
+
+  // Return the number of devices in the entire computation. In multi-headed
+  // client setting, some are addressable by this client, some are not. In a
+  // single-client setting, this is equal to the number of addressable devices.
+  virtual int device_count() const = 0;
+
+  // Return number of addressable devices. Addressable devices are those that
+  // the client can issue commands to.
+  virtual int addressable_device_count() const = 0;
+
+  // Return all devices in the entire computation, including addressable and
+  // non-addressable devices.
+  virtual absl::Span<PjRtDevice* const> devices() const = 0;
+
+  // TODO(zhangqiaorjc): Rename to addressable_devices.
+  // Return only addressable devices.
+  virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
+
+  // Lookup any PjRtDevice for a given PjRtDevice::id().
+  virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
+
+  // Return an addressable PjRtDevice for a given
+  // PjRtDevice::local_hardware_id().
+  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const = 0;
+
+  // Return an ID that identifies the platform (CPU/GPU/TPU).
+  virtual PjRtPlatformId platform_id() const = 0;
+
+  // Returns a string that identifies the platform (CPU/GPU/TPU).
+  virtual const std::string& platform_name() const = 0;
+
+  // Return a device-specific default device assignment, e.g., GPU and TPU may
+  // be different.
   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
-      int num_replicas, int num_partitions) const;
-
-  int device_count() const { return devices_.size(); }
-  int local_device_count() const { return local_devices_.size(); }
-  const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
-    return devices_;
-  }
-  absl::Span<PjRtDevice* const> local_devices() const { return local_devices_; }
-  const std::map<int, PjRtDevice*>& id_to_device() const {
-    return id_to_device_;
-  }
-  int host_id() const { return host_id_; }
-  PjRtPlatformId platform_id() const { return platform_id_; }
-  const std::string& platform_name() const { return platform_name_; }
-
-  LocalDeviceState& device_state(int device_ordinal) const {
-    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
-                local_devices_.at(device_ordinal))
-                ->local_device_state();
-  }
-
-  // Return an addressable PjRtDevice for a given `device_id`.
-  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(int device_id) const;
-
-  LocalClient* client() const { return client_; }
-  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
-  tensorflow::Allocator* host_memory_allocator() const {
-    return host_memory_allocator_.get();
-  }
-  bool should_stage_host_to_device_transfers() const {
-    return should_stage_host_to_device_transfers_;
-  }
-
-  gpu::GpuExecutableRunOptions* gpu_run_options() const {
-    return gpu_run_options_.get();
-  }
-
-  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
-    return &h2d_transfer_pool_;
-  }
-
-  // Most platforms expect device-to-device transfers to be enqueued on the
-  // source d2d stream, but some platforms use the destination d2d stream. This
-  // function specifies which one the platform expects.
-  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
-
-  // Generates a unique fingerprint for `executable`.
-  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
-      const PjRtExecutable& executable) const {
-    return absl::optional<std::string>();
-  }
+      int num_replicas, int num_partitions) const = 0;
 
   // Returns a backend-specific HLO cost analysis visitor.
-  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
+  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;
 
   // Compile `computation` with given `options`.
   virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      const XlaComputation& computation, CompileOptions options);
+      const XlaComputation& computation, CompileOptions options) = 0;
+
+  // Generates a unique fingerprint for `executable`, may be absl::nullopt.
+  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const = 0;
 
   // Creates a buffer on the device without initializing or copying any data.
   // An optional `definition_event` may be speficied that can be used to
   // ensure the buffer isn't referenced until some external mechanism has
   // initialized the data.
   // NOTE: The sequencing mechanism is not guaranteed to be supported by all
   // future backends and so callers should avoid wherever possible.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device);
-  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device,
-      std::shared_ptr<BufferSequencingEvent> definition_event);
+      const Shape& shape, PjRtDevice* device) = 0;
 
   // Describes the semantics the caller to BufferFromHostBuffer expects from the
   // runtime, in a total order from most restrictive to least restrictive.
@@ -330,13 +306,13 @@ class PjRtClient {
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
       const void* data, const Shape& shape,
       HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtDevice* device);
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
 
   // Note that literal must remain in scope until the transfer has completed, so
   // the caller should, for example, wait for BlockHostUntilReady() completes on
   // the return value before letting literal go out of scope.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
-      const LiteralSlice& literal, PjRtDevice* device);
+      const LiteralSlice& literal, PjRtDevice* device) = 0;
 
   // Asynchronously makes a vector of PjRtBuffers that can be used to receive
   // cross host transfers using `client` on `device'. `shapes` must be the exact
@@ -349,18 +325,140 @@ class PjRtClient {
   // buffers will become ready until *all* of the sends have completed.
   virtual void MakeCrossHostReceiveBuffers(
       absl::Span<const Shape> shapes, PjRtDevice* device,
-      PjRtCrossHostRecvNotifier&& notifier);
+      PjRtCrossHostRecvNotifier&& notifier) = 0;
 
-  virtual StatusOr<ChannelHandle> CreateChannelHandle() {
+  // Create ChannelHandles for XLA send/recv.
+  virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
+};
+
+class PjRtStreamExecutorClient : public PjRtClient {
+ public:
+  // `allocator` may null, in which case the platform default allocator is used.
+  explicit PjRtStreamExecutorClient(
+      std::string platform_name, LocalClient* client,
+      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+      int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+      bool should_stage_host_to_device_transfers,
+      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
+  ~PjRtStreamExecutorClient() override = default;
+
+  int host_id() const override { return host_id_; }
+
+  int device_count() const override { return devices_.size(); }
+  int addressable_device_count() const override {
+    return local_devices_.size();
+  }
+  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
+  absl::Span<PjRtDevice* const> local_devices() const override {
+    return local_devices_;
+  }
+
+  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
+    auto it = id_to_device_.find(device_id);
+    if (it != id_to_device_.end()) {
+      return it->second;
+    }
+    return InvalidArgument("No matching device found for device_id %d",
+                           device_id);
+  }
+
+  StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const override;
+
+  PjRtPlatformId platform_id() const override { return platform_id_; }
+  const std::string& platform_name() const override { return platform_name_; }
+
+  // Most platforms expect device-to-device transfers to be enqueued on the
+  // source d2d stream, but some platforms use the destination d2d stream. This
+  // function specifies which one the platform expects.
+  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
+
+  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
+      int num_replicas, int num_partitions) const override;
+
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options) override;
+
+  // Generates a unique fingerprint for `executable`.
+  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const override {
+    return absl::optional<std::string>();
+  }
+
+  // Returns a backend-specific HLO cost analysis visitor.
+  std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
+
+  // Creates a buffer on the device without initializing or copying any data.
+  // An optional `definition_event` may be speficied that can be used to
+  // ensure the buffer isn't referenced until some external mechanism has
+  // initialized the data.
+  // NOTE: The sequencing mechanism is not guaranteed to be supported by all
+  // future backends and so callers should avoid wherever possible.
+  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device) override;
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device,
+      std::shared_ptr<BufferSequencingEvent> definition_event);
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
+
+  // Note that literal must remain in scope until the transfer has completed, so
+  // the caller should, for example, wait for BlockHostUntilReady() completes on
+  // the return value before letting literal go out of scope.
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device) override;
+
+  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
+  // cross host transfers using `client` on `device'. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier) override;
+
+  StatusOr<ChannelHandle> CreateChannelHandle() override {
     return client()->CreateChannelHandle();
   }
-  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
+  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
     return client()->CreateDeviceToHostChannelHandle();
   }
-  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
+  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
     return client()->CreateHostToDeviceChannelHandle();
   }
 
+  LocalDeviceState& device_state(int device_ordinal) const {
+    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
+                local_devices_.at(device_ordinal))
+                ->local_device_state();
+  }
+  LocalClient* client() const { return client_; }
+  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
+  tensorflow::Allocator* host_memory_allocator() const {
+    return host_memory_allocator_.get();
+  }
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }
+
+  gpu::GpuExecutableRunOptions* gpu_run_options() const {
+    return gpu_run_options_.get();
+  }
+
+  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
+    return &h2d_transfer_pool_;
+  }
+
  protected:
   friend class PjRtBuffer;
   virtual void EnqueueCrossHostReceive(
@@ -383,7 +481,9 @@ class PjRtClient {
   std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
 
   // Includes all devices, including non-local devices on multi-host platforms.
-  std::vector<std::unique_ptr<PjRtDevice>> devices_;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
+  // Pointers to `owned_devices_`.
+  std::vector<PjRtDevice*> devices_;
   // Maps Device::id() to the corresponding Device. Includes all devices.
   std::map<int, PjRtDevice*> id_to_device_;
   // Local devices indexed by local device ordinal.
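A design note on the hunk above: the client previously owned its devices directly in devices_. It now keeps the concrete PjRtStreamExecutorDevice objects in owned_devices_ and mirrors them as raw PjRtDevice* in devices_, which is what lets the virtual devices() accessor return the non-owning view type absl::Span<PjRtDevice* const>. The constructor hunk earlier in this change populates the mirror:

  // One owning vector of the concrete type, one non-owning vector of
  // interface pointers (copied from the pjrt_client.cc constructor above).
  for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
       owned_devices_) {
    devices_.push_back(device.get());
  }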
@@ -550,7 +650,7 @@ class PjRtBuffer {
 
   private:
    friend class PjRtBuffer;
-   friend class PjRtClient;
+   friend class PjRtStreamExecutorClient;
 
    // Helper struct that makes it possible to move a ScopedHold through a
    // closure.
@@ -810,7 +910,7 @@ class PjRtExecutable {
   virtual PjRtClient* client() const = 0;
 
   // Unique name for this executable, e.g., HloModule name.
-  virtual const string& name() const = 0;
+  virtual const std::string& name() const = 0;
 
   virtual int num_replicas() const = 0;
 
@@ -875,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
       bool parameter_is_tupled_arguments,
       std::shared_ptr<DeviceAssignment> device_assignment,
       std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-      std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
+      std::vector<PjRtDevice*> addressable_devices,
+      PjRtStreamExecutorClient* client);
 
   ~PjRtStreamExecutorExecutable() override = default;
 
-  PjRtClient* client() const override { return client_; }
+  PjRtStreamExecutorClient* client() const override { return client_; }
 
-  const string& name() const override;
+  const std::string& name() const override;
 
   int num_replicas() const override {
     return executables_[0]->build_options().num_replicas();
@@ -940,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
   }
 
  private:
-  friend class PjRtClient;
+  friend class PjRtStreamExecutorClient;
   // Initializes information about which arguments to which executables must be
   // donated due to aliases that were specified by the computation.
   Status SetUpDonation(bool tuple_inputs);
@@ -975,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the
   // executable itself.
-  PjRtClient* const client_;
+  PjRtStreamExecutorClient* const client_;
   // One executable per partition.
   std::vector<std::shared_ptr<LocalExecutable>> executables_;
   // Per-executable set of parameters that have any aliased buffers and thus
@@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
   return Status::OK();
 }
 
-class PjRtTpuClient : public PjRtClient {
+class PjRtTpuClient : public PjRtStreamExecutorClient {
  public:
   PjRtTpuClient(LocalClient* client,
-                std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
+                std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+                int host_id);
 
   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
       const PjRtExecutable& executable) const override;
 };
 
-PjRtTpuClient::PjRtTpuClient(LocalClient* client,
-                             std::vector<std::unique_ptr<PjRtDevice>> devices,
-                             int host_id)
-    : PjRtClient(kTpuName, client, std::move(devices), host_id,
-                 /*allocator=*/nullptr,
-                 /*host_memory_allocator=*/nullptr,
-                 /*should_stage_host_to_device_transfers=*/false,
-                 /*gpu_run_options=*/nullptr) {}
+PjRtTpuClient::PjRtTpuClient(
+    LocalClient* client,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
+    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
+                               /*allocator=*/nullptr,
+                               /*host_memory_allocator=*/nullptr,
+                               /*should_stage_host_to_device_transfers=*/false,
+                               /*gpu_run_options=*/nullptr) {}
 
 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
@@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
                                                        num_partitions);
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }
 
 StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
@@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
   return absl::optional<std::string>(tpu_executable->fingerprint());
 }
 
-StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
+StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
     LocalClient* client,
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   tf_tpu::TpuTopologyExternal topology =
       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
 
@@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
   callback_ = callback;
   max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
   for (const auto& client : clients) {
-    for (const auto& device : client->devices()) {
-      devices_.push_back(device.get());
+    for (auto device : client->devices()) {
+      devices_.push_back(device);
     }
   }
   CHECK_GT(devices_.size(), 0);
@@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
 
 std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
   std::vector<ClientAndPtr<PjRtDevice>> devices;
-  devices.reserve(pjrt_client_->devices().size());
-  for (const auto& device : pjrt_client_->devices()) {
-    devices.push_back(WrapWithClient(shared_from_this(), device.get()));
+  auto span = pjrt_client_->devices();
+  devices.reserve(span.size());
+  for (PjRtDevice* device : span) {
+    devices.push_back(WrapWithClient(shared_from_this(), device));
   }
   return devices;
 }
@@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
     result[r].resize(num_partitions);
     for (int p = 0; p < num_partitions; ++p) {
       int device_id = device_assignment(r, p);
-      auto iter = pjrt_client_->id_to_device().find(device_id);
-      CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-      result[r][p] = WrapWithClient(shared_from_this(), iter->second);
+      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                          pjrt_client_->LookupDevice(device_id));
+      result[r][p] = WrapWithClient(shared_from_this(), device);
     }
   }
   return result;
@@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
   std::vector<ClientAndPtr<PjRtDevice>> result;
   for (int i = 0; i < num_replicas; ++i) {
     int device_id = device_assignment(i, 0);
-    auto iter = pjrt_client_->id_to_device().find(device_id);
-    CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-    result.push_back(WrapWithClient(shared_from_this(), iter->second));
+    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                        pjrt_client_->LookupDevice(device_id));
+    result.push_back(WrapWithClient(shared_from_this(), device));
   }
   return result;
 }
@@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
     device = pjrt_client_->local_devices().front();
   }
   CHECK(device != nullptr);
-  auto iter = pjrt_client_->id_to_device().find(device->id());
-  if (iter->second != device) {
+  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
+                      pjrt_client_->LookupDevice(device->id()));
+  if (found_device != device) {
     return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                            device->DebugString(),
                            pjrt_client_->platform_name());
@@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
   const std::string& platform_name() const {
     return pjrt_client_->platform_name();
   }
-  int local_device_count() const { return pjrt_client_->local_device_count(); }
+  int addressable_device_count() const {
+    return pjrt_client_->addressable_device_count();
+  }
   int device_count() const { return pjrt_client_->device_count(); }
   int host_id() const { return pjrt_client_->host_id(); }
 
@@ -240,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
       .def("device_count", &PyClient::device_count)
-      .def("local_device_count", &PyClient::local_device_count)
+      .def("local_device_count", &PyClient::addressable_device_count)
      .def("devices", &PyClient::Devices)
      .def("local_devices", &PyClient::LocalDevices)
      .def("host_id", &PyClient::host_id)
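Note that the binding keeps the Python-visible name local_device_count while delegating to the renamed C++ accessor addressable_device_count, so Python callers see no behavior change from this rename.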