diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index f087a39f019..c0407af29ba 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -112,17 +112,9 @@ GrpcServer::~GrpcServer() { void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {} -Status GrpcServer::Init(const GrpcServerOptions& opts) { - mutex_lock l(mu_); - CHECK_EQ(state_, NEW); - master_env_.env = env_; - worker_env_.env = env_; - - // Check parameters before DeviceFactory::AddDevices, - // otherwise if 'task_index=-1' the program will abort. - - // Look up the port that has been requested for this task in `server_def_`. - int requested_port = -1; +// Look up the port that has been requested for this task in `server_def_`. +Status GrpcServer::GetPort(int* port) const { + *port = -1; for (const auto& job : server_def_.cluster().job()) { if (job.name() == server_def_.job_name()) { auto iter = job.tasks().find(server_def_.task_index()); @@ -132,8 +124,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { server_def_.job_name(), "\""); } auto colon_index = iter->second.find_last_of(':'); - if (!strings::safe_strto32(iter->second.substr(colon_index + 1), - &requested_port)) { + if (!strings::safe_strto32(iter->second.substr(colon_index + 1), port)) { return errors::InvalidArgument( "Could not parse port for local server from \"", iter->second, "\"."); @@ -141,11 +132,26 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { break; } } - if (requested_port == -1) { + if (*port == -1) { return errors::Internal("Job \"", server_def_.job_name(), "\" was not defined in cluster"); } + return Status::OK(); +} + +Status GrpcServer::Init(const GrpcServerOptions& opts) { + mutex_lock l(mu_); + CHECK_EQ(state_, NEW); + master_env_.env = env_; + worker_env_.env = env_; + + // Check parameters before DeviceFactory::AddDevices, + // otherwise if 'task_index=-1' the program will abort. + + int requested_port; + TF_RETURN_IF_ERROR(GetPort(&requested_port)); + SessionOptions sess_opts; ConfigProto config = server_def_.default_session_config(); sess_opts.config = config; @@ -337,8 +343,8 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, " differs from expected port ", bound_port_); } - *worker_cache = NewGrpcWorkerCacheWithLocalWorker( - channel_cache_, worker_impl_.get(), name_prefix); + *worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache_, + worker_impl(), name_prefix); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index f66d7eb82e8..17bc93588c3 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -98,6 +98,7 @@ class GrpcServer : public ServerInterface { std::shared_ptr channel_cache() { return channel_cache_; } protected: + virtual Status GetPort(int* port) const; Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); // A subclass can override this method to support secure credentials. @@ -109,8 +110,8 @@ class GrpcServer : public ServerInterface { virtual std::unique_ptr CreateMaster(MasterEnv* master_env); // Creates a WorkerCacheInterface for a session. - Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, - WorkerCacheInterface** worker_cache); + virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache); // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, @@ -121,6 +122,11 @@ class GrpcServer : public ServerInterface { int bound_port() const { return bound_port_; } const ServerDef& server_def() const { return server_def_; } + GrpcWorker* worker_impl() const { return worker_impl_.get(); } + + void set_channel_cache(GrpcChannelCache* channel_cache) { + channel_cache_.reset(channel_cache); + } private: // The overall server configuration.