diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index dc6e8c5b500..4d7a6335c3f 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -1413,6 +1413,24 @@ PyLocalExecutable::Execute(absl::Span argument_handles, RunId(), options); } +StatusOr>> +PyLocalExecutable::ExecuteOnLocalDevice( + absl::Span argument_handles, Device* device, + const ExecuteOptions& options) const { + for (int i = 0; i < local_devices_.size(); ++i) { + if (local_devices_[i] == device) { + VLOG(1) << "Executing computation " << name(); + return ExecuteHelper(argument_handles, + /*replica=*/local_logical_device_ids_[i].first, + /*partition=*/local_logical_device_ids_[i].second, + RunId(), options); + } + } + return InvalidArgument( + "Attempted to execute on device id %d which is not a local device", + device->id()); +} + StatusOr>>> PyLocalExecutable::ExecuteOnLocalDevices( absl::Span> argument_handles, @@ -1435,8 +1453,8 @@ PyLocalExecutable::ExecuteOnLocalDevices( VLOG(1) << "Executing computation " << name() << "; num_replicas=" << num_replicas() - << " num_partitions=" << num_partitions() << " num_local_devices=8" - << num_local_devices; + << " num_partitions=" << num_partitions() + << " num_local_devices=" << num_local_devices; std::vector>>> results( num_local_devices); if (num_local_devices == 1) { diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 63786042955..ea10693255f 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -558,6 +558,10 @@ class PyLocalExecutable { absl::Span argument_handles, const ExecuteOptions& options) const; + StatusOr>> ExecuteOnLocalDevice( + absl::Span argument_handles, Device* device, + const ExecuteOptions& options) const; + // Execute on local devices. Takes a sequence of argument lists (one argument // list per local device) and returns a tuple of results (one result per local // device). The number of argument lists must be equal to the local device