Allow overriding the GetPort and WorkerCacheFactory functions in

grpc_server_lib

PiperOrigin-RevId: 243700130
This commit is contained in:
Akshay Modi 2019-04-15 15:27:09 -07:00 committed by TensorFlower Gardener
parent 554caf4863
commit 4307698dce
2 changed files with 30 additions and 18 deletions

View File

@ -112,17 +112,9 @@ GrpcServer::~GrpcServer() {
void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
Status GrpcServer::Init(const GrpcServerOptions& opts) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
worker_env_.env = env_;
// Check parameters before DeviceFactory::AddDevices,
// otherwise if 'task_index=-1' the program will abort.
// Look up the port that has been requested for this task in `server_def_`.
int requested_port = -1;
// Look up the port that has been requested for this task in `server_def_`.
Status GrpcServer::GetPort(int* port) const {
*port = -1;
for (const auto& job : server_def_.cluster().job()) {
if (job.name() == server_def_.job_name()) {
auto iter = job.tasks().find(server_def_.task_index());
@ -132,8 +124,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
server_def_.job_name(), "\"");
}
auto colon_index = iter->second.find_last_of(':');
if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
&requested_port)) {
if (!strings::safe_strto32(iter->second.substr(colon_index + 1), port)) {
return errors::InvalidArgument(
"Could not parse port for local server from \"", iter->second,
"\".");
@ -141,11 +132,26 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
break;
}
}
if (requested_port == -1) {
if (*port == -1) {
return errors::Internal("Job \"", server_def_.job_name(),
"\" was not defined in cluster");
}
return Status::OK();
}
Status GrpcServer::Init(const GrpcServerOptions& opts) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
worker_env_.env = env_;
// Check parameters before DeviceFactory::AddDevices,
// otherwise if 'task_index=-1' the program will abort.
int requested_port;
TF_RETURN_IF_ERROR(GetPort(&requested_port));
SessionOptions sess_opts;
ConfigProto config = server_def_.default_session_config();
sess_opts.config = config;
@ -337,8 +343,8 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
" differs from expected port ", bound_port_);
}
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(
channel_cache_, worker_impl_.get(), name_prefix);
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache_,
worker_impl(), name_prefix);
return Status::OK();
}

View File

@ -98,6 +98,7 @@ class GrpcServer : public ServerInterface {
std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
protected:
virtual Status GetPort(int* port) const;
Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
// A subclass can override this method to support secure credentials.
@ -109,8 +110,8 @@ class GrpcServer : public ServerInterface {
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
// Creates a WorkerCacheInterface for a session.
Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache);
virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache);
// Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
@ -121,6 +122,11 @@ class GrpcServer : public ServerInterface {
int bound_port() const { return bound_port_; }
const ServerDef& server_def() const { return server_def_; }
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
void set_channel_cache(GrpcChannelCache* channel_cache) {
channel_cache_.reset(channel_cache);
}
private:
// The overall server configuration.