TpuDriver: Add more detail to connection errors and make timeout configurable.
PiperOrigin-RevId: 281108883 Change-Id: I6bbeb951815816b00d9591ba2fed27879edcd51c
This commit is contained in:
parent
c455ab4555
commit
c85d9b1369
@ -978,28 +978,30 @@ Status GrpcTpuDriver::Reset() {
|
||||
return xla::Unimplemented("GRPC driver reset is not implemented yet.");
|
||||
}
|
||||
|
||||
// Registers the factory used for "grpc://" driver addresses.
//
// The factory creates a stub from `config`, issues an Open RPC with a fixed
// 10-second deadline, and on success returns a GrpcTpuDriver built from the
// client id in the response. If the RPC fails, the gRPC status is logged and
// returned to the caller as an xla::Status.
REGISTER_TPU_DRIVER("grpc://",
                    [](const TpuDriverConfig& config)
                        -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
                      auto driver_stub =
                          GrpcTpuDriver::CreateTpuDriverStub(config);

                      // Fail-fast is disabled; the deadline below bounds how
                      // long the Open call may take.
                      ::grpc::ClientContext client_context;
                      client_context.set_fail_fast(false);
                      const auto open_deadline =
                          std::chrono::system_clock::now() +
                          std::chrono::seconds(10);
                      client_context.set_deadline(open_deadline);

                      OpenRequest open_request;
                      OpenResponse open_response;
                      const ::grpc::Status open_status = driver_stub->Open(
                          &client_context, open_request, &open_response);

                      // Success: hand back a driver bound to the client id
                      // assigned by the server.
                      if (open_status.ok()) {
                        return std::unique_ptr<TpuDriver>(new GrpcTpuDriver(
                            config, open_response.client_id()));
                      }

                      // Failure: log, then translate the gRPC status code and
                      // text into an xla::Status.
                      LOG(ERROR) << "Failed to open the gRPC driver: "
                                 << open_status.error_code() << ": "
                                 << open_status.error_details();
                      return xla::Status(
                          tensorflow::error::Code(open_status.error_code()),
                          open_status.error_message() +
                              open_status.error_details());
                    });
|
||||
// Registers the factory used for "grpc://" driver addresses.
//
// The factory creates a stub from `config`, issues an Open RPC whose deadline
// is `config.connection_timeout_secs` seconds from now, and on success returns
// a GrpcTpuDriver built from the client id in the response.
//
// On failure it returns an xla::Status carrying the gRPC status code and a
// message that names the worker address and includes the gRPC error text.
REGISTER_TPU_DRIVER(
    "grpc://",
    [](const TpuDriverConfig& config)
        -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
      auto stub = GrpcTpuDriver::CreateTpuDriverStub(config);

      // Fail-fast is disabled; the configurable deadline bounds how long the
      // Open call may take.
      ::grpc::ClientContext ctx;
      ctx.set_fail_fast(false);
      ctx.set_deadline(std::chrono::system_clock::now() +
                       std::chrono::seconds(config.connection_timeout_secs));

      OpenRequest req;
      OpenResponse resp;
      ::grpc::Status status = stub->Open(&ctx, req, &resp);
      if (!status.ok()) {
        // error_message() is the human-readable description of the failure.
        // error_details() is an optional payload (per the gRPC docs, usually
        // a serialized google.rpc.Status) and is frequently empty, so it must
        // not be the only thing reported.
        std::string grpc_error = status.error_message();
        if (!status.error_details().empty()) {
          grpc_error = absl::StrCat(grpc_error, " (details: ",
                                    status.error_details(), ")");
        }

        LOG(ERROR) << "Failed to open the gRPC driver: " << status.error_code()
                   << ": " << grpc_error;
        return xla::Status(
            tensorflow::error::Code(status.error_code()),
            absl::StrCat("Failed to connect to remote server at address: ",
                         config.worker, ". Error from gRPC: ", grpc_error));
      }

      return std::unique_ptr<TpuDriver>(
          new GrpcTpuDriver(config, resp.client_id()));
    });
|
||||
|
||||
} // namespace
|
||||
} // namespace tpu_driver
|
||||
|
@ -229,6 +229,7 @@ class TpuDriver {
|
||||
|
||||
// Settings passed to a registered driver factory when a TpuDriver is created.
struct TpuDriverConfig {
  // Worker address. The gRPC driver factory reports this value as the remote
  // server address when a connection attempt fails.
  std::string worker;

  // Number of seconds the gRPC driver factory allows for its initial Open
  // call before the deadline expires. Defaults to 10.
  int64_t connection_timeout_secs{10};
};
|
||||
|
||||
class TpuDriverRegistry {
|
||||
|
Loading…
x
Reference in New Issue
Block a user