diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index b2192c5a801..37029f3f1a7 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -562,6 +562,7 @@ cc_library(
     deps = [
         ":worker_cache",
         ":worker_interface",
+        "//tensorflow/core:framework",
     ],
 )
 
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index a48f734d3e2..269f620e42e 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -53,6 +53,7 @@ limitations under the License.
 #include "tensorflow/core/protobuf/master.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
 #include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 
@@ -167,13 +168,55 @@ class DeviceFinder {
     }
     // Enumerates all known workers' target. A target name is a
     // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
-    CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
-    const string& local_device_name = env_->local_devices[0]->name();
-    std::vector<string> workers;
-    worker_cache->ListWorkers(&workers);
     if (filters_.empty()) {
+      // If no filters were specified, we list all known workers in
+      // `worker_cache`.
+      std::vector<string> workers;
+      worker_cache->ListWorkers(&workers);
       std::swap(workers, targets_);
     } else {
+      // When applying filters, we must include the local worker, even if it
+      // does not match any of the filters.
+      CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
+      const string& local_device_name = env_->local_devices[0]->name();
+      DeviceNameUtils::ParsedName local_parsed_name;
+      CHECK(DeviceNameUtils::ParseFullName(local_device_name,
+                                           &local_parsed_name));
+      bool all_filters_have_job = true;
+      std::unordered_set<string> filter_job_names({local_parsed_name.job});
+      for (const DeviceNameUtils::ParsedName& filter : filters_) {
+        all_filters_have_job = all_filters_have_job && filter.has_job;
+        if (filter.has_job) {
+          filter_job_names.insert(filter.job);
+        }
+      }
+
+      std::vector<string> workers;
+      if (all_filters_have_job) {
+        // If all of the device filters have a job specified, then we only need
+        // to list the workers in the jobs named in the filter, because a worker
+        // in any other job would not match any filter.
+        for (const string& job_name : filter_job_names) {
+          VLOG(2) << "Selectively listing workers in job: " << job_name;
+          std::vector<string> workers_in_job;
+          worker_cache->ListWorkersInJob(job_name, &workers_in_job);
+          workers.insert(workers.end(), workers_in_job.begin(),
+                         workers_in_job.end());
+        }
+      } else {
+        // If any of the device filters does not have a job specified, then we
+        // must list the workers from all jobs.
+        VLOG(2) << "Listing workers in all jobs because some device "
+                << "filter has no job specified. Filters were:";
+        if (device_filters.empty()) {
+          VLOG(2) << "- <NO FILTERS>";
+        } else {
+          for (const string& filter : device_filters) {
+            VLOG(2) << "- " << filter;
+          }
+        }
+        worker_cache->ListWorkers(&workers);
+      }
       for (const string& name : workers) {
         if (MatchFilters(name) ||
             DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index b7eb3c9015a..456c30ecf49 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
     }
   }
 
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) override {
+    for (GrpcChannelCache* cache : caches_) {
+      cache->ListWorkersInJob(job_name, workers);
+    }
+  }
+
   string TranslateTask(const string& target) override {
     mutex_lock l(mu_);  // could use reader lock
     GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
@@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
     }
   }
 
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) override {
+    if (job_name == job_id_) {
+      ListWorkers(workers);
+    }
+  }
+
   string TranslateTask(const string& target) override {
     DeviceNameUtils::ParsedName parsed;
     if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 4861cdb691c..6fa99d7b148 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -66,6 +66,8 @@ class GrpcChannelCache {
   //   /job:<job identifier>/task:<task id>
   // e.g. /job:mnist/task:2
   virtual void ListWorkers(std::vector<string>* workers) = 0;
+  virtual void ListWorkersInJob(const string& job_name,
+                                std::vector<string>* workers) = 0;
 
   // If found, returns a gRPC channel that is connected to the remote
   // worker named by 'target'. 'target' is of the following
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index f07a5a09746..a814ef85e20 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
     EXPECT_NE(d_4_1.get(), e_5_2.get());
   }
 
-  std::vector<string> workers;
-  cc->ListWorkers(&workers);
-  EXPECT_EQ(std::vector<string>(
-                {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
-                 "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
-                 "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
-            workers);
+  {
+    std::vector<string> workers;
+    cc->ListWorkers(&workers);
+    EXPECT_EQ(
+        std::vector<string>(
+            {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+             "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+             "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+        workers);
+  }
+
+  {
+    std::vector<string> workers;
+    cc->ListWorkersInJob("mnist", &workers);
+    EXPECT_EQ(
+        std::vector<string>(
+            {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+             "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+             "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+        workers);
+  }
+
+  {
+    std::vector<string> workers;
+    cc->ListWorkersInJob("other", &workers);
+    EXPECT_TRUE(workers.empty());
+  }
 }
 
 TEST(GrpcChannelTest, SparseHostPorts) {
@@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) {
     EXPECT_NE(d_4_1.get(), e_5_2.get());
   }
 
-  std::vector<string> workers;
-  cc->ListWorkers(&workers);
-  std::sort(workers.begin(), workers.end());
-  EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
-                                 "/job:mnist/replica:0/task:3",
-                                 "/job:mnist/replica:0/task:4"}),
-            workers);
+  {
+    std::vector<string> workers;
+    cc->ListWorkers(&workers);
+    std::sort(workers.begin(), workers.end());
+    EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+                                   "/job:mnist/replica:0/task:3",
+                                   "/job:mnist/replica:0/task:4"}),
+              workers);
+  }
+
+  {
+    std::vector<string> workers;
+    cc->ListWorkersInJob("mnist", &workers);
+    EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+                                   "/job:mnist/replica:0/task:3",
+                                   "/job:mnist/replica:0/task:4"}),
+              workers);
+  }
+
+  {
+    std::vector<string> workers;
+    cc->ListWorkersInJob("other", &workers);
+    EXPECT_TRUE(workers.empty());
+  }
 }
 
 TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index b9f21ea211b..e1541db69bf 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
     channel_cache_->ListWorkers(workers);
   }
 
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) const override {
+    channel_cache_->ListWorkersInJob(job_name, workers);
+  }
+
   WorkerInterface* CreateWorker(const string& target) override {
     if (target == local_target_) {
       return local_worker_;
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 25ff6512a03..b070dd13dd6 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -50,6 +50,8 @@ namespace {
 // Fake cache implementation for WorkerEnv.
 class DummyWorkerCache : public WorkerCacheInterface {
   void ListWorkers(std::vector<string>* workers) const override {}
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) const override {}
   WorkerInterface* CreateWorker(const string& target) override {
     return nullptr;
   }
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 48d83845dd3..88a97da34d6 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
 #include <unordered_map>
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 
@@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
     }
   }
 
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) const override {
+    workers->clear();
+    for (auto it : workers_) {
+      DeviceNameUtils::ParsedName device_name;
+      CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
+      CHECK(device_name.has_job);
+      if (job_name == device_name.job) {
+        workers->push_back(it.first);
+      }
+    }
+  }
+
   WorkerInterface* CreateWorker(const string& target) override {
     auto it = workers_.find(target);
     if (it != workers_.end()) {
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
index 8521f8956b9..0c8575b4d5d 100644
--- a/tensorflow/core/distributed_runtime/worker_cache.h
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -36,6 +36,8 @@ class WorkerCacheInterface {
   // Updates *workers with strings naming the remote worker tasks to
   // which open channels have been established.
   virtual void ListWorkers(std::vector<string>* workers) const = 0;
+  virtual void ListWorkersInJob(const string& job_name,
+                                std::vector<string>* workers) const = 0;
 
   // If "target" names a remote task for which an RPC channel exists
   // or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
index 43c3b6285b9..1f309b4361f 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
   virtual void ListWorkers(std::vector<string>* workers) const {
     return wrapped_->ListWorkers(workers);
   }
+  virtual void ListWorkersInJob(const string& job_name,
+                                std::vector<string>* workers) const {
+    return wrapped_->ListWorkersInJob(job_name, workers);
+  }
 
   // If "target" names a remote task for which an RPC channel exists
   // or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index ca6dc1b1dea..c7d0c6b7f30 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
     wrapped_->ListWorkers(workers);
   }
 
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) const override {
+    wrapped_->ListWorkersInJob(job_name, workers);
+  }
+
   WorkerInterface* CreateWorker(const string& target) override {
     mutex_lock l(mu_);
     auto p = workers_.find(target);