Add overload to PyLocalExecutable to run an executable on a designated local device.
PiperOrigin-RevId: 305494326 Change-Id: I795a0f1a97399c07cac72503f9cc40b3f4ec7859
This commit is contained in:
parent
f394a76871
commit
3009d8c3ff
@ -1413,6 +1413,24 @@ PyLocalExecutable::Execute(absl::Span<PyLocalBuffer* const> argument_handles,
|
|||||||
RunId(), options);
|
RunId(), options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
|
||||||
|
PyLocalExecutable::ExecuteOnLocalDevice(
|
||||||
|
absl::Span<PyLocalBuffer* const> argument_handles, Device* device,
|
||||||
|
const ExecuteOptions& options) const {
|
||||||
|
for (int i = 0; i < local_devices_.size(); ++i) {
|
||||||
|
if (local_devices_[i] == device) {
|
||||||
|
VLOG(1) << "Executing computation " << name();
|
||||||
|
return ExecuteHelper(argument_handles,
|
||||||
|
/*replica=*/local_logical_device_ids_[i].first,
|
||||||
|
/*partition=*/local_logical_device_ids_[i].second,
|
||||||
|
RunId(), options);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return InvalidArgument(
|
||||||
|
"Attempted to execute on device id %d which is not a local device",
|
||||||
|
device->id());
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
|
StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
|
||||||
PyLocalExecutable::ExecuteOnLocalDevices(
|
PyLocalExecutable::ExecuteOnLocalDevices(
|
||||||
absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
|
absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
|
||||||
@ -1435,8 +1453,8 @@ PyLocalExecutable::ExecuteOnLocalDevices(
|
|||||||
|
|
||||||
VLOG(1) << "Executing computation " << name()
|
VLOG(1) << "Executing computation " << name()
|
||||||
<< "; num_replicas=" << num_replicas()
|
<< "; num_replicas=" << num_replicas()
|
||||||
<< " num_partitions=" << num_partitions() << " num_local_devices=8"
|
<< " num_partitions=" << num_partitions()
|
||||||
<< num_local_devices;
|
<< " num_local_devices=" << num_local_devices;
|
||||||
std::vector<StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>> results(
|
std::vector<StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>> results(
|
||||||
num_local_devices);
|
num_local_devices);
|
||||||
if (num_local_devices == 1) {
|
if (num_local_devices == 1) {
|
||||||
|
@ -558,6 +558,10 @@ class PyLocalExecutable {
|
|||||||
absl::Span<PyLocalBuffer* const> argument_handles,
|
absl::Span<PyLocalBuffer* const> argument_handles,
|
||||||
const ExecuteOptions& options) const;
|
const ExecuteOptions& options) const;
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevice(
|
||||||
|
absl::Span<PyLocalBuffer* const> argument_handles, Device* device,
|
||||||
|
const ExecuteOptions& options) const;
|
||||||
|
|
||||||
// Execute on local devices. Takes a sequence of argument lists (one argument
|
// Execute on local devices. Takes a sequence of argument lists (one argument
|
||||||
// list per local device) and returns a tuple of results (one result per local
|
// list per local device) and returns a tuple of results (one result per local
|
||||||
// device). The number of argument lists must be equal to the local device
|
// device). The number of argument lists must be equal to the local device
|
||||||
|
Loading…
Reference in New Issue
Block a user