Add overload to PyLocalExecutable to run an executable on a designated local device.
PiperOrigin-RevId: 305494326 Change-Id: I795a0f1a97399c07cac72503f9cc40b3f4ec7859
This commit is contained in:
parent
f394a76871
commit
3009d8c3ff
@ -1413,6 +1413,24 @@ PyLocalExecutable::Execute(absl::Span<PyLocalBuffer* const> argument_handles,
|
||||
RunId(), options);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
|
||||
PyLocalExecutable::ExecuteOnLocalDevice(
|
||||
absl::Span<PyLocalBuffer* const> argument_handles, Device* device,
|
||||
const ExecuteOptions& options) const {
|
||||
for (int i = 0; i < local_devices_.size(); ++i) {
|
||||
if (local_devices_[i] == device) {
|
||||
VLOG(1) << "Executing computation " << name();
|
||||
return ExecuteHelper(argument_handles,
|
||||
/*replica=*/local_logical_device_ids_[i].first,
|
||||
/*partition=*/local_logical_device_ids_[i].second,
|
||||
RunId(), options);
|
||||
}
|
||||
}
|
||||
return InvalidArgument(
|
||||
"Attempted to execute on device id %d which is not a local device",
|
||||
device->id());
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
|
||||
PyLocalExecutable::ExecuteOnLocalDevices(
|
||||
absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
|
||||
@ -1435,8 +1453,8 @@ PyLocalExecutable::ExecuteOnLocalDevices(
|
||||
|
||||
VLOG(1) << "Executing computation " << name()
|
||||
<< "; num_replicas=" << num_replicas()
|
||||
<< " num_partitions=" << num_partitions() << " num_local_devices=8"
|
||||
<< num_local_devices;
|
||||
<< " num_partitions=" << num_partitions()
|
||||
<< " num_local_devices=" << num_local_devices;
|
||||
std::vector<StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>> results(
|
||||
num_local_devices);
|
||||
if (num_local_devices == 1) {
|
||||
|
@ -558,6 +558,10 @@ class PyLocalExecutable {
|
||||
absl::Span<PyLocalBuffer* const> argument_handles,
|
||||
const ExecuteOptions& options) const;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevice(
|
||||
absl::Span<PyLocalBuffer* const> argument_handles, Device* device,
|
||||
const ExecuteOptions& options) const;
|
||||
|
||||
// Execute on local devices. Takes a sequence of argument lists (one argument
|
||||
// list per local device) and returns a tuple of results (one result per local
|
||||
// device). The number of argument lists must be equal to the local device
|
||||
|
Loading…
Reference in New Issue
Block a user