Add the ability to compile a single-replica executable that is portable across devices.
PiperOrigin-RevId: 318485159
Change-Id: Ied5115fab26b49f448ee61fded64904e92794b2b
commit 530aa3e4e0
parent e91990219b
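In practice the new flow looks like this — a minimal sketch, assuming a `PjRtClient* client` with at least one local device, an `XlaComputation computation`, argument buffers `args`, and the usual Status-returning plumbing around TF_ASSIGN_OR_RETURN; the exact Compile signature may differ slightly from this revision:

  // Compile once, without binding the executable to any device.
  CompileOptions options;
  options.compile_portable_executable = true;
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtExecutable> executable,
      PjRtExecutable::Compile(computation, client, options));

  // Bind the device at execution time instead; any local device works.
  for (Device* device : client->local_devices()) {
    TF_ASSIGN_OR_RETURN(
        std::vector<std::unique_ptr<PjRtBuffer>> outputs,
        executable->ExecuteOnLocalDevice(args, device, ExecuteOptions()));
  }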
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc
@@ -1402,11 +1402,12 @@ static Device* LookupDevice(const PjRtClient& client, int device_id) {
 
 PjRtExecutable::PjRtExecutable(
     std::vector<std::unique_ptr<LocalExecutable>> executables,
-    bool parameter_is_tupled_arguments, DeviceAssignment device_assignment,
+    bool parameter_is_tupled_arguments,
+    std::shared_ptr<DeviceAssignment> device_assignment,
     std::vector<std::pair<int, int>> local_logical_device_ids,
     std::vector<Device*> local_devices, PjRtClient* client)
     : client_(client),
-      device_assignment_(std::make_shared<DeviceAssignment>(device_assignment)),
+      device_assignment_(std::move(device_assignment)),
       parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
       local_logical_device_ids_(std::move(local_logical_device_ids)),
       local_devices_(std::move(local_devices)) {
@@ -1415,11 +1416,21 @@ PjRtExecutable::PjRtExecutable(
     executables_.emplace_back(std::move(executable));
   }
 
-  // This must go after `executables_` is initialized.
-  VLOG(1) << "PjRtExecutable " << name() << " device_assignment:\n"
-          << device_assignment_->ToString();
-
-  const int num_partitions = device_assignment_->computation_count();
+  int num_partitions;
+  if (device_assignment_ == nullptr) {
+    // This must go after `executables_` is initialized.
+    VLOG(1) << "PjRtExecutable " << name() << " portable single-core";
+    num_partitions = 1;
+    CHECK(local_devices_.empty());
+  } else {
+    // This must go after `executables_` is initialized.
+    VLOG(1) << "PjRtExecutable " << name() << " device_assignment:\n"
+            << device_assignment_->ToString();
+    CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
+    CHECK_LE(local_devices_.size(), client_->local_device_count())
+        << "Inconsistent local device count.";
+    num_partitions = device_assignment_->computation_count();
+  }
 
   // SPMD sharding produces a single executable for multiple partitions.
   if (executables_.size() > 1) {
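The constructor now encodes an invariant worth keeping in mind while reading the rest of the diff, summarized here as comments (the CHECKs above are the authoritative statement):

  // device_assignment_ == nullptr  <=>  portable executable:
  //   local_devices_ is empty and num_partitions is 1.
  // device_assignment_ != nullptr  <=>  conventional executable:
  //   1 <= local_devices_.size() <= client_->local_device_count().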
@@ -1427,10 +1438,6 @@ PjRtExecutable::PjRtExecutable(
         << "Number of executables " << executables_.size()
         << " did not match number of partitions " << num_partitions;
   }
-
-  CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
-  CHECK_LE(local_devices_.size(), client_->local_device_count())
-      << "Inconsistent local device count.";
 }
 
 Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) {
@@ -1462,7 +1469,8 @@ const std::string& PjRtExecutable::name() const {
 StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
     int executable_idx, const RunId& run_id, const ExecuteOptions& options,
-    Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers) const {
+    Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers,
+    std::shared_ptr<DeviceAssignment> device_assignment) const {
   int device_ordinal = device->local_device_state()->device_ordinal();
   tensorflow::profiler::TraceMeConsumer activity(
       "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
@@ -1559,7 +1567,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
   run_options.set_allocator(client_->allocator());
   run_options.set_intra_op_thread_pool(
       client_->client()->backend().eigen_intra_op_thread_pool_device());
-  run_options.set_device_assignment(device_assignment_.get());
+  run_options.set_device_assignment(device_assignment.get());
   run_options.set_run_id(run_id);
   run_options.set_rng_seed(device_state->GetNewPrngSeed());
   run_options.set_gpu_executable_run_options(client_->gpu_run_options());
@@ -1603,7 +1611,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     device_state->ThenExecuteOnCallbackThread(
         device_state->compute_stream(),
         [references{std::make_tuple(executables_[executable_idx],
-                                    compute_reservation, device_assignment_)},
+                                    compute_reservation, device_assignment)},
          donated_ptrs{std::move(donated_ptrs)}, allocator{client_->allocator()},
          device_ordinal]() {
           for (const auto& ptr : donated_ptrs) {
@@ -1616,7 +1624,7 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
     device_state->ThenRelease(
         device_state->compute_stream(),
         std::make_tuple(executables_[executable_idx], compute_reservation,
-                        device_assignment_));
+                        device_assignment));
   }
 
   return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult();
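The device assignment is now threaded through EnqueueExecution as a std::shared_ptr and captured by the cleanup callbacks above, keeping it alive until the compute stream is done with it. The same lifetime pattern in miniature (illustrative only; `stream` stands in for whatever se::Stream the caller owns):

  auto assignment = std::make_shared<xla::DeviceAssignment>(1, 1);
  stream->ThenDoHostCallback([held = assignment]() {
    // `held` keeps the DeviceAssignment alive until the stream reaches
    // this callback, even if the caller's copy is gone by then.
  });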
@@ -1625,9 +1633,22 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
 StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
                               int replica, int partition, const RunId& run_id,
-                              const ExecuteOptions& options) const {
-  const int device_id = (*device_assignment_)(replica, partition);
-  Device* device = LookupDevice(*client_, device_id);
+                              const ExecuteOptions& options,
+                              Device* device) const {
+  std::shared_ptr<DeviceAssignment> device_assignment;
+  if (device == nullptr) {
+    CHECK(device_assignment_ != nullptr);
+    const int device_id = (*device_assignment_)(replica, partition);
+    device = LookupDevice(*client_, device_id);
+    device_assignment = device_assignment_;
+  } else {
+    CHECK(device_assignment_ == nullptr);
+    CHECK_EQ(replica, 0);
+    CHECK_EQ(partition, 0);
+    CHECK(local_devices_.empty());
+    device_assignment = std::make_shared<DeviceAssignment>(1, 1);
+    (*device_assignment)(0, 0) = device->id();
+  }
 
   CHECK_EQ(device->host_id(), client_->host_id());
   int device_ordinal = device->local_device_state()->device_ordinal();
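The portable branch above synthesizes a fresh single-device assignment per call. For reference, xla::DeviceAssignment (from tensorflow/compiler/xla/service/computation_placer.h) is a (replica, computation) grid of device ids, so the smallest possible one looks like:

  // One replica, one partition, mapped to the chosen device's id.
  xla::DeviceAssignment assignment(/*replica_count=*/1,
                                   /*computation_count=*/1);
  assignment(0, 0) = device->id();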
@@ -1640,9 +1661,9 @@ PjRtExecutable::ExecuteHelper(absl::Span<PjRtBuffer* const> argument_handles,
 
   std::vector<PjRtBuffer::ScopedHold> device_buffers;
   device_buffers.reserve(argument_handles.size());
-  StatusOr<ScopedShapedBuffer> result_buffer_or_status =
-      EnqueueExecution(argument_handles, replica, partition, executable_idx,
-                       run_id, options, device, &device_buffers);
+  StatusOr<ScopedShapedBuffer> result_buffer_or_status = EnqueueExecution(
+      argument_handles, replica, partition, executable_idx, run_id, options,
+      device, &device_buffers, std::move(device_assignment));
 
   if (!result_buffer_or_status.ok()) {
     LOG(ERROR) << "Execution of replica " << replica
@@ -1736,6 +1757,13 @@ StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
 PjRtExecutable::ExecuteOnLocalDevice(
     absl::Span<PjRtBuffer* const> argument_handles, Device* device,
     const ExecuteOptions& options) const {
+  if (device_assignment_ == nullptr) {
+    VLOG(1) << "Executing portable single-core program on "
+            << device->DebugString();
+    return ExecuteHelper(argument_handles,
+                         /*replica=*/0,
+                         /*partition=*/0, RunId(), options, device);
+  }
   for (int i = 0; i < local_devices_.size(); ++i) {
     if (local_devices_[i] == device) {
       VLOG(1) << "Executing computation " << name();
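ExecuteOnLocalDevice therefore has two dispatch modes, sketched below (assumes `portable_exe` and `sharded_exe` were compiled with and without compile_portable_executable respectively; error handling elided):

  Device* device = client->local_devices().front();
  ExecuteOptions opts;
  // Portable: no baked-in assignment, so any local device is accepted and
  // replica/partition are pinned to 0.
  auto out1 = portable_exe->ExecuteOnLocalDevice(args, device, opts);
  // Conventional: `device` must be one of the executable's local devices,
  // i.e. a device named by its DeviceAssignment.
  auto out2 = sharded_exe->ExecuteOnLocalDevice(args, device, opts);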
@@ -1754,6 +1782,8 @@ StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
 PjRtExecutable::ExecuteOnLocalDevices(
     absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
     const ExecuteOptions& options) const {
+  CHECK(device_assignment_ != nullptr);
+
   RunId run_id;
   tensorflow::profiler::TraceMeProducer activity(
       "LocalExecutable::ExecuteOnLocalDevices",
@@ -1952,16 +1982,33 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
     build_options.set_device_allocator(client->allocator());
   }
 
-  if (!build_options.has_device_assignment()) {
-    VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
-    TF_ASSIGN_OR_RETURN(
-        DeviceAssignment device_assignment,
-        client->GetDefaultDeviceAssignment(build_options.num_replicas(),
-                                           build_options.num_partitions()));
-    build_options.set_device_assignment(device_assignment);
-  }
-  VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
-          << build_options.device_assignment().ToString();
+  int num_replicas;
+  int num_partitions;
+  std::shared_ptr<DeviceAssignment> device_assignment;
+  if (options.compile_portable_executable) {
+    if (build_options.has_device_assignment()) {
+      return InvalidArgument(
+          "CompileOptions requests portable executable but "
+          "ExecutableBuildOptions includes a device assignment");
+    }
+    num_replicas = 1;
+    num_partitions = 1;
+  } else {
+    if (!build_options.has_device_assignment()) {
+      VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
+      TF_ASSIGN_OR_RETURN(
+          DeviceAssignment device_assignment,
+          client->GetDefaultDeviceAssignment(build_options.num_replicas(),
+                                             build_options.num_partitions()));
+      build_options.set_device_assignment(device_assignment);
+    }
+    VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
+            << build_options.device_assignment().ToString();
+    num_replicas = build_options.device_assignment().replica_count();
+    num_partitions = build_options.device_assignment().computation_count();
+    device_assignment =
+        std::make_shared<DeviceAssignment>(build_options.device_assignment());
+  }
 
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                       computation.GetProgramShape());
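A consequence of the branch above: requesting portability while also pinning devices is rejected at compile time. A hypothetical caller that trips the new check:

  CompileOptions options;
  options.compile_portable_executable = true;
  options.executable_build_options.set_device_assignment(
      xla::DeviceAssignment(/*replica_count=*/2, /*computation_count=*/1));
  // PjRtExecutable::Compile(...) now returns InvalidArgument: portable
  // executables must leave device binding to execution time.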
@@ -2020,33 +2067,31 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
   TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
   build_options.set_result_layout(result_layout);
 
-  const int num_replicas = build_options.device_assignment().replica_count();
-  const int num_partitions =
-      build_options.device_assignment().computation_count();
-
   std::vector<std::pair<int, int>> local_logical_device_ids;
   std::vector<Device*> local_devices;
-  for (int replica = 0; replica < num_replicas; ++replica) {
-    for (int partition = 0; partition < num_partitions; ++partition) {
-      int device_id = build_options.device_assignment()(replica, partition);
-      Device* device = LookupDevice(*client, device_id);
-      if (device->host_id() != client->host_id()) {
-        VLOG(3) << "Non-local device: " << device_id;
-        continue;
-      }
-      local_logical_device_ids.emplace_back(replica, partition);
-      local_devices.push_back(device);
-    }
-  }
-  if (local_devices.empty()) {
-    return InvalidArgument(
-        "Device assignment (%s) does not have any local devices.",
-        build_options.device_assignment().ToString());
-  }
-
-  if (build_options.device_ordinal() < 0) {
-    build_options.set_device_ordinal(
-        local_devices.front()->local_device_state()->device_ordinal());
+  if (device_assignment != nullptr) {
+    for (int replica = 0; replica < num_replicas; ++replica) {
+      for (int partition = 0; partition < num_partitions; ++partition) {
+        int device_id = (*device_assignment)(replica, partition);
+        Device* device = LookupDevice(*client, device_id);
+        if (device->host_id() != client->host_id()) {
+          VLOG(3) << "Non-local device: " << device_id;
+          continue;
+        }
+        local_logical_device_ids.emplace_back(replica, partition);
+        local_devices.push_back(device);
+      }
+    }
+    if (local_devices.empty()) {
+      return InvalidArgument(
+          "Device assignment (%s) does not have any local devices.",
+          device_assignment->ToString());
+    }
+
+    if (build_options.device_ordinal() < 0) {
+      build_options.set_device_ordinal(
+          local_devices.front()->local_device_state()->device_ordinal());
+    }
   }
 
   TF_ASSIGN_OR_RETURN(
@@ -2056,7 +2101,7 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
   auto py_executable = absl::make_unique<PjRtExecutable>(
       std::move(local_executables), options.parameter_is_tupled_arguments,
-      build_options.device_assignment(), std::move(local_logical_device_ids),
+      std::move(device_assignment), std::move(local_logical_device_ids),
       std::move(local_devices), client);
   TF_RETURN_IF_ERROR(py_executable->SetUpDonation(
       client, options.parameter_is_tupled_arguments));
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -647,6 +647,12 @@ struct CompileOptions {
 
   // XLA's compilation time options.
   ExecutableBuildOptions executable_build_options;
+
+  // If true, the executable can be run on any device. May only be true if
+  // !executable_build_options.has_device_assignment(), so only applies to
+  // single-device executables. Beware: on GPUs, sometimes an executable
+  // compiled for one device doesn't run on another.
+  bool compile_portable_executable = false;
 };
 
 struct ExecuteOptions {
@@ -673,7 +679,7 @@ class PjRtExecutable {
 
   PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
                  bool parameter_is_tupled_arguments,
-                 DeviceAssignment device_assignment,
+                 std::shared_ptr<DeviceAssignment> device_assignment,
                  std::vector<std::pair<int, int>> local_logical_device_ids,
                  std::vector<Device*> local_devices, PjRtClient* client);
 
@@ -738,10 +744,12 @@ class PjRtExecutable {
       absl::Span<PjRtBuffer* const> argument_handles, int replica,
       int partition, int executable_idx, const RunId& run_id,
       const ExecuteOptions& options, Device* device,
-      std::vector<PjRtBuffer::ScopedHold>* device_buffers) const;
+      std::vector<PjRtBuffer::ScopedHold>* device_buffers,
+      std::shared_ptr<DeviceAssignment> device_assignment) const;
   StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
       absl::Span<PjRtBuffer* const> argument_handles, int replica,
-      int partition, const RunId& run_id, const ExecuteOptions& options) const;
+      int partition, const RunId& run_id, const ExecuteOptions& options,
+      Device* device = nullptr) const;
 
   // Create shared pointers so we can free them after the execution: with
   // asynchronous execution, the process being executed can outlive the