Allow overriding the GetPort and WorkerCacheFactory functions in
grpc_server_lib PiperOrigin-RevId: 243700130
This commit is contained in:
parent
554caf4863
commit
4307698dce
@ -112,17 +112,9 @@ GrpcServer::~GrpcServer() {
|
||||
|
||||
void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
|
||||
|
||||
Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
||||
mutex_lock l(mu_);
|
||||
CHECK_EQ(state_, NEW);
|
||||
master_env_.env = env_;
|
||||
worker_env_.env = env_;
|
||||
|
||||
// Check parameters before DeviceFactory::AddDevices,
|
||||
// otherwise if 'task_index=-1' the program will abort.
|
||||
|
||||
// Look up the port that has been requested for this task in `server_def_`.
|
||||
int requested_port = -1;
|
||||
// Look up the port that has been requested for this task in `server_def_`.
|
||||
Status GrpcServer::GetPort(int* port) const {
|
||||
*port = -1;
|
||||
for (const auto& job : server_def_.cluster().job()) {
|
||||
if (job.name() == server_def_.job_name()) {
|
||||
auto iter = job.tasks().find(server_def_.task_index());
|
||||
@ -132,8 +124,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
||||
server_def_.job_name(), "\"");
|
||||
}
|
||||
auto colon_index = iter->second.find_last_of(':');
|
||||
if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
|
||||
&requested_port)) {
|
||||
if (!strings::safe_strto32(iter->second.substr(colon_index + 1), port)) {
|
||||
return errors::InvalidArgument(
|
||||
"Could not parse port for local server from \"", iter->second,
|
||||
"\".");
|
||||
@ -141,11 +132,26 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (requested_port == -1) {
|
||||
if (*port == -1) {
|
||||
return errors::Internal("Job \"", server_def_.job_name(),
|
||||
"\" was not defined in cluster");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
||||
mutex_lock l(mu_);
|
||||
CHECK_EQ(state_, NEW);
|
||||
master_env_.env = env_;
|
||||
worker_env_.env = env_;
|
||||
|
||||
// Check parameters before DeviceFactory::AddDevices,
|
||||
// otherwise if 'task_index=-1' the program will abort.
|
||||
|
||||
int requested_port;
|
||||
TF_RETURN_IF_ERROR(GetPort(&requested_port));
|
||||
|
||||
SessionOptions sess_opts;
|
||||
ConfigProto config = server_def_.default_session_config();
|
||||
sess_opts.config = config;
|
||||
@ -337,8 +343,8 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||
" differs from expected port ", bound_port_);
|
||||
}
|
||||
|
||||
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(
|
||||
channel_cache_, worker_impl_.get(), name_prefix);
|
||||
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache_,
|
||||
worker_impl(), name_prefix);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -98,6 +98,7 @@ class GrpcServer : public ServerInterface {
|
||||
std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
|
||||
|
||||
protected:
|
||||
virtual Status GetPort(int* port) const;
|
||||
Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
|
||||
|
||||
// A subclass can override this method to support secure credentials.
|
||||
@ -109,8 +110,8 @@ class GrpcServer : public ServerInterface {
|
||||
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
|
||||
|
||||
// Creates a WorkerCacheInterface for a session.
|
||||
Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||
WorkerCacheInterface** worker_cache);
|
||||
virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||
WorkerCacheInterface** worker_cache);
|
||||
|
||||
// Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
|
||||
Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
|
||||
@ -121,6 +122,11 @@ class GrpcServer : public ServerInterface {
|
||||
int bound_port() const { return bound_port_; }
|
||||
|
||||
const ServerDef& server_def() const { return server_def_; }
|
||||
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
|
||||
|
||||
void set_channel_cache(GrpcChannelCache* channel_cache) {
|
||||
channel_cache_.reset(channel_cache);
|
||||
}
|
||||
|
||||
private:
|
||||
// The overall server configuration.
|
||||
|
Loading…
Reference in New Issue
Block a user