Extract a PjRtClient interface.

- Extract a pure PjRtClient interface.
- The current implementation is renamed to PjRtStreamExecutorClient.

TODO: split into pjrt_stream_executor.{h,cc}
PiperOrigin-RevId: 346470160
Change-Id: I6235d9eed58cfd281c59f441dfed67ea3b9035ed
Authored by Qiao Zhang on 2020-12-08 20:50:53 -08:00; committed by TensorFlower Gardener
parent b9187102b6
commit 93dfb9b68f
10 changed files with 318 additions and 195 deletions
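Before the per-file diffs, a minimal sketch of the shape of the change may help orientation. It is not the actual header (almost all members and every PJRT-specific type are elided); it only illustrates the pure-interface/implementation split this commit introduces, which backend clients such as GpuClient and PjRtTpuClient then build on.

// Hedged sketch only: the two class names mirror the commit, everything else
// is simplified away.
#include <string>
#include <utility>

// After this commit, PjRtClient is a pure interface...
class PjRtClient {
 public:
  virtual ~PjRtClient() = default;
  virtual int host_id() const = 0;
  virtual int device_count() const = 0;
  virtual int addressable_device_count() const = 0;
  virtual const std::string& platform_name() const = 0;
};

// ...and the previous concrete PjRtClient is renamed to
// PjRtStreamExecutorClient, which implements that interface on top of
// StreamExecutor (LocalClient, LocalDeviceState, allocators, etc. omitted).
class PjRtStreamExecutorClient : public PjRtClient {
 public:
  explicit PjRtStreamExecutorClient(std::string platform_name)
      : platform_name_(std::move(platform_name)) {}

  int host_id() const override { return 0; }
  int device_count() const override { return 0; }
  int addressable_device_count() const override { return 0; }
  const std::string& platform_name() const override { return platform_name_; }

 private:
  std::string platform_name_;
};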

View File

@@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutorConfig config;
config.ordinal = i;
@@ -57,11 +57,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
devices.push_back(std::move(device));
}
return std::make_unique<PjRtClient>(
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
kCpuName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
/*gpu_run_options=*/nullptr));
}
} // namespace xla
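A note on the pattern above (repeated in the interpreter client below): the factory still returns StatusOr<std::unique_ptr<PjRtClient>>, but now constructs the derived PjRtStreamExecutorClient, so the result is explicitly wrapped as std::unique_ptr<PjRtClient> before being returned. A plausible reason, assuming StatusOr's converting constructor takes the value type directly, is that returning the unwrapped unique_ptr to the derived type would need two user-defined conversions (derived-to-base unique_ptr, then into StatusOr), which C++ will not perform implicitly. A standalone sketch of that constraint with stand-in types:

#include <memory>
#include <utility>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

// Stand-in for StatusOr<T>: implicitly constructible from T and nothing else.
template <typename T>
struct Wrapper {
  Wrapper(T v) : value(std::move(v)) {}  // intentionally implicit
  T value;
};

Wrapper<std::unique_ptr<Base>> MakeBase() {
  // return std::make_unique<Derived>();  // ill-formed: would require two
  //                                      // user-defined conversions
  return std::unique_ptr<Base>(std::make_unique<Derived>());  // one conversion
}

int main() { MakeBase(); }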

View File

@@ -35,9 +35,9 @@ namespace xla {
namespace {
// A custom PjRtClient that overrides the device assignment method.
class GpuClient : public xla::PjRtClient {
class GpuClient : public xla::PjRtStreamExecutorClient {
public:
using xla::PjRtClient::PjRtClient;
using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@@ -55,7 +55,8 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
return assignment;
}
// Fallback to default global device assignment if we can't run locally.
return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
num_partitions);
}
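Backend-specific clients such as this GpuClient (and PjRtTpuClient later in the diff) now derive from PjRtStreamExecutorClient rather than from the interface, overriding only what they customize and deferring to the shared StreamExecutor implementation otherwise. A hedged, compilable sketch of the pattern using a hypothetical FooClient; the predicate and placement logic are placeholders, not the real GPU code:

// Simplified stand-ins for the real XLA types.
struct DeviceAssignment { int replicas; int partitions; };

class PjRtStreamExecutorClient {
 public:
  virtual ~PjRtStreamExecutorClient() = default;
  // Shared default; the real implementation defers to the ComputationPlacer.
  virtual DeviceAssignment GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const {
    return DeviceAssignment{num_replicas, num_partitions};
  }
};

// Hypothetical backend client mirroring GpuClient / PjRtTpuClient above.
class FooClient : public PjRtStreamExecutorClient {
 public:
  DeviceAssignment GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const override {
    if (num_partitions == 1) {
      // Backend-specific, topology-aware placement would go here.
      return DeviceAssignment{num_replicas, num_partitions};
    }
    // Otherwise fall back to the shared StreamExecutor implementation.
    return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
                                                                num_partitions);
  }
};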
// Builds an xla::LocalClient for the GPU platform.
@@ -225,9 +226,9 @@ StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
return result.first->second;
}
std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
for (auto& local_device : local_device_states) {
int device_ordinal = local_device->device_ordinal();
const se::DeviceDescription& description =
@@ -243,7 +244,7 @@ std::vector<std::unique_ptr<PjRtDevice>> BuildLocalDevices(
Status BuildDistributedDevices(
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
std::vector<std::unique_ptr<PjRtDevice>>* devices,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
LocalTopologyProto local_topology;
local_topology.set_node_id(node_id);
@@ -322,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
auto host_memory_allocator =
GetGpuHostAllocator(local_device_states.front()->executor());
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
if (distributed_client) {
TF_RETURN_IF_ERROR(BuildDistributedDevices(

View File

@@ -41,7 +41,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
se::StreamExecutor* executor =
client->backend().stream_executor(0).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(
@@ -51,11 +51,11 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
devices.push_back(std::move(device));
return std::make_unique<PjRtClient>(
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
"interpreter", client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
/*gpu_run_options=*/nullptr));
}
} // namespace xla

View File

@@ -184,9 +184,9 @@ class CpuAllocator : public tensorflow::Allocator {
}
};
PjRtClient::PjRtClient(
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
@@ -195,7 +195,7 @@ PjRtClient::PjRtClient(
platform_name_(std::move(platform_name)),
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
devices_(std::move(devices)),
owned_devices_(std::move(devices)),
host_id_(host_id),
owned_allocator_(std::move(allocator)),
should_stage_host_to_device_transfers_(
@@ -213,7 +213,9 @@ PjRtClient::PjRtClient(
host_memory_allocator_ = std::make_unique<CpuAllocator>();
}
for (const std::unique_ptr<PjRtDevice>& device : devices_) {
for (const std::unique_ptr<PjRtStreamExecutorDevice>& device :
owned_devices_) {
devices_.push_back(device.get());
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
@@ -225,21 +227,21 @@ PjRtClient::PjRtClient(
CHECK(local_devices_[idx] == nullptr) << idx;
local_devices_[idx] = device.get();
}
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device.get())
->SetClient(this);
device->SetClient(this);
}
for (int idx = 0; idx < local_devices_.size(); ++idx) {
CHECK(local_devices_[idx] != nullptr) << idx;
}
}
StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
StatusOr<DeviceAssignment> PjRtStreamExecutorClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
return client_->backend().computation_placer()->AssignDevices(num_replicas,
num_partitions);
}
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
std::unique_ptr<HloCostAnalysis>
PjRtStreamExecutorClient::GetHloCostAnalysis() {
return absl::make_unique<HloCostAnalysis>(
client_->backend().compiler()->ShapeSizeBytesFunction());
}
@@ -349,12 +351,13 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
return InvalidArgument("Can't make a buffer from an empty tuple");
}
auto* se_client = tensorflow::down_cast<PjRtStreamExecutorClient*>(client);
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape, client->allocator(), local_device->device_ordinal()));
se_client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer dst_buffer,
transfer_manager->AllocateScopedShapedBuffer(
on_host_shape, se_client->allocator(),
local_device->device_ordinal()));
if (local_device->allocation_model() ==
LocalDeviceState::kComputeSynchronized) {
if (copy_stream == nullptr) {
@@ -546,13 +549,15 @@ void PjRtBuffer::ScopedHold::AddToInput(
bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
<< " device: " << device->DebugString();
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostBuffer");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
if (shape.IsTuple()) {
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
}
@@ -712,17 +717,19 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
return py_buffer;
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) {
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(const Shape& shape,
PjRtDevice* device) {
return CreateUninitializedBuffer(shape, device, nullptr);
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event) {
tensorflow::profiler::TraceMe traceme(
"PjRtClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
"PjRtStreamExecutorClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtStreamExecutorClient::CreateUninitializedBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -738,10 +745,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
definition_event);
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
StatusOr<std::unique_ptr<PjRtBuffer>>
PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtStreamExecutorClient::BufferFromHostLiteral");
VLOG(2) << "PjRtStreamExecutorClient::BufferFromHostLiteral: shape: "
<< literal.shape().ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
@@ -798,7 +807,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
return py_buffer;
}
void PjRtClient::MakeCrossHostReceiveBuffers(
void PjRtStreamExecutorClient::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) {
@@ -851,14 +860,15 @@ StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
shape, local_device->device_ordinal());
}
StatusOr<PjRtDevice*> PjRtClient::LookupAddressableDevice(int device_id) const {
StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
int local_hardware_id) const {
for (auto* device : local_devices_) {
if (device_id == device->local_hardware_id()) {
if (local_hardware_id == device->local_hardware_id()) {
return device;
}
}
return InvalidArgument("No matching device found for device_id %d",
device_id);
return InvalidArgument("No matching device found for local_hardware_id %d",
local_hardware_id);
}
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -883,7 +893,8 @@ PjRtBuffer::~PjRtBuffer() {
}
int64 PjRtBuffer::OnDeviceSizeInBytes() const {
return client_->client()
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->client()
->backend()
.transfer_manager()
->GetByteSizeRequirement(on_device_shape_);
@@ -1136,12 +1147,16 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
host_value->value = std::make_shared<Literal>(host_shape);
ShapedBuffer shaped_buffer =
device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
stream, shaped_buffer, host_value->value.get(),
[host_value](Status done_status) {
host_value->status = done_status;
host_value->ready.Notify();
});
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->client()
->backend()
.transfer_manager()
->TransferLiteralFromDevice(stream, shaped_buffer,
host_value->value.get(),
[host_value](Status done_status) {
host_value->status = done_status;
host_value->ready.Notify();
});
auto usage_event = std::make_shared<BufferSequencingEvent>();
StatusOr<EventPool::Handle> event_or =
@@ -1170,7 +1185,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::ToLiteral");
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
CopyToHostAsyncInternal(discard_cached_copy, layout));
if (host_value == nullptr) {
@@ -1280,7 +1295,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
return dst_device->client()->BufferFromHostBuffer(
literal->untyped_data(), literal->shape(),
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
PjRtStreamExecutorClient::HostBufferSemantics::kZeroCopy, nullptr,
dst_device);
}
TF_ASSIGN_OR_RETURN(
@@ -1288,7 +1304,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
tensorflow::down_cast<PjRtStreamExecutorDevice*>(dst_device)
->GetLocalDeviceState());
LocalDeviceState* transfer_local_device =
client_->EnqueueD2DTransfersOnSrcStream()
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->EnqueueD2DTransfersOnSrcStream()
? tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state()
: dst_local_device;
@@ -1339,7 +1356,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
}
Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) {
return client_->CopyToRemoteDevice(this, serialized_descriptor);
return tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->CopyToRemoteDevice(this, serialized_descriptor);
}
Status PjRtBuffer::BlockHostUntilReady() {
@@ -1401,9 +1419,13 @@ StatusOr<TupleHandle> MakeTupleHelper(
Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
se::DeviceMemoryAllocator* allocator = client->allocator();
se::DeviceMemoryAllocator* allocator =
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)->allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
tensorflow::down_cast<PjRtStreamExecutorClient*>(client)
->client()
->backend()
.transfer_manager();
se::Stream* stream = local_device->host_to_device_stream();
TF_ASSIGN_OR_RETURN(
se::OwningDeviceMemory root_table_memory,
@@ -1467,14 +1489,6 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
/*prefer_to_retain_reference=*/false);
return pjrt_buffer;
}
static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) {
auto it = client.id_to_device().find(device_id);
CHECK(it != client.id_to_device().end())
<< "Unknown device id: " << device_id;
return it->second;
}
} // namespace
PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
@@ -1482,7 +1496,8 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client)
: client_(client),
device_assignment_(std::move(device_assignment)),
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
@@ -1505,7 +1520,7 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
<< device_assignment_->ToString();
CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
CHECK_LE(addressable_devices_.size(), client_->local_device_count())
CHECK_LE(addressable_devices_.size(), client_->addressable_device_count())
<< "Inconsistent local device count.";
num_partitions = device_assignment_->computation_count();
}
@@ -1607,7 +1622,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
absl::Span<const PjRtBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const {
std::vector<ExecutionInput> execution_inputs;
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
// Lift tuple_handle outside the conditional so that the event it returns is
// not destroyed until after the loop below that waits on events.
absl::optional<TupleHandle> tuple_handle;
@@ -1630,8 +1645,10 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
execution_input.MutableBuffers()->begin();
ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
execution_input.MutableBuffers()->end();
device_buffers[i].AddToInput(&input_iterator, iterator_end,
&execution_input, client_->allocator());
device_buffers[i].AddToInput(
&input_iterator, iterator_end, &execution_input,
tensorflow::down_cast<PjRtStreamExecutorClient*>(client_)
->allocator());
CHECK(input_iterator == iterator_end);
}
}
@@ -1654,7 +1671,7 @@ StatusOr<ScopedShapedBuffer> PjRtStreamExecutorExecutable::EnqueueExecution(
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state()
->device_ordinal();
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
tensorflow::profiler::TraceMeConsumer activity(
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
run_id.ToInt());
@@ -1790,7 +1807,7 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtDevice* device) const {
std::vector<std::unique_ptr<PjRtBuffer>> outputs;
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
outputs.reserve(tuple_count);
@@ -1827,7 +1844,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
if (device == nullptr) {
CHECK(device_assignment_ != nullptr);
const int device_id = (*device_assignment_)(replica, partition);
device = LookupDevice(*client_, device_id);
TF_ASSIGN_OR_RETURN(device, client_->LookupDevice(device_id));
device_assignment = device_assignment_;
} else {
CHECK(device_assignment_ == nullptr);
@@ -1863,7 +1880,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
ScopedShapedBuffer result_buffer =
result_buffer_or_status.ConsumeValueOrDie();
LocalDeviceState* device_state = &client_->device_state(device_ordinal);
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
se::Stream* stream = device_state->compute_stream();
StatusOr<EventPool::Handle> event_or =
device_state->event_pool().ThenAllocateAndRecordEvent(stream);
@@ -2160,9 +2177,9 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
} // namespace
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
const XlaComputation& computation, CompileOptions options) {
tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
ExecutableBuildOptions& build_options = options.executable_build_options;
if (!build_options.device_allocator()) {
@@ -2182,14 +2199,15 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
num_partitions = 1;
} else {
if (!build_options.has_device_assignment()) {
VLOG(2) << "PjRtClient::Compile using default device_assignment.";
VLOG(2) << "PjRtStreamExecutorClient::Compile using default "
"device_assignment.";
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
GetDefaultDeviceAssignment(build_options.num_replicas(),
build_options.num_partitions()));
build_options.set_device_assignment(device_assignment);
}
VLOG(2) << "PjRtClient::Compile device_assignment:\n"
VLOG(2) << "PjRtStreamExecutorClient::Compile device_assignment:\n"
<< build_options.device_assignment().ToString();
num_replicas = build_options.device_assignment().replica_count();
num_partitions = build_options.device_assignment().computation_count();
@@ -2263,7 +2281,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition);
PjRtDevice* device = LookupDevice(*this, device_id);
TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
if (device->host_id() != host_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
@@ -2283,10 +2301,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
if (build_options.device_ordinal() < 0) {
build_options.set_device_ordinal(
tensorflow::down_cast<PjRtStreamExecutorDevice*>(
addressable_devices.front())
->local_device_state()
->device_ordinal());
addressable_devices.front()->local_hardware_id());
}
}
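Several of the hunks above follow the same pattern: buffers and executables keep a PjRtClient* (or accept one) and reach StreamExecutor-specific state — the LocalClient, the allocator, per-ordinal LocalDeviceState — through tensorflow::down_cast to PjRtStreamExecutorClient. A minimal sketch of that pattern with simplified stand-in types (not the real classes); down_cast itself behaves like a static_cast that is verified with dynamic_cast in debug builds.

// Stand-ins only; the real code lives in pjrt_client.{h,cc}.
struct Allocator {};

class Client {                     // plays the role of the PjRtClient interface
 public:
  virtual ~Client() = default;
};

class SEClient : public Client {   // plays the role of PjRtStreamExecutorClient
 public:
  Allocator* allocator() { return &allocator_; }

 private:
  Allocator allocator_;
};

class Buffer {                     // plays the role of PjRtBuffer
 public:
  explicit Buffer(Client* client) : client_(client) {}

  Allocator* GetAllocator() {
    // The real code uses tensorflow::down_cast, i.e. a static_cast with a
    // dynamic_cast-based DCHECK in debug builds.
    return static_cast<SEClient*>(client_)->allocator();
  }

 private:
  Client* client_;  // non-owning interface pointer, as in PjRtBuffer::client_
};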

View File

@@ -219,86 +219,62 @@ class PjRtExecutable;
// alive as long as any of the other runtime objects are alive.
class PjRtClient {
public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
virtual ~PjRtClient() = default;
// TODO(zhangqiaorjc): Rename to task_id.
// Return the task id of this client. In single-task setting, always 0.
virtual int host_id() const = 0;
// Return the number of devices in the entire computation. In multi-headed
// client setting, some are addressable by this client, some are not. In a
// single-client setting, this is equal to the number of addressable devices.
virtual int device_count() const = 0;
// Return number of addressable devices. Addressable devices are those that
// the client can issue commands to.
virtual int addressable_device_count() const = 0;
// Return all devices in the entire computation, including addressable and
// non-addressable devices.
virtual absl::Span<PjRtDevice* const> devices() const = 0;
// TODO(zhangqiaorjc): Rename to addressable_devices.
// Return only addressable devices.
virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
// Lookup any PjRtDevice for a given PjRtDevice::id().
virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
// Return an addressable PjRtDevice for a given
// PjRtDevice::local_hardware_id().
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const = 0;
// Return an ID that identifies the platform (CPU/GPU/TPU).
virtual PjRtPlatformId platform_id() const = 0;
// Returns a string that identifies the platform (CPU/GPU/TPU).
virtual const std::string& platform_name() const = 0;
// Return a device-specific default device assignment, e.g., GPU and TPU may
// be different.
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const;
int device_count() const { return devices_.size(); }
int local_device_count() const { return local_devices_.size(); }
const std::vector<std::unique_ptr<PjRtDevice>>& devices() const {
return devices_;
}
absl::Span<PjRtDevice* const> local_devices() const { return local_devices_; }
const std::map<int, PjRtDevice*>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }
PjRtPlatformId platform_id() const { return platform_id_; }
const std::string& platform_name() const { return platform_name_; }
LocalDeviceState& device_state(int device_ordinal) const {
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
local_devices_.at(device_ordinal))
->local_device_state();
}
// Return an addressable PjRtDevice for a given `device_id`.
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(int device_id) const;
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
// Generates a unique fingerprint for `executable`.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const {
return absl::optional<std::string>();
}
int num_replicas, int num_partitions) const = 0;
// Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;
// Compile `computation` with given `options`.
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
const XlaComputation& computation, CompileOptions options) = 0;
// Generates a unique fingerprint for `executable`, may be absl::nullopt.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const = 0;
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be specified that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid it wherever possible.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device);
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
const Shape& shape, PjRtDevice* device) = 0;
// Describes the semantics the caller of BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
@@ -330,13 +306,13 @@ class PjRtClient {
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device);
std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait until BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device);
const LiteralSlice& literal, PjRtDevice* device) = 0;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device`. `shapes` must be the exact
@@ -349,18 +325,140 @@ class PjRtClient {
// buffers will become ready until *all* of the sends have completed.
virtual void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
PjRtCrossHostRecvNotifier&& notifier) = 0;
virtual StatusOr<ChannelHandle> CreateChannelHandle() {
// Create ChannelHandles for XLA send/recv.
virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};
class PjRtStreamExecutorClient : public PjRtClient {
public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
~PjRtStreamExecutorClient() override = default;
int host_id() const override { return host_id_; }
int device_count() const override { return devices_.size(); }
int addressable_device_count() const override {
return local_devices_.size();
}
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
absl::Span<PjRtDevice* const> local_devices() const override {
return local_devices_;
}
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
auto it = id_to_device_.find(device_id);
if (it != id_to_device_.end()) {
return it->second;
}
return InvalidArgument("No matching device found for device_id %d",
device_id);
}
StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const override;
PjRtPlatformId platform_id() const override { return platform_id_; }
const std::string& platform_name() const override { return platform_name_; }
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
// Generates a unique fingerprint for `executable`.
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const override {
return absl::optional<std::string>();
}
// Returns a backend-specific HLO cost analysis visitor.
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be specified that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid it wherever possible.
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) override;
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait until BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device`. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) override;
StatusOr<ChannelHandle> CreateChannelHandle() override {
return client()->CreateChannelHandle();
}
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
return client()->CreateDeviceToHostChannelHandle();
}
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
return client()->CreateHostToDeviceChannelHandle();
}
LocalDeviceState& device_state(int device_ordinal) const {
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
local_devices_.at(device_ordinal))
->local_device_state();
}
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
protected:
friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive(
@@ -383,7 +481,9 @@ class PjRtClient {
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<PjRtDevice>> devices_;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
// Pointers to `owned_devices_`.
std::vector<PjRtDevice*> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, PjRtDevice*> id_to_device_;
// Local devices indexed by local device ordinal.
@@ -550,7 +650,7 @@ class PjRtBuffer {
private:
friend class PjRtBuffer;
friend class PjRtClient;
friend class PjRtStreamExecutorClient;
// Helper struct that makes it possible to move a ScopedHold through a
// closure.
@@ -810,7 +910,7 @@ class PjRtExecutable {
virtual PjRtClient* client() const = 0;
// Unique name for this executable, e.g., HloModule name.
virtual const string& name() const = 0;
virtual const std::string& name() const = 0;
virtual int num_replicas() const = 0;
@@ -875,13 +975,14 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client);
~PjRtStreamExecutorExecutable() override = default;
PjRtClient* client() const override { return client_; }
PjRtStreamExecutorClient* client() const override { return client_; }
const string& name() const override;
const std::string& name() const override;
int num_replicas() const override {
return executables_[0]->build_options().num_replicas();
@@ -940,7 +1041,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
}
private:
friend class PjRtClient;
friend class PjRtStreamExecutorClient;
// Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation.
Status SetUpDonation(bool tuple_inputs);
@@ -975,7 +1076,7 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
// Create shared pointers so we can free them after the execution: with
// asynchronous execution, the process being executed can outlive the
// executable itself.
PjRtClient* const client_;
PjRtStreamExecutorClient* const client_;
// One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_;
// Per-executable set of parameters that have any aliased buffers and thus
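For callers, the header changes above boil down to a few mechanical migrations: id_to_device() map lookups move behind LookupDevice (which reports unknown ids via StatusOr instead of CHECK-failing), local_device_count() becomes addressable_device_count(), and devices() now returns a span of non-owning PjRtDevice* rather than the owned vector. A hedged sketch of an adapted caller follows; LogDevices is a hypothetical helper and the include paths are elided, but every PjRtClient call shown appears in this diff.

// Hypothetical caller, sketched against the new interface
// (inside namespace xla; the usual pjrt_client.h / statusor / logging
// includes are assumed and elided).
Status LogDevices(PjRtClient* client, int device_id) {
  // Old: client->id_to_device().find(device_id) + CHECK on a miss.
  // New: lookup is part of the interface and reports misses as a Status.
  TF_ASSIGN_OR_RETURN(PjRtDevice * device, client->LookupDevice(device_id));
  VLOG(1) << "Found " << device->DebugString() << "; this client addresses "
          << client->addressable_device_count()  // was local_device_count()
          << " of " << client->device_count() << " devices";
  // devices() now yields raw pointers, so there is no .get() when iterating.
  for (PjRtDevice* d : client->devices()) {
    VLOG(2) << d->DebugString();
  }
  return Status::OK();
}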

View File

@@ -94,10 +94,11 @@ Status TpuDeviceState::ThenMemcpyDeviceToDevice(
return Status::OK();
}
class PjRtTpuClient : public PjRtClient {
class PjRtTpuClient : public PjRtStreamExecutorClient {
public:
PjRtTpuClient(LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id);
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id);
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@@ -108,14 +109,14 @@ class PjRtTpuClient : public PjRtClient {
const PjRtExecutable& executable) const override;
};
PjRtTpuClient::PjRtTpuClient(LocalClient* client,
std::vector<std::unique_ptr<PjRtDevice>> devices,
int host_id)
: PjRtClient(kTpuName, client, std::move(devices), host_id,
/*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr) {}
PjRtTpuClient::PjRtTpuClient(
LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
: PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
/*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr) {}
StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
@@ -128,7 +129,8 @@ StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
num_partitions);
}
// Fallback to default global device assignment if we can't run locally.
return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions);
return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
num_partitions);
}
StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
@@ -152,10 +154,10 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
return absl::optional<std::string>(tpu_executable->fingerprint());
}
StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
LocalClient* client,
std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
std::vector<std::unique_ptr<PjRtDevice>> devices;
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
tf_tpu::TpuTopologyExternal topology =
tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();

View File

@@ -230,8 +230,8 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
callback_ = callback;
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
for (const auto& client : clients) {
for (const auto& device : client->devices()) {
devices_.push_back(device.get());
for (auto device : client->devices()) {
devices_.push_back(device);
}
}
CHECK_GT(devices_.size(), 0);

View File

@@ -37,9 +37,10 @@ PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
std::vector<ClientAndPtr<PjRtDevice>> devices;
devices.reserve(pjrt_client_->devices().size());
for (const auto& device : pjrt_client_->devices()) {
devices.push_back(WrapWithClient(shared_from_this(), device.get()));
auto span = pjrt_client_->devices();
devices.reserve(span.size());
for (PjRtDevice* device : span) {
devices.push_back(WrapWithClient(shared_from_this(), device));
}
return devices;
}
@@ -64,9 +65,9 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
result[r].resize(num_partitions);
for (int p = 0; p < num_partitions; ++p) {
int device_id = device_assignment(r, p);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result[r][p] = WrapWithClient(shared_from_this(), iter->second);
TF_ASSIGN_OR_RETURN(PjRtDevice * device,
pjrt_client_->LookupDevice(device_id));
result[r][p] = WrapWithClient(shared_from_this(), device);
}
}
return result;
@@ -80,9 +81,9 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
std::vector<ClientAndPtr<PjRtDevice>> result;
for (int i = 0; i < num_replicas; ++i) {
int device_id = device_assignment(i, 0);
auto iter = pjrt_client_->id_to_device().find(device_id);
CHECK(iter != pjrt_client_->id_to_device().end()) << device_id;
result.push_back(WrapWithClient(shared_from_this(), iter->second));
TF_ASSIGN_OR_RETURN(PjRtDevice * device,
pjrt_client_->LookupDevice(device_id));
result.push_back(WrapWithClient(shared_from_this(), device));
}
return result;
}
@@ -95,8 +96,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
device = pjrt_client_->local_devices().front();
}
CHECK(device != nullptr);
auto iter = pjrt_client_->id_to_device().find(device->id());
if (iter->second != device) {
TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
pjrt_client_->LookupDevice(device->id()));
if (found_device != device) {
return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
device->DebugString(),
pjrt_client_->platform_name());

View File

@@ -97,7 +97,9 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
const std::string& platform_name() const {
return pjrt_client_->platform_name();
}
int local_device_count() const { return pjrt_client_->local_device_count(); }
int addressable_device_count() const {
return pjrt_client_->addressable_device_count();
}
int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); }

View File

@@ -240,7 +240,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def("device_count", &PyClient::device_count)
.def("local_device_count", &PyClient::local_device_count)
.def("local_device_count", &PyClient::addressable_device_count)
.def("devices", &PyClient::Devices)
.def("local_devices", &PyClient::LocalDevices)
.def("host_id", &PyClient::host_id)