Clarify ResetDevice operates on all devices associated with backend.

Change: 144258290
This commit is contained in:
Jacques Pienaar 2017-01-11 15:32:45 -08:00 committed by TensorFlower Gardener
parent ca5ba46fb5
commit 99e1b19ceb
7 changed files with 28 additions and 15 deletions

View File

@ -234,4 +234,8 @@ StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
executor_b->GetDeviceDescription().name());
}
Status Backend::ResetDevices() {
return transfer_manager_->ResetDevices(stream_executors_);
}
} // namespace xla

View File

@ -149,6 +149,9 @@ class Backend {
// used for scheduling work. For other platforms, returns NULL.
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
// Resets the devices associated with this backend.
Status ResetDevices();
private:
struct EigenThreadPoolWrapper;
Backend(int64 replica_count, perftools::gputools::Platform* platform,

View File

@ -160,7 +160,9 @@ Status GenericTransferManager::TransferLiteralToInfeed(
return Unimplemented("Infeed is not supported on GPU (b/30467474)");
}
Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) {
Status GenericTransferManager::ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executors) {
return Unimplemented(
"Device reset is not yet supported on CPU and GPU (b/30481585)");
}

View File

@ -55,7 +55,9 @@ class GenericTransferManager : public TransferManager {
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
const Literal& literal) override;
Status ResetDevice(perftools::gputools::StreamExecutor* executor) override;
Status ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executors) override;
StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
ShallowCopyTupleFromDevice(

View File

@ -1019,16 +1019,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) {
int first_device_ordinal = arg->has_device_handle()
? arg->device_handle().handle()
: execute_backend_->default_device_ordinal();
TF_ASSIGN_OR_RETURN(auto executors,
execute_backend_->Replicas(first_device_ordinal));
for (se::StreamExecutor* executor : executors) {
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->ResetDevice(executor));
}
return tensorflow::Status::OK();
return execute_backend_->ResetDevices();
}
tensorflow::Status Service::TransferToClientInProcess(

View File

@ -162,7 +162,15 @@ class Service : public ServiceInterface {
const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) override;
// Resets the device, clearing all existing state on the device.
// Resets devices, clearing all existing state on all the devices associated
// with this service (including memory allocated on the devices).
//
// ResetDevice may only be called where no previous Execution state on the
// device is used by the next Execution.
//
// ResetDevice should be called before an Execution that expect the device to
// be in the reset state. For example, if the prior Execution modifies device
// state (e.g., architectural state) that the next Execution depends on.
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@ -63,8 +64,10 @@ class TransferManager {
perftools::gputools::StreamExecutor* executor,
const Literal& literal) = 0;
// Resets the device that the given executor runs on.
virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0;
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executor) = 0;
// Shallow copy a tuple from the device and create a DeviceMemoryBase object
// for each element in the tuple. A DeviceMemoryBase object refers to the