Clarify ResetDevice operates on all devices associated with backend.
Change: 144258290
This commit is contained in:
parent
ca5ba46fb5
commit
99e1b19ceb
@ -234,4 +234,8 @@ StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
|
|||||||
executor_b->GetDeviceDescription().name());
|
executor_b->GetDeviceDescription().name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Backend::ResetDevices() {
|
||||||
|
return transfer_manager_->ResetDevices(stream_executors_);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -149,6 +149,9 @@ class Backend {
|
|||||||
// used for scheduling work. For other platforms, returns NULL.
|
// used for scheduling work. For other platforms, returns NULL.
|
||||||
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
|
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
|
||||||
|
|
||||||
|
// Resets the devices associated with this backend.
|
||||||
|
Status ResetDevices();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct EigenThreadPoolWrapper;
|
struct EigenThreadPoolWrapper;
|
||||||
Backend(int64 replica_count, perftools::gputools::Platform* platform,
|
Backend(int64 replica_count, perftools::gputools::Platform* platform,
|
||||||
|
@ -160,7 +160,9 @@ Status GenericTransferManager::TransferLiteralToInfeed(
|
|||||||
return Unimplemented("Infeed is not supported on GPU (b/30467474)");
|
return Unimplemented("Infeed is not supported on GPU (b/30467474)");
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) {
|
Status GenericTransferManager::ResetDevices(
|
||||||
|
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||||
|
executors) {
|
||||||
return Unimplemented(
|
return Unimplemented(
|
||||||
"Device reset is not yet supported on CPU and GPU (b/30481585)");
|
"Device reset is not yet supported on CPU and GPU (b/30481585)");
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,9 @@ class GenericTransferManager : public TransferManager {
|
|||||||
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
|
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
|
||||||
const Literal& literal) override;
|
const Literal& literal) override;
|
||||||
|
|
||||||
Status ResetDevice(perftools::gputools::StreamExecutor* executor) override;
|
Status ResetDevices(
|
||||||
|
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||||
|
executors) override;
|
||||||
|
|
||||||
StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
|
StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
|
||||||
ShallowCopyTupleFromDevice(
|
ShallowCopyTupleFromDevice(
|
||||||
|
@ -1019,16 +1019,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
|
|||||||
|
|
||||||
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
|
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
|
||||||
ResetDeviceResponse* result) {
|
ResetDeviceResponse* result) {
|
||||||
int first_device_ordinal = arg->has_device_handle()
|
return execute_backend_->ResetDevices();
|
||||||
? arg->device_handle().handle()
|
|
||||||
: execute_backend_->default_device_ordinal();
|
|
||||||
TF_ASSIGN_OR_RETURN(auto executors,
|
|
||||||
execute_backend_->Replicas(first_device_ordinal));
|
|
||||||
for (se::StreamExecutor* executor : executors) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
execute_backend_->transfer_manager()->ResetDevice(executor));
|
|
||||||
}
|
|
||||||
return tensorflow::Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tensorflow::Status Service::TransferToClientInProcess(
|
tensorflow::Status Service::TransferToClientInProcess(
|
||||||
|
@ -162,7 +162,15 @@ class Service : public ServiceInterface {
|
|||||||
const TransferToInfeedRequest* arg,
|
const TransferToInfeedRequest* arg,
|
||||||
TransferToInfeedResponse* result) override;
|
TransferToInfeedResponse* result) override;
|
||||||
|
|
||||||
// Resets the device, clearing all existing state on the device.
|
// Resets devices, clearing all existing state on all the devices associated
|
||||||
|
// with this service (including memory allocated on the devices).
|
||||||
|
//
|
||||||
|
// ResetDevice may only be called where no previous Execution state on the
|
||||||
|
// device is used by the next Execution.
|
||||||
|
//
|
||||||
|
// ResetDevice should be called before an Execution that expects the device to
|
||||||
|
// be in the reset state. For example, if the prior Execution modifies device
|
||||||
|
// state (e.g., architectural state) that the next Execution depends on.
|
||||||
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
|
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
|
||||||
ResetDeviceResponse* result) override;
|
ResetDeviceResponse* result) override;
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
@ -63,8 +64,10 @@ class TransferManager {
|
|||||||
perftools::gputools::StreamExecutor* executor,
|
perftools::gputools::StreamExecutor* executor,
|
||||||
const Literal& literal) = 0;
|
const Literal& literal) = 0;
|
||||||
|
|
||||||
// Resets the device that the given executor runs on.
|
// Resets the devices associated with this transfer manager.
|
||||||
virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0;
|
virtual Status ResetDevices(
|
||||||
|
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
|
||||||
|
executor) = 0;
|
||||||
|
|
||||||
// Shallow copy a tuple from the device and create a DeviceMemoryBase object
|
// Shallow copy a tuple from the device and create a DeviceMemoryBase object
|
||||||
// for each element in the tuple. A DeviceMemoryBase object refers to the
|
// for each element in the tuple. A DeviceMemoryBase object refers to the
|
||||||
|
Loading…
x
Reference in New Issue
Block a user