Add the ability to compile a single-replica executable that is portable across devices.

PiperOrigin-RevId: 318485159
Change-Id: Ied5115fab26b49f448ee61fded64904e92794b2b
commit 530aa3e4e0 (parent e91990219b)
Author: A. Unique TensorFlower
Committed-by: TensorFlower Gardener
Date: 2020-06-26 09:20:51 -07:00

2 changed files with 109 additions and 56 deletions
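
A usage sketch (illustrative, not part of this commit): the flow enabled here is to set the new CompileOptions::compile_portable_executable flag, compile without a device assignment, and pick the device at execution time via ExecuteOnLocalDevice. The `client`, `computation`, and `args` objects are assumed to exist, the code must sit in a function returning a Status for TF_ASSIGN_OR_RETURN, and the exact Compile signature may differ slightly in this revision:

    // Hypothetical sketch of the portable-executable flow this change adds.
    CompileOptions options;
    options.compile_portable_executable = true;  // new flag from this commit
    // executable_build_options must NOT carry a device assignment here, or
    // Compile returns InvalidArgument (see the check added below).
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtExecutable> executable,
        PjRtExecutable::Compile(computation, client, options));
    // The one executable can now target any local device; at execution time
    // the portable path binds a fresh 1x1 DeviceAssignment to the device.
    for (Device* device : client->local_devices()) {
      TF_ASSIGN_OR_RETURN(
          auto outputs,
          executable->ExecuteOnLocalDevice(args, device, ExecuteOptions()));
    }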

@@ -1402,11 +1402,12 @@ static Device* LookupDevice(const PjRtClient& client, int device_id) {
 PjRtExecutable::PjRtExecutable(
     std::vector<std::unique_ptr<LocalExecutable>> executables,
-    bool parameter_is_tupled_arguments, DeviceAssignment device_assignment,
+    bool parameter_is_tupled_arguments,
+    std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<std::pair<int, int>> local_logical_device_ids,
     std::vector<Device*> local_devices, PjRtClient* client)
     : client_(client),
-      device_assignment_(std::make_shared<DeviceAssignment>(device_assignment)),
+      device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
       local_logical_device_ids_(std::move(local_logical_device_ids)),
       local_devices_(std::move(local_devices)) {
@@ -1415,11 +1416,21 @@ PjRtExecutable::PjRtExecutable(
     executables_.emplace_back(std::move(executable));
   }
 
-  // This must go after `executables_` is initialized.
-  VLOG(1) << "PjRtExecutable " << name() << " device_assignment:\n"
-          << device_assignment_->ToString();
-
-  const int num_partitions = device_assignment_->computation_count();
+  int num_partitions;
+  if (device_assignment_ == nullptr) {
+    // This must go after `executables_` is initialized.
+    VLOG(1) << "PjRtExecutable " << name() << " portable single-core";
+    num_partitions = 1;
+    CHECK(local_devices_.empty());
+  } else {
+    // This must go after `executables_` is initialized.
+    VLOG(1) << "PjRtExecutable " << name() << " device_assignment:\n"
+            << device_assignment_->ToString();
+    CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
+    CHECK_LE(local_devices_.size(), client_->local_device_count())
+        << "Inconsistent local device count.";
+    num_partitions = device_assignment_->computation_count();
+  }
 
   // SPMD sharding produces a single executable for multiple partitions.
   if (executables_.size() > 1) {
@@ -1427,10 +1438,6 @@ PjRtExecutable::PjRtExecutable(
         << "Number of executables " << executables_.size()
         << " did not match number of partitions " << num_partitions;
   }
-
-  CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
-  CHECK_LE(local_devices_.size(), client_->local_device_count())
-      << "Inconsistent local device count.";
 }
 
 Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) {
@@ -1462,7 +1469,8 @@ const std::string& PjRtExecutable::name() const {
 StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
     int executable_idx, const RunId& run_id, const ExecuteOptions& options,
-    Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers) const {
+    Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
+    std::shared_ptr<DeviceAssignment> device_assignment) const {
   int device_ordinal = device->local_device_state()->device_ordinal();
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
@@ -1559,7 +1567,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
   run_options.set_allocator(client_->allocator());
   run_options.set_intra_op_thread_pool(
       client_->client()->backend().eigen_intra_op_thread_pool_device());
-  run_options.set_device_assignment(device_assignment_.get());
+  run_options.set_device_assignment(device_assignment.get());
   run_options.set_run_id(run_id);
   run_options.set_rng_seed(device_state->GetNewPrngSeed());
   run_options.set_gpu_executable_run_options(client_->gpu_run_options());
@@ -1603,7 +1611,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     device_state->ThenExecuteOnCallbackThread(
         device_state->compute_stream(),
         [references{std::make_tuple(executables_[executable_idx],
-                                    compute_reservation, device_assignment_)},
+                                    compute_reservation, device_assignment)},
          donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
          device_ordinal]() {
           for (const auto& ptr : donated_ptrs) {
@@ -1616,7 +1624,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     device_state->ThenRelease(
         device_state->compute_stream(),
         std::make_tuple(executables_[executable_idx], compute_reservation,
-                        device_assignment_));
+                        device_assignment));
   }
 
   return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
@@ -1625,9 +1633,22 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
                               int replica, int partition, const RunId& run_id,
-                              const ExecuteOptions& options) const {
-  const int device_id = (*device_assignment_)(replica, partition);
-  Device* device = LookupDevice(*client_, device_id);
+                              const ExecuteOptions& options,
+                              Device* device) const {
+  std::shared_ptr<DeviceAssignment> device_assignment;
+  if (device == nullptr) {
+    CHECK(device_assignment_ != nullptr);
+    const int device_id = (*device_assignment_)(replica, partition);
+    device = LookupDevice(*client_, device_id);
+    device_assignment = device_assignment_;
+  } else {
+    CHECK(device_assignment_ == nullptr);
+    CHECK_EQ(replica, 0);
+    CHECK_EQ(partition, 0);
+    CHECK(local_devices_.empty());
+    device_assignment = std::make_shared<DeviceAssignment>(1, 1);
+    (*device_assignment)(0, 0) = device->id();
+  }
 
   CHECK_EQ(device->host_id(), client_->host_id());
   int device_ordinal = device->local_device_state()->device_ordinal();
@@ -1640,9 +1661,9 @@ PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
   std::vector<PjRtBuffer::ScopedHold> device_buffers;
   device_buffers.reserve(argument_handles.size());
-  StatusOr<ScopedShapedBuffer> result_buffer_or_status =
-      EnqueueExecution(argument_handles, replica, partition, executable_idx,
-                       run_id, options, device, &device_buffers);
+  StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
+      argument_handles, replica, partition, executable_idx, run_id, options,
+      device, &device_buffers, std::move(device_assignment));
 
   if (!result_buffer_or_status.ok()) {
     LOG(ERROR) << "Execution of replica " << replica
@@ -1736,6 +1757,13 @@ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtExecutable::ExecuteOnLocalDevice(
     absl::Span<PjRtBuffer* const> argument_handles, Device* device,
     const ExecuteOptions& options) const {
+  if (device_assignment_ == nullptr) {
+    VLOG(1) << "Executing portable single-core program on "
+            << device->DebugString();
+    return ExecuteHelper(argument_handles,
+                         /*replica=*/0,
+                         /*partition=*/0, RunId(), options, device);
+  }
   for (int i = 0; i < local_devices_.size(); ++i) {
     if (local_devices_[i] == device) {
       VLOG(1) << "Executing computation " << name();
@@ -1754,6 +1782,8 @@ StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
 PjRtExecutable::ExecuteOnLocalDevices(
     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
     const ExecuteOptions& options) const {
+  CHECK(device_assignment_ != nullptr);
+
   RunId run_id;
   tensorflow::profiler::TraceMeProducer activity(
       "LocalExecutable::ExecuteOnLocalDevices",
@@ -1952,16 +1982,33 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
     build_options.set_device_allocator(client->allocator());
   }
 
-  if (!build_options.has_device_assignment()) {
-    VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
-    TF_ASSIGN_OR_RETURN(
-        DeviceAssignment device_assignment,
-        client->GetDefaultDeviceAssignment(build_options.num_replicas(),
-                                           build_options.num_partitions()));
-    build_options.set_device_assignment(device_assignment);
+  int num_replicas;
+  int num_partitions;
+  std::shared_ptr<DeviceAssignment> device_assignment;
+  if (options.compile_portable_executable) {
+    if (build_options.has_device_assignment()) {
+      return InvalidArgument(
+          "CompileOptions requests portable executable but "
+          "ExecutableBuildOptions includes a device assignment");
+    }
+    num_replicas = 1;
+    num_partitions = 1;
+  } else {
+    if (!build_options.has_device_assignment()) {
+      VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
+      TF_ASSIGN_OR_RETURN(
+          DeviceAssignment device_assignment,
+          client->GetDefaultDeviceAssignment(build_options.num_replicas(),
+                                             build_options.num_partitions()));
+      build_options.set_device_assignment(device_assignment);
+    }
+    VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
+            << build_options.device_assignment().ToString();
+    num_replicas = build_options.device_assignment().replica_count();
+    num_partitions = build_options.device_assignment().computation_count();
+    device_assignment =
+        std::make_shared<DeviceAssignment>(build_options.device_assignment());
   }
-  VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
-          << build_options.device_assignment().ToString();
 
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                       computation.GetProgramShape());
@@ -2020,33 +2067,31 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
   build_options.set_result_layout(result_layout);
 
-  const int num_replicas = build_options.device_assignment().replica_count();
-  const int num_partitions =
-      build_options.device_assignment().computation_count();
-
   std::vector<std::pair<int, int>> local_logical_device_ids;
   std::vector<Device*> local_devices;
-  for (int replica = 0; replica < num_replicas; ++replica) {
-    for (int partition = 0; partition < num_partitions; ++partition) {
-      int device_id = build_options.device_assignment()(replica, partition);
-      Device* device = LookupDevice(*client, device_id);
-      if (device->host_id() != client->host_id()) {
-        VLOG(3) << "Non-local device: " << device_id;
-        continue;
+  if (device_assignment != nullptr) {
+    for (int replica = 0; replica < num_replicas; ++replica) {
+      for (int partition = 0; partition < num_partitions; ++partition) {
+        int device_id = (*device_assignment)(replica, partition);
+        Device* device = LookupDevice(*client, device_id);
+        if (device->host_id() != client->host_id()) {
+          VLOG(3) << "Non-local device: " << device_id;
+          continue;
+        }
+        local_logical_device_ids.emplace_back(replica, partition);
+        local_devices.push_back(device);
       }
-      local_logical_device_ids.emplace_back(replica, partition);
-      local_devices.push_back(device);
     }
-  }
-  if (local_devices.empty()) {
-    return InvalidArgument(
-        "Device assignment (%s) does not have any local devices.",
-        build_options.device_assignment().ToString());
-  }
+    if (local_devices.empty()) {
+      return InvalidArgument(
+          "Device assignment (%s) does not have any local devices.",
+          device_assignment->ToString());
+    }
 
-  if (build_options.device_ordinal() < 0) {
-    build_options.set_device_ordinal(
-        local_devices.front()->local_device_state()->device_ordinal());
+    if (build_options.device_ordinal() < 0) {
+      build_options.set_device_ordinal(
+          local_devices.front()->local_device_state()->device_ordinal());
+    }
   }
TF_ASSIGN_OR_RETURN(
@@ -2056,7 +2101,7 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   auto py_executable = absl::make_unique<PjRtExecutable>(
       std::move(local_executables), options.parameter_is_tupled_arguments,
-      build_options.device_assignment(), std::move(local_logical_device_ids),
+      std::move(device_assignment), std::move(local_logical_device_ids),
       std::move(local_devices), client);
   TF_RETURN_IF_ERROR(py_executable->SetUpDonation(
       client, options.parameter_is_tupled_arguments));

@@ -647,6 +647,12 @@ struct CompileOptions {
   // XLA's compilation time options.
   ExecutableBuildOptions executable_build_options;
+
+  // If true, the executable can be run on any device. May only be true if
+  // !executable_build_options.has_device_assignment(), so only applies to
+  // single-device executables. Beware: on GPUs, sometimes an executable
+  // compiled for one device doesn't run on another.
+  bool compile_portable_executable = false;
 };
 
 struct ExecuteOptions {
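
The header comment above pins down the contract: the portability flag and an explicit device assignment are mutually exclusive. A hypothetical snippet of the rejected combination, mirroring the InvalidArgument check added in Compile above:

    CompileOptions options;
    options.compile_portable_executable = true;
    // Also pinning a device assignment contradicts portability...
    options.executable_build_options.set_device_assignment(
        DeviceAssignment(1, 1));
    // ...so Compile fails with InvalidArgument: "CompileOptions requests
    // portable executable but ExecutableBuildOptions includes a device
    // assignment".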
@@ -673,7 +679,7 @@ class PjRtExecutable {
   PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
                  bool parameter_is_tupled_arguments,
-                 DeviceAssignment device_assignment,
+                 std::shared_ptr<DeviceAssignment> device_assignment,
                  std::vector<std::pair<int, int>> local_logical_device_ids,
                  std::vector<Device*> local_devices, PjRtClient* client);
@@ -738,10 +744,12 @@ class PjRtExecutable {
       absl::Span<PjRtBuffer* const> argument_handles, int replica,
       int partition, int executable_idx, const RunId& run_id,
       const ExecuteOptions& options, Device* device,
-      std::vector<PjRtBuffer::ScopedHold>* device_buffers) const;
+      std::vector<PjRtBuffer::ScopedHold>* device_buffers,
+      std::shared_ptr<DeviceAssignment> device_assignment) const;
 
   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
       absl::Span<PjRtBuffer* const> argument_handles, int replica,
-      int partition, const RunId& run_id, const ExecuteOptions& options) const;
+      int partition, const RunId& run_id, const ExecuteOptions& options,
+      Device* device = nullptr) const;
 
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the