From a33022c1470ce1334766b0cad38d9e91c17a2e5d Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Fri, 12 May 2017 13:11:33 +0800 Subject: [PATCH] Fix verbs compile error (#9791) --- tensorflow/contrib/verbs/rdma.cc | 7 ++----- tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 15 ++++++--------- tensorflow/contrib/verbs/rdma_rendezvous_mgr.h | 7 +++---- tensorflow/contrib/verbs/verbs_server_lib.cc | 14 ++++---------- 4 files changed, 15 insertions(+), 28 deletions(-) diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 05df05de353..bfd00ba3f51 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -778,11 +778,8 @@ void RdmaTensorBuffer::SendNextItem() { EnqueueItem(key_with_step_id); } }; - // Use default session (legacy_session_) - // TODO use WorkerSessionForSession - // need to pass in session handle - channel_->adapter_->worker_env_->session_mgr->LegacySession() - ->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb); + channel_->adapter_->worker_env_->rendezvous_mgr + ->RecvLocalAsync(step_id, parsed, cb); } } diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index 8cbdfaa9439..d665f92cd92 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -29,9 +29,9 @@ namespace tensorflow { class RdmaRemoteRendezvous : public BaseRemoteRendezvous { public: - RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name, + RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr) - : BaseRemoteRendezvous(env, worker_name, step_id, true), + : BaseRemoteRendezvous(env, step_id, true), rdma_mgr_(rdma_mgr) {} protected: @@ -133,15 +133,12 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( rb->SendNextItem(); } -RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env, - const string& worker_name, - WorkerCacheInterface* worker_cache) - : BaseRendezvousMgr(env, worker_name) {} +RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env) + : BaseRendezvousMgr(env) {} BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id, - const WorkerEnv* worker_env, - const string& worker_name) { - return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_); + const WorkerEnv* worker_env) { + return new RdmaRemoteRendezvous(worker_env, step_id, rdma_mgr_); } } // end namespace tensorflow diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h index 57cd4bf5e4e..2dedd6c48f9 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h @@ -45,13 +45,12 @@ namespace tensorflow { // RendezvousMgr must have keys generated by Rendezvous::CreateKey. class RdmaRendezvousMgr : public BaseRendezvousMgr { public: - explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache); + explicit RdmaRendezvousMgr(const WorkerEnv* env); void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; } protected: - BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env, - const string& worker_name) override; + BaseRemoteRendezvous* Create(int64 step_id, + const WorkerEnv* worker_env) override; private: RdmaMgr* rdma_mgr_; diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index b061c81d2d8..c3597249354 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -27,10 +27,8 @@ namespace tensorflow { namespace { // static utility function -RendezvousMgrInterface* NewRdmaRendezvousMgr( - const WorkerEnv* env, const string& worker_name, - WorkerCacheInterface* worker_cache) { - return new RdmaRendezvousMgr(env, worker_name, worker_cache); +RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) { + return new RdmaRendezvousMgr(env); } } // namespace @@ -56,7 +54,7 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec)); *channel_cache = - NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def)); + NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()); const string host_port = (*channel_cache)->TranslateTask(name_prefix); int requested_port; @@ -86,11 +84,7 @@ Status VerbsServer::Init(ServiceInitFunction service_func, rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_); // set rdma_mgr for verbs_service and rdma_rendezvous_mgr verbs_service_->SetRdmaMgr(rdma_mgr_); - // hardcoded to default session (legacy_session_) - // TODO: use WorkerSessionForSession - // need to pass in session handle - dynamic_cast( - worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get()) + dynamic_cast(worker_env()->rendezvous_mgr) ->SetRdmaMgr(rdma_mgr_); } return s;