Implement GRPC TPU driver reset.
PiperOrigin-RevId: 281642385 Change-Id: Ie2bbf98a1895adacabb3c2738f4400186d1ed691
This commit is contained in:
parent
472804bd1f
commit
25cccbcaad
@ -332,19 +332,11 @@ class GrpcTpuDriver : public TpuDriver {
|
|||||||
}
|
}
|
||||||
|
|
||||||
~GrpcTpuDriver() override {
|
~GrpcTpuDriver() override {
|
||||||
auto stub = CreateTpuDriverStub(config_, creds_);
|
if (closed_) {
|
||||||
::grpc::ClientContext ctx;
|
return;
|
||||||
ctx.set_fail_fast(false);
|
|
||||||
ctx.set_deadline(std::chrono::system_clock::now() +
|
|
||||||
std::chrono::seconds(10));
|
|
||||||
CloseRequest req;
|
|
||||||
req.set_client_id(client_id_);
|
|
||||||
CloseResponse resp;
|
|
||||||
::grpc::Status status = stub->Close(&ctx, req, &resp);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "Failed to close the gRPC driver: " << status.error_code()
|
|
||||||
<< ": " << status.error_details();
|
|
||||||
}
|
}
|
||||||
|
auto status = Close();
|
||||||
|
LOG_IF(ERROR, !status.ok()) << status;
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuerySystemInfo(SystemInfo* system_info) override;
|
void QuerySystemInfo(SystemInfo* system_info) override;
|
||||||
@ -432,6 +424,7 @@ class GrpcTpuDriver : public TpuDriver {
|
|||||||
uint32_t client_id() const { return client_id_; }
|
uint32_t client_id() const { return client_id_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Status Close();
|
||||||
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
|
std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
|
||||||
|
|
||||||
const TpuDriverConfig config_;
|
const TpuDriverConfig config_;
|
||||||
@ -442,6 +435,7 @@ class GrpcTpuDriver : public TpuDriver {
|
|||||||
std::unique_ptr<GrpcTpuStream> host_stream_;
|
std::unique_ptr<GrpcTpuStream> host_stream_;
|
||||||
// Shared by all streams.
|
// Shared by all streams.
|
||||||
std::atomic<uint64_t> operation_id_{0};
|
std::atomic<uint64_t> operation_id_{0};
|
||||||
|
std::atomic<bool> closed_{false};
|
||||||
}; // namespace
|
}; // namespace
|
||||||
|
|
||||||
GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
|
GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
|
||||||
@ -1007,14 +1001,52 @@ void GrpcTpuDriver::QuerySystemInfo(SystemInfo* system_info) {
|
|||||||
::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp);
|
::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code()
|
LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code()
|
||||||
<< ":" << status.error_details();
|
<< ": " << status.error_message() << ": "
|
||||||
|
<< status.error_details();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
*system_info = resp.system_info();
|
*system_info = resp.system_info();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GrpcTpuDriver::Reset() {
|
Status GrpcTpuDriver::Reset() {
|
||||||
return xla::Unimplemented("GRPC driver reset is not implemented yet.");
|
auto stub = CreateTpuDriverStub(config_, creds_);
|
||||||
|
::grpc::ClientContext ctx;
|
||||||
|
ctx.set_fail_fast(false);
|
||||||
|
ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
|
||||||
|
ResetRequest req;
|
||||||
|
ResetResponse resp;
|
||||||
|
::grpc::Status status = stub->Reset(&ctx, req, &resp);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to reset the gRPC driver: " << status.error_code()
|
||||||
|
<< ": " << status.error_message() << ": "
|
||||||
|
<< status.error_details();
|
||||||
|
return xla::Status(tensorflow::error::Code(status.error_code()),
|
||||||
|
absl::StrCat("Failed to reset TPU driver. Error was: ",
|
||||||
|
status.error_message(),
|
||||||
|
". Details: ", status.error_details()));
|
||||||
|
}
|
||||||
|
streams_.clear();
|
||||||
|
host_stream_.reset();
|
||||||
|
return Close();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GrpcTpuDriver::Close() {
|
||||||
|
auto stub = CreateTpuDriverStub(config_, creds_);
|
||||||
|
::grpc::ClientContext ctx;
|
||||||
|
ctx.set_fail_fast(false);
|
||||||
|
ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
|
||||||
|
CloseRequest req;
|
||||||
|
req.set_client_id(client_id_);
|
||||||
|
CloseResponse resp;
|
||||||
|
::grpc::Status status = stub->Close(&ctx, req, &resp);
|
||||||
|
if (!status.ok()) {
|
||||||
|
return xla::Status(tensorflow::error::Code(status.error_code()),
|
||||||
|
absl::StrCat("Failed to close TPU driver. Error was: ",
|
||||||
|
status.error_message(),
|
||||||
|
". Details: ", status.error_details()));
|
||||||
|
}
|
||||||
|
closed_ = true;
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -153,11 +153,13 @@ class TpuDriver {
|
|||||||
virtual ~TpuDriver() {}
|
virtual ~TpuDriver() {}
|
||||||
|
|
||||||
virtual void QuerySystemInfo(SystemInfo* system_info) = 0;
|
virtual void QuerySystemInfo(SystemInfo* system_info) = 0;
|
||||||
// Synchronous. Reset the state of the TPU driver. All running programs
|
// Synchronous. Reset the state of the TPU driver. After Reset(), this TPU
|
||||||
// will be terminated and all allocations reset.
|
// driver object is no longer usable. Users must destroy this object and
|
||||||
|
// create a new one.
|
||||||
//
|
//
|
||||||
// All events and buffer handles created prior to Reset() will be invalid,
|
// All running programs will be terminated and all allocations reset. All
|
||||||
// and any use will result in undefined behavior.
|
// events and buffer handles created prior to Reset() will be invalid, and any
|
||||||
|
// use will result in undefined behavior.
|
||||||
virtual xla::Status Reset() = 0;
|
virtual xla::Status Reset() = 0;
|
||||||
|
|
||||||
virtual std::unique_ptr<BufferHandle> Allocate(
|
virtual std::unique_ptr<BufferHandle> Allocate(
|
||||||
|
@ -157,6 +157,10 @@ message CloseRequest {
|
|||||||
|
|
||||||
message CloseResponse {}
|
message CloseResponse {}
|
||||||
|
|
||||||
|
message ResetRequest {}
|
||||||
|
|
||||||
|
message ResetResponse {}
|
||||||
|
|
||||||
message QuerySystemInfoRequest {}
|
message QuerySystemInfoRequest {}
|
||||||
|
|
||||||
message QuerySystemInfoResponse {
|
message QuerySystemInfoResponse {
|
||||||
@ -170,6 +174,9 @@ service CloudTpuDriver {
|
|||||||
// Close the driver. Any outstanding requests will be terminated.
|
// Close the driver. Any outstanding requests will be terminated.
|
||||||
rpc Close(CloseRequest) returns (CloseResponse);
|
rpc Close(CloseRequest) returns (CloseResponse);
|
||||||
|
|
||||||
|
// Reset the driver. All connected clients will be disconnected.
|
||||||
|
rpc Reset(ResetRequest) returns (ResetResponse);
|
||||||
|
|
||||||
// Query the driver for current system performance information.
|
// Query the driver for current system performance information.
|
||||||
rpc QuerySystemInfo(QuerySystemInfoRequest) returns (QuerySystemInfoResponse);
|
rpc QuerySystemInfo(QuerySystemInfoRequest) returns (QuerySystemInfoResponse);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user