Clarify ResetDevice operates on all devices associated with backend.
Change: 144258290
This commit is contained in:
parent
ca5ba46fb5
commit
99e1b19ceb
tensorflow/compiler/xla/service
@ -234,4 +234,8 @@ StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
|
||||
executor_b->GetDeviceDescription().name());
|
||||
}
|
||||
|
||||
Status Backend::ResetDevices() {
|
||||
return transfer_manager_->ResetDevices(stream_executors_);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -149,6 +149,9 @@ class Backend {
|
||||
// used for scheduling work. For other platforms, returns NULL.
|
||||
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
|
||||
|
||||
// Resets the devices associated with this backend.
|
||||
Status ResetDevices();
|
||||
|
||||
private:
|
||||
struct EigenThreadPoolWrapper;
|
||||
Backend(int64 replica_count, perftools::gputools::Platform* platform,
|
||||
|
@ -160,7 +160,9 @@ Status GenericTransferManager::TransferLiteralToInfeed(
|
||||
return Unimplemented("Infeed is not supported on GPU (b/30467474)");
|
||||
}
|
||||
|
||||
Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) {
|
||||
Status GenericTransferManager::ResetDevices(
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||
executors) {
|
||||
return Unimplemented(
|
||||
"Device reset is not yet supported on CPU and GPU (b/30481585)");
|
||||
}
|
||||
|
@ -55,7 +55,9 @@ class GenericTransferManager : public TransferManager {
|
||||
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
|
||||
const Literal& literal) override;
|
||||
|
||||
Status ResetDevice(perftools::gputools::StreamExecutor* executor) override;
|
||||
Status ResetDevices(
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||
executors) override;
|
||||
|
||||
StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
|
||||
ShallowCopyTupleFromDevice(
|
||||
|
@ -1019,16 +1019,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
||||
|
||||
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
|
||||
ResetDeviceResponse* result) {
|
||||
int first_device_ordinal = arg->has_device_handle()
|
||||
? arg->device_handle().handle()
|
||||
: execute_backend_->default_device_ordinal();
|
||||
TF_ASSIGN_OR_RETURN(auto executors,
|
||||
execute_backend_->Replicas(first_device_ordinal));
|
||||
for (se::StreamExecutor* executor : executors) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
execute_backend_->transfer_manager()->ResetDevice(executor));
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
return execute_backend_->ResetDevices();
|
||||
}
|
||||
|
||||
tensorflow::Status Service::TransferToClientInProcess(
|
||||
|
@ -162,7 +162,15 @@ class Service : public ServiceInterface {
|
||||
const TransferToInfeedRequest* arg,
|
||||
TransferToInfeedResponse* result) override;
|
||||
|
||||
// Resets the device, clearing all existing state on the device.
|
||||
// Resets devices, clearing all existing state on all the devices associated
|
||||
// with this service (including memory allocated on the devices).
|
||||
//
|
||||
// ResetDevice may only be called where no previous Execution state on the
|
||||
// device is used by the next Execution.
|
||||
//
|
||||
// ResetDevice should be called before an Execution that expect the device to
|
||||
// be in the reset state. For example, if the prior Execution modifies device
|
||||
// state (e.g., architectural state) that the next Execution depends on.
|
||||
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
|
||||
ResetDeviceResponse* result) override;
|
||||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
@ -63,8 +64,10 @@ class TransferManager {
|
||||
perftools::gputools::StreamExecutor* executor,
|
||||
const Literal& literal) = 0;
|
||||
|
||||
// Resets the device that the given executor runs on.
|
||||
virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0;
|
||||
// Resets the devices associated with this transfer manager.
|
||||
virtual Status ResetDevices(
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||
executor) = 0;
|
||||
|
||||
// Shallow copy a tuple from the device and create a DeviceMemoryBase object
|
||||
// for each element in the tuple. A DeviceMemoryBase object refers to the
|
||||
|
Loading…
Reference in New Issue
Block a user