diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc index 2cbeffd62c6..64aeb13d71d 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc @@ -978,28 +978,30 @@ Status GrpcTpuDriver::Reset() { return xla::Unimplemented("GRPC driver reset is not implemented yet."); } -REGISTER_TPU_DRIVER("grpc://", - [](const TpuDriverConfig& config) - -> xla::StatusOr> { - auto stub = GrpcTpuDriver::CreateTpuDriverStub(config); - ::grpc::ClientContext ctx; - ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + - std::chrono::seconds(10)); - OpenRequest req; - OpenResponse resp; - ::grpc::Status status = stub->Open(&ctx, req, &resp); - if (!status.ok()) { - LOG(ERROR) << "Failed to open the gRPC driver: " - << status.error_code() << ": " - << status.error_details(); - return xla::Status( - tensorflow::error::Code(status.error_code()), - status.error_message() + status.error_details()); - } - return std::unique_ptr( - new GrpcTpuDriver(config, resp.client_id())); - }); +REGISTER_TPU_DRIVER( + "grpc://", + [](const TpuDriverConfig& config) + -> xla::StatusOr> { + auto stub = GrpcTpuDriver::CreateTpuDriverStub(config); + ::grpc::ClientContext ctx; + ctx.set_fail_fast(false); + ctx.set_deadline(std::chrono::system_clock::now() + + std::chrono::seconds(config.connection_timeout_secs)); + OpenRequest req; + OpenResponse resp; + ::grpc::Status status = stub->Open(&ctx, req, &resp); + if (!status.ok()) { + LOG(ERROR) << "Failed to open the gRPC driver: " << status.error_code() + << ": " << status.error_details(); + return xla::Status( + tensorflow::error::Code(status.error_code()), + absl::StrCat("Failed to connect to remote server at address: ", + config.worker, + ". Error from gRPC: ", status.error_details())); + } + return std::unique_ptr( + new GrpcTpuDriver(config, resp.client_id())); + }); } // namespace } // namespace tpu_driver diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h index 3b010b38a17..fa45328e649 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h @@ -229,6 +229,7 @@ class TpuDriver { struct TpuDriverConfig { std::string worker; + int64_t connection_timeout_secs = 10; }; class TpuDriverRegistry {