[Distributed] Add methods to WorkerCache that selectively list workers by job name.
PiperOrigin-RevId: 209597829
This commit is contained in:
parent
aeab291563
commit
5b456c9ab5
@ -562,6 +562,7 @@ cc_library(
|
||||
deps = [
|
||||
":worker_cache",
|
||||
":worker_interface",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -53,6 +53,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/master.pb.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -167,13 +168,55 @@ class DeviceFinder {
|
||||
}
|
||||
// Enumerates all known workers' targets. A target name is a
|
||||
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
|
||||
CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
|
||||
const string& local_device_name = env_->local_devices[0]->name();
|
||||
std::vector<string> workers;
|
||||
worker_cache->ListWorkers(&workers);
|
||||
if (filters_.empty()) {
|
||||
// If no filters were specified, we list all known workers in
|
||||
// `worker_cache`.
|
||||
std::vector<string> workers;
|
||||
worker_cache->ListWorkers(&workers);
|
||||
std::swap(workers, targets_);
|
||||
} else {
|
||||
// When applying filters, we must include the local worker, even if it
|
||||
// does not match any of the filters.
|
||||
CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
|
||||
const string& local_device_name = env_->local_devices[0]->name();
|
||||
DeviceNameUtils::ParsedName local_parsed_name;
|
||||
CHECK(DeviceNameUtils::ParseFullName(local_device_name,
|
||||
&local_parsed_name));
|
||||
bool all_filters_have_job = true;
|
||||
std::unordered_set<string> filter_job_names({local_parsed_name.job});
|
||||
for (const DeviceNameUtils::ParsedName& filter : filters_) {
|
||||
all_filters_have_job = all_filters_have_job && filter.has_job;
|
||||
if (filter.has_job) {
|
||||
filter_job_names.insert(filter.job);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<string> workers;
|
||||
if (all_filters_have_job) {
|
||||
// If all of the device filters have a job specified, then we only need
|
||||
// to list the workers in the jobs named in the filter, because a worker
|
||||
// in any other job would not match any filter.
|
||||
for (const string& job_name : filter_job_names) {
|
||||
VLOG(2) << "Selectively listing workers in job: " << job_name;
|
||||
std::vector<string> workers_in_job;
|
||||
worker_cache->ListWorkersInJob(job_name, &workers_in_job);
|
||||
workers.insert(workers.end(), workers_in_job.begin(),
|
||||
workers_in_job.end());
|
||||
}
|
||||
} else {
|
||||
// If any of the device filters does not have a job specified, then we
|
||||
// must list the workers from all jobs.
|
||||
VLOG(2) << "Listing workers in all jobs because some device "
|
||||
<< "filter has no job specified. Filters were:";
|
||||
if (device_filters.empty()) {
|
||||
VLOG(2) << "- <NO FILTERS>";
|
||||
} else {
|
||||
for (const string& filter : device_filters) {
|
||||
VLOG(2) << "- " << filter;
|
||||
}
|
||||
}
|
||||
worker_cache->ListWorkers(&workers);
|
||||
}
|
||||
for (const string& name : workers) {
|
||||
if (MatchFilters(name) ||
|
||||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
|
||||
|
@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
|
||||
}
|
||||
}
|
||||
|
||||
void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) override {
|
||||
for (GrpcChannelCache* cache : caches_) {
|
||||
cache->ListWorkersInJob(job_name, workers);
|
||||
}
|
||||
}
|
||||
|
||||
string TranslateTask(const string& target) override {
|
||||
mutex_lock l(mu_); // could use reader lock
|
||||
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
|
||||
@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
|
||||
}
|
||||
}
|
||||
|
||||
void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) override {
|
||||
if (job_name == job_id_) {
|
||||
ListWorkers(workers);
|
||||
}
|
||||
}
|
||||
|
||||
string TranslateTask(const string& target) override {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
|
||||
|
@ -66,6 +66,8 @@ class GrpcChannelCache {
|
||||
// /job:<job identifier>/task:<task id>
|
||||
// e.g. /job:mnist/task:2
|
||||
virtual void ListWorkers(std::vector<string>* workers) = 0;
|
||||
// Like ListWorkers(), but restricted to the workers that belong to the job
// named `job_name`. A job name with no matching workers yields no entries.
// NOTE(review): whether implementations append to or overwrite `*workers` is
// not specified here -- confirm before relying on either behavior.
virtual void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) = 0;
|
||||
|
||||
// If found, returns a gRPC channel that is connected to the remote
|
||||
// worker named by 'target'. 'target' is of the following
|
||||
|
@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
|
||||
EXPECT_NE(d_4_1.get(), e_5_2.get());
|
||||
}
|
||||
|
||||
std::vector<string> workers;
|
||||
cc->ListWorkers(&workers);
|
||||
EXPECT_EQ(std::vector<string>(
|
||||
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
|
||||
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
|
||||
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
|
||||
workers);
|
||||
{
|
||||
std::vector<string> workers;
|
||||
cc->ListWorkers(&workers);
|
||||
EXPECT_EQ(
|
||||
std::vector<string>(
|
||||
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
|
||||
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
|
||||
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
|
||||
workers);
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<string> workers;
|
||||
cc->ListWorkersInJob("mnist", &workers);
|
||||
EXPECT_EQ(
|
||||
std::vector<string>(
|
||||
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
|
||||
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
|
||||
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
|
||||
workers);
|
||||
}
|
||||
|
||||
{
|
||||
std::vector<string> workers;
|
||||
cc->ListWorkersInJob("other", &workers);
|
||||
EXPECT_TRUE(workers.empty());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GrpcChannelTest, SparseHostPorts) {
|
||||
@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) {
|
||||
EXPECT_NE(d_4_1.get(), e_5_2.get());
|
||||
}
|
||||
|
||||
std::vector<string> workers;
|
||||
cc->ListWorkers(&workers);
|
||||
std::sort(workers.begin(), workers.end());
|
||||
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
|
||||
"/job:mnist/replica:0/task:3",
|
||||
"/job:mnist/replica:0/task:4"}),
|
||||
workers);
|
||||
{
|
||||
std::vector<string> workers;
|
||||
cc->ListWorkers(&workers);
|
||||
std::sort(workers.begin(), workers.end());
|
||||
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
|
||||
"/job:mnist/replica:0/task:3",
|
||||
"/job:mnist/replica:0/task:4"}),
|
||||
workers);
|
||||
}
|
||||
|
||||
{
  // Per-job listing for the job this sparse cache serves.
  // Sort before comparing so the check does not depend on the order in which
  // the cache enumerates its tasks; the ListWorkers check in this test does
  // the same.
  std::vector<string> workers;
  cc->ListWorkersInJob("mnist", &workers);
  std::sort(workers.begin(), workers.end());
  EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
                                 "/job:mnist/replica:0/task:3",
                                 "/job:mnist/replica:0/task:4"}),
            workers);
}
|
||||
|
||||
{
|
||||
std::vector<string> workers;
|
||||
cc->ListWorkersInJob("other", &workers);
|
||||
EXPECT_TRUE(workers.empty());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
|
||||
|
@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
||||
channel_cache_->ListWorkers(workers);
|
||||
}
|
||||
|
||||
void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) const override {
|
||||
channel_cache_->ListWorkersInJob(job_name, workers);
|
||||
}
|
||||
|
||||
WorkerInterface* CreateWorker(const string& target) override {
|
||||
if (target == local_target_) {
|
||||
return local_worker_;
|
||||
|
@ -50,6 +50,8 @@ namespace {
|
||||
// Fake cache implementation for WorkerEnv.
|
||||
class DummyWorkerCache : public WorkerCacheInterface {
|
||||
// No-op: this fake cache never reports any workers; `*workers` is untouched.
void ListWorkers(std::vector<string>* workers) const override {}
|
||||
void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) const override {}
|
||||
WorkerInterface* CreateWorker(const string& target) override {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
|
||||
}
|
||||
}
|
||||
|
||||
void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) const override {
|
||||
workers->clear();
|
||||
for (auto it : workers_) {
|
||||
DeviceNameUtils::ParsedName device_name;
|
||||
CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
|
||||
CHECK(device_name.has_job);
|
||||
if (job_name == device_name.job) {
|
||||
workers->push_back(it.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WorkerInterface* CreateWorker(const string& target) override {
|
||||
auto it = workers_.find(target);
|
||||
if (it != workers_.end()) {
|
||||
|
@ -36,6 +36,8 @@ class WorkerCacheInterface {
|
||||
// Updates *workers with strings naming the remote worker tasks to
|
||||
// which open channels have been established.
|
||||
virtual void ListWorkers(std::vector<string>* workers) const = 0;
|
||||
// Like ListWorkers(), but intended to report only the workers that belong to
// the job named `job_name`.
// NOTE(review): implementations visible alongside this change differ on
// whether `*workers` is cleared first or only added to -- confirm the
// intended contract before relying on either behavior.
virtual void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) const = 0;
|
||||
|
||||
// If "target" names a remote task for which an RPC channel exists
|
||||
// or can be constructed, returns a pointer to a WorkerInterface object
|
||||
|
@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
|
||||
virtual void ListWorkers(std::vector<string>* workers) const {
|
||||
return wrapped_->ListWorkers(workers);
|
||||
}
|
||||
virtual void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) const {
|
||||
return wrapped_->ListWorkersInJob(job_name, workers);
|
||||
}
|
||||
|
||||
// If "target" names a remote task for which an RPC channel exists
|
||||
// or can be constructed, returns a pointer to a WorkerInterface object
|
||||
|
@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
|
||||
wrapped_->ListWorkers(workers);
|
||||
}
|
||||
|
||||
void ListWorkersInJob(const string& job_name,
|
||||
std::vector<string>* workers) const override {
|
||||
wrapped_->ListWorkersInJob(job_name, workers);
|
||||
}
|
||||
|
||||
WorkerInterface* CreateWorker(const string& target) override {
|
||||
mutex_lock l(mu_);
|
||||
auto p = workers_.find(target);
|
||||
|
Loading…
Reference in New Issue
Block a user