From fda82aa165cde37cafca07a30b1e12f5bea46c14 Mon Sep 17 00:00:00 2001
From: jiakai
Date: Thu, 23 May 2019 17:54:08 +0800
Subject: [PATCH] Rename WorkerCacheInterface::CreateWorker to
 GetOrCreateWorker().

This resolves the TODO(mrry) in worker_cache.h: the new name makes it
obvious that the method may return a shared, cached object rather than
always constructing a fresh one.

Change-Id: I7d3a0c40f7a578a6eef97f82bad13bb6ac2f9cbc
---
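Note for reviewers (below the "---", so it is not committed): the sketch
that follows illustrates the acquire/release contract that motivates the
rename. DemoWorker and DemoWorkerCache are hypothetical stand-ins for
WorkerInterface and WorkerCacheInterface, not TensorFlow classes; this is
a minimal standalone illustration, not the real implementation.

    // Toy model of the WorkerCacheInterface contract: repeated calls to
    // GetOrCreateWorker(target) may hand back the same shared object, and
    // every returned pointer must be paired with a ReleaseWorker() call.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>

    class DemoWorker {
     public:
      explicit DemoWorker(std::string target) : target_(std::move(target)) {}
      const std::string& target() const { return target_; }

     private:
      std::string target_;
    };

    class DemoWorkerCache {
     public:
      // Returns the cached worker for `target` if present; otherwise
      // constructs and caches one. The caller does not own the result.
      DemoWorker* GetOrCreateWorker(const std::string& target) {
        auto it = workers_.find(target);
        if (it != workers_.end()) return it->second.get();
        auto inserted =
            workers_.emplace(target, std::make_unique<DemoWorker>(target));
        return inserted.first->second.get();
      }

      // Hands back a worker previously returned by GetOrCreateWorker().
      // A no-op in this toy; a real cache might ref-count or pool.
      void ReleaseWorker(const std::string& /*target*/,
                         DemoWorker* /*worker*/) {}

     private:
      std::unordered_map<std::string, std::unique_ptr<DemoWorker>> workers_;
    };

    int main() {
      DemoWorkerCache cache;
      DemoWorker* w1 = cache.GetOrCreateWorker("/job:worker/replica:0/task:0");
      DemoWorker* w2 = cache.GetOrCreateWorker("/job:worker/replica:0/task:0");
      std::cout << (w1 == w2) << "\n";  // prints 1: both calls share one object
      cache.ReleaseWorker(w1->target(), w1);
      cache.ReleaseWorker(w2->target(), w2);
    }

This mirrors what WorkerFreeListCache in worker_session.cc already does:
it looks the target up in a map and only falls through to the wrapped
cache on a miss, which is why "GetOrCreate" describes the behavior better
than "Create".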
 tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc            | 2 +-
 tensorflow/core/distributed_runtime/cancellable_call.h  | 2 +-
 .../cluster_function_library_runtime.cc                 | 2 +-
 .../distributed_runtime/device_resolver_distributed.cc  | 2 +-
 tensorflow/core/distributed_runtime/master.cc           | 2 +-
 tensorflow/core/distributed_runtime/master_session.cc   | 6 +++---
 tensorflow/core/distributed_runtime/remote_device.cc    | 2 +-
 .../core/distributed_runtime/remote_device_test.cc      | 2 +-
 .../core/distributed_runtime/rpc/grpc_worker_cache.cc   | 2 +-
 .../core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc  | 2 +-
 .../distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc  | 2 +-
 .../distributed_runtime/rpc_collective_executor_mgr.cc  | 2 +-
 tensorflow/core/distributed_runtime/test_utils.h        | 2 +-
 tensorflow/core/distributed_runtime/worker_cache.h      | 7 ++-----
 .../core/distributed_runtime/worker_cache_partial.cc    | 2 +-
 .../core/distributed_runtime/worker_cache_wrapper.h     | 9 +++------
 tensorflow/core/distributed_runtime/worker_session.cc   | 4 ++--
 17 files changed, 23 insertions(+), 29 deletions(-)

diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
index 1124dff7413..4744a9ee9a8 100644
--- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
+++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
@@ -142,7 +142,7 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous {
     }
 
     WorkerSession* sess = session();
-    WorkerInterface* rwi = sess->worker_cache->CreateWorker(src_worker);
+    WorkerInterface* rwi = sess->worker_cache->GetOrCreateWorker(src_worker);
     if (rwi == nullptr) {
       Status s = errors::Internal("No worker known as ", src_worker);
       done(s, Args(), recv_args, Tensor{}, false);
diff --git a/tensorflow/core/distributed_runtime/cancellable_call.h b/tensorflow/core/distributed_runtime/cancellable_call.h
index dcf9f973a9e..3d82bef5c80 100644
--- a/tensorflow/core/distributed_runtime/cancellable_call.h
+++ b/tensorflow/core/distributed_runtime/cancellable_call.h
@@ -32,7 +32,7 @@ class CancellableCall {
       : cancel_mgr_(cancel_mgr),
         remote_worker_(remote_worker),
         wc_(wc),
-        wi_(wc_->CreateWorker(remote_worker_)) {}
+        wi_(wc_->GetOrCreateWorker(remote_worker_)) {}
 
   virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
 
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
index d2fcd9cdd2c..a26cdb6eb4d 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
@@ -127,7 +127,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
   VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target
           << " (this: " << this << ")";
   WorkerInterface* wi =
-      worker_session_->worker_cache->CreateWorker(options.target);
+      worker_session_->worker_cache->GetOrCreateWorker(options.target);
 
   if (wi == nullptr) {
     std::vector<string> workers;
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
index 038974cb390..165a19641ef 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc
@@ -97,7 +97,7 @@ void DeviceResolverDistributed::RefreshRemoteAttributes(
     const string& device, const string& task, const StatusCallback& done) {
   GetStatusRequest* req = new GetStatusRequest;
   GetStatusResponse* resp = new GetStatusResponse;
-  WorkerInterface* worker = worker_cache_->CreateWorker(task);
+  WorkerInterface* worker = worker_cache_->GetOrCreateWorker(task);
   CHECK(worker) << "Failed to get worker for " << task;
   worker->GetStatusAsync(
       req, resp, [this, device, task, req, resp, worker, done](Status s) {
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index fc8d2871ac7..3b926d9dfcd 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -629,7 +629,7 @@ void Master::CleanupWorkers(const ResetRequest& reset) {
   int c = 0;
   for (int i = 0; i < num_workers; ++i) {
     const string& worker_name = worker_names[i];
-    auto worker = env_->worker_cache->CreateWorker(worker_name);
+    auto worker = env_->worker_cache->GetOrCreateWorker(worker_name);
     if (worker) {
       worker->CleanupAllAsync(
           &req, &resp[i], [this, &n, worker_name, worker, c](Status s) {
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index e3b788f437c..81bfd90265a 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -444,7 +444,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
     Part* part = &partitions_.back();
     part->name = name_def.first;
     TrackFeedsAndFetches(part, name_def.second, popts);
-    part->worker = worker_cache_->CreateWorker(part->name);
+    part->worker = worker_cache_->GetOrCreateWorker(part->name);
     if (part->worker == nullptr) {
       s = errors::NotFound("worker ", part->name);
       break;
@@ -1279,7 +1279,7 @@ Status MasterSession::CreateWorkerSessions(
   // Create all the workers & kick off the computations.
   for (size_t i = 0; i < worker_names.size(); ++i) {
     workers[i].name = &worker_names[i];
-    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
+    workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
     workers[i].request.set_session_handle(handle_);
     if (session_opts_.config.experimental()
             .share_cluster_devices_in_session()) {
@@ -1377,7 +1377,7 @@ Status MasterSession::DeleteWorkerSessions() {
   // Create all the workers & kick off the computations.
   for (size_t i = 0; i < worker_names.size(); ++i) {
     workers[i].name = &worker_names[i];
-    workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
+    workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
     workers[i].request.set_session_handle(handle_);
     // Since the worker may have gone away, set a timeout to avoid blocking the
     // session-close operation.
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
index f0fc6666b29..82fd5832179 100644
--- a/tensorflow/core/distributed_runtime/remote_device.cc
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -66,7 +66,7 @@ void AsRemoteDevices(
 
 void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
                       const string& worker_name, NewRemoteDevicesDone done) {
-  WorkerInterface* wi = worker_cache->CreateWorker(worker_name);
+  WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
   if (wi == nullptr) {
     std::vector<Device*> empty;
     done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
diff --git a/tensorflow/core/distributed_runtime/remote_device_test.cc b/tensorflow/core/distributed_runtime/remote_device_test.cc
index a04e79328b0..62082aa6d59 100644
--- a/tensorflow/core/distributed_runtime/remote_device_test.cc
+++ b/tensorflow/core/distributed_runtime/remote_device_test.cc
@@ -53,7 +53,7 @@ class RemoteDeviceTest : public ::testing::Test {
         NewGrpcChannelCache(spec, channel_func));
     worker_cache_.reset(NewGrpcWorkerCache(channel_cache));
     remote_name_ = "/job:localhost/replica:0/task:0";
-    wi_ = worker_cache_->CreateWorker(remote_name_);
+    wi_ = worker_cache_->GetOrCreateWorker(remote_name_);
   }
 
   ~RemoteDeviceTest() override {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index 60d5881d4ca..0c54232d0bc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -69,7 +69,7 @@ class GrpcWorkerCache : public WorkerCachePartial {
     channel_cache_->ListWorkersInJob(job_name, workers);
   }
 
-  WorkerInterface* CreateWorker(const string& target) override {
+  WorkerInterface* GetOrCreateWorker(const string& target) override {
     if (target == local_target_) {
      return local_worker_;
    } else {
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index ee561e1a8a0..1a4298d5550 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -232,7 +232,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
   // The worker will be released in a subsequent call to
   // `sess->worker_cache->ReleaseWorker()` (if the call has not yet been
   // initialized) or `call->ReleaseWorker()` (if it has been initialized).
-  WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
+  WorkerInterface* rwi = sess->worker_cache->GetOrCreateWorker(call->src_worker_);
   if (s.ok() && rwi == nullptr) {
     s = errors::Internal("No worker known as ", call->src_worker_);
   }
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 28ac30d07ae..7633f59ed31 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -52,7 +52,7 @@ class DummyWorkerCache : public WorkerCacheInterface {
   void ListWorkers(std::vector<string>* workers) const override {}
   void ListWorkersInJob(const string& job_name,
                         std::vector<string>* workers) const override {}
-  WorkerInterface* CreateWorker(const string& target) override {
+  WorkerInterface* GetOrCreateWorker(const string& target) override {
     return nullptr;
   }
   bool GetDeviceLocalityNonBlocking(const string& device,
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 054bed7781b..9157dbe648c 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -80,7 +80,7 @@ void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
     gks->next_step_id_ = NewRandomStepId();
     done(Status::OK());
   } else {
-    WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_);
+    WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_);
     GetStepSequenceRequest* req = new GetStepSequenceRequest;
     GetStepSequenceResponse* resp = new GetStepSequenceResponse;
     req->add_graph_key(graph_key);
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 88a97da34d6..12803ed997f 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -152,7 +152,7 @@ class TestWorkerCache : public WorkerCacheInterface {
     }
   }
 
-  WorkerInterface* CreateWorker(const string& target) override {
+  WorkerInterface* GetOrCreateWorker(const string& target) override {
     auto it = workers_.find(target);
     if (it != workers_.end()) {
       return it->second;
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
index 0c8575b4d5d..8e2c06af77c 100644
--- a/tensorflow/core/distributed_runtime/worker_cache.h
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -43,12 +43,9 @@ class WorkerCacheInterface {
   // or can be constructed, returns a pointer to a WorkerInterface object
   // wrapping that channel. The returned value must be destroyed by
   // calling `this->ReleaseWorker(target, ret)`
-  // TODO(mrry): rename this to GetOrCreateWorker() or something that
-  // makes it more obvious that this method returns a potentially
-  // shared object.
-  virtual WorkerInterface* CreateWorker(const string& target) = 0;
+  virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0;
 
-  // Release a worker previously returned by this->CreateWorker(target).
+  // Release a worker previously returned by this->GetOrCreateWorker(target).
   //
   // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
   // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.cc b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
index 55b6957b962..ec59d12a021 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_partial.cc
+++ b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
@@ -65,7 +65,7 @@ Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) {
   auto deleter = [this, &task](WorkerInterface* wi) {
     ReleaseWorker(task, wi);
   };
-  std::unique_ptr<WorkerInterface, decltype(deleter)> rwi(CreateWorker(task),
+  std::unique_ptr<WorkerInterface, decltype(deleter)> rwi(GetOrCreateWorker(task),
                                                           deleter);
   if (s.ok() && !rwi) {
     s = errors::Internal("RefreshDeviceStatus, unknown worker task: ", task);
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
index 1f309b4361f..d242ff82ab7 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -41,14 +41,11 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
   // or can be constructed, returns a pointer to a WorkerInterface object
   // wrapping that channel. The returned value must be destroyed by
   // calling `this->ReleaseWorker(target, ret)`
-  // TODO(mrry): rename this to GetOrCreateWorker() or something that
-  // makes it more obvious that this method returns a potentially
-  // shared object.
-  virtual WorkerInterface* CreateWorker(const string& target) {
-    return wrapped_->CreateWorker(target);
+  virtual WorkerInterface* GetOrCreateWorker(const string& target) {
+    return wrapped_->GetOrCreateWorker(target);
   }
 
-  // Release a worker previously returned by this->CreateWorker(target).
+  // Release a worker previously returned by this->GetOrCreateWorker(target).
   //
   // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
   // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index 1a716c618f2..02721421d94 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -42,14 +42,14 @@ class WorkerFreeListCache : public WorkerCacheInterface {
     wrapped_->ListWorkersInJob(job_name, workers);
   }
 
-  WorkerInterface* CreateWorker(const string& target) override {
+  WorkerInterface* GetOrCreateWorker(const string& target) override {
     mutex_lock l(mu_);
     auto p = workers_.find(target);
     if (p != workers_.end()) {
       return p->second.worker;
     }
     WorkerState state;
-    state.worker = wrapped_->CreateWorker(target);
+    state.worker = wrapped_->GetOrCreateWorker(target);
     if (state.worker != nullptr) {
       workers_.insert(std::make_pair(target, state));
     }