Add overload to PyLocalExecutable to run an executable on a designated local device.

PiperOrigin-RevId: 305494326
Change-Id: I795a0f1a97399c07cac72503f9cc40b3f4ec7859
This commit is contained in:
A. Unique TensorFlower 2020-04-08 09:35:50 -07:00 committed by TensorFlower Gardener
parent f394a76871
commit 3009d8c3ff
2 changed files with 24 additions and 2 deletions

View File

@ -1413,6 +1413,24 @@ PyLocalExecutable::Execute(absl::Span<PyLocalBuffer* const> argument_handles,
RunId(), options);
}
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
PyLocalExecutable::ExecuteOnLocalDevice(
absl::Span<PyLocalBuffer* const> argument_handles, Device* device,
const ExecuteOptions& options) const {
for (int i = 0; i < local_devices_.size(); ++i) {
if (local_devices_[i] == device) {
VLOG(1) << "Executing computation " << name();
return ExecuteHelper(argument_handles,
/*replica=*/local_logical_device_ids_[i].first,
/*partition=*/local_logical_device_ids_[i].second,
RunId(), options);
}
}
return InvalidArgument(
"Attempted to execute on device id %d which is not a local device",
device->id());
}
StatusOr<std::vector<std::vector<std::unique_ptr<PyLocalBuffer>>>>
PyLocalExecutable::ExecuteOnLocalDevices(
absl::Span<const std::vector<PyLocalBuffer*>> argument_handles,
@ -1435,8 +1453,8 @@ PyLocalExecutable::ExecuteOnLocalDevices(
VLOG(1) << "Executing computation " << name()
<< "; num_replicas=" << num_replicas()
<< " num_partitions=" << num_partitions() << " num_local_devices=8"
<< num_local_devices;
<< " num_partitions=" << num_partitions()
<< " num_local_devices=" << num_local_devices;
std::vector<StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>> results(
num_local_devices);
if (num_local_devices == 1) {

View File

@ -558,6 +558,10 @@ class PyLocalExecutable {
absl::Span<PyLocalBuffer* const> argument_handles,
const ExecuteOptions& options) const;
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>> ExecuteOnLocalDevice(
absl::Span<PyLocalBuffer* const> argument_handles, Device* device,
const ExecuteOptions& options) const;
// Execute on local devices. Takes a sequence of argument lists (one argument
// list per local device) and returns a tuple of results (one result per local
// device). The number of argument lists must be equal to the local device