Use device attributes from group resolution

1. CompleteInstanceLocal now uses the device attributes stored in GroupRec, so it no longer needs to issue GetStatus calls to fetch device attributes.
2. CollectiveParamResolverDistributed::CompleteParamsAsync copies the device attributes obtained from group resolution into the local DeviceResolver, so that RecvFromPeer can query the attributes of the target device (sketched below).
3. DeviceResolver is simplified since it no longer needs to issue RPCs.
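
The resolver surface that results from this change (see the collective.h diff at the end) is fully synchronous. A minimal caller-side sketch of the new flow; PublishAndLookup is a hypothetical helper for illustration, not part of this change:

// Hypothetical helper: attributes learned during group resolution are pushed
// into the resolver, after which lookups (e.g. from RecvFromPeer) are
// RPC-free table reads.
Status PublishAndLookup(DeviceResolverInterface* resolver,
                        const std::vector<DeviceAttributes>& group_devices,
                        const string& peer_device) {
  TF_RETURN_IF_ERROR(resolver->UpdateDeviceAttributes(group_devices));
  DeviceAttributes attr;
  return resolver->GetDeviceAttributes(peer_device, &attr);
}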

PiperOrigin-RevId: 330752257
Change-Id: Id6d5a1a3b1115928570b354fc27ebf7351c9d6ca
Ran Chen 2020-09-09 10:40:59 -07:00 committed by TensorFlower Gardener
parent f7620b089c
commit 9d584860d7
16 changed files with 307 additions and 487 deletions

View File

@@ -546,6 +546,9 @@ cc_library(
     deps = [
         ":device_mgr",
         "//tensorflow/core:framework",
+        "//tensorflow/core/framework:device_attributes_proto_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
     ],
 )

View File

@@ -495,7 +495,8 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
 void CollectiveParamResolverLocal::InitInstanceSharedParams(
     const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
-    const StatusCallback& done) {
+    const StatusCallback& done) TF_NO_THREAD_SAFETY_ANALYSIS {
+  std::vector<DeviceAttributes> attributes;
   ir->shared.instance = cp->instance;
   {
     mutex_lock gl(gr->mu);
@@ -504,10 +505,12 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
     ir->shared.instance.task_names.clear();
     ir->shared.instance.device_names.reserve(gr->devices.size());
     ir->shared.instance.task_names.reserve(gr->devices.size());
+    attributes.reserve(gr->devices.size());
     for (const auto& item : gr->devices) {
       ir->shared.instance.device_names.push_back(item.first);
       ir->shared.instance.task_names.push_back(
           TaskNameFromDeviceName(item.first));
+      attributes.push_back(item.second);
     }
     VLOG(2) << "Initialized names for instance: "
             << ir->shared.instance.ToString();
@@ -526,6 +529,9 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
   // GetDeviceAttributesAsync will use those fields to launch RPCs.
   CompleteTaskIsLocal(task_name_, &ir->shared);
+  // TODO(b/151232436): clean up the following code since we no longer need to
+  // execute it in a callback.
   // Because the callback may execute in a different thread, we release
   // ir->out_mu here. Before releasing, we mark it as unavailable for other
   // threads.
@@ -533,30 +539,24 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
   const auto device_names = ir->shared.instance.device_names;
   const auto task_names = ir->shared.instance.task_names;
   ir->out_mu.unlock();
-  std::vector<DeviceAttributes>* attributes = new std::vector<DeviceAttributes>;
-  // Suppress linter warning about access to shared without mutex because in
-  // principle the members are locked due to out_mu_available=false.
-  dev_resolver_->GetAllDeviceAttributesAsync(
-      ir->shared.instance.device_names,  // NOLINT
-      ir->shared.instance.task_names,    // NOLINT
-      attributes,
-      [this, gr, cp, ir, attributes, done](const Status& s)
+  auto complete_init = [this, gr, cp, ir, attributes, done](const Status& s)
           TF_EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
-            // Then we recover the lock in the callback thread that will hold it
-            // through the rest of the call chain. Signal the cv now, any
-            // waiting threads will wake only when out_mu is released later.
+            // Then we recover the lock in the callback thread
+            // that will hold it through the rest of the call
+            // chain. Signal the cv now, any waiting threads
+            // will wake only when out_mu is released later.
             ir->out_mu.lock();
             DCHECK(!ir->out_mu_available);
             ir->out_mu_available = true;
             ir->out_cv.notify_all();
             if (s.ok()) {
-              CompleteDefaultRanking(gr, cp, ir, *attributes);
+              CompleteDefaultRanking(gr, cp, ir, attributes);
               done(Status::OK());
             } else {
               done(s);
             }
-            delete attributes;
-          });
+          };
+  complete_init(Status::OK());
 }
 
 // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
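
The net effect of the hunk above: the attribute list is built directly from GroupRec under gr->mu, and the retained completion lambda now runs inline on the caller's thread rather than as an RPC callback (hence the TODO). A condensed sketch of the new control flow, not the verbatim code:

std::vector<DeviceAttributes> attributes;
{
  mutex_lock gl(gr->mu);
  // GroupRec already holds each device's attributes from group resolution.
  for (const auto& item : gr->devices) attributes.push_back(item.second);
}
// No RPC is issued; the callback shape is kept only to preserve the
// out_mu unlock/relock protocol.
CompleteDefaultRanking(gr, cp, ir, attributes);
done(Status::OK());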

View File

@@ -15,41 +15,34 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
-void DeviceResolverLocal::GetAllDeviceAttributesAsync(
-    const std::vector<string>& devices, const std::vector<string>& tasks,
-    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
-  attributes->clear();
-  for (const string& device_name : devices) {
-    Device* dev;
-    Status s = dev_mgr_->LookupDevice(device_name, &dev);
-    if (!s.ok()) {
-      done(s);
-      return;
-    }
-    attributes->push_back(dev->attributes());
-  }
-  done(Status::OK());
-}
-
-void DeviceResolverLocal::GetDeviceAttributesAsync(const string& device,
-                                                   const string& task,
-                                                   DeviceAttributes* attributes,
-                                                   const StatusCallback& done) {
+Status DeviceResolverLocal::GetDeviceAttributes(const string& device,
+                                                DeviceAttributes* attributes) {
   Device* dev;
+  // LookupDevice returns InvalidArgument if the device is not found.
   Status s = dev_mgr_->LookupDevice(device, &dev);
-  if (s.ok()) {
-    *attributes = dev->attributes();
+  if (errors::IsInvalidArgument(s)) {
+    return errors::NotFound(device, " not found");
+  } else if (!s.ok()) {
+    return s;
   }
-  done(s);
+  *attributes = dev->attributes();
+  return Status::OK();
 }
 
-Status DeviceResolverLocal::GetTaskCached(
+Status DeviceResolverLocal::GetAllDeviceAttributes(
     const string& task, std::vector<DeviceAttributes>* attributes) {
   return errors::Internal(
       "GetTaskCached is not supposed to be called in local collectives");
 }
 
+Status DeviceResolverLocal::UpdateDeviceAttributes(
+    const std::vector<DeviceAttributes>& attributes) {
+  return errors::Internal(
+      "UpdateDeviceAttributes shouldn't be called with local collectives");
+}
+
 }  // namespace tensorflow
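
Lookup against the simplified local resolver is now a plain synchronous call. A small usage sketch mirroring the updated test below, assuming dev_mgr is a populated DeviceMgr*:

DeviceResolverLocal resolver(dev_mgr);
DeviceAttributes attr;
Status s = resolver.GetDeviceAttributes(
    "/job:localhost/replica:0/task:0/device:CPU:0", &attr);
// An unknown device now surfaces as a NotFound status instead of arriving
// through a StatusCallback.
if (!s.ok()) CHECK(errors::IsNotFound(s));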

View File

@@ -16,9 +16,11 @@ limitations under the License.
 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
 
 #include <string>
+#include <vector>
 
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
 class DeviceMgr;
@@ -26,21 +28,16 @@ class DeviceMgr;
 // Implements DeviceResolverInterface in a single-task context.
 class DeviceResolverLocal : public DeviceResolverInterface {
  public:
-  DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {}
+  explicit DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {}
 
-  virtual ~DeviceResolverLocal() {}
+  Status GetDeviceAttributes(const string& device,
+                             DeviceAttributes* attributes) override;
 
-  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
-                                   const std::vector<string>& tasks,
-                                   std::vector<DeviceAttributes>* attributes,
-                                   const StatusCallback& done) override;
+  Status GetAllDeviceAttributes(
+      const string& task, std::vector<DeviceAttributes>* attributes) override;
 
-  void GetDeviceAttributesAsync(const string& device, const string& task,
-                                DeviceAttributes* attributes,
-                                const StatusCallback& done) override;
+  Status UpdateDeviceAttributes(
+      const std::vector<DeviceAttributes>& attributes) override;
 
-  Status GetTaskCached(const string& task,
-                       std::vector<DeviceAttributes>* attributes) override;
-
  protected:
   const DeviceMgr* dev_mgr_;

View File

@@ -17,7 +17,6 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
@@ -46,35 +45,27 @@ class DeviceResolverLocalTest : public ::testing::Test {
 };
 
 TEST_F(DeviceResolverLocalTest, GetDeviceAttributesKnown) {
-  std::vector<DeviceAttributes> attributes;
-  std::vector<string> devices{"/job:localhost/replica:0/task:0/device:CPU:1",
-                              "/job:localhost/replica:0/task:0/device:CPU:2"};
-  Notification note;
-  Status status;
-  drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes,
-                                    [&note, &status](const Status& s) {
-                                      status = s;
-                                      note.Notify();
-                                    });
-  note.WaitForNotification();
-  TF_EXPECT_OK(status);
-  EXPECT_EQ(2, attributes.size());
+  DeviceAttributes attributes;
+  TF_EXPECT_OK(drl_->GetDeviceAttributes(
+      "/job:localhost/replica:0/task:0/device:CPU:1", &attributes));
+  EXPECT_EQ(attributes.name(), "/job:localhost/replica:0/task:0/device:CPU:1");
 }
 
 TEST_F(DeviceResolverLocalTest, GetDeviceAttributesUnknown) {
-  std::vector<DeviceAttributes> attributes;
-  // In some builds there may be 1 GPU, but there should never be 9.
-  std::vector<string> devices{"/job:localhost/replica:0/task:0/device:GPU:9"};
-  Notification note;
-  Status status;
-  drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes,
-                                    [&note, &status](const Status& s) {
-                                      status = s;
-                                      note.Notify();
-                                    });
-  note.WaitForNotification();
-  EXPECT_FALSE(status.ok());
-  EXPECT_EQ(0, attributes.size());
+  DeviceAttributes attributes;
+  EXPECT_TRUE(errors::IsNotFound(drl_->GetDeviceAttributes(
+      "/job:localhost/replica:0/task:0/device:CPU:9", &attributes)));
+}
+
+TEST_F(DeviceResolverLocalTest, GetAllDeviceAttributes) {
+  std::vector<DeviceAttributes> attributes;
+  EXPECT_TRUE(errors::IsInternal(
+      drl_->GetAllDeviceAttributes(/*task*/ "", &attributes)));
+}
+
+TEST_F(DeviceResolverLocalTest, UpdateDeviceAttributes) {
+  std::vector<DeviceAttributes> attributes;
+  EXPECT_TRUE(errors::IsInternal(drl_->UpdateDeviceAttributes(attributes)));
 }
 
 }  // namespace

View File

@@ -574,6 +574,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/strings",
     ],
 )
@@ -621,6 +622,8 @@ cc_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )

View File

@@ -106,12 +106,19 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync(
     CancellationManager* cancel_mgr, const StatusCallback& done) {
   VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
           << ": " << cp->ToString();
-  CompleteGroupDistributed(device, cp, cancel_mgr,
-                           [this, device, cp, cancel_mgr, done](
-                               const Status& s, const GroupRec* gr) {
-                             if (s.ok()) {
-                               CompleteInstanceDistributed(
-                                   device.name(), gr, cp, cancel_mgr, done);
+  CompleteGroupDistributed(
+      device, cp, cancel_mgr,
+      [this, device, cp, cancel_mgr, done](Status s, const GroupRec* gr) {
+        if (s.ok()) {
+          std::vector<DeviceAttributes> attributes;
+          mutex_lock l(gr->mu);
+          for (const auto& item : gr->devices) {
+            attributes.push_back(item.second);
+          }
+          s = dev_resolver_->UpdateDeviceAttributes(attributes);
+        }
+        if (s.ok()) {
+          CompleteInstanceDistributed(device.name(), gr, cp, cancel_mgr, done);
         } else {
           done(s);
         }

View File

@@ -157,7 +157,7 @@ class DeviceResDistTest : public ::testing::Test {
       dv->push_back(d->name());
     }
     dev_resolvers_[worker_name] = absl::make_unique<DeviceResolverDistributed>(
-        device_mgrs_[worker_name].get(), &wc_, worker_name);
+        device_mgrs_[worker_name].get());
     cp_resolvers_[worker_name] =
         absl::make_unique<CollectiveParamResolverDistributed>(
             config, device_mgrs_[worker_name].get(),
@@ -256,6 +256,7 @@ class DeviceResDistTest : public ::testing::Test {
       EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
       EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
       EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
+      ValidateDeviceResolver(cp_[device_name], task_name);
       if (idx > 0) {
         EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
                   cp_[device_name].group.runtime_details.communicator_key);
@@ -270,6 +271,14 @@ class DeviceResDistTest : public ::testing::Test {
     }
   }
 
+  void ValidateDeviceResolver(const CollectiveParams& cp, const string& task) {
+    for (const string& device_name : cp.instance.device_names) {
+      DeviceAttributes attributes;
+      TF_ASSERT_OK(
+          dev_resolvers_[task]->GetDeviceAttributes(device_name, &attributes));
+    }
+  }
+
   void RestartWorker(int worker_idx, int num_workers, int num_devices,
                      const string& device_type, bool nccl) {
     string worker_name =

View File

@@ -160,26 +160,17 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
     done(s);
   };
 
-  // Logic to execute once we have the device attributes for the server-side
-  // device.
-  auto dev_attributes_callback = [this, state, peer_device, peer_task, key,
-                                  to_device, to_device_ctx, to_alloc_attr,
-                                  to_tensor, client_locality,
-                                  recv_buf_callback](const Status& s) {
-    if (!s.ok()) {
-      recv_buf_callback(s);
-    } else {
-      state->call.reset(new RecvBufCall(
-          step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
-          to_alloc_attr, to_tensor, client_locality, state->server_attributes,
-          &cancel_mgr_, worker_cache_));
-      state->call->Start(recv_buf_callback);
-    }
-  };
-
-  dev_resolver_->GetDeviceAttributesAsync(peer_device, peer_task,
-                                          &state->server_attributes,
-                                          dev_attributes_callback);
+  Status s = dev_resolver_->GetDeviceAttributes(peer_device,
+                                                &state->server_attributes);
+  if (!s.ok()) {
+    recv_buf_callback(s);
+    return;
+  }
+  state->call.reset(
+      new RecvBufCall(step_id_, peer_device, peer_task, key, to_device,
+                      to_device_ctx, to_alloc_attr, to_tensor, client_locality,
+                      state->server_attributes, &cancel_mgr_, worker_cache_));
+  state->call->Start(recv_buf_callback);
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
@@ -209,7 +200,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
       [this, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
-          s = dev_resolver_->GetTaskCached(peer_task, &cached_attrs);
+          s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
         }
         if (s.ok()) {
           absl::flat_hash_set<uint64> remote_incarnations;
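
Note the ordering dependence this introduces: RecvFromPeer no longer discovers a peer's attributes on demand, so it assumes CompleteParamsAsync has already populated the resolver. If that has not happened, the lookup fails fast; a sketch of the failure path at a hypothetical call site:

DeviceAttributes server_attributes;
Status s = dev_resolver->GetDeviceAttributes(peer_device, &server_attributes);
if (!s.ok()) {
  // e.g. NotFound(".../task:1/device:CPU:0 not found"), with no GetStatus
  // RPC issued to the peer.
  recv_buf_callback(s);
}

The test changes that follow make this explicit by calling ResolveDeviceAttributes() at the start of each test.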

View File

@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/test_utils.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/random/random.h"
@@ -30,7 +31,6 @@ limitations under the License.
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/transport_options.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
-#include "tensorflow/core/util/device_name_utils.h"
 
 // The only interesting method on CollectiveRemoteAccessDistributed
 // that's not on CollectiveRemoteAccessLocal is RecvFromPeer which
@@ -224,6 +224,18 @@ class CollRMADistTest : public ::testing::Test {
     }
   }
 
+  // Populates all device resolvers with device attributes of the cluster. This
+  // should be called in the beginning of all tests unless you would like to
+  // simulate a situation that is before parameter resolution.
+  void ResolveDeviceAttributes() {
+    for (auto& dev_resolver_item : dev_resolvers_) {
+      DeviceResolverDistributed* dev_resolver = dev_resolver_item.second;
+      for (const auto& item : dev_by_task_) {
+        TF_CHECK_OK(dev_resolver->UpdateDeviceAttributes(item.second));
+      }
+    }
+  }
+
   void DefineWorker(const string& worker_name, const string& device_type,
                     int num_devices, bool is_failed = false) {
     std::vector<std::unique_ptr<Device>> devices;
@@ -234,13 +246,12 @@ class CollRMADistTest : public ::testing::Test {
     }
     DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
     device_mgrs_.push_back(dev_mgr);
-    std::vector<string>* dv = &dev_by_task_[worker_name];
+    std::vector<DeviceAttributes>* dv = &dev_by_task_[worker_name];
     dv->clear();
     for (auto d : dev_mgr->ListDevices()) {
-      dv->push_back(d->name());
+      dv->push_back(d->attributes());
     }
-    DeviceResolverDistributed* dev_res =
-        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+    DeviceResolverDistributed* dev_res = new DeviceResolverDistributed(dev_mgr);
     dev_resolvers_[worker_name] = dev_res;
     FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res, is_failed);
     workers_.push_back(fw);
@@ -254,6 +265,9 @@ class CollRMADistTest : public ::testing::Test {
       delete it->second;
       dev_resolvers_.erase(it);
     }
+    // After restarting a worker, the other workers already have the device
+    // attributes of the old worker. We don't broadcast device attributes of the
+    // new worker to mimic the real world.
     DefineWorker(worker_name, device_type, num_devices, is_failed);
   }
 
@@ -269,7 +283,7 @@ class CollRMADistTest : public ::testing::Test {
   CancellationManager cm_;
   std::vector<DeviceMgr*> device_mgrs_;
   std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
-  std::unordered_map<string, std::vector<string>> dev_by_task_;
+  std::unordered_map<string, std::vector<DeviceAttributes>> dev_by_task_;
   std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<FakeWorker*> workers_;
   std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
@@ -284,6 +298,7 @@ class CollRMADistTest : public ::testing::Test {
 };
 
 TEST_F(CollRMADistTest, ProdFirstOK) {
+  ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
   Status consumer_status;
@@ -319,6 +334,7 @@ TEST_F(CollRMADistTest, ProdFirstOK) {
 }
 
 TEST_F(CollRMADistTest, ConsFirstOK) {
+  ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
   Status consumer_status;
@@ -354,6 +370,7 @@ TEST_F(CollRMADistTest, ConsFirstOK) {
 }
 
 TEST_F(CollRMADistTest, ConsFirstAbort) {
+  ResolveDeviceAttributes();
   Notification consumer_note;
   Status consumer_status;
   const string kBufKey = "fake_buf_key";
@@ -377,6 +394,7 @@ TEST_F(CollRMADistTest, ConsFirstAbort) {
 }
 
 TEST_F(CollRMADistTest, WorkerRestart) {
+  ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
   Status consumer_status;
@@ -428,21 +446,7 @@ TEST_F(CollRMADistTest, WorkerRestart) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:CPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
+  ResolveDeviceAttributes();
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
@@ -469,21 +473,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthRestarted) {
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:CPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
+  ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
 
   Status check_health_status;
@@ -499,21 +489,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:CPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
+  ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
                 /*is_failed*/ true);
 
@@ -530,25 +506,8 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
+  ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:GPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
-  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
 
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(

View File

@@ -15,105 +15,30 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
-DeviceResolverDistributed::DeviceResolverDistributed(
-    const DeviceMgr* dev_mgr, WorkerCacheInterface* worker_cache,
-    const string& task_name)
-    : dev_mgr_(dev_mgr), worker_cache_(worker_cache), task_name_(task_name) {}
-
-void DeviceResolverDistributed::GetDeviceAttributesAsync(
-    const string& device, const string& task, DeviceAttributes* attributes,
-    const StatusCallback& done) {
-  if (task.empty() || task == task_name_) {
-    // Device is local to this task.
-    Device* dev;
-    Status s = dev_mgr_->LookupDevice(device, &dev);
-    if (s.ok()) {
-      *attributes = dev->attributes();
-    }
-    done(s);
-    return;
-  } else {
-    // Lookup of a remote device: first try the local cache.
-    bool found = false;
-    {
-      mutex_lock l(mu_);
-      auto it = attr_table_.find(device);
-      if (it != attr_table_.end()) {
-        *attributes = it->second;
-        found = true;
-      }
-    }
-    if (found) {
-      done(Status::OK());
-      return;
-    }
-  }
-  // Device is remote and no cache entry was found. Refresh the cache
-  // then retry the lookup.
-  RefreshRemoteAttributes(
-      device, task, [this, device, task, attributes, done](const Status& s) {
-        if (!s.ok()) {
-          done(s);
-        } else {
-          GetDeviceAttributesAsync(device, task, attributes, done);
-        }
-      });
-}
-
-void DeviceResolverDistributed::GetAllDeviceAttributesAsync(
-    const std::vector<string>& devices, const std::vector<string>& tasks,
-    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
-  attributes->clear();
-  GetAllDeviceAttributesRecursive(devices, tasks, attributes, done);
-}
-
-void DeviceResolverDistributed::GetAllDeviceAttributesRecursive(
-    const std::vector<string>& devices, const std::vector<string>& tasks,
-    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
-  size_t i = attributes->size();
-  if (i < devices.size()) {
-    attributes->push_back(DeviceAttributes());
-    GetDeviceAttributesAsync(
-        devices[i], tasks[i], &attributes->back(),
-        [this, &devices, &tasks, attributes, done](const Status& s) {
-          if (!s.ok()) {
-            done(s);
-            return;
-          } else {
-            GetAllDeviceAttributesRecursive(devices, tasks, attributes, done);
-          }
-        });
-  } else {
-    done(Status::OK());
-  }
-}
-
-void DeviceResolverDistributed::RefreshRemoteAttributes(
-    const string& device, const string& task, const StatusCallback& done) {
-  GetStatusRequest* req = new GetStatusRequest;
-  GetStatusResponse* resp = new GetStatusResponse;
-  WorkerInterface* worker = worker_cache_->GetOrCreateWorker(task);
-  CHECK(worker) << "Failed to get worker for " << task;
-  worker->GetStatusAsync(
-      req, resp, /*fail_fast=*/false,
-      [this, device, task, req, resp, worker, done](Status s) {
-        if (s.ok()) {
-          mutex_lock l(mu_);
-          for (const DeviceAttributes& da : resp->device_attributes()) {
-            attr_table_[da.name()] = da;
-          }
-        }
-        done(s);
-        delete req;
-        delete resp;
-        worker_cache_->ReleaseWorker(task, worker);
-      });
-}
-
-Status DeviceResolverDistributed::GetTaskCached(
+DeviceResolverDistributed::DeviceResolverDistributed(const DeviceMgr* dev_mgr) {
+  mutex_lock l(mu_);
+  for (Device* device : dev_mgr->ListDevices()) {
+    attr_table_[device->name()] = device->attributes();
+  }
+}
+
+Status DeviceResolverDistributed::GetDeviceAttributes(
+    const string& device, DeviceAttributes* attributes) {
+  mutex_lock l(mu_);
+  auto it = attr_table_.find(device);
+  if (it == attr_table_.end()) {
+    return errors::NotFound(device, " not found");
+  }
+  *attributes = it->second;
+  return Status::OK();
+}
+
+Status DeviceResolverDistributed::GetAllDeviceAttributes(
     const string& task, std::vector<DeviceAttributes>* attributes) {
   mutex_lock l(mu_);
   attributes->clear();
@@ -129,4 +54,23 @@ Status DeviceResolverDistributed::GetTaskCached(
   return Status::OK();
 }
 
+Status DeviceResolverDistributed::UpdateDeviceAttributes(
+    const std::vector<DeviceAttributes>& attributes) {
+  mutex_lock l(mu_);
+  for (const DeviceAttributes& attr : attributes) {
+    auto item = attr_table_.insert({attr.name(), attr});
+    auto it = item.first;
+    bool success = item.second;
+    // Returns error if the device already exists in the cache and has a
+    // different incarnation.
+    if (!success && it->second.incarnation() != attr.incarnation()) {
+      return errors::FailedPrecondition(
+          attr.name(),
+          "exists in cache with a different incarnation. "
+          "This usually means the remote worker has restarted");
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace tensorflow
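
The incarnation check in UpdateDeviceAttributes above is what turns a silent worker restart into a detectable error. A sketch of the behavior, assuming dev_mgr holds a device named "/job:worker/replica:0/task:0/device:CPU:0" (cf. the tests later in this commit):

DeviceResolverDistributed resolver(dev_mgr);
DeviceAttributes attr;
TF_CHECK_OK(resolver.GetDeviceAttributes(
    "/job:worker/replica:0/task:0/device:CPU:0", &attr));
attr.set_incarnation(attr.incarnation() + 1);  // simulate a restarted worker
Status s = resolver.UpdateDeviceAttributes({attr});
CHECK(errors::IsFailedPrecondition(s));  // stale cached attributes detected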

View File

@@ -21,6 +21,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
 class DeviceMgr;
@@ -28,39 +29,18 @@ class WorkerCacheInterface;
 
 class DeviceResolverDistributed : public DeviceResolverInterface {
  public:
-  DeviceResolverDistributed(const DeviceMgr* dev_mgr,
-                            WorkerCacheInterface* worker_cache,
-                            const string& task_name);
+  explicit DeviceResolverDistributed(const DeviceMgr* dev_mgr);
 
-  virtual ~DeviceResolverDistributed() {}
+  Status GetDeviceAttributes(const string& device,
+                             DeviceAttributes* attributes) override;
 
-  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
-                                   const std::vector<string>& tasks,
-                                   std::vector<DeviceAttributes>* attributes,
-                                   const StatusCallback& done) override;
+  Status GetAllDeviceAttributes(
+      const string& task, std::vector<DeviceAttributes>* attributes) override;
 
-  void GetDeviceAttributesAsync(const string& device, const string& task,
-                                DeviceAttributes* attributes,
-                                const StatusCallback& done) override;
+  Status UpdateDeviceAttributes(
+      const std::vector<DeviceAttributes>& attributes) override;
 
-  Status GetTaskCached(const string& task,
-                       std::vector<DeviceAttributes>* attributes) override;
-
  protected:
-  // Loads attr_table_ with device attributes retrieved from remote task.
-  void RefreshRemoteAttributes(const string& device, const string& task,
-                               const StatusCallback& done)
-      TF_LOCKS_EXCLUDED(mu_);
-
-  // Subroutine used by GetAllDeviceAttributesAsync. Recursively extends
-  // *attributes with DeviceAttributes of the corresponding device named
-  // by inst_params.instance.device_names.
-  void GetAllDeviceAttributesRecursive(
-      const std::vector<string>& devices, const std::vector<string>& tasks,
-      std::vector<DeviceAttributes>* attributes, const StatusCallback& done);
-
-  const DeviceMgr* dev_mgr_;            // Not owned
-  WorkerCacheInterface* worker_cache_;  // Not owned
   const string task_name_;
   mutex mu_;
   absl::flat_hash_map<string, DeviceAttributes> attr_table_ TF_GUARDED_BY(mu_);

View File

@@ -22,30 +22,19 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 namespace {
 
-// Subclass of DeviceResolverDistributed which behaves identically but
-// allows access to the attr_table_.
-class TestableDeviceResolverDistributed : public DeviceResolverDistributed {
- public:
-  TestableDeviceResolverDistributed(const DeviceMgr* dev_mgr,
-                                    WorkerCacheInterface* worker_cache,
-                                    const string& task)
-      : DeviceResolverDistributed(dev_mgr, worker_cache, task) {}
-
-  absl::flat_hash_map<string, DeviceAttributes>& attr_table() {
-    return attr_table_;
-  }
-};
+using ::testing::Property;
+using ::testing::UnorderedElementsAre;
 
 // Create a fake 'Device' whose only interesting attribute is a non-default
 // DeviceLocality and incarnation.
-static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
-                                         int numa_node, uint64 incarnation) {
+std::unique_ptr<Device> NewDevice(const string& type, const string& name) {
   class FakeDevice : public Device {
    public:
     explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@@ -55,161 +44,121 @@ static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
   DeviceAttributes attr;
   attr.set_name(name);
   attr.set_device_type(type);
-  attr.mutable_locality()->set_numa_node(numa_node);
-  attr.set_incarnation(incarnation);
+  attr.set_incarnation(random::New64());
   return absl::make_unique<FakeDevice>(attr);
 }
 
-// Create a fake WorkerInterface that responds to requests without RPCs,
-// in this case returning the DeviceAttributes of a fake remote worker.
-class FakeWorker : public TestWorkerInterface {
- public:
-  FakeWorker(const string& name, DeviceMgr* dev_mgr,
-             DeviceResolverDistributed* dres)
-      : name_(name), device_mgr_(dev_mgr), device_resolver_(dres) {}
-
-  void GetStatusAsync(const GetStatusRequest* request,
-                      GetStatusResponse* response, bool fail_fast,
-                      StatusCallback done) override {
-    std::vector<DeviceAttributes> dev_attr;
-    device_mgr_->ListDeviceAttributes(&dev_attr);
-    for (const auto& da : dev_attr) {
-      *response->add_device_attributes() = da;
-    }
-    done(Status::OK());
-  }
-
- private:
-  string name_;
-  DeviceMgr* device_mgr_;
-  DeviceResolverDistributed* device_resolver_;
-};
-
-// An implementation of WorkerCacheInterface that routes all requests
-// to local FakeWorkers, implementing only the methods needed for tests.
-class FakeCache : public TestWorkerCache {
- public:
-  // Override the Locality methods to actually pass through to the
-  // worker.
-  bool GetDeviceLocalityNonBlocking(const string& device,
-                                    DeviceLocality* locality) override {
-    return false;
-  }
-
-  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
-                              StatusCallback done) override {
-    string task_name;
-    string dev_part;
-    if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
-      done(errors::Internal("failed to parse device name"));
-      return;
-    }
-    auto it = workers_.find(task_name);
-    if (it == workers_.end()) {
-      done(errors::Internal("failed to find worker ", task_name));
-      return;
-    }
-    WorkerInterface* wi = it->second;
-    GetStatusRequest req;
-    GetStatusResponse resp;
-    Status status = wi->GetStatus(&req, &resp);
-    if (!status.ok()) {
-      done(status);
-      return;
-    }
-    for (const auto& it : resp.device_attributes()) {
-      if (it.name() == device) {
-        *locality = it.locality();
-        done(Status::OK());
-        return;
-      }
-    }
-    done(errors::Internal("device not found: ", device));
-  }
-};
-
 class DeviceResDistTest : public ::testing::Test {
  protected:
-  DeviceResDistTest() {}
-
-  ~DeviceResDistTest() override {
-    for (DeviceMgr* dm : device_mgrs_) {
-      delete dm;
-    }
-    for (auto it : resolvers_) {
-      delete it.second;
-    }
-    for (FakeWorker* w : workers_) {
-      delete w;
-    }
-  }
-
-  void DefineWorkers(int num_workers, int num_devices,
-                     const string& device_type,
-                     uint64 device_incarnation_base) {
-    for (int w = 0; w < num_workers; ++w) {
-      string name = strings::StrCat("/job:worker/replica:0/task:", w);
-      DefineWorker(name, device_type, num_devices,
-                   w * num_devices + device_incarnation_base);
-    }
-  }
-
-  void DefineWorker(const string& worker_name, const string& device_type,
-                    int num_devices, uint64 device_incarnation_base) {
+  void SetUp() override {
     std::vector<std::unique_ptr<Device>> devices;
-    for (int i = 0; i < num_devices; ++i) {
-      devices.push_back(NewDevice(
-          device_type,
-          strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
-          device_incarnation_base + i));
-    }
-    DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
-    TestableDeviceResolverDistributed* dev_res =
-        new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
-    resolvers_[worker_name] = dev_res;
-    device_mgrs_.push_back(dev_mgr);
-    std::vector<string>* dv = &dev_by_task_[worker_name];
-    dv->clear();
-    for (auto* d : dev_mgr->ListDevices()) {
-      dv->push_back(d->name());
-    }
-    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
-    workers_.push_back(fw);
-    wc_.AddWorker(worker_name, fw);
+    devices.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:0"));
+    devices.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:1"));
+    dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
+    dev_resolver_ =
+        absl::make_unique<DeviceResolverDistributed>(dev_mgr_.get());
+
+    std::vector<DeviceAttributes> attributes;
+    attributes.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:1/device:CPU:0")
+            ->attributes());
+    attributes.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:1/device:CPU:1")
+            ->attributes());
+    TF_ASSERT_OK(dev_resolver_->UpdateDeviceAttributes(attributes));
   }
 
-  FakeCache wc_;
-  std::vector<DeviceMgr*> device_mgrs_;
-  std::unordered_map<string, TestableDeviceResolverDistributed*> resolvers_;
-  std::unordered_map<string, std::vector<string>> dev_by_task_;
-  std::vector<FakeWorker*> workers_;
+  std::unique_ptr<DeviceMgr> dev_mgr_;
+  std::unique_ptr<DeviceResolverDistributed> dev_resolver_;
 };
 
-TEST_F(DeviceResDistTest, Workers3Devices4) {
-  DefineWorkers(/*num_workers=*/3, /*num_devices=*/4, /*device_type=*/"CPU",
-                /*device_incarnation_base=*/1);
-  // Check that every device is available from every task.
-  for (auto it : resolvers_) {
-    DeviceResolverDistributed* dres = it.second;
-    for (auto it2 : dev_by_task_) {
-      const string& task_name = it2.first;
-      for (const auto& dev_name : it2.second) {
-        DeviceNameUtils::ParsedName parsed;
-        ASSERT_TRUE(DeviceNameUtils::ParseFullName(dev_name, &parsed));
-        Notification note;
-        Status status;
-        DeviceAttributes attributes;
-        dres->GetDeviceAttributesAsync(dev_name, task_name, &attributes,
-                                       [&note, &status](const Status& s) {
-                                         status = s;
-                                         note.Notify();
-                                       });
-        note.WaitForNotification();
-        TF_EXPECT_OK(status);
-        EXPECT_EQ(parsed.id, attributes.locality().numa_node());
-      }
-    }
-  }
+TEST_F(DeviceResDistTest, GetDeviceAttributesLocal) {
+  DeviceAttributes attributes;
+  TF_ASSERT_OK(dev_resolver_->GetDeviceAttributes(
+      "/job:worker/replica:0/task:0/device:CPU:0", &attributes));
+  EXPECT_EQ(attributes.name(), "/job:worker/replica:0/task:0/device:CPU:0");
+}
+
+TEST_F(DeviceResDistTest, GetDeviceAttributesLocalUnknown) {
+  DeviceAttributes attributes;
+  EXPECT_TRUE(errors::IsNotFound(dev_resolver_->GetDeviceAttributes(
+      "/job:worker/replica:0/task:0/device:CPU:9", &attributes)));
+}
+
+TEST_F(DeviceResDistTest, GetAllDeviceAttributes) {
+  std::vector<DeviceAttributes> attributes;
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:0", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:1")));
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:1", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:1/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:1/device:CPU:1")));
+}
+
+TEST_F(DeviceResDistTest, GetAllDeviceAttributesUnknown) {
+  std::vector<DeviceAttributes> attributes;
+  EXPECT_TRUE(errors::IsNotFound(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:3", &attributes)));
+}
+
+TEST_F(DeviceResDistTest, UpdateDeviceAttributes) {
+  std::vector<DeviceAttributes> attributes;
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:2/device:CPU:0")
+          ->attributes());
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:2/device:CPU:1")
+          ->attributes());
+  TF_ASSERT_OK(dev_resolver_->UpdateDeviceAttributes(attributes));
+  // Get the new task.
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:2", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:2/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:2/device:CPU:1")));
+  // Get an existing task.
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:0", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:1")));
+}
+
+TEST_F(DeviceResDistTest, UpdateDeviceAttributesExisting) {
+  std::vector<DeviceAttributes> attributes;
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:0", &attributes));
+  TF_ASSERT_OK(dev_resolver_->UpdateDeviceAttributes(attributes));
+}
+
+TEST_F(DeviceResDistTest, UpdateDeviceAttributesDifferentIncarnation) {
+  std::vector<DeviceAttributes> attributes;
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:0")
+          ->attributes());
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:1")
+          ->attributes());
+  EXPECT_TRUE(errors::IsFailedPrecondition(
+      dev_resolver_->UpdateDeviceAttributes(attributes)));
 }
 
 }  // namespace

View File

@@ -279,8 +279,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
     }
   } else {
     std::unique_ptr<DeviceResolverDistributed> dev_resolver(
-        new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
-                                      default_worker_name));
+        new DeviceResolverDistributed(worker_env_.device_mgr));
     std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
         new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
                                                dev_resolver.get(), worker_cache,
@@ -448,8 +447,7 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
       return errors::Internal("Could not parse worker name.");
     }
     std::unique_ptr<DeviceResolverDistributed> dev_resolver(
-        new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
-                                      default_worker_name));
+        new DeviceResolverDistributed(worker_env_.device_mgr));
     std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
         new CollectiveParamResolverDistributed(
             server_def_.default_session_config(), worker_env_.device_mgr,

View File

@@ -45,8 +45,8 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
     std::vector<std::unique_ptr<Device>> devices;
     TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
     device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
-    std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
-        device_mgr_.get(), worker_cache, task_name));
+    std::unique_ptr<DeviceResolverDistributed> dr(
+        new DeviceResolverDistributed(device_mgr_.get()));
     std::unique_ptr<CollectiveParamResolverDistributed> cpr(
         new CollectiveParamResolverDistributed(options.config,
                                                device_mgr_.get(), dr.get(),

View File

@@ -154,22 +154,18 @@ class DeviceResolverInterface {
  public:
   virtual ~DeviceResolverInterface() {}
 
-  // Collects DeviceAttributes protobufs from all of the devices identified
-  // in 'col_params'.
-  virtual void GetAllDeviceAttributesAsync(
-      const std::vector<string>& devices, const std::vector<string>& tasks,
-      std::vector<DeviceAttributes>* attributes,
-      const StatusCallback& done) = 0;
-
   // Populates *attributes with the DeviceAttributes of the specified device.
-  virtual void GetDeviceAttributesAsync(const string& device,
-                                        const string& task,
-                                        DeviceAttributes* attributes,
-                                        const StatusCallback& done) = 0;
+  virtual Status GetDeviceAttributes(const string& device,
+                                     DeviceAttributes* attributes) = 0;
 
-  // Returns the cached device attributes of a task.
-  virtual Status GetTaskCached(const string& task,
-                               std::vector<DeviceAttributes>* attributes) = 0;
+  // Returns all device attributes of a task.
+  virtual Status GetAllDeviceAttributes(
+      const string& task, std::vector<DeviceAttributes>* attributes) = 0;
+
+  // Updates device attributes. It returns error if any device already
+  // exists in the DeviceResolver and has a different incarnation.
+  virtual Status UpdateDeviceAttributes(
+      const std::vector<DeviceAttributes>& attributes) = 0;
 };
 
 // Interface that provides resolution of shared CollectiveParams fields.
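
Any resolver must now implement all three synchronous methods. A minimal in-memory test double, hypothetical and for illustration only (assumes absl::StartsWith from absl/strings/match.h; it omits the incarnation guard the real distributed resolver enforces):

class MapDeviceResolver : public DeviceResolverInterface {
 public:
  Status GetDeviceAttributes(const string& device,
                             DeviceAttributes* attributes) override {
    auto it = table_.find(device);
    if (it == table_.end()) return errors::NotFound(device, " not found");
    *attributes = it->second;
    return Status::OK();
  }
  Status GetAllDeviceAttributes(
      const string& task, std::vector<DeviceAttributes>* attributes) override {
    attributes->clear();
    // Device names are prefixed by their task name, e.g.
    // "/job:worker/replica:0/task:0/device:CPU:0".
    for (const auto& entry : table_) {
      if (absl::StartsWith(entry.first, task)) {
        attributes->push_back(entry.second);
      }
    }
    return attributes->empty() ? errors::NotFound(task, " not found")
                               : Status::OK();
  }
  Status UpdateDeviceAttributes(
      const std::vector<DeviceAttributes>& attributes) override {
    for (const DeviceAttributes& attr : attributes) {
      table_[attr.name()] = attr;  // no incarnation check in this toy version
    }
    return Status::OK();
  }

 private:
  absl::flat_hash_map<string, DeviceAttributes> table_;
};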