diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 0c422f5c0ec..c251d14c8db 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -48,7 +48,7 @@ static std::shared_ptr MakeDevice(const std::string& platform_name, StatusOr> PyTpuClient::Get( const std::string& worker) { tpu_driver::TpuDriverConfig driver_config; - driver_config.worker = worker; + driver_config.set_worker(worker); auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config); if (!client_status.ok()) { diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc index e6ca04e357f..c50053c0981 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc @@ -907,8 +907,8 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) { args.SetMaxSendMessageSize(std::numeric_limits::max()); // Send at least 20 keep-alives before giving up. - int keepalive_timeout_ms = config.keepalive_timeout_secs * 1000; - int keepalive_interval_ms = config.keepalive_timeout_secs / 20; + int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000; + int keepalive_interval_ms = keepalive_timeout_ms / 20; grpc_arg client_arg_vals[] = { {.type = GRPC_ARG_INTEGER, @@ -935,7 +935,7 @@ GrpcTpuDriver::CreateTpuDriverStub(const TpuDriverConfig& config) { args.SetChannelArgs(&client_args); // strips out 'grpc://' - auto worker_addr = absl::StripPrefix(config.worker, kGrpcProtocol); + auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol); std::shared_ptr<::grpc::Channel> channel = ::grpc::CreateCustomChannel(std::string(worker_addr), creds, args); return grpc::CloudTpuDriver::NewStub(channel); @@ -977,8 +977,9 @@ REGISTER_TPU_DRIVER( auto stub = GrpcTpuDriver::CreateTpuDriverStub(config); ::grpc::ClientContext ctx; ctx.set_fail_fast(false); - ctx.set_deadline(std::chrono::system_clock::now() + - std::chrono::seconds(config.connection_timeout_secs)); + ctx.set_deadline( + std::chrono::system_clock::now() + + std::chrono::seconds(config.grpc().connection_timeout_secs())); OpenRequest req; OpenResponse resp; ::grpc::Status status = stub->Open(&ctx, req, &resp); @@ -988,7 +989,7 @@ REGISTER_TPU_DRIVER( return xla::Status( tensorflow::error::Code(status.error_code()), absl::StrCat("Failed to connect to remote server at address: ", - config.worker, + config.worker(), ". Error from gRPC: ", status.error_details())); } return std::unique_ptr( diff --git a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc index efc5e73e3d2..4116691eb0a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc @@ -516,7 +516,7 @@ class RecordingTpuDriver : public TpuDriver { xla::StatusOr> RegisterRecordingTpuDriver( const TpuDriverConfig& config) { - std::vector configs = absl::StrSplit(config.worker, '|'); + std::vector configs = absl::StrSplit(config.worker(), '|'); std::string file; std::string worker; @@ -533,7 +533,7 @@ xla::StatusOr> RegisterRecordingTpuDriver( } TpuDriverConfig worker_config; - worker_config.worker = worker; + worker_config.set_worker(worker); auto driver_status = TpuDriverRegistry::Open(worker_config); if (!driver_status.ok()) return driver_status.status(); diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc index b77c81ca73a..1920cf75e26 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.cc @@ -88,12 +88,12 @@ uint64_t ByteSizeOfPrimitiveType(xla::PrimitiveType primitive_type) { /*static*/ xla::StatusOr> TpuDriverRegistry::Open( const TpuDriverConfig& config) { for (const auto& driver : *GetDriverRegistryMap()) { - if (absl::StartsWith(config.worker, driver.first)) { + if (absl::StartsWith(config.worker(), driver.first)) { return driver.second(config); } } return xla::NotFound("Unable to find driver in registry given worker: %s", - config.worker); + config.worker()); } uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape) { diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h index 13b0a930b5e..8c9bc4ea26d 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h @@ -37,6 +37,9 @@ #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" +// This API is EXPERIMENTAL and under active developement. It is subject to +// change without notice. + namespace tpu_driver { uint64_t ComputeBytesFromShape(const xla::ShapeProto& shape); @@ -227,17 +230,6 @@ class TpuDriver { virtual std::unique_ptr GetLinearizer() { return nullptr; } }; -struct TpuDriverConfig { - std::string worker; - - // Time in seconds before the initial connection to the server will timeout. - int64_t connection_timeout_secs = 10; - - // Time in seconds the server may be unresponsive before terminating the - // connection. - int64_t keepalive_timeout_secs = 30; -}; - class TpuDriverRegistry { public: static xla::StatusOr> Open( diff --git a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.proto b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.proto index f047de26289..a8721839789 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.proto +++ b/tensorflow/compiler/xla/python/tpu_driver/tpu_driver.proto @@ -41,3 +41,18 @@ message SystemInfo { repeated TpuChipInfo tpu_chip = 1; required CpuInfo cpu = 2; } + +message TpuDriverConfig { + optional string worker = 1; + + message GrpcConfig { + // Time in seconds before the initial connection to the server will timeout. + optional int64 connection_timeout_secs = 1 [default = 30]; + + // Time in seconds the server may be unresponsive before terminating the + // connection. + optional int64 keepalive_timeout_secs = 2 [default = 30]; + } + + optional GrpcConfig grpc = 2; +}