[Distributed] Add methods to WorkerCache that selectively list workers by job name.

PiperOrigin-RevId: 209597829
Derek Murray 2018-08-21 08:18:48 -07:00 committed by TensorFlower Gardener
parent aeab291563
commit 5b456c9ab5
11 changed files with 147 additions and 18 deletions
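For context: WorkerCacheInterface::ListWorkers enumerates every known task, while the new ListWorkersInJob restricts the listing to a single job. A minimal usage sketch, not taken from this commit (the cache pointer and the "ps" job name are hypothetical):

#include <string>
#include <vector>

#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/logging.h"

// Sketch only: list the tasks of a single job instead of enumerating every
// worker and filtering the result by hand.
void LogParameterServerTasks(tensorflow::WorkerCacheInterface* cache) {
  std::vector<std::string> ps_tasks;
  cache->ListWorkersInJob("ps", &ps_tasks);
  for (const std::string& task : ps_tasks) {
    // Each entry is a task target such as "/job:ps/replica:0/task:0".
    LOG(INFO) << "ps task: " << task;
  }
}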


@@ -562,6 +562,7 @@ cc_library(
deps = [
":worker_cache",
":worker_interface",
"//tensorflow/core:framework",
],
)


@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -167,13 +168,55 @@ class DeviceFinder {
}
// Enumerates all known workers' targets. A target name is a
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
const string& local_device_name = env_->local_devices[0]->name();
std::vector<string> workers;
worker_cache->ListWorkers(&workers);
if (filters_.empty()) {
// If no filters were specified, we list all known workers in
// `worker_cache`.
std::vector<string> workers;
worker_cache->ListWorkers(&workers);
std::swap(workers, targets_);
} else {
// When applying filters, we must include the local worker, even if it
// does not match any of the filters.
CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
const string& local_device_name = env_->local_devices[0]->name();
DeviceNameUtils::ParsedName local_parsed_name;
CHECK(DeviceNameUtils::ParseFullName(local_device_name,
&local_parsed_name));
bool all_filters_have_job = true;
std::unordered_set<string> filter_job_names({local_parsed_name.job});
for (const DeviceNameUtils::ParsedName& filter : filters_) {
all_filters_have_job = all_filters_have_job && filter.has_job;
if (filter.has_job) {
filter_job_names.insert(filter.job);
}
}
std::vector<string> workers;
if (all_filters_have_job) {
// If all of the device filters have a job specified, then we only need
// to list the workers in the jobs named in the filter, because a worker
// in any other job would not match any filter.
for (const string& job_name : filter_job_names) {
VLOG(2) << "Selectively listing workers in job: " << job_name;
std::vector<string> workers_in_job;
worker_cache->ListWorkersInJob(job_name, &workers_in_job);
workers.insert(workers.end(), workers_in_job.begin(),
workers_in_job.end());
}
} else {
// If any of the device filters does not have a job specified, then we
// must list the workers from all jobs.
VLOG(2) << "Listing workers in all jobs because some device "
<< "filter has no job specified. Filters were:";
if (device_filters.empty()) {
VLOG(2) << "- <NO FILTERS>";
} else {
for (const string& filter : device_filters) {
VLOG(2) << "- " << filter;
}
}
worker_cache->ListWorkers(&workers);
}
for (const string& name : workers) {
if (MatchFilters(name) ||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {

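The branch above reduces to a simple rule: if every device filter names a job, only those jobs (plus the local worker's job) need to be listed; a filter without a job forces a full ListWorkers. A standalone sketch of that rule, using invented filter strings rather than the DeviceFinder code itself:

#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/util/device_name_utils.h"

int main() {
  using tensorflow::DeviceNameUtils;
  // Hypothetical filters: both name a job, so only "ps" and "worker" (plus
  // the local worker's job) would be listed via ListWorkersInJob. A filter
  // such as "/task:0" has no job and would force ListWorkers instead.
  std::vector<std::string> device_filters = {"/job:ps", "/job:worker/task:1"};

  bool all_filters_have_job = true;
  std::unordered_set<std::string> filter_job_names;
  for (const std::string& filter : device_filters) {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(filter, &parsed)) continue;
    all_filters_have_job = all_filters_have_job && parsed.has_job;
    if (parsed.has_job) filter_job_names.insert(parsed.job);
  }
  // Here all_filters_have_job == true and filter_job_names == {"ps", "worker"}.
  return 0;
}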

@@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
}
}
void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) override {
for (GrpcChannelCache* cache : caches_) {
cache->ListWorkersInJob(job_name, workers);
}
}
string TranslateTask(const string& target) override {
mutex_lock l(mu_); // could use reader lock
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
@@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
}
}
void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) override {
if (job_name == job_id_) {
ListWorkers(workers);
}
}
string TranslateTask(const string& target) override {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {


@@ -66,6 +66,8 @@ class GrpcChannelCache {
// /job:<job identifier>/task:<task id>
// e.g. /job:mnist/task:2
virtual void ListWorkers(std::vector<string>* workers) = 0;
virtual void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) = 0;
// If found, returns a gRPC channel that is connected to the remote
// worker named by 'target'. 'target' is of the following

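Together with the per-job channel caches above, the new method lets a multi-job cache answer per-job queries directly. A hedged usage sketch (the job names and host:port addresses are invented; the construction mirrors the test below):

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/lib/core/status.h"

void Example() {
  tensorflow::GrpcChannelSpec spec;
  TF_CHECK_OK(spec.AddHostPortsJob("worker", {"w0:2222", "w1:2222"}));
  TF_CHECK_OK(spec.AddHostPortsJob("ps", {"ps0:2222"}));
  tensorflow::ChannelCreationFunction channel_func =
      tensorflow::ConvertToChannelCreationFunction(
          tensorflow::NewHostPortGrpcChannel);
  std::unique_ptr<tensorflow::GrpcChannelCache> cache(
      tensorflow::NewGrpcChannelCache(spec, channel_func));

  std::vector<std::string> all_tasks;
  cache->ListWorkers(&all_tasks);  // Three targets across both jobs.

  std::vector<std::string> ps_tasks;
  cache->ListWorkersInJob("ps", &ps_tasks);  // Only "/job:ps/replica:0/task:0".
}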

@@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
std::vector<string> workers;
cc->ListWorkers(&workers);
EXPECT_EQ(std::vector<string>(
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
workers);
{
std::vector<string> workers;
cc->ListWorkers(&workers);
EXPECT_EQ(
std::vector<string>(
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
workers);
}
{
std::vector<string> workers;
cc->ListWorkersInJob("mnist", &workers);
EXPECT_EQ(
std::vector<string>(
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
workers);
}
{
std::vector<string> workers;
cc->ListWorkersInJob("other", &workers);
EXPECT_TRUE(workers.empty());
}
}
TEST(GrpcChannelTest, SparseHostPorts) {
@@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
std::vector<string> workers;
cc->ListWorkers(&workers);
std::sort(workers.begin(), workers.end());
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
"/job:mnist/replica:0/task:3",
"/job:mnist/replica:0/task:4"}),
workers);
{
std::vector<string> workers;
cc->ListWorkers(&workers);
std::sort(workers.begin(), workers.end());
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
"/job:mnist/replica:0/task:3",
"/job:mnist/replica:0/task:4"}),
workers);
}
{
std::vector<string> workers;
cc->ListWorkersInJob("mnist", &workers);
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
"/job:mnist/replica:0/task:3",
"/job:mnist/replica:0/task:4"}),
workers);
}
{
std::vector<string> workers;
cc->ListWorkersInJob("other", &workers);
EXPECT_TRUE(workers.empty());
}
}
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {


@@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
channel_cache_->ListWorkers(workers);
}
void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const override {
channel_cache_->ListWorkersInJob(job_name, workers);
}
WorkerInterface* CreateWorker(const string& target) override {
if (target == local_target_) {
return local_worker_;


@@ -50,6 +50,8 @@ namespace {
// Fake cache implementation for WorkerEnv.
class DummyWorkerCache : public WorkerCacheInterface {
void ListWorkers(std::vector<string>* workers) const override {}
void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const override {}
WorkerInterface* CreateWorker(const string& target) override {
return nullptr;
}


@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
}
}
void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const override {
workers->clear();
for (auto it : workers_) {
DeviceNameUtils::ParsedName device_name;
CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
CHECK(device_name.has_job);
if (job_name == device_name.job) {
workers->push_back(it.first);
}
}
}
WorkerInterface* CreateWorker(const string& target) override {
auto it = workers_.find(target);
if (it != workers_.end()) {


@@ -36,6 +36,8 @@ class WorkerCacheInterface {
// Updates *workers with strings naming the remote worker tasks to
// which open channels have been established.
virtual void ListWorkers(std::vector<string>* workers) const = 0;
virtual void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const = 0;
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object


@@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
virtual void ListWorkers(std::vector<string>* workers) const {
return wrapped_->ListWorkers(workers);
}
virtual void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const {
return wrapped_->ListWorkersInJob(job_name, workers);
}
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object


@@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
wrapped_->ListWorkers(workers);
}
void ListWorkersInJob(const string& job_name,
std::vector<string>* workers) const override {
wrapped_->ListWorkersInJob(job_name, workers);
}
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);