Implement GRPC TPU driver reset.
PiperOrigin-RevId: 281642385 Change-Id: Ie2bbf98a1895adacabb3c2738f4400186d1ed691
This commit is contained in:
parent
472804bd1f
commit
25cccbcaad
@ -332,19 +332,11 @@ class GrpcTpuDriver : public TpuDriver {
|
||||
}
|
||||
|
||||
// Best-effort shutdown. If Close() has not already succeeded (it sets closed_
// on success), issue the Close RPC through Close() and log any failure, since
// a destructor has no way to return the error to the caller.
~GrpcTpuDriver() override {
  if (closed_) {
    return;
  }
  auto status = Close();
  LOG_IF(ERROR, !status.ok()) << status;
}
|
||||
|
||||
void QuerySystemInfo(SystemInfo* system_info) override;
|
||||
@ -432,6 +424,7 @@ class GrpcTpuDriver : public TpuDriver {
|
||||
// Accessor for the id stored in client_id_.
uint32_t client_id() const {
  return client_id_;
}
|
||||
|
||||
private:
|
||||
// Sends the Close RPC for this client. On success marks the driver as closed
// and returns OK; on failure returns a Status built from the gRPC error.
Status Close();
|
||||
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
|
||||
|
||||
const TpuDriverConfig config_;
|
||||
@ -442,6 +435,7 @@ class GrpcTpuDriver : public TpuDriver {
|
||||
std::unique_ptr<GrpcTpuStream> host_stream_;
|
||||
// Shared by all streams.
|
||||
std::atomic<uint64_t> operation_id_{0};
|
||||
// Set to true once Close() succeeds; the destructor checks it so the Close
// RPC is not sent a second time.
std::atomic<bool> closed_{false};
|
||||
}; // class GrpcTpuDriver
|
||||
|
||||
// On destruction, asks stream_ to delete the event identified by id_.
GrpcEvent::~GrpcEvent() {
  stream_->DeleteEvent(id_);
}
|
||||
@ -1007,14 +1001,52 @@ void GrpcTpuDriver::QuerySystemInfo(SystemInfo* system_info) {
|
||||
::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code()
|
||||
<< ":" << status.error_details();
|
||||
<< ": " << status.error_message() << ": "
|
||||
<< status.error_details();
|
||||
return;
|
||||
}
|
||||
*system_info = resp.system_info();
|
||||
}
|
||||
|
||||
// Resets the remote TPU driver via the Reset RPC, then tears down this client.
//
// The RPC is issued with a 10-second deadline. If it fails, the error is
// logged and returned as a Status carrying the gRPC code, message and details;
// local state is left untouched in that case. On success all streams are
// dropped and the result of Close() is returned.
Status GrpcTpuDriver::Reset() {
  auto stub = CreateTpuDriverStub(config_, creds_);
  ::grpc::ClientContext ctx;
  ctx.set_fail_fast(false);
  ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
  ResetRequest req;
  ResetResponse resp;
  ::grpc::Status status = stub->Reset(&ctx, req, &resp);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to reset the gRPC driver: " << status.error_code()
               << ": " << status.error_message() << ": "
               << status.error_details();
    return xla::Status(tensorflow::error::Code(status.error_code()),
                       absl::StrCat("Failed to reset TPU driver. Error was: ",
                                    status.error_message(),
                                    ". Details: ", status.error_details()));
  }
  // Release the per-core streams and the host stream before closing.
  streams_.clear();
  host_stream_.reset();
  return Close();
}
|
||||
|
||||
// Issues the Close RPC for this client (identified by client_id_) with a
// 10-second deadline. When the RPC succeeds the driver is recorded as closed
// and OK is returned; otherwise an error Status carrying the gRPC code,
// message and details is returned and closed_ is left unchanged.
Status GrpcTpuDriver::Close() {
  auto driver_stub = CreateTpuDriverStub(config_, creds_);

  ::grpc::ClientContext context;
  context.set_fail_fast(false);
  context.set_deadline(std::chrono::system_clock::now() +
                       std::chrono::seconds(10));

  CloseRequest request;
  request.set_client_id(client_id_);
  CloseResponse response;

  const ::grpc::Status rpc_status =
      driver_stub->Close(&context, request, &response);
  if (rpc_status.ok()) {
    closed_ = true;
    return Status::OK();
  }

  return xla::Status(
      static_cast<tensorflow::error::Code>(rpc_status.error_code()),
      absl::StrCat("Failed to close TPU driver. Error was: ",
                   rpc_status.error_message(),
                   ". Details: ", rpc_status.error_details()));
}
|
||||
} // namespace
|
||||
|
||||
|
@ -153,11 +153,13 @@ class TpuDriver {
|
||||
virtual ~TpuDriver() {}
|
||||
|
||||
virtual void QuerySystemInfo(SystemInfo* system_info) = 0;
|
||||
// Synchronous. Reset the state of the TPU driver. After Reset(), this TPU
// driver object is no longer usable. Users must destroy this object and
// create a new one.
//
// All running programs will be terminated and all allocations reset. All
// events and buffer handles created prior to Reset() will be invalid, and any
// use will result in undefined behavior.
|
||||
virtual xla::Status Reset() = 0;
|
||||
|
||||
virtual std::unique_ptr<BufferHandle> Allocate(
|
||||
|
@ -157,6 +157,10 @@ message CloseRequest {
|
||||
|
||||
// Response to the Close RPC; declares no fields.
message CloseResponse {}
|
||||
|
||||
// Request for the Reset RPC; declares no fields.
message ResetRequest {}
|
||||
|
||||
// Response to the Reset RPC; declares no fields.
message ResetResponse {}
|
||||
|
||||
// Request for the QuerySystemInfo RPC; declares no fields.
message QuerySystemInfoRequest {}
|
||||
|
||||
message QuerySystemInfoResponse {
|
||||
@ -170,6 +174,9 @@ service CloudTpuDriver {
|
||||
// Close the driver. Any outstanding requests will be terminated.
|
||||
rpc Close(CloseRequest) returns (CloseResponse);
|
||||
|
||||
// Reset the driver. All connected clients will be disconnected.
|
||||
rpc Reset(ResetRequest) returns (ResetResponse);
|
||||
|
||||
// Query the driver for current system performance information.
|
||||
rpc QuerySystemInfo(QuerySystemInfoRequest) returns (QuerySystemInfoResponse);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user