Extract a PjRtClient interface.

- Extract a pure interface
- Current implementation is renamed to PjRtStreamExecutorClient.

TODO: split into a pjrt_stream_executor.{h,cc}
PiperOrigin-RevId: 346470160
Change-Id: I6235d9eed58cfd281c59f441dfed67ea3b9035ed
Qiao Zhang, 2020-12-08 20:50:53 -08:00 (committed by TensorFlower Gardener)
parent b9187102b6
commit 93dfb9b68f
10 changed files with 318 additions and 195 deletions
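In outline: PjRtClient keeps its name but becomes a pure-virtual interface, and the existing StreamExecutor-based implementation moves, unchanged in behavior, to a new subclass PjRtStreamExecutorClient. A standalone sketch of the pattern (everything except the two class names is a placeholder; the real interface carries many more methods):

#include <memory>
#include <string>
#include <utility>

// The former concrete class becomes an abstract interface...
class PjRtClient {
 public:
  virtual ~PjRtClient() = default;
  virtual int device_count() const = 0;
  virtual const std::string& platform_name() const = 0;
};

// ...and the old implementation is renamed, keeping its behavior.
class PjRtStreamExecutorClient : public PjRtClient {
 public:
  explicit PjRtStreamExecutorClient(std::string platform_name)
      : platform_name_(std::move(platform_name)) {}
  int device_count() const override { return 1; }  // placeholder body
  const std::string& platform_name() const override { return platform_name_; }

 private:
  std::string platform_name_;
};

// Factories keep returning the interface type while constructing the
// concrete class, which is exactly what the cpu/interpreter hunks below do.
std::unique_ptr<PjRtClient> GetToyClient() {
  return std::unique_ptr<PjRtClient>(
      std::make_unique<PjRtStreamExecutorClient>("toy"));
}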

View File

@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (int i = 0; i < client->device_count(); ++i) {
     se::StreamExecutorConfig config;
     config.ordinal = i;
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
     devices.push_back(std::move(device));
   }

-  return std::make_unique<PjRtClient>(
-      kCpuName, client, std::move(devices), /*host_id=*/0,
-      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
-      /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
+      kCpuName, client, std::move(devices), /*host_id=*/0,
+      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
+      /*gpu_run_options=*/nullptr));
 }

 }  // namespace xla
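Since GetCpuClient still returns StatusOr<std::unique_ptr<PjRtClient>>, callers of the factory are untouched; only the dynamic type behind the pointer changes. A hedged sketch of such a call site:

TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                    GetCpuClient(/*asynchronous=*/true));
// `client` is concretely a PjRtStreamExecutorClient, but code that sticks
// to the PjRtClient interface never needs to know.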

View File

@@ -35,9 +35,9 @@ namespace xla {
 namespace {

 // A custom PjRtClient that overrides the device assignment method.
-class GpuClient : public xla::PjRtClient {
+class GpuClient : public xla::PjRtStreamExecutorClient {
  public:
-  using xla::PjRtClient::PjRtClient;
+  using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;

   xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
     return assignment;
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }

 // Builds an xla::LocalClient for the GPU platform.
@@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
   return result.first->second;
 }

-std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
+std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   for (auto& local_device : local_device_states) {
     int device_ordinal = local_device->device_ordinal();
     const se::DeviceDescription& description =
@@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
 Status BuildDistributedDevices(
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
-    std::vector<std::unique_ptr<PjRtDevice>>* devices,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
     gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
   LocalTopologyProto local_topology;
   local_topology.set_node_id(node_id);
@@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
   auto host_memory_allocator =
       GetGpuHostAllocator(local_device_states.front()->executor());
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
   if (distributed_client) {
     TF_RETURN_IF_ERROR(BuildDistributedDevices(

View File

@@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   TF_ASSIGN_OR_RETURN(LocalClient * client,
                       ClientLibrary::GetOrCreateLocalClient(options));
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   se::StreamExecutor* executor =
       client->backend().stream_executor(0).ValueOrDie();
   auto device_state = absl::make_unique<LocalDeviceState>(
@@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
       absl::make_unique<InterpreterDevice>(0, std::move(device_state));
   devices.push_back(std::move(device));

-  return std::make_unique<PjRtClient>(
-      "interpreter", client, std::move(devices), /*host_id=*/0,
-      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
-      /*should_stage_host_to_device_transfers=*/false,
-      /*gpu_run_options=*/nullptr);
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
+      "interpreter", client, std::move(devices), /*host_id=*/0,
+      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
+      /*gpu_run_options=*/nullptr));
 }

 }  // namespace xla

View File

@@ -184,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
   }
 };

-PjRtClient::PjRtClient(
+PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
-    std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
@@ -195,7 +195,7 @@ PjRtClient::PjRtClient(
     : platform_name_(std::move(platform_name)),
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
-      devices_(std::move(devices)),
+      owned_devices_(std::move(devices)),
       host_id_(host_id),
       owned_allocator_(std::move(allocator)),
       should_stage_host_to_device_transfers_(
@@ -213,7 +213,9 @@ PjRtClient::PjRtClient(
     host_memory_allocator_ = std::make_unique<CpuAllocator>();
   }

-  for (const std::unique_ptr<PjRtDevice>& device : devices_) {
+  for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
+       owned_devices_) {
+    devices_.push_back(device.get());
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();
@@ -225,21 +227,21 @@ PjRtClient::PjRtClient(
       CHECK(local_devices_[idx] == nullptr) << idx;
       local_devices_[idx] = device.get();
     }
-    tensorflow::down_cast<PjRtStreamExecutorDevice*>(device.get())
-        ->SetClient(this);
+    device->SetClient(this);
   }
   for (int idx = 0; idx < local_devices_.size(); ++idx) {
     CHECK(local_devices_[idx] != nullptr) << idx;
   }
 }

-StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
+StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
   return client_->backend().computation_placer()->AssignDevices(num_replicas,
                                                                 num_partitions);
 }

-std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
+std::unique_ptr<HloCostAnalysis>
+PjRtStreamExecutorClient::GetHloCostAnalysis() {
   return absl::make_unique<HloCostAnalysis>(
       client_->backend().compiler()->ShapeSizeBytesFunction());
 }
@@ -349,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
     return InvalidArgument("Can't make a buffer from an empty tuple");
   }

+  auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer dst_buffer,
-      transfer_manager->AllocateScopedShapedBuffer(
-          on_host_shape, client->allocator(), local_device->device_ordinal()));
+      se_client->client()->backend().transfer_manager();
+  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
+                      transfer_manager->AllocateScopedShapedBuffer(
+                          on_host_shape, se_client->allocator(),
+                          local_device->device_ordinal()));
   if (local_device->allocation_model() ==
       LocalDeviceState::kComputeSynchronized) {
     if (copy_stream == nullptr) {
@@ -546,13 +549,15 @@ void PjRtBuffer::ScopedHold::AddToInput(

 bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
-  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   if (shape.IsTuple()) {
     return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
@@ -712,17 +717,19 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
   return py_buffer;
 }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
-    const Shape& shape, PjRtDevice* device) {
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
+                                                    PjRtDevice* device) {
   return CreateUninitializedBuffer(shape, device, nullptr);
 }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::CreateUninitializedBuffer(
     const Shape& shape, PjRtDevice* device,
     std::shared_ptr<BufferSequencingEvent> definition_event) {
   tensorflow::profiler::TraceMe traceme(
-      "PjRtClient::CreateUninitializedBuffer");
-  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+      "PjRtStreamExecutorClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
           << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -738,10 +745,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
       definition_event);
 }

-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
-    const LiteralSlice& literal, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
-  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>>
+PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
+                                                PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtStreamExecutorClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
           << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -798,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
   return py_buffer;
 }

-void PjRtClient::MakeCrossHostReceiveBuffers(
+void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
     absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
@@ -851,14 +860,15 @@ StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
       shape, local_device->device_ordinal());
 }

-StatusOr<PjRtDevice*> PjRtClient::LookupAddressableDevice(int device_id) const {
+StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
+    int local_hardware_id) const {
   for (auto* device : local_devices_) {
-    if (device_id == device->local_hardware_id()) {
+    if (local_hardware_id == device->local_hardware_id()) {
       return device;
     }
   }
-  return InvalidArgument("No matching device found for device_id %d",
-                         device_id);
+  return InvalidArgument("No matching device found for local_hardware_id %d",
+                         local_hardware_id);
 }

 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -883,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
 }

 int64 PjRtBuffer::OnDeviceSizeInBytes() const {
-  return client_->client()
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->client()
       ->backend()
       .transfer_manager()
       ->GetByteSizeRequirement(on_device_shape_);
@@ -1136,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
     host_value->value = std::make_shared<Literal>(host_shape);
     ShapedBuffer shaped_buffer =
         device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
-    client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
-        stream, shaped_buffer, host_value->value.get(),
-        [host_value](Status done_status) {
-          host_value->status = done_status;
-          host_value->ready.Notify();
-        });
+    tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+        ->client()
+        ->backend()
+        .transfer_manager()
+        ->TransferLiteralFromDevice(stream, shaped_buffer,
+                                    host_value->value.get(),
+                                    [host_value](Status done_status) {
+                                      host_value->status = done_status;
+                                      host_value->ready.Notify();
+                                    });

     auto usage_event = std::make_shared<BufferSequencingEvent>();
     StatusOr<EventPool::Handle> event_or =
@@ -1170,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,

 StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1280,7 +1295,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
     return dst_device->client()->BufferFromHostBuffer(
         literal->untyped_data(), literal->shape(),
-        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
+        PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
+        dst_device);
   }

   TF_ASSIGN_OR_RETURN(
@@ -1288,7 +1304,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
       tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
           ->GetLocalDeviceState());
   LocalDeviceState* transfer_local_device =
-      client_->EnqueueD2DTransfersOnSrcStream()
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->EnqueueD2DTransfersOnSrcStream()
           ? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
                 ->local_device_state()
           : dst_local_device;
@@ -1339,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
 }

 Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
-  return client_->CopyToRemoteDevice(this, serialized_descriptor);
+  return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+      ->CopyToRemoteDevice(this, serialized_descriptor);
 }

 Status PjRtBuffer::BlockHostUntilReady() {
@@ -1401,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
   Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);

-  se::DeviceMemoryAllocator* allocator = client->allocator();
+  se::DeviceMemoryAllocator* allocator =
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
   TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+      tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
+          ->client()
+          ->backend()
+          .transfer_manager();
   se::Stream* stream = local_device->host_to_device_stream();
   TF_ASSIGN_OR_RETURN(
       se::OwningDeviceMemory root_table_memory,
@@ -1467,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
       /*prefer_to_retain_reference=*/false);
   return pjrt_buffer;
 }

-static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
-  auto it = client.id_to_device().find(device_id);
-  CHECK(it != client.id_to_device().end())
-      << "Unknown device id: " << device_id;
-  return it->second;
-}
-
 }  // namespace

 PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@@ -1482,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     bool parameter_is_tupled_arguments,
     std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-    std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
+    std::vector<PjRtDevice*> addressable_devices,
+    PjRtStreamExecutorClient* client)
     : client_(client),
       device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1505,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
     VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
             << device_assignment_->ToString();
     CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
-    CHECK_LE(addressable_devices_.size(), client_->local_device_count())
+    CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
         << "Inconsistent local device count.";
     num_partitions = device_assignment_->computation_count();
   }
@@ -1607,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
     absl::flat_hash_set<BufferSequencingEvent*>& events) const {
   std::vector<ExecutionInput> execution_inputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   // Lift tuple_handle outside the conditional so that the event it returns is
   // not destroyed until after the loop below that waits on events.
   absl::optional<TupleHandle> tuple_handle;
@@ -1630,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
           execution_input.MutableBuffers()->begin();
       ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
           execution_input.MutableBuffers()->end();
-      device_buffers[i].AddToInput(&input_iterator, iterator_end,
-                                   &execution_input, client_->allocator());
+      device_buffers[i].AddToInput(
+          &input_iterator, iterator_end, &execution_input,
+          tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
+              ->allocator());
       CHECK(input_iterator == iterator_end);
     }
   }
@@ -1654,7 +1671,7 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                            ->local_device_state()
                            ->device_ordinal();
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
       run_id.ToInt());
@@ -1790,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
     std::shared_ptr<BufferSequencingEvent> definition_event,
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
     int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
@@ -1827,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   if (device == nullptr) {
     CHECK(device_assignment_ != nullptr);
     const int device_id = (*device_assignment_)(replica, partition);
-    device = LookupDevice(*client_, device_id);
+    TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
     device_assignment = device_assignment_;
   } else {
     CHECK(device_assignment_ == nullptr);
@@ -1863,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
   ScopedShapedBuffer result_buffer =
       result_buffer_or_status.ConsumeValueOrDie();

-  LocalDeviceState* device_state = &client_->device_state(device_ordinal);
+  LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
   se::Stream* stream = device_state->compute_stream();
   StatusOr<EventPool::Handle> event_or =
       device_state->event_pool().ThenAllocateAndRecordEvent(stream);
@@ -2160,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(

 }  // namespace

-StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
     const XlaComputation& computation, CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
+  tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");

   ExecutableBuildOptions& build_options = options.executable_build_options;
   if (!build_options.device_allocator()) {
@@ -2182,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     num_partitions = 1;
   } else {
     if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
+      VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
+                 "device_assignment.";
       TF_ASSIGN_OR_RETURN(
           DeviceAssignment device_assignment,
           GetDefaultDeviceAssignment(build_options.num_replicas(),
                                      build_options.num_partitions()));
       build_options.set_device_assignment(device_assignment);
     }
-    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
+    VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
            << build_options.device_assignment().ToString();
     num_replicas = build_options.device_assignment().replica_count();
     num_partitions = build_options.device_assignment().computation_count();
@@ -2263,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     for (int replica = 0; replica < num_replicas; ++replica) {
       for (int partition = 0; partition < num_partitions; ++partition) {
         int device_id = (*device_assignment)(replica, partition);
-        PjRtDevice* device = LookupDevice(*this, device_id);
+        TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
         if (device->host_id() != host_id()) {
           VLOG(3) << "Non-local device: " << device_id;
           continue;
@@ -2283,10 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
     if (build_options.device_ordinal() < 0) {
       build_options.set_device_ordinal(
-          tensorflow::down_cast<PjRtStreamExecutorDevice*>(
-              addressable_devices.front())
-              ->local_device_state()
-              ->device_ordinal());
+          addressable_devices.front()->local_hardware_id());
     }
   }
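A pattern worth noting in this file: members that did not move into the interface (client(), allocator(), device_state(), EnqueueD2DTransfersOnSrcStream()) are now reached through a down_cast from the stored PjRtClient*, as in this fragment lifted from the AllocateDestinationBuffer hunk above:

auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
TransferManager* transfer_manager =
    se_client->client()->backend().transfer_manager();

The casts are safe as long as every client flowing through this file is StreamExecutor-backed; the TODO in the commit message about splitting out pjrt_stream_executor.{h,cc} points at the longer-term cleanup.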

View File

@@ -219,86 +219,62 @@ class PjRtExecutable;
 // alive as long as any of the other runtime objects are alive.
 class PjRtClient {
  public:
-  // `allocator` may null, in which case the platform default allocator is used.
-  explicit PjRtClient(
-      std::string platform_name, LocalClient* client,
-      std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
-      std::unique_ptr<se::DeviceMemoryAllocator> allocator,
-      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
-      bool should_stage_host_to_device_transfers,
-      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
   virtual ~PjRtClient() = default;

+  // TODO(zhangqiaorjc): Rename to task_id.
+  // Return the task id of this client. In single-task setting, always 0.
+  virtual int host_id() const = 0;
+
+  // Return the number of devices in the entire computation. In multi-headed
+  // client setting, some are addressable by this client, some are not. In a
+  // single-client setting, this is equal to the number of addressable devices.
+  virtual int device_count() const = 0;
+
+  // Return number of addressable devices. Addressable devices are those that
+  // the client can issue commands to.
+  virtual int addressable_device_count() const = 0;
+
+  // Return all devices in the entire computation, including addressable and
+  // non-addressable devices.
+  virtual absl::Span<PjRtDevice* const> devices() const = 0;
+
+  // TODO(zhangqiaorjc): Rename to addressable_devices.
+  // Return only addressable devices.
+  virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
+
+  // Lookup any PjRtDevice for a given PjRtDevice::id().
+  virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
+
+  // Return an addressable PjRtDevice for a given
+  // PjRtDevice::local_hardware_id().
+  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const = 0;
+
+  // Return an ID that identifies the platform (CPU/GPU/TPU).
+  virtual PjRtPlatformId platform_id() const = 0;
+
+  // Returns a string that identifies the platform (CPU/GPU/TPU).
+  virtual const std::string& platform_name() const = 0;
+
+  // Return a device-specific default device assignment, e.g., GPU and TPU may
+  // be different.
   virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
-      int num_replicas, int num_partitions) const;
-
-  int device_count() const { return devices_.size(); }
-  int local_device_count() const { return local_devices_.size(); }
-  const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
-    return devices_;
-  }
-  absl::Span<PjRtDevice* const> local_devices() const { return local_devices_; }
-  const std::map<int, PjRtDevice*>& id_to_device() const {
-    return id_to_device_;
-  }
-  int host_id() const { return host_id_; }
-  PjRtPlatformId platform_id() const { return platform_id_; }
-  const std::string& platform_name() const { return platform_name_; }
-
-  LocalDeviceState& device_state(int device_ordinal) const {
-    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
-                local_devices_.at(device_ordinal))
-                ->local_device_state();
-  }
-
-  // Return an addressable PjRtDevice for a given `device_id`.
-  virtual StatusOr<PjRtDevice*> LookupAddressableDevice(int device_id) const;
-
-  LocalClient* client() const { return client_; }
-  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
-  tensorflow::Allocator* host_memory_allocator() const {
-    return host_memory_allocator_.get();
-  }
-
-  bool should_stage_host_to_device_transfers() const {
-    return should_stage_host_to_device_transfers_;
-  }
-
-  gpu::GpuExecutableRunOptions* gpu_run_options() const {
-    return gpu_run_options_.get();
-  }
-
-  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
-    return &h2d_transfer_pool_;
-  }
-
-  // Most platforms expect device-to-device transfers to be enqueued on the
-  // source d2d stream, but some platforms use the destination d2d stream. This
-  // function specifies which one the platform expects.
-  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
-
-  // Generates a unique fingerprint for `executable`.
-  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
-      const PjRtExecutable& executable) const {
-    return absl::optional<std::string>();
-  }
+      int num_replicas, int num_partitions) const = 0;

   // Returns a backend-specific HLO cost analysis visitor.
-  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
+  virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;

+  // Compile `computation` with given `options`.
   virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      const XlaComputation& computation, CompileOptions options);
+      const XlaComputation& computation, CompileOptions options) = 0;
+
+  // Generates a unique fingerprint for `executable`, may be absl::nullopt.
+  virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const = 0;

   // Creates a buffer on the device without initializing or copying any data.
-  // An optional `definition_event` may be speficied that can be used to
-  // ensure the buffer isn't referenced until some external mechanism has
-  // initialized the data.
-  // NOTE: The sequencing mechanism is not guaranteed to be supported by all
-  // future backends and so callers should avoid wherever possible.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device);
-  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
-      const Shape& shape, PjRtDevice* device,
-      std::shared_ptr<BufferSequencingEvent> definition_event);
+      const Shape& shape, PjRtDevice* device) = 0;

   // Describes the semantics the caller to BufferFromHostBuffer expects from the
   // runtime, in a total order from most restrictive to least restrictive.
@@ -330,13 +306,13 @@ class PjRtClient {
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
       const void* data, const Shape& shape,
       HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtDevice* device);
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;

   // Note that literal must remain in scope until the transfer has completed, so
   // the caller should, for example, wait for BlockHostUntilReady() completes on
   // the return value before letting literal go out of scope.
   virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
-      const LiteralSlice& literal, PjRtDevice* device);
+      const LiteralSlice& literal, PjRtDevice* device) = 0;

   // Asynchronously makes a vector of PjRtBuffers that can be used to receive
   // cross host transfers using `client` on `device'. `shapes` must be the exact
@@ -349,18 +325,140 @@ class PjRtClient {
   // buffers will become ready until *all* of the sends have completed.
   virtual void MakeCrossHostReceiveBuffers(
       absl::Span<const Shape> shapes, PjRtDevice* device,
-      PjRtCrossHostRecvNotifier&& notifier);
+      PjRtCrossHostRecvNotifier&& notifier) = 0;

-  virtual StatusOr<ChannelHandle> CreateChannelHandle() {
+  // Create ChannelHandles for XLA send/recv.
+  virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
+  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
+};
+
+class PjRtStreamExecutorClient : public PjRtClient {
+ public:
+  // `allocator` may be null, in which case the platform default allocator is
+  // used.
+  explicit PjRtStreamExecutorClient(
+      std::string platform_name, LocalClient* client,
+      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+      int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+      bool should_stage_host_to_device_transfers,
+      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
+  ~PjRtStreamExecutorClient() override = default;
+
+  int host_id() const override { return host_id_; }
+
+  int device_count() const override { return devices_.size(); }
+  int addressable_device_count() const override {
+    return local_devices_.size();
+  }
+  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
+  absl::Span<PjRtDevice* const> local_devices() const override {
+    return local_devices_;
+  }
+
+  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
+    auto it = id_to_device_.find(device_id);
+    if (it != id_to_device_.end()) {
+      return it->second;
+    }
+    return InvalidArgument("No matching device found for device_id %d",
+                           device_id);
+  }
+
+  StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const override;
+
+  PjRtPlatformId platform_id() const override { return platform_id_; }
+  const std::string& platform_name() const override { return platform_name_; }
+
+  // Most platforms expect device-to-device transfers to be enqueued on the
+  // source d2d stream, but some platforms use the destination d2d stream. This
+  // function specifies which one the platform expects.
+  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
+
+  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
+      int num_replicas, int num_partitions) const override;
+
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options) override;
+
+  // Generates a unique fingerprint for `executable`.
+  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const override {
+    return absl::optional<std::string>();
+  }
+
+  // Returns a backend-specific HLO cost analysis visitor.
+  std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
+
+  // Creates a buffer on the device without initializing or copying any data.
+  // An optional `definition_event` may be specified that can be used to
+  // ensure the buffer isn't referenced until some external mechanism has
+  // initialized the data.
+  // NOTE: The sequencing mechanism is not guaranteed to be supported by all
+  // future backends and so callers should avoid wherever possible.
+  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device) override;
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device,
+      std::shared_ptr<BufferSequencingEvent> definition_event);
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
+
+  // Note that literal must remain in scope until the transfer has completed, so
+  // the caller should, for example, wait for BlockHostUntilReady() completes on
+  // the return value before letting literal go out of scope.
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device) override;
+
+  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
+  // cross host transfers using `client` on `device'. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier) override;
+
+  StatusOr<ChannelHandle> CreateChannelHandle() override {
     return client()->CreateChannelHandle();
   }
-  virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
+  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
     return client()->CreateDeviceToHostChannelHandle();
   }
-  virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
+  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
     return client()->CreateHostToDeviceChannelHandle();
   }

+  LocalDeviceState& device_state(int device_ordinal) const {
+    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
+                local_devices_.at(device_ordinal))
+                ->local_device_state();
+  }
+
+  LocalClient* client() const { return client_; }
+  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
+  tensorflow::Allocator* host_memory_allocator() const {
+    return host_memory_allocator_.get();
+  }
+
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }
+
+  gpu::GpuExecutableRunOptions* gpu_run_options() const {
+    return gpu_run_options_.get();
+  }
+
+  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
+    return &h2d_transfer_pool_;
+  }
+
  protected:
   friend class PjRtBuffer;
   virtual void EnqueueCrossHostReceive(
@@ -383,7 +481,9 @@ class PjRtClient {
   std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;

   // Includes all devices, including non-local devices on multi-host platforms.
-  std::vector<std::unique_ptr<PjRtDevice>> devices_;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
+  // Pointers to `owned_devices_`.
+  std::vector<PjRtDevice*> devices_;
   // Maps Device::id() to the corresponding Device. Includes all devices.
   std::map<int, PjRtDevice*> id_to_device_;
   // Local devices indexed by local device ordinal.
@@ -550,7 +650,7 @@ class PjRtBuffer {
    private:
     friend class PjRtBuffer;
-    friend class PjRtClient;
+    friend class PjRtStreamExecutorClient;

     // Helper struct that makes it possible to move a ScopedHold through a
     // closure.
@@ -810,7 +910,7 @@ class PjRtExecutable {
   virtual PjRtClient* client() const = 0;

   // Unique name for this executable, e.g., HloModule name.
-  virtual const string& name() const = 0;
+  virtual const std::string& name() const = 0;

   virtual int num_replicas() const = 0;
@@ -875,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
       bool parameter_is_tupled_arguments,
       std::shared_ptr<DeviceAssignment> device_assignment,
       std::vector<LogicalDeviceIds> addressable_device_logical_ids,
-      std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
+      std::vector<PjRtDevice*> addressable_devices,
+      PjRtStreamExecutorClient* client);

   ~PjRtStreamExecutorExecutable() override = default;

-  PjRtClient* client() const override { return client_; }
+  PjRtStreamExecutorClient* client() const override { return client_; }

-  const string& name() const override;
+  const std::string& name() const override;

   int num_replicas() const override {
     return executables_[0]->build_options().num_replicas();
@@ -940,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
   }

  private:
-  friend class PjRtClient;
+  friend class PjRtStreamExecutorClient;
   // Initializes information about which arguments to which executables must be
   // donated due to aliases that were specified by the computation.
   Status SetUpDonation(bool tuple_inputs);
@@ -975,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the
   // executable itself.
-  PjRtClient* const client_;
+  PjRtStreamExecutorClient* const client_;
   // One executable per partition.
   std::vector<std::shared_ptr<LocalExecutable>> executables_;
   // Per-executable set of parameters that have any aliased buffers and thus
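Platform clients now subclass the concrete PjRtStreamExecutorClient rather than PjRtClient; the GpuClient and PjRtTpuClient hunks in this commit both take this shape (the class name below is a placeholder):

class MyBackendClient : public xla::PjRtStreamExecutorClient {
 public:
  using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;

  // Override only the backend-specific policy hooks.
  xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override;
};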

View File

@@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
   return Status::OK();
 }

-class PjRtTpuClient : public PjRtClient {
+class PjRtTpuClient : public PjRtStreamExecutorClient {
  public:
   PjRtTpuClient(LocalClient* client,
-                std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
+                std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+                int host_id);

   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
       const PjRtExecutable& executable) const override;
 };

-PjRtTpuClient::PjRtTpuClient(LocalClient* client,
-                             std::vector<std::unique_ptr<PjRtDevice>> devices,
-                             int host_id)
-    : PjRtClient(kTpuName, client, std::move(devices), host_id,
-                 /*allocator=*/nullptr,
-                 /*host_memory_allocator=*/nullptr,
-                 /*should_stage_host_to_device_transfers=*/false,
-                 /*gpu_run_options=*/nullptr) {}
+PjRtTpuClient::PjRtTpuClient(
+    LocalClient* client,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
+    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
+                               /*allocator=*/nullptr,
+                               /*host_memory_allocator=*/nullptr,
+                               /*should_stage_host_to_device_transfers=*/false,
+                               /*gpu_run_options=*/nullptr) {}

 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
     int num_replicas, int num_partitions) const {
@@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
                                                 num_partitions);
   }
   // Fallback to default global device assignment if we can't run locally.
-  return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
+  return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
+                                                              num_partitions);
 }

 StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
@@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
   return absl::optional<std::string>(tpu_executable->fingerprint());
 }

-StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
+StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
     LocalClient* client,
     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
-  std::vector<std::unique_ptr<PjRtDevice>> devices;
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
   tf_tpu::TpuTopologyExternal topology =
       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();

View File

@@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
   callback_ = callback;
   max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
   for (const auto& client : clients) {
-    for (const auto& device : client->devices()) {
-      devices_.push_back(device.get());
+    for (auto device : client->devices()) {
+      devices_.push_back(device);
     }
   }
   CHECK_GT(devices_.size(), 0);
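devices() now returns absl::Span<PjRtDevice* const> rather than exposing the owning vector, so iteration code drops the unique_ptr and the .get() call, as above; the raw PjRtDevice pointers remain valid for the lifetime of the client that owns them.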

View File

@@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)

 std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
   std::vector<ClientAndPtr<PjRtDevice>> devices;
-  devices.reserve(pjrt_client_->devices().size());
-  for (const auto& device : pjrt_client_->devices()) {
-    devices.push_back(WrapWithClient(shared_from_this(), device.get()));
+  auto span = pjrt_client_->devices();
+  devices.reserve(span.size());
+  for (PjRtDevice* device : span) {
+    devices.push_back(WrapWithClient(shared_from_this(), device));
   }
   return devices;
 }
@@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
     result[r].resize(num_partitions);
     for (int p = 0; p < num_partitions; ++p) {
       int device_id = device_assignment(r, p);
-      auto iter = pjrt_client_->id_to_device().find(device_id);
-      CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-      result[r][p] = WrapWithClient(shared_from_this(), iter->second);
+      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                          pjrt_client_->LookupDevice(device_id));
+      result[r][p] = WrapWithClient(shared_from_this(), device);
     }
   }
   return result;
@@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
   std::vector<ClientAndPtr<PjRtDevice>> result;
   for (int i = 0; i < num_replicas; ++i) {
     int device_id = device_assignment(i, 0);
-    auto iter = pjrt_client_->id_to_device().find(device_id);
-    CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
-    result.push_back(WrapWithClient(shared_from_this(), iter->second));
+    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
+                        pjrt_client_->LookupDevice(device_id));
+    result.push_back(WrapWithClient(shared_from_this(), device));
   }
   return result;
 }
@@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
     device = pjrt_client_->local_devices().front();
   }
   CHECK(device != nullptr);
-  auto iter = pjrt_client_->id_to_device().find(device->id());
-  if (iter->second != device) {
+  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
+                      pjrt_client_->LookupDevice(device->id()));
+  if (found_device != device) {
     return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                            device->DebugString(),
                            pjrt_client_->platform_name());
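With id_to_device() removed from the client surface, the Python bindings go through LookupDevice, which reports an unknown id as an InvalidArgument Status that propagates to the caller instead of CHECK-crashing the process:

TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                    pjrt_client_->LookupDevice(device_id));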

View File

@@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
   const std::string& platform_name() const {
     return pjrt_client_->platform_name();
   }
-  int local_device_count() const { return pjrt_client_->local_device_count(); }
+  int addressable_device_count() const {
+    return pjrt_client_->addressable_device_count();
+  }
   int device_count() const { return pjrt_client_->device_count(); }
   int host_id() const { return pjrt_client_->host_id(); }

View File

@@ -240,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
       .def("device_count", &PyClient::device_count)
-      .def("local_device_count", &PyClient::local_device_count)
+      .def("local_device_count", &PyClient::addressable_device_count)
       .def("devices", &PyClient::Devices)
       .def("local_devices", &PyClient::LocalDevices)
       .def("host_id", &PyClient::host_id)