Add the ability to compile a single-replica executable that is portable across devices.

PiperOrigin-RevId: 318485159
Change-Id: Ied5115fab26b49f448ee61fded64904e92794b2b
A. Unique TensorFlower 2020-06-26 09:20:51 -07:00 committed by TensorFlower Gardener
parent e91990219b
commit 530aa3e4e0
2 changed files with 109 additions and 56 deletions
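The change adds a `compile_portable_executable` flag to CompileOptions: when it is set, compilation skips baking in a DeviceAssignment and the caller instead picks a device at each execution via ExecuteOnLocalDevice. A minimal caller-side sketch of the intended pattern follows; the client, `computation`, and `args` variables are assumed to exist, and the exact Compile signature is taken from this revision (the flag, Compile, and ExecuteOnLocalDevice appear in the diff below):

    // Sketch only: client, computation, and argument-buffer setup elided.
    CompileOptions options;
    options.compile_portable_executable = true;
    // Must NOT also set a device assignment on
    // options.executable_build_options; Compile rejects that combination.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<PjRtExecutable> executable,
        PjRtExecutable::Compile(computation, client.get(), options));
    // The same executable can now be pointed at any compatible local device.
    for (Device* device : client->local_devices()) {
      TF_ASSIGN_OR_RETURN(
          std::vector<std::unique_ptr<PjRtBuffer>> outputs,
          executable->ExecuteOnLocalDevice(args, device, ExecuteOptions()));
    }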

@@ -1402,11 +1402,12 @@ static Device* LookupDevice(const PjRtClient& client, int device_id) {
 PjRtExecutable::PjRtExecutable(
     std::vector<std::unique_ptr<LocalExecutable>> executables,
-    bool parameter_is_tupled_arguments, DeviceAssignment device_assignment,
+    bool parameter_is_tupled_arguments,
+    std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<std::pair<int, int>> local_logical_device_ids,
     std::vector<Device*> local_devices, PjRtClient* client)
     : client_(client),
-      device_assignment_(std::make_shared<DeviceAssignment>(device_assignment)),
+      device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
       local_logical_device_ids_(std::move(local_logical_device_ids)),
      local_devices_(std::move(local_devices)) {
@@ -1415,11 +1416,21 @@ PjRtExecutable::PjRtExecutable(
     executables_.emplace_back(std::move(executable));
   }
 
-  // This must go after `executables_` is initialized.
-  VLOG(1) << "PjRtExecutable " << name() << " device_assignment:\n"
-          << device_assignment_->ToString();
-  const int num_partitions = device_assignment_->computation_count();
+  int num_partitions;
+  if (device_assignment_ == nullptr) {
+    // This must go after `executables_` is initialized.
+    VLOG(1) << "PjRtExecutable " << name() << " portable single-core";
+    num_partitions = 1;
+    CHECK(local_devices_.empty());
+  } else {
+    // This must go after `executables_` is initialized.
+    VLOG(1) << "PjRtExecutable " << name() << " device_assignment:\n"
+            << device_assignment_->ToString();
+    CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
+    CHECK_LE(local_devices_.size(), client_->local_device_count())
+        << "Inconsistent local device count.";
+    num_partitions = device_assignment_->computation_count();
+  }
 
   // SPMD sharding produces a single executable for multiple partitions.
   if (executables_.size() > 1) {
@@ -1427,10 +1438,6 @@ PjRtExecutable::PjRtExecutable(
         << "Number of executables " << executables_.size()
         << " did not match number of partitions " << num_partitions;
   }
-  CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
-  CHECK_LE(local_devices_.size(), client_->local_device_count())
-      << "Inconsistent local device count.";
 }
 
 Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) {
@@ -1462,7 +1469,8 @@ const std::string& PjRtExecutable::name() const {
 StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
     int executable_idx, const RunId& run_id, const ExecuteOptions& options,
-    Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers) const {
+    Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
+    std::shared_ptr<DeviceAssignment> device_assignment) const {
   int device_ordinal = device->local_device_state()->device_ordinal();
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
@@ -1559,7 +1567,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
   run_options.set_allocator(client_->allocator());
   run_options.set_intra_op_thread_pool(
       client_->client()->backend().eigen_intra_op_thread_pool_device());
-  run_options.set_device_assignment(device_assignment_.get());
+  run_options.set_device_assignment(device_assignment.get());
   run_options.set_run_id(run_id);
   run_options.set_rng_seed(device_state->GetNewPrngSeed());
   run_options.set_gpu_executable_run_options(client_->gpu_run_options());
@@ -1603,7 +1611,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     device_state->ThenExecuteOnCallbackThread(
         device_state->compute_stream(),
         [references{std::make_tuple(executables_[executable_idx],
-                                    compute_reservation, device_assignment_)},
+                                    compute_reservation, device_assignment)},
          donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
          device_ordinal]() {
           for (const auto& ptr : donated_ptrs) {
@@ -1616,7 +1624,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     device_state->ThenRelease(
         device_state->compute_stream(),
         std::make_tuple(executables_[executable_idx], compute_reservation,
-                        device_assignment_));
+                        device_assignment));
   }
 
   return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
@@ -1625,9 +1633,22 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
                               int replica, int partition, const RunId& run_id,
-                              const ExecuteOptions& options) const {
-  const int device_id = (*device_assignment_)(replica, partition);
-  Device* device = LookupDevice(*client_, device_id);
+                              const ExecuteOptions& options,
+                              Device* device) const {
+  std::shared_ptr<DeviceAssignment> device_assignment;
+  if (device == nullptr) {
+    CHECK(device_assignment_ != nullptr);
+    const int device_id = (*device_assignment_)(replica, partition);
+    device = LookupDevice(*client_, device_id);
+    device_assignment = device_assignment_;
+  } else {
+    CHECK(device_assignment_ == nullptr);
+    CHECK_EQ(replica, 0);
+    CHECK_EQ(partition, 0);
+    CHECK(local_devices_.empty());
+    device_assignment = std::make_shared<DeviceAssignment>(1, 1);
+    (*device_assignment)(0, 0) = device->id();
+  }
   CHECK_EQ(device->host_id(), client_->host_id());
   int device_ordinal = device->local_device_state()->device_ordinal();
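In the portable branch above, ExecuteHelper fabricates a fresh single-replica, single-partition assignment for the caller's device instead of reusing a compile-time one. A standalone sketch of that construction, using the same XLA DeviceAssignment constructor and accessor as the diff:

    // Map the lone (replica 0, partition 0) computation onto the device the
    // caller selected at execution time.
    auto device_assignment = std::make_shared<DeviceAssignment>(
        /*replica_count=*/1, /*computation_count=*/1);
    (*device_assignment)(0, 0) = device->id();
    // This per-call assignment, not the (null) compile-time one, is what gets
    // threaded into EnqueueExecution and run_options.set_device_assignment().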
@@ -1640,9 +1661,9 @@ PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
   std::vector<PjRtBuffer::ScopedHold> device_buffers;
   device_buffers.reserve(argument_handles.size());
-  StatusOr<ScopedShapedBuffer> result_buffer_or_status =
-      EnqueueExecution(argument_handles, replica, partition, executable_idx,
-                       run_id, options, device, &device_buffers);
+  StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
+      argument_handles, replica, partition, executable_idx, run_id, options,
+      device, &device_buffers, std::move(device_assignment));
 
   if (!result_buffer_or_status.ok()) {
     LOG(ERROR) << "Execution of replica " << replica
@@ -1736,6 +1757,13 @@ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtExecutable::ExecuteOnLocalDevice(
     absl::Span<PjRtBuffer* const> argument_handles, Device* device,
     const ExecuteOptions& options) const {
+  if (device_assignment_ == nullptr) {
+    VLOG(1) << "Executing portable single-core program on "
+            << device->DebugString();
+    return ExecuteHelper(argument_handles,
+                         /*replica=*/0,
+                         /*partition=*/0, RunId(), options, device);
+  }
   for (int i = 0; i < local_devices_.size(); ++i) {
     if (local_devices_[i] == device) {
       VLOG(1) << "Executing computation " << name();
@@ -1754,6 +1782,8 @@ StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
 PjRtExecutable::ExecuteOnLocalDevices(
     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
     const ExecuteOptions& options) const {
+  CHECK(device_assignment_ != nullptr);
+
   RunId run_id;
   tensorflow::profiler::TraceMeProducer activity(
       "LocalExecutable::ExecuteOnLocalDevices",
@@ -1952,6 +1982,18 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
     build_options.set_device_allocator(client->allocator());
   }
 
+  int num_replicas;
+  int num_partitions;
+  std::shared_ptr<DeviceAssignment> device_assignment;
+  if (options.compile_portable_executable) {
+    if (build_options.has_device_assignment()) {
+      return InvalidArgument(
+          "CompileOptions requests portable executable but "
+          "ExecutableBuildOptions includes a device assignment");
+    }
+    num_replicas = 1;
+    num_partitions = 1;
+  } else {
   if (!build_options.has_device_assignment()) {
     VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
     TF_ASSIGN_OR_RETURN(
@@ -1962,6 +2004,11 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   }
   VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
           << build_options.device_assignment().ToString();
+  num_replicas = build_options.device_assignment().replica_count();
+  num_partitions = build_options.device_assignment().computation_count();
+  device_assignment =
+      std::make_shared<DeviceAssignment>(build_options.device_assignment());
+  }
 
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                       computation.GetProgramShape());
@@ -2020,15 +2067,12 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
   build_options.set_result_layout(result_layout);
 
-  const int num_replicas = build_options.device_assignment().replica_count();
-  const int num_partitions =
-      build_options.device_assignment().computation_count();
   std::vector<std::pair<int, int>> local_logical_device_ids;
   std::vector<Device*> local_devices;
+  if (device_assignment != nullptr) {
   for (int replica = 0; replica < num_replicas; ++replica) {
     for (int partition = 0; partition < num_partitions; ++partition) {
-      int device_id = build_options.device_assignment()(replica, partition);
+      int device_id = (*device_assignment)(replica, partition);
       Device* device = LookupDevice(*client, device_id);
       if (device->host_id() != client->host_id()) {
         VLOG(3) << "Non-local device: " << device_id;
@@ -2041,13 +2085,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   if (local_devices.empty()) {
     return InvalidArgument(
         "Device assignment (%s) does not have any local devices.",
-        build_options.device_assignment().ToString());
+        device_assignment->ToString());
   }
 
   if (build_options.device_ordinal() < 0) {
     build_options.set_device_ordinal(
         local_devices.front()->local_device_state()->device_ordinal());
   }
+  }
 
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<LocalExecutable>> local_executables,
@@ -2056,7 +2101,7 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   auto py_executable = absl::make_unique<PjRtExecutable>(
       std::move(local_executables), options.parameter_is_tupled_arguments,
-      build_options.device_assignment(), std::move(local_logical_device_ids),
+      std::move(device_assignment), std::move(local_logical_device_ids),
       std::move(local_devices), client);
   TF_RETURN_IF_ERROR(py_executable->SetUpDonation(
       client, options.parameter_is_tupled_arguments));

@@ -647,6 +647,12 @@ struct CompileOptions {
   // XLA's compilation time options.
   ExecutableBuildOptions executable_build_options;
+
+  // If true, the executable can be run on any device. May only be true if
+  // !executable_build_options.has_device_assignment(), so only applies to
+  // single-device executables. Beware: on GPUs, sometimes an executable
+  // compiled for one device doesn't run on another.
+  bool compile_portable_executable = false;
 };
 
 struct ExecuteOptions {
@@ -673,7 +679,7 @@ class PjRtExecutable {
   PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
                  bool parameter_is_tupled_arguments,
-                 DeviceAssignment device_assignment,
+                 std::shared_ptr<DeviceAssignment> device_assignment,
                  std::vector<std::pair<int, int>> local_logical_device_ids,
                  std::vector<Device*> local_devices, PjRtClient* client);
@@ -738,10 +744,12 @@ class PjRtExecutable {
       absl::Span<PjRtBuffer* const> argument_handles, int replica,
       int partition, int executable_idx, const RunId& run_id,
       const ExecuteOptions& options, Device* device,
-      std::vector<PjRtBuffer::ScopedHold>* device_buffers) const;
+      std::vector<PjRtBuffer::ScopedHold>* device_buffers,
+      std::shared_ptr<DeviceAssignment> device_assignment) const;
 
   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
       absl::Span<PjRtBuffer* const> argument_handles, int replica,
-      int partition, const RunId& run_id, const ExecuteOptions& options) const;
+      int partition, const RunId& run_id, const ExecuteOptions& options,
+      Device* device = nullptr) const;
 
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the
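
The defaulted `Device* device = nullptr` parameter keeps existing ExecuteHelper call sites unchanged while enabling the portable path. A hedged illustration of the two internal call shapes (argument names assumed for the sketch):

    // Assignment-driven path: the device is resolved from device_assignment_
    // using (replica, partition); `device` stays nullptr.
    auto per_replica = ExecuteHelper(args, /*replica=*/1, /*partition=*/0,
                                     run_id, options);
    // Portable path: the caller supplies the device, and replica/partition
    // must both be 0 since the executable was compiled for a single core.
    auto portable = ExecuteHelper(args, /*replica=*/0, /*partition=*/0,
                                  RunId(), options, device);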