Implement GRPC TPU driver reset.

PiperOrigin-RevId: 281642385
Change-Id: Ie2bbf98a1895adacabb3c2738f4400186d1ed691
This commit is contained in:
Wenhao Jia 2019-11-20 17:36:44 -08:00 committed by TensorFlower Gardener
parent 472804bd1f
commit 25cccbcaad
3 changed files with 59 additions and 18 deletions

View File

@ -332,19 +332,11 @@ class GrpcTpuDriver : public TpuDriver {
}
// Sends the Close RPC for this client when the driver is destroyed.
//
// If Close() has already succeeded (closed_ is set, e.g. after a successful
// Reset()), no further RPC is issued. A failure cannot be propagated out of a
// destructor, so it is only logged.
~GrpcTpuDriver() override {
  if (closed_) {
    return;
  }
  auto status = Close();
  LOG_IF(ERROR, !status.ok()) << status;
}
void QuerySystemInfo(SystemInfo* system_info) override;
@ -432,6 +424,7 @@ class GrpcTpuDriver : public TpuDriver {
// Returns client_id_, the id this driver sends to the server in CloseRequest.
uint32_t client_id() const { return client_id_; }
private:
// Sends a Close RPC for client_id_ and sets closed_ when the RPC succeeds.
// On failure returns the error and leaves closed_ unchanged.
Status Close();
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
const TpuDriverConfig config_;
@ -442,6 +435,7 @@ class GrpcTpuDriver : public TpuDriver {
std::unique_ptr<GrpcTpuStream> host_stream_;
// Shared by all streams.
std::atomic<uint64_t> operation_id_{0};
// Set to true once Close() has succeeded; the destructor checks it so that a
// second Close RPC is not sent.
std::atomic<bool> closed_{false};
}; // class GrpcTpuDriver
GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
@ -1007,14 +1001,52 @@ void GrpcTpuDriver::QuerySystemInfo(SystemInfo* system_info) {
::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp);
if (!status.ok()) {
LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code()
<< ":" << status.error_details();
<< ": " << status.error_message() << ": "
<< status.error_details();
return;
}
*system_info = resp.system_info();
}
// Resets the remote TPU driver and then closes this client.
//
// Issues a Reset RPC with a 10 second deadline and fail-fast disabled. If the
// RPC fails, the error is logged and returned as an xla::Status carrying the
// gRPC status code, message and details; streams_ and host_stream_ are left
// untouched. If it succeeds, streams_ is cleared, host_stream_ is released,
// and the result of Close() is returned.
Status GrpcTpuDriver::Reset() {
  auto stub = CreateTpuDriverStub(config_, creds_);
  ::grpc::ClientContext ctx;
  ctx.set_fail_fast(false);
  ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
  ResetRequest req;
  ResetResponse resp;
  ::grpc::Status status = stub->Reset(&ctx, req, &resp);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to reset the gRPC driver: " << status.error_code()
               << ": " << status.error_message() << ": "
               << status.error_details();
    return xla::Status(
        static_cast<tensorflow::error::Code>(status.error_code()),
        absl::StrCat("Failed to reset TPU driver. Error was: ",
                     status.error_message(),
                     ". Details: ", status.error_details()));
  }
  // The TpuDriver::Reset() contract says this object is unusable after a
  // reset, so drop the streams before closing the client.
  streams_.clear();
  host_stream_.reset();
  return Close();
}
// Issues a Close RPC for client_id_ with a 10 second deadline and fail-fast
// disabled. When the RPC succeeds, closed_ is set and OK is returned.
// Otherwise the gRPC status is converted into an xla::Status and returned,
// and closed_ is left as it was.
Status GrpcTpuDriver::Close() {
  auto driver_stub = CreateTpuDriverStub(config_, creds_);

  ::grpc::ClientContext context;
  context.set_fail_fast(false);
  const auto timeout = std::chrono::seconds(10);
  context.set_deadline(std::chrono::system_clock::now() + timeout);

  CloseRequest request;
  request.set_client_id(client_id_);
  CloseResponse response;

  const ::grpc::Status rpc_status =
      driver_stub->Close(&context, request, &response);
  if (rpc_status.ok()) {
    closed_ = true;
    return Status::OK();
  }

  const auto code =
      static_cast<tensorflow::error::Code>(rpc_status.error_code());
  return xla::Status(
      code, absl::StrCat("Failed to close TPU driver. Error was: ",
                         rpc_status.error_message(),
                         ". Details: ", rpc_status.error_details()));
}
} // namespace

View File

@ -153,11 +153,13 @@ class TpuDriver {
virtual ~TpuDriver() {}
virtual void QuerySystemInfo(SystemInfo* system_info) = 0;
// Synchronous. Reset the state of the TPU driver. After Reset(), this TPU
// driver object is no longer usable. Users must destroy this object and
// create a new one.
//
// All running programs will be terminated and all allocations reset. All
// events and buffer handles created prior to Reset() will be invalid, and any
// use will result in undefined behavior.
virtual xla::Status Reset() = 0;
virtual std::unique_ptr<BufferHandle> Allocate(

View File

@ -157,6 +157,10 @@ message CloseRequest {
// Response to CloudTpuDriver.Close; carries no fields.
message CloseResponse {}
// Request for CloudTpuDriver.Reset; carries no fields.
message ResetRequest {}
// Response to CloudTpuDriver.Reset; carries no fields.
message ResetResponse {}
message QuerySystemInfoRequest {}
message QuerySystemInfoResponse {
@ -170,6 +174,9 @@ service CloudTpuDriver {
// Close the driver. Any outstanding requests will be terminated.
rpc Close(CloseRequest) returns (CloseResponse);
// Reset the driver. All connected clients will be disconnected.
rpc Reset(ResetRequest) returns (ResetResponse);
// Query the driver for current system performance information.
rpc QuerySystemInfo(QuerySystemInfoRequest) returns (QuerySystemInfoResponse);