[Distributed] Add methods to WorkerCache that selectively list workers by job name.
PiperOrigin-RevId: 209597829
This commit is contained in:
parent
aeab291563
commit
5b456c9ab5
@ -562,6 +562,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":worker_cache",
|
":worker_cache",
|
||||||
":worker_interface",
|
":worker_interface",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,6 +53,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/protobuf/master.pb.h"
|
#include "tensorflow/core/protobuf/master.pb.h"
|
||||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -167,13 +168,55 @@ class DeviceFinder {
|
|||||||
}
|
}
|
||||||
// Enumerates all known workers' target. A target name is a
|
// Enumerates all known workers' target. A target name is a
|
||||||
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
|
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
|
||||||
CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
|
if (filters_.empty()) {
|
||||||
const string& local_device_name = env_->local_devices[0]->name();
|
// If no filters were specified, we list all known workers in
|
||||||
|
// `worker_cache`.
|
||||||
std::vector<string> workers;
|
std::vector<string> workers;
|
||||||
worker_cache->ListWorkers(&workers);
|
worker_cache->ListWorkers(&workers);
|
||||||
if (filters_.empty()) {
|
|
||||||
std::swap(workers, targets_);
|
std::swap(workers, targets_);
|
||||||
} else {
|
} else {
|
||||||
|
// When applying filters, we must include the local worker, even if it
|
||||||
|
// does not match any of the filters.
|
||||||
|
CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
|
||||||
|
const string& local_device_name = env_->local_devices[0]->name();
|
||||||
|
DeviceNameUtils::ParsedName local_parsed_name;
|
||||||
|
CHECK(DeviceNameUtils::ParseFullName(local_device_name,
|
||||||
|
&local_parsed_name));
|
||||||
|
bool all_filters_have_job = true;
|
||||||
|
std::unordered_set<string> filter_job_names({local_parsed_name.job});
|
||||||
|
for (const DeviceNameUtils::ParsedName& filter : filters_) {
|
||||||
|
all_filters_have_job = all_filters_have_job && filter.has_job;
|
||||||
|
if (filter.has_job) {
|
||||||
|
filter_job_names.insert(filter.job);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> workers;
|
||||||
|
if (all_filters_have_job) {
|
||||||
|
// If all of the device filters have a job specified, then we only need
|
||||||
|
// to list the workers in the jobs named in the filter, because a worker
|
||||||
|
// in any other job would not match any filter.
|
||||||
|
for (const string& job_name : filter_job_names) {
|
||||||
|
VLOG(2) << "Selectively listing workers in job: " << job_name;
|
||||||
|
std::vector<string> workers_in_job;
|
||||||
|
worker_cache->ListWorkersInJob(job_name, &workers_in_job);
|
||||||
|
workers.insert(workers.end(), workers_in_job.begin(),
|
||||||
|
workers_in_job.end());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If any of the device filters does not have a job specified, then we
|
||||||
|
// must list the workers from all jobs.
|
||||||
|
VLOG(2) << "Listing workers in all jobs because some device "
|
||||||
|
<< "filter has no job specified. Filters were:";
|
||||||
|
if (device_filters.empty()) {
|
||||||
|
VLOG(2) << "- <NO FILTERS>";
|
||||||
|
} else {
|
||||||
|
for (const string& filter : device_filters) {
|
||||||
|
VLOG(2) << "- " << filter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
worker_cache->ListWorkers(&workers);
|
||||||
|
}
|
||||||
for (const string& name : workers) {
|
for (const string& name : workers) {
|
||||||
if (MatchFilters(name) ||
|
if (MatchFilters(name) ||
|
||||||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
|
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
|
||||||
|
@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) override {
|
||||||
|
for (GrpcChannelCache* cache : caches_) {
|
||||||
|
cache->ListWorkersInJob(job_name, workers);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
string TranslateTask(const string& target) override {
|
string TranslateTask(const string& target) override {
|
||||||
mutex_lock l(mu_); // could use reader lock
|
mutex_lock l(mu_); // could use reader lock
|
||||||
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
|
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
|
||||||
@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) override {
|
||||||
|
if (job_name == job_id_) {
|
||||||
|
ListWorkers(workers);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
string TranslateTask(const string& target) override {
|
string TranslateTask(const string& target) override {
|
||||||
DeviceNameUtils::ParsedName parsed;
|
DeviceNameUtils::ParsedName parsed;
|
||||||
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
|
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
|
||||||
|
@ -66,6 +66,8 @@ class GrpcChannelCache {
|
|||||||
// /job:<job identifier>/task:<task id>
|
// /job:<job identifier>/task:<task id>
|
||||||
// e.g. /job:mnist/task:2
|
// e.g. /job:mnist/task:2
|
||||||
virtual void ListWorkers(std::vector<string>* workers) = 0;
|
virtual void ListWorkers(std::vector<string>* workers) = 0;
|
||||||
|
virtual void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) = 0;
|
||||||
|
|
||||||
// If found, returns a gRPC channel that is connected to the remote
|
// If found, returns a gRPC channel that is connected to the remote
|
||||||
// worker named by 'target'. 'target' is of the following
|
// worker named by 'target'. 'target' is of the following
|
||||||
|
@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
|
|||||||
EXPECT_NE(d_4_1.get(), e_5_2.get());
|
EXPECT_NE(d_4_1.get(), e_5_2.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
std::vector<string> workers;
|
std::vector<string> workers;
|
||||||
cc->ListWorkers(&workers);
|
cc->ListWorkers(&workers);
|
||||||
EXPECT_EQ(std::vector<string>(
|
EXPECT_EQ(
|
||||||
|
std::vector<string>(
|
||||||
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
|
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
|
||||||
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
|
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
|
||||||
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
|
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
|
||||||
workers);
|
workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<string> workers;
|
||||||
|
cc->ListWorkersInJob("mnist", &workers);
|
||||||
|
EXPECT_EQ(
|
||||||
|
std::vector<string>(
|
||||||
|
{"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
|
||||||
|
"/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
|
||||||
|
"/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
|
||||||
|
workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<string> workers;
|
||||||
|
cc->ListWorkersInJob("other", &workers);
|
||||||
|
EXPECT_TRUE(workers.empty());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GrpcChannelTest, SparseHostPorts) {
|
TEST(GrpcChannelTest, SparseHostPorts) {
|
||||||
@ -135,6 +155,7 @@ TEST(GrpcChannelTest, SparseHostPorts) {
|
|||||||
EXPECT_NE(d_4_1.get(), e_5_2.get());
|
EXPECT_NE(d_4_1.get(), e_5_2.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
std::vector<string> workers;
|
std::vector<string> workers;
|
||||||
cc->ListWorkers(&workers);
|
cc->ListWorkers(&workers);
|
||||||
std::sort(workers.begin(), workers.end());
|
std::sort(workers.begin(), workers.end());
|
||||||
@ -142,6 +163,22 @@ TEST(GrpcChannelTest, SparseHostPorts) {
|
|||||||
"/job:mnist/replica:0/task:3",
|
"/job:mnist/replica:0/task:3",
|
||||||
"/job:mnist/replica:0/task:4"}),
|
"/job:mnist/replica:0/task:4"}),
|
||||||
workers);
|
workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<string> workers;
|
||||||
|
cc->ListWorkersInJob("mnist", &workers);
|
||||||
|
EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
|
||||||
|
"/job:mnist/replica:0/task:3",
|
||||||
|
"/job:mnist/replica:0/task:4"}),
|
||||||
|
workers);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<string> workers;
|
||||||
|
cc->ListWorkersInJob("other", &workers);
|
||||||
|
EXPECT_TRUE(workers.empty());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
|
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
|
||||||
|
@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
|||||||
channel_cache_->ListWorkers(workers);
|
channel_cache_->ListWorkers(workers);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) const override {
|
||||||
|
channel_cache_->ListWorkersInJob(job_name, workers);
|
||||||
|
}
|
||||||
|
|
||||||
WorkerInterface* CreateWorker(const string& target) override {
|
WorkerInterface* CreateWorker(const string& target) override {
|
||||||
if (target == local_target_) {
|
if (target == local_target_) {
|
||||||
return local_worker_;
|
return local_worker_;
|
||||||
|
@ -50,6 +50,8 @@ namespace {
|
|||||||
// Fake cache implementation for WorkerEnv.
|
// Fake cache implementation for WorkerEnv.
|
||||||
class DummyWorkerCache : public WorkerCacheInterface {
|
class DummyWorkerCache : public WorkerCacheInterface {
|
||||||
void ListWorkers(std::vector<string>* workers) const override {}
|
void ListWorkers(std::vector<string>* workers) const override {}
|
||||||
|
void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) const override {}
|
||||||
WorkerInterface* CreateWorker(const string& target) override {
|
WorkerInterface* CreateWorker(const string& target) override {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) const override {
|
||||||
|
workers->clear();
|
||||||
|
for (auto it : workers_) {
|
||||||
|
DeviceNameUtils::ParsedName device_name;
|
||||||
|
CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
|
||||||
|
CHECK(device_name.has_job);
|
||||||
|
if (job_name == device_name.job) {
|
||||||
|
workers->push_back(it.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
WorkerInterface* CreateWorker(const string& target) override {
|
WorkerInterface* CreateWorker(const string& target) override {
|
||||||
auto it = workers_.find(target);
|
auto it = workers_.find(target);
|
||||||
if (it != workers_.end()) {
|
if (it != workers_.end()) {
|
||||||
|
@ -36,6 +36,8 @@ class WorkerCacheInterface {
|
|||||||
// Updates *workers with strings naming the remote worker tasks to
|
// Updates *workers with strings naming the remote worker tasks to
|
||||||
// which open channels have been established.
|
// which open channels have been established.
|
||||||
virtual void ListWorkers(std::vector<string>* workers) const = 0;
|
virtual void ListWorkers(std::vector<string>* workers) const = 0;
|
||||||
|
virtual void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) const = 0;
|
||||||
|
|
||||||
// If "target" names a remote task for which an RPC channel exists
|
// If "target" names a remote task for which an RPC channel exists
|
||||||
// or can be constructed, returns a pointer to a WorkerInterface object
|
// or can be constructed, returns a pointer to a WorkerInterface object
|
||||||
|
@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
|
|||||||
virtual void ListWorkers(std::vector<string>* workers) const {
|
virtual void ListWorkers(std::vector<string>* workers) const {
|
||||||
return wrapped_->ListWorkers(workers);
|
return wrapped_->ListWorkers(workers);
|
||||||
}
|
}
|
||||||
|
virtual void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) const {
|
||||||
|
return wrapped_->ListWorkersInJob(job_name, workers);
|
||||||
|
}
|
||||||
|
|
||||||
// If "target" names a remote task for which an RPC channel exists
|
// If "target" names a remote task for which an RPC channel exists
|
||||||
// or can be constructed, returns a pointer to a WorkerInterface object
|
// or can be constructed, returns a pointer to a WorkerInterface object
|
||||||
|
@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
|
|||||||
wrapped_->ListWorkers(workers);
|
wrapped_->ListWorkers(workers);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ListWorkersInJob(const string& job_name,
|
||||||
|
std::vector<string>* workers) const override {
|
||||||
|
wrapped_->ListWorkersInJob(job_name, workers);
|
||||||
|
}
|
||||||
|
|
||||||
WorkerInterface* CreateWorker(const string& target) override {
|
WorkerInterface* CreateWorker(const string& target) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
auto p = workers_.find(target);
|
auto p = workers_.find(target);
|
||||||
|
Loading…
Reference in New Issue
Block a user