Cleanup PjRtExecutable API.
- Regroup Execute API into 3 distinct Execute methods - Rename local_devices to addressable_devices - Introduce LogicalDeviceId struct to name replica and partition, to replace std::pair. - Return Span instead of const vector&. PiperOrigin-RevId: 345551501 Change-Id: I2f7b50101849af02c7188d547d78a53dc7d030be
This commit is contained in:
parent
077fe29d9d
commit
5bbe185466
@ -93,11 +93,11 @@ TEST(GpuMultiStream, Basics) {
|
|||||||
options.untuple_result = true;
|
options.untuple_result = true;
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
auto out_buffers,
|
auto out_buffers,
|
||||||
executable->Execute({in_buffer0.get(), in_buffer1.get()}, options));
|
executable->Execute({{in_buffer0.get(), in_buffer1.get()}}, options));
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto out_literal, out_buffers[0]->ToLiteral());
|
TF_ASSERT_OK_AND_ASSIGN(auto out_literal, out_buffers[0][0]->ToLiteral());
|
||||||
LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);
|
LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);
|
||||||
TF_ASSERT_OK_AND_ASSIGN(out_literal, out_buffers[1]->ToLiteral());
|
TF_ASSERT_OK_AND_ASSIGN(out_literal, out_buffers[0][1]->ToLiteral());
|
||||||
LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);
|
LiteralTestUtil::ExpectR1Equal<int32>(expected_outputs, *out_literal);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1458,13 +1458,14 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
|
|||||||
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||||
bool parameter_is_tupled_arguments,
|
bool parameter_is_tupled_arguments,
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||||
std::vector<std::pair<int, int>> local_logical_device_ids,
|
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
||||||
std::vector<PjRtDevice*> local_devices, PjRtClient* client)
|
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client)
|
||||||
: client_(client),
|
: client_(client),
|
||||||
device_assignment_(std::move(device_assignment)),
|
device_assignment_(std::move(device_assignment)),
|
||||||
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
|
parameter_is_tupled_arguments_(parameter_is_tupled_arguments),
|
||||||
local_logical_device_ids_(std::move(local_logical_device_ids)),
|
addressable_device_logical_ids_(
|
||||||
local_devices_(std::move(local_devices)) {
|
std::move(addressable_device_logical_ids)),
|
||||||
|
addressable_devices_(std::move(addressable_devices)) {
|
||||||
executables_.reserve(executables.size());
|
executables_.reserve(executables.size());
|
||||||
for (auto& executable : executables) {
|
for (auto& executable : executables) {
|
||||||
executables_.emplace_back(std::move(executable));
|
executables_.emplace_back(std::move(executable));
|
||||||
@ -1475,13 +1476,13 @@ PjRtStreamExecutorExecutable::PjRtStreamExecutorExecutable(
|
|||||||
// This must go after `executables_` is initialized.
|
// This must go after `executables_` is initialized.
|
||||||
VLOG(1) << "PjRtStreamExecutorExecutable portable single-core";
|
VLOG(1) << "PjRtStreamExecutorExecutable portable single-core";
|
||||||
num_partitions = 1;
|
num_partitions = 1;
|
||||||
CHECK(local_devices_.empty());
|
CHECK(addressable_devices_.empty());
|
||||||
} else {
|
} else {
|
||||||
// This must go after `executables_` is initialized.
|
// This must go after `executables_` is initialized.
|
||||||
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
|
VLOG(1) << "PjRtStreamExecutorExecutable device_assignment:\n"
|
||||||
<< device_assignment_->ToString();
|
<< device_assignment_->ToString();
|
||||||
CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString();
|
CHECK_GE(addressable_devices_.size(), 1) << device_assignment_->ToString();
|
||||||
CHECK_LE(local_devices_.size(), client_->local_device_count())
|
CHECK_LE(addressable_devices_.size(), client_->local_device_count())
|
||||||
<< "Inconsistent local device count.";
|
<< "Inconsistent local device count.";
|
||||||
num_partitions = device_assignment_->computation_count();
|
num_partitions = device_assignment_->computation_count();
|
||||||
}
|
}
|
||||||
@ -1807,7 +1808,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
|||||||
CHECK(device_assignment_ == nullptr);
|
CHECK(device_assignment_ == nullptr);
|
||||||
CHECK_EQ(replica, 0);
|
CHECK_EQ(replica, 0);
|
||||||
CHECK_EQ(partition, 0);
|
CHECK_EQ(partition, 0);
|
||||||
CHECK(local_devices_.empty());
|
CHECK(addressable_devices_.empty());
|
||||||
device_assignment = std::make_shared<DeviceAssignment>(1, 1);
|
device_assignment = std::make_shared<DeviceAssignment>(1, 1);
|
||||||
(*device_assignment)(0, 0) = device->id();
|
(*device_assignment)(0, 0) = device->id();
|
||||||
}
|
}
|
||||||
@ -1875,94 +1876,52 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
|
|||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
|
||||||
PjRtStreamExecutorExecutable::Execute(
|
|
||||||
absl::Span<PjRtBuffer* const> argument_handles,
|
|
||||||
const ExecuteOptions& options) const {
|
|
||||||
if (num_replicas() != 1) {
|
|
||||||
return InvalidArgument(
|
|
||||||
"Attempted to execute computation with %d replicas using Execute()",
|
|
||||||
num_replicas());
|
|
||||||
}
|
|
||||||
if (num_partitions() != 1) {
|
|
||||||
return InvalidArgument(
|
|
||||||
"Attempted to execute computation with %d partitions using Execute()",
|
|
||||||
num_partitions());
|
|
||||||
}
|
|
||||||
VLOG(1) << "Executing computation " << name();
|
|
||||||
return ExecuteHelper(argument_handles, /*replica=*/0, /*partition=*/0,
|
|
||||||
RunId(), options);
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
|
||||||
PjRtStreamExecutorExecutable::ExecuteOnLocalDevice(
|
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
|
||||||
const ExecuteOptions& options) const {
|
|
||||||
if (device_assignment_ == nullptr) {
|
|
||||||
VLOG(1) << "Executing portable single-core program on "
|
|
||||||
<< device->DebugString();
|
|
||||||
return ExecuteHelper(argument_handles,
|
|
||||||
/*replica=*/0,
|
|
||||||
/*partition=*/0, RunId(), options, device);
|
|
||||||
}
|
|
||||||
for (int i = 0; i < local_devices_.size(); ++i) {
|
|
||||||
if (local_devices_[i] == device) {
|
|
||||||
VLOG(1) << "Executing computation " << name();
|
|
||||||
return ExecuteHelper(argument_handles,
|
|
||||||
/*replica=*/local_logical_device_ids_[i].first,
|
|
||||||
/*partition=*/local_logical_device_ids_[i].second,
|
|
||||||
RunId(), options);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return InvalidArgument(
|
|
||||||
"Attempted to execute on device id %d which is not a local device",
|
|
||||||
device->id());
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
||||||
PjRtStreamExecutorExecutable::ExecuteOnLocalDevices(
|
PjRtStreamExecutorExecutable::Execute(
|
||||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||||
const ExecuteOptions& options) const {
|
const ExecuteOptions& options) const {
|
||||||
CHECK(device_assignment_ != nullptr);
|
if (device_assignment_ == nullptr) {
|
||||||
|
return InvalidArgument("Execute expects a non-null device_assignment");
|
||||||
|
}
|
||||||
|
|
||||||
RunId run_id;
|
RunId run_id;
|
||||||
tensorflow::profiler::TraceMeProducer activity(
|
tensorflow::profiler::TraceMeProducer activity(
|
||||||
"LocalExecutable::ExecuteOnLocalDevices",
|
"PjRtStreamExecutorExecutable::Execute",
|
||||||
tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
|
tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
|
||||||
|
|
||||||
const int num_local_devices = local_devices_.size();
|
const int num_addressable_devices = addressable_devices_.size();
|
||||||
|
|
||||||
if (argument_handles.size() != num_local_devices) {
|
if (argument_handles.size() != num_addressable_devices) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Attempted to execute with %d argument lists when local device "
|
"Attempted to execute with %d argument lists when local device "
|
||||||
"count is %d (total replica count: %d, partition count: %d)",
|
"count is %d (total replica count: %d, partition count: %d)",
|
||||||
argument_handles.size(), num_local_devices, num_replicas(),
|
argument_handles.size(), num_addressable_devices, num_replicas(),
|
||||||
num_partitions());
|
num_partitions());
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(1) << "Executing computation " << name()
|
VLOG(1) << "Executing computation " << name()
|
||||||
<< "; num_replicas=" << num_replicas()
|
<< "; num_replicas=" << num_replicas()
|
||||||
<< " num_partitions=" << num_partitions()
|
<< " num_partitions=" << num_partitions()
|
||||||
<< " num_local_devices=" << num_local_devices;
|
<< " num_addressable_devices=" << num_addressable_devices;
|
||||||
std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
|
std::vector<StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>> results(
|
||||||
num_local_devices);
|
num_addressable_devices);
|
||||||
if (num_local_devices == 1) {
|
if (num_addressable_devices == 1) {
|
||||||
// Fast-path if there is only one device — run the computation on the
|
// Fast-path if there is only one device — run the computation on the
|
||||||
// current thread.
|
// current thread.
|
||||||
const int replica = local_logical_device_ids_[0].first;
|
const int replica = addressable_device_logical_ids_[0].replica;
|
||||||
const int partition = local_logical_device_ids_[0].second;
|
const int partition = addressable_device_logical_ids_[0].partition;
|
||||||
results[0] =
|
results[0] =
|
||||||
ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
|
ExecuteHelper(argument_handles[0], replica, partition, run_id, options);
|
||||||
} else {
|
} else {
|
||||||
absl::Mutex mu;
|
absl::Mutex mu;
|
||||||
int running = num_local_devices;
|
int running = num_addressable_devices;
|
||||||
int failed = 0;
|
int failed = 0;
|
||||||
Status first_failure_status;
|
Status first_failure_status;
|
||||||
|
|
||||||
for (int i = 0; i < num_local_devices; ++i) {
|
for (int i = 0; i < num_addressable_devices; ++i) {
|
||||||
const int replica = local_logical_device_ids_[i].first;
|
const int replica = addressable_device_logical_ids_[i].replica;
|
||||||
const int partition = local_logical_device_ids_[i].second;
|
const int partition = addressable_device_logical_ids_[i].partition;
|
||||||
PjRtDevice* device = local_devices_[i];
|
PjRtDevice* device = addressable_devices_[i];
|
||||||
const LocalDeviceState& device_state = *device->local_device_state();
|
const LocalDeviceState& device_state = *device->local_device_state();
|
||||||
device_state.execute_thread()->Schedule([&, replica, partition, i] {
|
device_state.execute_thread()->Schedule([&, replica, partition, i] {
|
||||||
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
|
results[i] = ExecuteHelper(argument_handles[i], replica, partition,
|
||||||
@ -2008,10 +1967,10 @@ PjRtStreamExecutorExecutable::ExecuteOnLocalDevices(
|
|||||||
VLOG(1) << "Replicated execution complete.";
|
VLOG(1) << "Replicated execution complete.";
|
||||||
|
|
||||||
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
|
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> wrapped_results(
|
||||||
num_local_devices);
|
num_addressable_devices);
|
||||||
for (int i = 0; i < num_local_devices; ++i) {
|
for (int i = 0; i < num_addressable_devices; ++i) {
|
||||||
const int replica = local_logical_device_ids_[i].first;
|
const int replica = addressable_device_logical_ids_[i].replica;
|
||||||
const int partition = local_logical_device_ids_[i].second;
|
const int partition = addressable_device_logical_ids_[i].partition;
|
||||||
auto& statusor = results[i];
|
auto& statusor = results[i];
|
||||||
if (!statusor.ok()) {
|
if (!statusor.ok()) {
|
||||||
return AppendStatus(
|
return AppendStatus(
|
||||||
@ -2026,6 +1985,52 @@ PjRtStreamExecutorExecutable::ExecuteOnLocalDevices(
|
|||||||
return wrapped_results;
|
return wrapped_results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||||
|
PjRtStreamExecutorExecutable::ExecuteSharded(
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
|
const ExecuteOptions& options) const {
|
||||||
|
if (device_assignment_ == nullptr) {
|
||||||
|
return InvalidArgument("ExecuteShard expects a non-null device_assignment");
|
||||||
|
}
|
||||||
|
for (int i = 0; i < addressable_devices_.size(); ++i) {
|
||||||
|
if (addressable_devices_[i] == device) {
|
||||||
|
VLOG(1) << "ExecuteShard executes computation " << name()
|
||||||
|
<< " on assigned replica/partition on device "
|
||||||
|
<< device->DebugString();
|
||||||
|
return ExecuteHelper(
|
||||||
|
argument_handles, addressable_device_logical_ids_[i].replica,
|
||||||
|
addressable_device_logical_ids_[i].partition, RunId(), options);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return InvalidArgument(
|
||||||
|
"ExecuteShard attempted to execute on device id %d which is not "
|
||||||
|
"addressable by this client",
|
||||||
|
device->id());
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
||||||
|
PjRtStreamExecutorExecutable::ExecutePortable(
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
|
const ExecuteOptions& options) const {
|
||||||
|
if (device_assignment_ != nullptr) {
|
||||||
|
return InvalidArgument("ExecutePortable gets a non-portable executable");
|
||||||
|
}
|
||||||
|
if (num_replicas() != 1 || num_partitions() != 1) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"ExecutePortable expects a single-core executable but gets "
|
||||||
|
"one with %d replica %d partition",
|
||||||
|
num_replicas(), num_partitions());
|
||||||
|
}
|
||||||
|
if (device == nullptr) {
|
||||||
|
return InvalidArgument("ExecutePortable expects a device to be specified");
|
||||||
|
}
|
||||||
|
VLOG(1) << "ExecutePortable executes single-core portable executable "
|
||||||
|
<< name();
|
||||||
|
return ExecuteHelper(argument_handles,
|
||||||
|
/*replica=*/0,
|
||||||
|
/*partition=*/0, RunId(), options, device);
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::shared_ptr<HloModule>>>
|
StatusOr<std::vector<std::shared_ptr<HloModule>>>
|
||||||
PjRtStreamExecutorExecutable::GetHloModules() const {
|
PjRtStreamExecutorExecutable::GetHloModules() const {
|
||||||
std::vector<std::shared_ptr<HloModule>> modules;
|
std::vector<std::shared_ptr<HloModule>> modules;
|
||||||
@ -2220,9 +2225,12 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
|||||||
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
|
TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
|
||||||
build_options.set_result_layout(result_layout);
|
build_options.set_result_layout(result_layout);
|
||||||
|
|
||||||
std::vector<std::pair<int, int>> local_logical_device_ids;
|
// Find devices that are addressable by this client/task.
|
||||||
std::vector<PjRtDevice*> local_devices;
|
std::vector<PjRtExecutable::LogicalDeviceIds> addressable_device_logical_ids;
|
||||||
|
std::vector<PjRtDevice*> addressable_devices;
|
||||||
if (device_assignment != nullptr) {
|
if (device_assignment != nullptr) {
|
||||||
|
addressable_device_logical_ids.reserve(num_replicas * num_partitions);
|
||||||
|
addressable_devices.reserve(num_replicas * num_partitions);
|
||||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||||
for (int partition = 0; partition < num_partitions; ++partition) {
|
for (int partition = 0; partition < num_partitions; ++partition) {
|
||||||
int device_id = (*device_assignment)(replica, partition);
|
int device_id = (*device_assignment)(replica, partition);
|
||||||
@ -2231,11 +2239,13 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
|||||||
VLOG(3) << "Non-local device: " << device_id;
|
VLOG(3) << "Non-local device: " << device_id;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
local_logical_device_ids.emplace_back(replica, partition);
|
addressable_device_logical_ids.push_back(
|
||||||
local_devices.push_back(device);
|
PjRtExecutable::LogicalDeviceIds{.replica = replica,
|
||||||
|
.partition = partition});
|
||||||
|
addressable_devices.push_back(device);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (local_devices.empty()) {
|
if (addressable_devices.empty()) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Device assignment (%s) does not have any local devices.",
|
"Device assignment (%s) does not have any local devices.",
|
||||||
device_assignment->ToString());
|
device_assignment->ToString());
|
||||||
@ -2243,7 +2253,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
|||||||
|
|
||||||
if (build_options.device_ordinal() < 0) {
|
if (build_options.device_ordinal() < 0) {
|
||||||
build_options.set_device_ordinal(
|
build_options.set_device_ordinal(
|
||||||
local_devices.front()->local_device_state()->device_ordinal());
|
addressable_devices.front()->local_device_state()->device_ordinal());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2253,8 +2263,8 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
|||||||
|
|
||||||
auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
|
auto executable = absl::make_unique<PjRtStreamExecutorExecutable>(
|
||||||
std::move(local_executables), options.parameter_is_tupled_arguments,
|
std::move(local_executables), options.parameter_is_tupled_arguments,
|
||||||
std::move(device_assignment), std::move(local_logical_device_ids),
|
std::move(device_assignment), std::move(addressable_device_logical_ids),
|
||||||
std::move(local_devices), this);
|
std::move(addressable_devices), this);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
executable->SetUpDonation(options.parameter_is_tupled_arguments));
|
executable->SetUpDonation(options.parameter_is_tupled_arguments));
|
||||||
return std::unique_ptr<PjRtExecutable>(std::move(executable));
|
return std::unique_ptr<PjRtExecutable>(std::move(executable));
|
||||||
|
@ -781,41 +781,43 @@ class PjRtExecutable {
|
|||||||
|
|
||||||
// The replica and partition indices of device_assignment to be run by this
|
// The replica and partition indices of device_assignment to be run by this
|
||||||
// client. On single-host platforms without partitioning, this is all replicas
|
// client. On single-host platforms without partitioning, this is all replicas
|
||||||
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
|
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
|
||||||
// on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
||||||
// single host platform, size of local_logical_device_ids_ is 4*2 = 8.
|
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
|
||||||
// TODO(zhangqiaorjc): Add a struct for the pair and return a span.
|
struct LogicalDeviceIds {
|
||||||
virtual const std::vector<std::pair<int, int>>& local_logical_device_ids()
|
int replica;
|
||||||
|
int partition;
|
||||||
|
};
|
||||||
|
virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
|
||||||
const = 0;
|
const = 0;
|
||||||
|
|
||||||
// local_devices()[i] is the Device to which local_logical_device_ids()[i] is
|
// addressable_devices()[i] is the Device to which
|
||||||
// assigned.
|
// addressable_device_logical_ids()[i] is assigned.
|
||||||
virtual const std::vector<PjRtDevice*>& local_devices() const = 0;
|
virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
|
||||||
|
|
||||||
// Return an HloModule (optimized) per partition.
|
// Return an HloModule (optimized) per partition.
|
||||||
virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
|
virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
|
||||||
const = 0;
|
const = 0;
|
||||||
|
|
||||||
// Execute on replica 0 and partition 0 with the requirement that there's a
|
// Executes on devices addressable by the client. Requires executable has a
|
||||||
// single replica and partition.
|
// device_assignment and all devices in the device_assignment are addressable
|
||||||
// TODO(zhangqiaorjc): Merge with ExecuteOnLocalDevice. Remove "local".
|
// by the client.
|
||||||
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> Execute(
|
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
||||||
absl::Span<PjRtBuffer* const> argument_handles,
|
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||||
|
const ExecuteOptions& options) const = 0;
|
||||||
|
|
||||||
|
// Execute the assigned replica/partition on a given `device`. Requires
|
||||||
|
// executable has a device_assignment, `device` is present in the
|
||||||
|
// device_assignment and addressable by the client.
|
||||||
|
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const = 0;
|
const ExecuteOptions& options) const = 0;
|
||||||
|
|
||||||
// Execute on a given local device.
|
// Execute on a given `device`. Requires `device` to be addressable by client.
|
||||||
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>>
|
// Requires executable has exactly 1 replica and 1 partition and no
|
||||||
ExecuteOnLocalDevice(absl::Span<PjRtBuffer* const> argument_handles,
|
// device_assignment (thus portable).
|
||||||
PjRtDevice* device,
|
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
||||||
const ExecuteOptions& options) const = 0;
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
|
|
||||||
// Execute on local devices. Takes a sequence of argument lists (one argument
|
|
||||||
// list per local device) and returns a tuple of results (one result per local
|
|
||||||
// device). The number of argument lists must be equal to the local device
|
|
||||||
// count.
|
|
||||||
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
|
||||||
ExecuteOnLocalDevices(
|
|
||||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
|
||||||
const ExecuteOptions& options) const = 0;
|
const ExecuteOptions& options) const = 0;
|
||||||
|
|
||||||
// Asynchronously free resources after the last execution completes.
|
// Asynchronously free resources after the last execution completes.
|
||||||
@ -830,8 +832,8 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
|||||||
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||||
bool parameter_is_tupled_arguments,
|
bool parameter_is_tupled_arguments,
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||||
std::vector<std::pair<int, int>> local_logical_device_ids,
|
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
||||||
std::vector<PjRtDevice*> local_devices, PjRtClient* client);
|
std::vector<PjRtDevice*> addressable_devices, PjRtClient* client);
|
||||||
|
|
||||||
~PjRtStreamExecutorExecutable() override = default;
|
~PjRtStreamExecutorExecutable() override = default;
|
||||||
|
|
||||||
@ -859,39 +861,34 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
|||||||
return *device_assignment_;
|
return *device_assignment_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::pair<int, int>>& local_logical_device_ids()
|
absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
|
||||||
const override {
|
const override {
|
||||||
return local_logical_device_ids_;
|
return addressable_device_logical_ids_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<PjRtDevice*>& local_devices() const override {
|
absl::Span<PjRtDevice* const> addressable_devices() const override {
|
||||||
return local_devices_;
|
return addressable_devices_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return an HloModule per partition.
|
// Return an HloModule per partition.
|
||||||
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
|
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
|
||||||
const override;
|
const override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> Execute(
|
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles,
|
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||||
const ExecuteOptions& options) const override;
|
const ExecuteOptions& options) const override;
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteOnLocalDevice(
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
const ExecuteOptions& options) const override;
|
const ExecuteOptions& options) const override;
|
||||||
|
|
||||||
// Execute on local devices. Takes a sequence of argument lists (one argument
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
||||||
// list per local device) and returns a tuple of results (one result per local
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
// device). The number of argument lists must be equal to the local device
|
|
||||||
// count.
|
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
|
|
||||||
ExecuteOnLocalDevices(
|
|
||||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
|
||||||
const ExecuteOptions& options) const override;
|
const ExecuteOptions& options) const override;
|
||||||
|
|
||||||
void Delete() override { executables_.clear(); }
|
void Delete() override { executables_.clear(); }
|
||||||
|
|
||||||
const std::vector<std::shared_ptr<LocalExecutable>>& executables() const {
|
absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
|
||||||
return executables_;
|
return executables_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -950,17 +947,15 @@ class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
|||||||
|
|
||||||
// The replica and partition indices of device_assignment_ to be run by this
|
// The replica and partition indices of device_assignment_ to be run by this
|
||||||
// client. On single-host platforms without partitioning, this is all replicas
|
// client. On single-host platforms without partitioning, this is all replicas
|
||||||
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
|
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
|
||||||
// on multi-host platforms.
|
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
||||||
// If there are 4 replicas and 2 partitions on a single host platform, size of
|
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
|
||||||
// local_logical_device_ids_ is 4*2 = 8.
|
std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
|
||||||
std::vector<std::pair<int, int>> local_logical_device_ids_;
|
|
||||||
|
|
||||||
// local_devices_[i] is the Device to which local_logical_device_ids_[i] is
|
// addressable_devices_[i] is the Device to which
|
||||||
// assigned.
|
// addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
|
||||||
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
|
// unique_ptrs to play well with the Python bindings (see xla.cc).
|
||||||
// (see xla.cc).
|
std::vector<PjRtDevice*> addressable_devices_;
|
||||||
std::vector<PjRtDevice*> local_devices_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Executables can donate buffers so that buffers can be aliased from inputs
|
// Executables can donate buffers so that buffers can be aliased from inputs
|
||||||
|
@ -781,7 +781,7 @@ CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
|
|||||||
|
|
||||||
cache_entry->executable = std::move(executable);
|
cache_entry->executable = std::move(executable);
|
||||||
int num_devices =
|
int num_devices =
|
||||||
cache_entry->executable->pjrt_executable().local_devices().size();
|
cache_entry->executable->pjrt_executable().addressable_devices().size();
|
||||||
// The presence of jit(pmap) is detected from Python.
|
// The presence of jit(pmap) is detected from Python.
|
||||||
CHECK_EQ(num_devices, 1);
|
CHECK_EQ(num_devices, 1);
|
||||||
|
|
||||||
|
@ -413,8 +413,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
|
|||||||
devices_[device_idx]->client()->Compile(
|
devices_[device_idx]->client()->Compile(
|
||||||
computation, std::move(compile_options)));
|
computation, std::move(compile_options)));
|
||||||
ExecuteOptions execute_options;
|
ExecuteOptions execute_options;
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
|
TF_ASSIGN_OR_RETURN(
|
||||||
executable->Execute({}, execute_options));
|
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
|
||||||
|
executable->Execute({{}}, execute_options));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,8 +43,9 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
|
|||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
||||||
client->Compile(computation, std::move(compile_options)));
|
client->Compile(computation, std::move(compile_options)));
|
||||||
ExecuteOptions execute_options;
|
ExecuteOptions execute_options;
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
|
TF_ASSIGN_OR_RETURN(
|
||||||
executable->Execute({}, execute_options));
|
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
|
||||||
|
executable->Execute({{}}, execute_options));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,26 +58,29 @@ PyExecutable::~PyExecutable() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::LocalDevices() const {
|
std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
|
||||||
std::vector<ClientAndPtr<PjRtDevice>> devices;
|
std::vector<ClientAndPtr<PjRtDevice>> devices;
|
||||||
devices.reserve(executable_->local_devices().size());
|
devices.reserve(executable_->addressable_devices().size());
|
||||||
for (PjRtDevice* device : executable_->local_devices()) {
|
for (PjRtDevice* device : executable_->addressable_devices()) {
|
||||||
devices.push_back(WrapWithClient(client_, device));
|
devices.push_back(WrapWithClient(client_, device));
|
||||||
}
|
}
|
||||||
return devices;
|
return devices;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Used by JAX JIT which has C++ PjRtBuffers as inputs (Numpy to PjRtBuffer is
|
||||||
|
// faster and simpler than Numpy to PyBuffer to PjRtBuffer) and requires
|
||||||
|
// PyBuffer as outputs as it will return to Python.
|
||||||
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::PjRtExecute(
|
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::PjRtExecute(
|
||||||
absl::Span<PjRtBuffer* const> args) {
|
const std::vector<PjRtBuffer*>& args) {
|
||||||
std::vector<std::unique_ptr<PjRtBuffer>> output_buffers;
|
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute(args, options_));
|
TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute({args}, options_));
|
||||||
}
|
}
|
||||||
auto traceback = Traceback::Get();
|
auto traceback = Traceback::Get();
|
||||||
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
||||||
outputs.reserve(output_buffers.size());
|
outputs.reserve(output_buffers[0].size());
|
||||||
for (auto& buffer : output_buffers) {
|
for (auto& buffer : output_buffers[0]) {
|
||||||
outputs.push_back(
|
outputs.push_back(
|
||||||
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||||
}
|
}
|
||||||
@ -86,19 +89,19 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::PjRtExecute(
|
|||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
||||||
absl::Span<PyBuffer* const> args) {
|
absl::Span<PyBuffer* const> args) {
|
||||||
std::vector<std::unique_ptr<PjRtBuffer>> output_buffers;
|
std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
std::vector<PjRtBuffer*> arg_buffers(args.size());
|
std::vector<PjRtBuffer*> arg_buffers(args.size());
|
||||||
absl::c_transform(args, arg_buffers.begin(),
|
absl::c_transform(args, arg_buffers.begin(),
|
||||||
[](PyBuffer* buf) { return buf->buffer(); });
|
[](PyBuffer* buf) { return buf->buffer(); });
|
||||||
TF_ASSIGN_OR_RETURN(output_buffers,
|
TF_ASSIGN_OR_RETURN(output_buffers,
|
||||||
executable_->Execute(arg_buffers, options_));
|
executable_->Execute({arg_buffers}, options_));
|
||||||
}
|
}
|
||||||
auto traceback = Traceback::Get();
|
auto traceback = Traceback::Get();
|
||||||
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
std::vector<std::unique_ptr<PyBuffer>> outputs;
|
||||||
outputs.reserve(output_buffers.size());
|
outputs.reserve(output_buffers[0].size());
|
||||||
for (auto& buffer : output_buffers) {
|
for (auto& buffer : output_buffers[0]) {
|
||||||
outputs.push_back(
|
outputs.push_back(
|
||||||
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
|
||||||
}
|
}
|
||||||
@ -117,8 +120,8 @@ PyExecutable::ExecuteOnLocalDevices(
|
|||||||
absl::c_transform(args[computation], arg_buffers[computation].begin(),
|
absl::c_transform(args[computation], arg_buffers[computation].begin(),
|
||||||
[](PyBuffer* buf) { return buf->buffer(); });
|
[](PyBuffer* buf) { return buf->buffer(); });
|
||||||
}
|
}
|
||||||
TF_ASSIGN_OR_RETURN(output_buffers, executable_->ExecuteOnLocalDevices(
|
TF_ASSIGN_OR_RETURN(output_buffers,
|
||||||
arg_buffers, options_));
|
executable_->Execute(arg_buffers, options_));
|
||||||
}
|
}
|
||||||
auto traceback = Traceback::Get();
|
auto traceback = Traceback::Get();
|
||||||
std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
|
std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
|
||||||
|
@ -43,11 +43,12 @@ class PyExecutable {
|
|||||||
|
|
||||||
std::shared_ptr<PyClient> client() const { return client_; }
|
std::shared_ptr<PyClient> client() const { return client_; }
|
||||||
|
|
||||||
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
|
absl::Span<const PjRtExecutable::LogicalDeviceIds>
|
||||||
return executable_->local_logical_device_ids();
|
addressable_device_logical_ids() const {
|
||||||
|
return executable_->addressable_device_logical_ids();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ClientAndPtr<PjRtDevice>> LocalDevices() const;
|
std::vector<ClientAndPtr<PjRtDevice>> AddressableDevices() const;
|
||||||
|
|
||||||
int64 SizeOfGeneratedCodeInBytes() const {
|
int64 SizeOfGeneratedCodeInBytes() const {
|
||||||
return executable_->SizeOfGeneratedCodeInBytes();
|
return executable_->SizeOfGeneratedCodeInBytes();
|
||||||
@ -60,7 +61,7 @@ class PyExecutable {
|
|||||||
|
|
||||||
// Same as above, but take as inputs `PjRtBuffer*`. Only targets C++ code.
|
// Same as above, but take as inputs `PjRtBuffer*`. Only targets C++ code.
|
||||||
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PjRtExecute(
|
StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PjRtExecute(
|
||||||
absl::Span<PjRtBuffer* const> args);
|
const std::vector<PjRtBuffer*>& args);
|
||||||
|
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
|
StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
|
||||||
ExecuteOnLocalDevices(absl::Span<const std::vector<PyBuffer*>> args);
|
ExecuteOnLocalDevices(absl::Span<const std::vector<PyBuffer*>> args);
|
||||||
|
@ -531,7 +531,7 @@ PyTpuExecutable::PyTpuExecutable(
|
|||||||
<< "Inserting duplicate replica:" << replica;
|
<< "Inserting duplicate replica:" << replica;
|
||||||
executables_[replica] =
|
executables_[replica] =
|
||||||
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
|
client_->driver()->LoadProgram(device_id, compiled_program.get(), {});
|
||||||
local_logical_device_ids_.emplace_back(replica, partition);
|
addressable_device_logical_ids_.emplace_back(replica, partition);
|
||||||
local_devices_.push_back(device);
|
local_devices_.push_back(device);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -711,8 +711,8 @@ PyTpuExecutable::ExecuteOnLocalDevices(
|
|||||||
// long time and we want all cores to be scheduled in parallel.
|
// long time and we want all cores to be scheduled in parallel.
|
||||||
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
|
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
|
||||||
&execute_semaphore]() {
|
&execute_semaphore]() {
|
||||||
const int replica = local_logical_device_ids_[i].first;
|
const int replica = addressable_device_logical_ids_[i].first;
|
||||||
const int partition = local_logical_device_ids_[i].second;
|
const int partition = addressable_device_logical_ids_[i].second;
|
||||||
RunId run_id;
|
RunId run_id;
|
||||||
auto result = ExecuteHelper(argument_handles, argument_handles[i],
|
auto result = ExecuteHelper(argument_handles, argument_handles[i],
|
||||||
replica, partition, run_id);
|
replica, partition, run_id);
|
||||||
|
@ -298,8 +298,9 @@ class PyTpuExecutable {
|
|||||||
return device_assignment_;
|
return device_assignment_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
|
const std::vector<std::pair<int, int>>& addressable_device_logical_ids()
|
||||||
return local_logical_device_ids_;
|
const {
|
||||||
|
return addressable_device_logical_ids_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
|
const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
|
||||||
@ -340,16 +341,14 @@ class PyTpuExecutable {
|
|||||||
|
|
||||||
// The replica and partition indices of device_assignment_ to be run by this
|
// The replica and partition indices of device_assignment_ to be run by this
|
||||||
// client. On single-host platforms without partitioning, this is all replicas
|
// client. On single-host platforms without partitioning, this is all replicas
|
||||||
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
|
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
|
||||||
// on multi-host platforms.
|
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
||||||
// If there are 4 replicas and 2 partitions on a single host platform, size of
|
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
|
||||||
// local_logical_device_ids_ is 4*2 = 8.
|
std::vector<std::pair<int, int>> addressable_device_logical_ids_;
|
||||||
std::vector<std::pair<int, int>> local_logical_device_ids_;
|
|
||||||
|
|
||||||
// local_devices_[i] is the Device to which local_logical_device_ids_[i] is
|
// local_devices_[i] is the Device to which addressable_device_logical_ids_[i]
|
||||||
// assigned.
|
// is assigned. shared_ptrs instead of unique_ptrs to play well with the
|
||||||
// shared_ptrs instead of unique_ptrs to play well with the Python bindings
|
// Python bindings (see xla.cc).
|
||||||
// (see xla.cc).
|
|
||||||
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
|
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
|
||||||
|
|
||||||
xla::Shape result_shape_;
|
xla::Shape result_shape_;
|
||||||
|
@ -186,7 +186,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
|||||||
|
|
||||||
py::class_<PyTpuExecutable>(m, "TpuExecutable")
|
py::class_<PyTpuExecutable>(m, "TpuExecutable")
|
||||||
.def("local_logical_device_ids",
|
.def("local_logical_device_ids",
|
||||||
&PyTpuExecutable::local_logical_device_ids)
|
&PyTpuExecutable::addressable_device_logical_ids)
|
||||||
.def("local_devices", &PyTpuExecutable::local_devices)
|
.def("local_devices", &PyTpuExecutable::local_devices)
|
||||||
.def_property_readonly("client", &PyTpuExecutable::client)
|
.def_property_readonly("client", &PyTpuExecutable::client)
|
||||||
.def("size_of_generated_code_in_bytes",
|
.def("size_of_generated_code_in_bytes",
|
||||||
|
@ -377,8 +377,18 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
py::class_<PyExecutable, std::shared_ptr<PyExecutable>> executable(
|
py::class_<PyExecutable, std::shared_ptr<PyExecutable>> executable(
|
||||||
m, "Executable");
|
m, "Executable");
|
||||||
executable.def_property_readonly("client", &PyExecutable::client)
|
executable.def_property_readonly("client", &PyExecutable::client)
|
||||||
.def("local_logical_device_ids", &PyExecutable::local_logical_device_ids)
|
.def("local_logical_device_ids",
|
||||||
.def("local_devices", &PyExecutable::LocalDevices)
|
[](PyExecutable* exec) {
|
||||||
|
auto span = exec->addressable_device_logical_ids();
|
||||||
|
// Not on dispatch critical path, so ok to have heap allocation.
|
||||||
|
std::vector<std::pair<int, int>> addressable_device_logical_ids;
|
||||||
|
addressable_device_logical_ids.reserve(span.size());
|
||||||
|
for (const auto& logical_device_id : span) {
|
||||||
|
addressable_device_logical_ids.push_back(std::make_pair(
|
||||||
|
logical_device_id.replica, logical_device_id.partition));
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.def("local_devices", &PyExecutable::AddressableDevices)
|
||||||
.def("size_of_generated_code_in_bytes",
|
.def("size_of_generated_code_in_bytes",
|
||||||
&PyExecutable::SizeOfGeneratedCodeInBytes)
|
&PyExecutable::SizeOfGeneratedCodeInBytes)
|
||||||
.def("delete", &PyExecutable::Delete)
|
.def("delete", &PyExecutable::Delete)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user