Use device attributes from group resolution
1. CompleteInstanceLocal now uses the device attributes already collected in the GroupRec, so it no longer needs to issue GetStatus calls to fetch device attributes.
2. CollectiveParamResolverDistributed::CompleteParamsAsync pushes the device attributes obtained from group resolution into the local DeviceResolver, so that RecvFromPeer can query the attributes of the target device.
3. DeviceResolver is simplified since it no longer needs to issue RPCs.

A condensed sketch of the resulting call pattern follows the commit metadata below.

PiperOrigin-RevId: 330752257
Change-Id: Id6d5a1a3b1115928570b354fc27ebf7351c9d6ca
parent f7620b089c
commit 9d584860d7
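The net effect of the three changes, condensed into one call pattern (a sketch assembled from the hunks below, not a verbatim excerpt of any single file):

    // 1. Group resolution already gathered every member's DeviceAttributes
    //    into the GroupRec, so instance initialization reads them from there.
    // 2. CompleteParamsAsync then pushes those attributes into the local
    //    DeviceResolver...
    std::vector<DeviceAttributes> attributes;
    {
      mutex_lock l(gr->mu);
      for (const auto& item : gr->devices) attributes.push_back(item.second);
    }
    Status s = dev_resolver_->UpdateDeviceAttributes(attributes);

    // 3. ...so that later lookups (e.g. RecvFromPeer resolving the
    //    server-side device) are synchronous cache reads instead of
    //    GetStatus RPCs.
    DeviceAttributes server_attributes;
    s = dev_resolver_->GetDeviceAttributes(peer_device, &server_attributes);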
tensorflow/core/common_runtime/BUILD
@@ -546,6 +546,9 @@ cc_library(
     deps = [
         ":device_mgr",
         "//tensorflow/core:framework",
+        "//tensorflow/core/framework:device_attributes_proto_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
     ],
 )
tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -495,7 +495,8 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
 
 void CollectiveParamResolverLocal::InitInstanceSharedParams(
     const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
-    const StatusCallback& done) {
+    const StatusCallback& done) TF_NO_THREAD_SAFETY_ANALYSIS {
+  std::vector<DeviceAttributes> attributes;
   ir->shared.instance = cp->instance;
   {
     mutex_lock gl(gr->mu);
@@ -504,10 +505,12 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
     ir->shared.instance.task_names.clear();
     ir->shared.instance.device_names.reserve(gr->devices.size());
     ir->shared.instance.task_names.reserve(gr->devices.size());
+    attributes.reserve(gr->devices.size());
     for (const auto& item : gr->devices) {
       ir->shared.instance.device_names.push_back(item.first);
       ir->shared.instance.task_names.push_back(
           TaskNameFromDeviceName(item.first));
+      attributes.push_back(item.second);
     }
     VLOG(2) << "Initialized names for instance: "
             << ir->shared.instance.ToString();
@@ -526,6 +529,9 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
   // GetDeviceAttributesAsync will use those fields to launch RPCs.
   CompleteTaskIsLocal(task_name_, &ir->shared);
 
+  // TODO(b/151232436): clean up the following code since we no longer need to
+  // execute it in a callback.
+
   // Because the callback may execute in a different thread, we release
   // ir->out_mu here. Before releasing, we mark it as unavailable for other
   // threads.
@@ -533,30 +539,24 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
   const auto device_names = ir->shared.instance.device_names;
   const auto task_names = ir->shared.instance.task_names;
   ir->out_mu.unlock();
-  std::vector<DeviceAttributes>* attributes = new std::vector<DeviceAttributes>;
-  // Suppress linter warning about access to shared without mutex because in
-  // principle the members are locked due to out_mu_available=false.
-  dev_resolver_->GetAllDeviceAttributesAsync(
-      ir->shared.instance.device_names,  // NOLINT
-      ir->shared.instance.task_names,    // NOLINT
-      attributes,
-      [this, gr, cp, ir, attributes, done](const Status& s)
-          TF_EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
-            // Then we recover the lock in the callback thread that will hold it
-            // through the rest of the call chain. Signal the cv now, any
-            // waiting threads will wake only when out_mu is released later.
-            ir->out_mu.lock();
-            DCHECK(!ir->out_mu_available);
-            ir->out_mu_available = true;
-            ir->out_cv.notify_all();
-            if (s.ok()) {
-              CompleteDefaultRanking(gr, cp, ir, *attributes);
-              done(Status::OK());
-            } else {
-              done(s);
-            }
-            delete attributes;
-          });
+  auto complete_init = [this, gr, cp, ir, attributes, done](const Status& s)
+      TF_EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
+        // Then we recover the lock in the callback thread
+        // that will hold it through the rest of the call
+        // chain. Signal the cv now, any waiting threads
+        // will wake only when out_mu is released later.
+        ir->out_mu.lock();
+        DCHECK(!ir->out_mu_available);
+        ir->out_mu_available = true;
+        ir->out_cv.notify_all();
+        if (s.ok()) {
+          CompleteDefaultRanking(gr, cp, ir, attributes);
+          done(Status::OK());
+        } else {
+          done(s);
+        }
+      };
+  complete_init(Status::OK());
 }
 
 // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
tensorflow/core/common_runtime/device_resolver_local.cc
@@ -15,41 +15,34 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
-void DeviceResolverLocal::GetAllDeviceAttributesAsync(
-    const std::vector<string>& devices, const std::vector<string>& tasks,
-    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
-  attributes->clear();
-  for (const string& device_name : devices) {
-    Device* dev;
-    Status s = dev_mgr_->LookupDevice(device_name, &dev);
-    if (!s.ok()) {
-      done(s);
-      return;
-    }
-    attributes->push_back(dev->attributes());
-  }
-  done(Status::OK());
-}
-
-void DeviceResolverLocal::GetDeviceAttributesAsync(const string& device,
-                                                   const string& task,
-                                                   DeviceAttributes* attributes,
-                                                   const StatusCallback& done) {
+Status DeviceResolverLocal::GetDeviceAttributes(const string& device,
+                                                DeviceAttributes* attributes) {
   Device* dev;
+  // LookupDevice returns InvalidArgument if the device is not found.
   Status s = dev_mgr_->LookupDevice(device, &dev);
-  if (s.ok()) {
-    *attributes = dev->attributes();
+  if (errors::IsInvalidArgument(s)) {
+    return errors::NotFound(device, " not found");
+  } else if (!s.ok()) {
+    return s;
   }
-  done(s);
+  *attributes = dev->attributes();
+  return Status::OK();
 }
 
-Status DeviceResolverLocal::GetTaskCached(
+Status DeviceResolverLocal::GetAllDeviceAttributes(
     const string& task, std::vector<DeviceAttributes>* attributes) {
   return errors::Internal(
       "GetTaskCached is not supposed to be called in local collectives");
 }
 
+Status DeviceResolverLocal::UpdateDeviceAttributes(
+    const std::vector<DeviceAttributes>& attributes) {
+  return errors::Internal(
+      "UpdateDeviceAttributes shouldn't be called with local collectives");
+}
+
 }  // namespace tensorflow
tensorflow/core/common_runtime/device_resolver_local.h
@@ -16,9 +16,11 @@ limitations under the License.
 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
 
 #include <string>
 #include <vector>
 
 #include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
 class DeviceMgr;
@@ -26,21 +28,16 @@ class DeviceMgr;
 // Implements DeviceResolverInterface in a single-task context.
 class DeviceResolverLocal : public DeviceResolverInterface {
  public:
-  DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {}
+  explicit DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {}
 
-  virtual ~DeviceResolverLocal() {}
+  Status GetDeviceAttributes(const string& device,
+                             DeviceAttributes* attributes) override;
 
-  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
-                                   const std::vector<string>& tasks,
-                                   std::vector<DeviceAttributes>* attributes,
-                                   const StatusCallback& done) override;
+  Status GetAllDeviceAttributes(
+      const string& task, std::vector<DeviceAttributes>* attributes) override;
 
-  void GetDeviceAttributesAsync(const string& device, const string& task,
-                                DeviceAttributes* attributes,
-                                const StatusCallback& done) override;
-
-  Status GetTaskCached(const string& task,
-                       std::vector<DeviceAttributes>* attributes) override;
+  Status UpdateDeviceAttributes(
+      const std::vector<DeviceAttributes>& attributes) override;
 
  protected:
   const DeviceMgr* dev_mgr_;
tensorflow/core/common_runtime/device_resolver_local_test.cc
@@ -17,7 +17,6 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
@@ -46,35 +45,27 @@ class DeviceResolverLocalTest : public ::testing::Test {
 };
 
 TEST_F(DeviceResolverLocalTest, GetDeviceAttributesKnown) {
-  std::vector<DeviceAttributes> attributes;
-  std::vector<string> devices{"/job:localhost/replica:0/task:0/device:CPU:1",
-                              "/job:localhost/replica:0/task:0/device:CPU:2"};
-  Notification note;
-  Status status;
-  drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes,
-                                    [&note, &status](const Status& s) {
-                                      status = s;
-                                      note.Notify();
-                                    });
-  note.WaitForNotification();
-  TF_EXPECT_OK(status);
-  EXPECT_EQ(2, attributes.size());
+  DeviceAttributes attributes;
+  TF_EXPECT_OK(drl_->GetDeviceAttributes(
+      "/job:localhost/replica:0/task:0/device:CPU:1", &attributes));
+  EXPECT_EQ(attributes.name(), "/job:localhost/replica:0/task:0/device:CPU:1");
 }
 
 TEST_F(DeviceResolverLocalTest, GetDeviceAttributesUnknown) {
+  DeviceAttributes attributes;
+  EXPECT_TRUE(errors::IsNotFound(drl_->GetDeviceAttributes(
+      "/job:localhost/replica:0/task:0/device:CPU:9", &attributes)));
+}
+
+TEST_F(DeviceResolverLocalTest, GetAllDeviceAttributes) {
   std::vector<DeviceAttributes> attributes;
-  // In some builds there may be 1 GPU, but there should never be 9.
-  std::vector<string> devices{"/job:localhost/replica:0/task:0/device:GPU:9"};
-  Notification note;
-  Status status;
-  drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes,
-                                    [&note, &status](const Status& s) {
-                                      status = s;
-                                      note.Notify();
-                                    });
-  note.WaitForNotification();
-  EXPECT_FALSE(status.ok());
-  EXPECT_EQ(0, attributes.size());
+  EXPECT_TRUE(errors::IsInternal(
+      drl_->GetAllDeviceAttributes(/*task*/ "", &attributes)));
 }
 
+TEST_F(DeviceResolverLocalTest, UpdateDeviceAttributes) {
+  std::vector<DeviceAttributes> attributes;
+  EXPECT_TRUE(errors::IsInternal(drl_->UpdateDeviceAttributes(attributes)));
+}
+
 }  // namespace
tensorflow/core/distributed_runtime/BUILD
@@ -574,6 +574,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -621,6 +622,8 @@ cc_library(
        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:errors",
+        "//tensorflow/core/platform:status",
         "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -106,16 +106,23 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync(
     CancellationManager* cancel_mgr, const StatusCallback& done) {
   VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
           << ": " << cp->ToString();
-  CompleteGroupDistributed(device, cp, cancel_mgr,
-                           [this, device, cp, cancel_mgr, done](
-                               const Status& s, const GroupRec* gr) {
-                             if (s.ok()) {
-                               CompleteInstanceDistributed(
-                                   device.name(), gr, cp, cancel_mgr, done);
-                             } else {
-                               done(s);
-                             }
-                           });
+  CompleteGroupDistributed(
+      device, cp, cancel_mgr,
+      [this, device, cp, cancel_mgr, done](Status s, const GroupRec* gr) {
+        if (s.ok()) {
+          std::vector<DeviceAttributes> attributes;
+          mutex_lock l(gr->mu);
+          for (const auto& item : gr->devices) {
+            attributes.push_back(item.second);
+          }
+          s = dev_resolver_->UpdateDeviceAttributes(attributes);
+        }
+        if (s.ok()) {
+          CompleteInstanceDistributed(device.name(), gr, cp, cancel_mgr, done);
+        } else {
+          done(s);
+        }
+      });
 }
 
 void CollectiveParamResolverDistributed::CompleteGroupAsync(
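The callback above funnels both steps through the same Status: the group's attributes are first copied into the resolver, and only if that succeeds does instance resolution proceed. In sequential form the same control flow would read roughly as follows (CompleteGroup and CompleteInstance are illustrative stand-ins, not functions from this patch):

    Status CompleteParamsSequential(const DeviceAttributes& device,
                                    CollectiveParams* cp) {
      const GroupRec* gr = nullptr;
      TF_RETURN_IF_ERROR(CompleteGroup(device, cp, &gr));  // hypothetical helper
      std::vector<DeviceAttributes> attributes;
      {
        mutex_lock l(gr->mu);
        for (const auto& item : gr->devices) {
          attributes.push_back(item.second);
        }
      }
      // Make the group's devices visible to synchronous resolver lookups.
      TF_RETURN_IF_ERROR(dev_resolver_->UpdateDeviceAttributes(attributes));
      return CompleteInstance(device.name(), gr, cp);  // hypothetical helper
    }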
tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -157,7 +157,7 @@ class DeviceResDistTest : public ::testing::Test {
       dv->push_back(d->name());
     }
     dev_resolvers_[worker_name] = absl::make_unique<DeviceResolverDistributed>(
-        device_mgrs_[worker_name].get(), &wc_, worker_name);
+        device_mgrs_[worker_name].get());
     cp_resolvers_[worker_name] =
         absl::make_unique<CollectiveParamResolverDistributed>(
             config, device_mgrs_[worker_name].get(),
@@ -256,6 +256,7 @@ class DeviceResDistTest : public ::testing::Test {
     EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count);
     EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name);
     EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name);
+    ValidateDeviceResolver(cp_[device_name], task_name);
     if (idx > 0) {
       EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
                 cp_[device_name].group.runtime_details.communicator_key);
@@ -270,6 +271,14 @@ class DeviceResDistTest : public ::testing::Test {
     }
   }
 
+  void ValidateDeviceResolver(const CollectiveParams& cp, const string& task) {
+    for (const string& device_name : cp.instance.device_names) {
+      DeviceAttributes attributes;
+      TF_ASSERT_OK(
+          dev_resolvers_[task]->GetDeviceAttributes(device_name, &attributes));
+    }
+  }
+
   void RestartWorker(int worker_idx, int num_workers, int num_devices,
                      const string& device_type, bool nccl) {
     string worker_name =
tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -160,26 +160,17 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
     done(s);
   };
 
-  // Logic to execute once we have the device attributes for the server-side
-  // device.
-  auto dev_attributes_callback = [this, state, peer_device, peer_task, key,
-                                  to_device, to_device_ctx, to_alloc_attr,
-                                  to_tensor, client_locality,
-                                  recv_buf_callback](const Status& s) {
-    if (!s.ok()) {
-      recv_buf_callback(s);
-    } else {
-      state->call.reset(new RecvBufCall(
-          step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
-          to_alloc_attr, to_tensor, client_locality, state->server_attributes,
-          &cancel_mgr_, worker_cache_));
-      state->call->Start(recv_buf_callback);
-    }
-  };
-
-  dev_resolver_->GetDeviceAttributesAsync(peer_device, peer_task,
-                                          &state->server_attributes,
-                                          dev_attributes_callback);
+  Status s = dev_resolver_->GetDeviceAttributes(peer_device,
+                                                &state->server_attributes);
+  if (!s.ok()) {
+    recv_buf_callback(s);
+    return;
+  }
+  state->call.reset(
+      new RecvBufCall(step_id_, peer_device, peer_task, key, to_device,
+                      to_device_ctx, to_alloc_attr, to_tensor, client_locality,
+                      state->server_attributes, &cancel_mgr_, worker_cache_));
+  state->call->Start(recv_buf_callback);
 }
 
 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
@@ -209,7 +200,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
       [this, req, resp, wi, peer_task, done](Status s) {
         std::vector<DeviceAttributes> cached_attrs;
         if (s.ok()) {
-          s = dev_resolver_->GetTaskCached(peer_task, &cached_attrs);
+          s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
         }
         if (s.ok()) {
           absl::flat_hash_set<uint64> remote_incarnations;
tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/test_utils.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/random/random.h"
@@ -30,7 +31,6 @@ limitations under the License.
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/transport_options.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
-#include "tensorflow/core/util/device_name_utils.h"
 
 // The only interesting method on CollectiveRemoteAccessDistributed
 // that's not on CollectiveRemoteAccessLocal is RecvFromPeer which
@@ -224,6 +224,18 @@ class CollRMADistTest : public ::testing::Test {
     }
   }
 
+  // Populates all device resolvers with device attributes of the cluster. This
+  // should be called in the beginning of all tests unless you would like to
+  // simulate a situation that is before parameter resolution.
+  void ResolveDeviceAttributes() {
+    for (auto& dev_resolver_item : dev_resolvers_) {
+      DeviceResolverDistributed* dev_resolver = dev_resolver_item.second;
+      for (const auto& item : dev_by_task_) {
+        TF_CHECK_OK(dev_resolver->UpdateDeviceAttributes(item.second));
+      }
+    }
+  }
+
   void DefineWorker(const string& worker_name, const string& device_type,
                     int num_devices, bool is_failed = false) {
     std::vector<std::unique_ptr<Device>> devices;
@@ -234,13 +246,12 @@ class CollRMADistTest : public ::testing::Test {
     }
     DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
     device_mgrs_.push_back(dev_mgr);
-    std::vector<string>* dv = &dev_by_task_[worker_name];
+    std::vector<DeviceAttributes>* dv = &dev_by_task_[worker_name];
     dv->clear();
     for (auto d : dev_mgr->ListDevices()) {
-      dv->push_back(d->name());
+      dv->push_back(d->attributes());
     }
-    DeviceResolverDistributed* dev_res =
-        new DeviceResolverDistributed(dev_mgr, &wc_, worker_name);
+    DeviceResolverDistributed* dev_res = new DeviceResolverDistributed(dev_mgr);
     dev_resolvers_[worker_name] = dev_res;
     FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res, is_failed);
     workers_.push_back(fw);
@@ -254,6 +265,9 @@ class CollRMADistTest : public ::testing::Test {
       delete it->second;
       dev_resolvers_.erase(it);
     }
+    // After restarting a worker, the other workers already have the device
+    // attributes of the old worker. We don't broadcast device attributes of
+    // the new worker to mimic the real world.
     DefineWorker(worker_name, device_type, num_devices, is_failed);
   }
@@ -269,7 +283,7 @@ class CollRMADistTest : public ::testing::Test {
   CancellationManager cm_;
   std::vector<DeviceMgr*> device_mgrs_;
   std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
-  std::unordered_map<string, std::vector<string>> dev_by_task_;
+  std::unordered_map<string, std::vector<DeviceAttributes>> dev_by_task_;
   std::shared_ptr<UnboundedWorkQueue> work_queue_;
   std::vector<FakeWorker*> workers_;
   std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
@@ -284,6 +298,7 @@ class CollRMADistTest : public ::testing::Test {
 };
 
 TEST_F(CollRMADistTest, ProdFirstOK) {
+  ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
   Status consumer_status;
@@ -319,6 +334,7 @@ TEST_F(CollRMADistTest, ProdFirstOK) {
 }
 
 TEST_F(CollRMADistTest, ConsFirstOK) {
+  ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
   Status consumer_status;
@@ -354,6 +370,7 @@ TEST_F(CollRMADistTest, ConsFirstOK) {
 }
 
 TEST_F(CollRMADistTest, ConsFirstAbort) {
+  ResolveDeviceAttributes();
   Notification consumer_note;
   Status consumer_status;
   const string kBufKey = "fake_buf_key";
||||
@ -377,6 +394,7 @@ TEST_F(CollRMADistTest, ConsFirstAbort) {
|
||||
}
|
||||
|
||||
TEST_F(CollRMADistTest, WorkerRestart) {
|
||||
ResolveDeviceAttributes();
|
||||
Notification consumer_note;
|
||||
Notification producer_note;
|
||||
Status consumer_status;
|
||||
@@ -428,21 +446,7 @@ TEST_F(CollRMADistTest, WorkerRestart) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:CPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
-
+  ResolveDeviceAttributes();
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
@@ -469,21 +473,7 @@ TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthRestarted) {
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:CPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
-
+  ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
 
   Status check_health_status;
@@ -499,21 +489,7 @@ TEST_F(CollRMADistTest, CheckHealthRestarted) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:CPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
-
+  ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
                 /*is_failed*/ true);
 
@@ -530,25 +506,8 @@ TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
 }
 
 TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
+  ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
 
-  DeviceAttributes attr;
-  Status get_attr_status;
-  Notification get_attr_done;
-  // Call GetDeviceAttributesAsync to cache the device attributes of a remote
-  // worker.
-  dev_resolvers_["/job:worker/replica:0/task:0"]->GetDeviceAttributesAsync(
-      "/job:worker/replica:0/task:1/device:GPU:0",
-      "/job:worker/replica:0/task:1", &attr,
-      [&get_attr_status, &get_attr_done](const Status& s) {
-        get_attr_status = s;
-        get_attr_done.Notify();
-      });
-  get_attr_done.WaitForNotification();
-  TF_ASSERT_OK(get_attr_status);
-
-  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
-
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
tensorflow/core/distributed_runtime/device_resolver_distributed.cc
@@ -15,105 +15,30 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 
 #include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 
-DeviceResolverDistributed::DeviceResolverDistributed(
-    const DeviceMgr* dev_mgr, WorkerCacheInterface* worker_cache,
-    const string& task_name)
-    : dev_mgr_(dev_mgr), worker_cache_(worker_cache), task_name_(task_name) {}
-
-void DeviceResolverDistributed::GetDeviceAttributesAsync(
-    const string& device, const string& task, DeviceAttributes* attributes,
-    const StatusCallback& done) {
-  if (task.empty() || task == task_name_) {
-    // Device is local to this task.
-    Device* dev;
-    Status s = dev_mgr_->LookupDevice(device, &dev);
-    if (s.ok()) {
-      *attributes = dev->attributes();
-    }
-    done(s);
-    return;
-  } else {
-    // Lookup of a remote device: first try the local cache.
-    bool found = false;
-    {
-      mutex_lock l(mu_);
-      auto it = attr_table_.find(device);
-      if (it != attr_table_.end()) {
-        *attributes = it->second;
-        found = true;
-      }
-    }
-    if (found) {
-      done(Status::OK());
-      return;
-    }
-  }
-  // Device is remote and no cache entry was found. Refresh the cache
-  // then retry the lookup.
-  RefreshRemoteAttributes(
-      device, task, [this, device, task, attributes, done](const Status& s) {
-        if (!s.ok()) {
-          done(s);
-        } else {
-          GetDeviceAttributesAsync(device, task, attributes, done);
-        }
-      });
-}
-
-void DeviceResolverDistributed::GetAllDeviceAttributesAsync(
-    const std::vector<string>& devices, const std::vector<string>& tasks,
-    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
-  attributes->clear();
-  GetAllDeviceAttributesRecursive(devices, tasks, attributes, done);
-}
-
-void DeviceResolverDistributed::GetAllDeviceAttributesRecursive(
-    const std::vector<string>& devices, const std::vector<string>& tasks,
-    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
-  size_t i = attributes->size();
-  if (i < devices.size()) {
-    attributes->push_back(DeviceAttributes());
-    GetDeviceAttributesAsync(
-        devices[i], tasks[i], &attributes->back(),
-        [this, &devices, &tasks, attributes, done](const Status& s) {
-          if (!s.ok()) {
-            done(s);
-            return;
-          } else {
-            GetAllDeviceAttributesRecursive(devices, tasks, attributes, done);
-          }
-        });
-  } else {
-    done(Status::OK());
-  }
-}
-
-void DeviceResolverDistributed::RefreshRemoteAttributes(
-    const string& device, const string& task, const StatusCallback& done) {
-  GetStatusRequest* req = new GetStatusRequest;
-  GetStatusResponse* resp = new GetStatusResponse;
-  WorkerInterface* worker = worker_cache_->GetOrCreateWorker(task);
-  CHECK(worker) << "Failed to get worker for " << task;
-  worker->GetStatusAsync(
-      req, resp, /*fail_fast=*/false,
-      [this, device, task, req, resp, worker, done](Status s) {
-        if (s.ok()) {
-          mutex_lock l(mu_);
-          for (const DeviceAttributes& da : resp->device_attributes()) {
-            attr_table_[da.name()] = da;
-          }
-        }
-        done(s);
-        delete req;
-        delete resp;
-        worker_cache_->ReleaseWorker(task, worker);
-      });
-}
-
+DeviceResolverDistributed::DeviceResolverDistributed(const DeviceMgr* dev_mgr) {
+  mutex_lock l(mu_);
+  for (Device* device : dev_mgr->ListDevices()) {
+    attr_table_[device->name()] = device->attributes();
+  }
+}
+
+Status DeviceResolverDistributed::GetDeviceAttributes(
+    const string& device, DeviceAttributes* attributes) {
+  mutex_lock l(mu_);
+  auto it = attr_table_.find(device);
+  if (it == attr_table_.end()) {
+    return errors::NotFound(device, " not found");
+  }
+  *attributes = it->second;
+  return Status::OK();
+}
+
-Status DeviceResolverDistributed::GetTaskCached(
+Status DeviceResolverDistributed::GetAllDeviceAttributes(
     const string& task, std::vector<DeviceAttributes>* attributes) {
   mutex_lock l(mu_);
   attributes->clear();
@@ -129,4 +54,23 @@ Status DeviceResolverDistributed::GetTaskCached(
   return Status::OK();
 }
 
+Status DeviceResolverDistributed::UpdateDeviceAttributes(
+    const std::vector<DeviceAttributes>& attributes) {
+  mutex_lock l(mu_);
+  for (const DeviceAttributes& attr : attributes) {
+    auto item = attr_table_.insert({attr.name(), attr});
+    auto it = item.first;
+    bool success = item.second;
+    // Returns error if the device already exists in the cache and has a
+    // different incarnation.
+    if (!success && it->second.incarnation() != attr.incarnation()) {
+      return errors::FailedPrecondition(
+          attr.name(),
+          "exists in cache with a different incarnation. "
+          "This usually means the remote worker has restarted");
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace tensorflow
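The insert-or-verify loop above makes UpdateDeviceAttributes idempotent for attributes that match the cache, and an error when a cached device reappears with a new incarnation, which is how a silent worker restart surfaces. A small usage sketch of that contract (the resolver and attribute setup here are assumed, mirroring the tests further below):

    std::vector<DeviceAttributes> attrs = ...;             // some task's devices
    TF_CHECK_OK(resolver->UpdateDeviceAttributes(attrs));  // first insert
    TF_CHECK_OK(resolver->UpdateDeviceAttributes(attrs));  // same incarnation: OK
    attrs[0].set_incarnation(attrs[0].incarnation() + 1);  // simulate a restart
    Status s = resolver->UpdateDeviceAttributes(attrs);
    CHECK(errors::IsFailedPrecondition(s));                // restart detected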
tensorflow/core/distributed_runtime/device_resolver_distributed.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
 class DeviceMgr;
@@ -28,39 +29,18 @@ class WorkerCacheInterface;
 
 class DeviceResolverDistributed : public DeviceResolverInterface {
  public:
-  DeviceResolverDistributed(const DeviceMgr* dev_mgr,
-                            WorkerCacheInterface* worker_cache,
-                            const string& task_name);
+  explicit DeviceResolverDistributed(const DeviceMgr* dev_mgr);
 
-  virtual ~DeviceResolverDistributed() {}
+  Status GetDeviceAttributes(const string& device,
+                             DeviceAttributes* attributes) override;
 
-  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
-                                   const std::vector<string>& tasks,
-                                   std::vector<DeviceAttributes>* attributes,
-                                   const StatusCallback& done) override;
+  Status GetAllDeviceAttributes(
+      const string& task, std::vector<DeviceAttributes>* attributes) override;
 
-  void GetDeviceAttributesAsync(const string& device, const string& task,
-                                DeviceAttributes* attributes,
-                                const StatusCallback& done) override;
-
-  Status GetTaskCached(const string& task,
-                       std::vector<DeviceAttributes>* attributes) override;
+  Status UpdateDeviceAttributes(
      const std::vector<DeviceAttributes>& attributes) override;
 
  protected:
-  // Loads attr_table_ with device attributes retrieved from remote task.
-  void RefreshRemoteAttributes(const string& device, const string& task,
-                               const StatusCallback& done)
-      TF_LOCKS_EXCLUDED(mu_);
-
-  // Subroutine used by GetAllDeviceAttributesAsync. Recursively extends
-  // *attributes with DeviceAttributes of the corresponding device named
-  // by inst_params.instance.device_names.
-  void GetAllDeviceAttributesRecursive(
-      const std::vector<string>& devices, const std::vector<string>& tasks,
-      std::vector<DeviceAttributes>* attributes, const StatusCallback& done);
-
-  const DeviceMgr* dev_mgr_;            // Not owned
-  WorkerCacheInterface* worker_cache_;  // Not owned
-  const string task_name_;
   mutex mu_;
   absl::flat_hash_map<string, DeviceAttributes> attr_table_ TF_GUARDED_BY(mu_);
tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
@@ -22,30 +22,19 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/random.h"
 #include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 namespace {
 
-// Subclass of DeviceResolverDistributed which behaves identically but
-// allows access to the attr_table_.
-class TestableDeviceResolverDistributed : public DeviceResolverDistributed {
- public:
-  TestableDeviceResolverDistributed(const DeviceMgr* dev_mgr,
-                                    WorkerCacheInterface* worker_cache,
-                                    const string& task)
-      : DeviceResolverDistributed(dev_mgr, worker_cache, task) {}
-
-  absl::flat_hash_map<string, DeviceAttributes>& attr_table() {
-    return attr_table_;
-  }
-};
+using ::testing::Property;
+using ::testing::UnorderedElementsAre;
 
 // Create a fake 'Device' whose only interesting attribute is a non-default
 // DeviceLocality and incarnation.
-static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
-                                         int numa_node, uint64 incarnation) {
+std::unique_ptr<Device> NewDevice(const string& type, const string& name) {
   class FakeDevice : public Device {
    public:
    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@@ -55,161 +44,121 @@ static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
   DeviceAttributes attr;
   attr.set_name(name);
   attr.set_device_type(type);
-  attr.mutable_locality()->set_numa_node(numa_node);
-  attr.set_incarnation(incarnation);
+  attr.set_incarnation(random::New64());
   return absl::make_unique<FakeDevice>(attr);
 }
 
-// Create a fake WorkerInterface that responds to requests without RPCs,
-// in this case returning the DeviceAttributes of a fake remote worker.
-class FakeWorker : public TestWorkerInterface {
- public:
-  FakeWorker(const string& name, DeviceMgr* dev_mgr,
-             DeviceResolverDistributed* dres)
-      : name_(name), device_mgr_(dev_mgr), device_resolver_(dres) {}
-
-  void GetStatusAsync(const GetStatusRequest* request,
-                      GetStatusResponse* response, bool fail_fast,
-                      StatusCallback done) override {
-    std::vector<DeviceAttributes> dev_attr;
-    device_mgr_->ListDeviceAttributes(&dev_attr);
-    for (const auto& da : dev_attr) {
-      *response->add_device_attributes() = da;
-    }
-    done(Status::OK());
-  }
-
- private:
-  string name_;
-  DeviceMgr* device_mgr_;
-  DeviceResolverDistributed* device_resolver_;
-};
-
-// An implementation of WorkerCacheInterface that routes all requests
-// to local FakeWorkers, implementing only the methods needed for tests.
-class FakeCache : public TestWorkerCache {
- public:
-  // Override the Locality methods to actually pass through to the
-  // worker.
-  bool GetDeviceLocalityNonBlocking(const string& device,
-                                    DeviceLocality* locality) override {
-    return false;
-  }
-
-  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
-                              StatusCallback done) override {
-    string task_name;
-    string dev_part;
-    if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
-      done(errors::Internal("failed to parse device name"));
-      return;
-    }
-    auto it = workers_.find(task_name);
-    if (it == workers_.end()) {
-      done(errors::Internal("failed to find worker ", task_name));
-      return;
-    }
-    WorkerInterface* wi = it->second;
-    GetStatusRequest req;
-    GetStatusResponse resp;
-    Status status = wi->GetStatus(&req, &resp);
-    if (!status.ok()) {
-      done(status);
-      return;
-    }
-    for (const auto& it : resp.device_attributes()) {
-      if (it.name() == device) {
-        *locality = it.locality();
-        done(Status::OK());
-        return;
-      }
-    }
-    done(errors::Internal("device not found: ", device));
-  }
-};
-
 class DeviceResDistTest : public ::testing::Test {
  protected:
-  DeviceResDistTest() {}
-
-  ~DeviceResDistTest() override {
-    for (DeviceMgr* dm : device_mgrs_) {
-      delete dm;
-    }
-    for (auto it : resolvers_) {
-      delete it.second;
-    }
-    for (FakeWorker* w : workers_) {
-      delete w;
-    }
-  }
-
-  void DefineWorkers(int num_workers, int num_devices,
-                     const string& device_type,
-                     uint64 device_incarnation_base) {
-    for (int w = 0; w < num_workers; ++w) {
-      string name = strings::StrCat("/job:worker/replica:0/task:", w);
-      DefineWorker(name, device_type, num_devices,
-                   w * num_devices + device_incarnation_base);
-    }
-  }
-
-  void DefineWorker(const string& worker_name, const string& device_type,
-                    int num_devices, uint64 device_incarnation_base) {
+  void SetUp() override {
     std::vector<std::unique_ptr<Device>> devices;
-    for (int i = 0; i < num_devices; ++i) {
-      devices.push_back(NewDevice(
-          device_type,
-          strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
-          device_incarnation_base + i));
-    }
-    DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
-    TestableDeviceResolverDistributed* dev_res =
-        new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
-    resolvers_[worker_name] = dev_res;
-    device_mgrs_.push_back(dev_mgr);
-    std::vector<string>* dv = &dev_by_task_[worker_name];
-    dv->clear();
-    for (auto* d : dev_mgr->ListDevices()) {
-      dv->push_back(d->name());
-    }
-    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
-    workers_.push_back(fw);
-    wc_.AddWorker(worker_name, fw);
+    devices.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:0"));
+    devices.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:1"));
+    dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
+    dev_resolver_ =
+        absl::make_unique<DeviceResolverDistributed>(dev_mgr_.get());
+
+    std::vector<DeviceAttributes> attributes;
+    attributes.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:1/device:CPU:0")
+            ->attributes());
+    attributes.push_back(
+        NewDevice("CPU", "/job:worker/replica:0/task:1/device:CPU:1")
+            ->attributes());
+    TF_ASSERT_OK(dev_resolver_->UpdateDeviceAttributes(attributes));
   }
 
-  FakeCache wc_;
-  std::vector<DeviceMgr*> device_mgrs_;
-  std::unordered_map<string, TestableDeviceResolverDistributed*> resolvers_;
-  std::unordered_map<string, std::vector<string>> dev_by_task_;
-  std::vector<FakeWorker*> workers_;
+  std::unique_ptr<DeviceMgr> dev_mgr_;
+  std::unique_ptr<DeviceResolverDistributed> dev_resolver_;
 };
 
-TEST_F(DeviceResDistTest, Workers3Devices4) {
-  DefineWorkers(/*num_workers=*/3, /*num_devices=*/4, /*device_type=*/"CPU",
-                /*device_incarnation_base=*/1);
-  // Check that every device is available from every task.
-  for (auto it : resolvers_) {
-    DeviceResolverDistributed* dres = it.second;
-    for (auto it2 : dev_by_task_) {
-      const string& task_name = it2.first;
-      for (const auto& dev_name : it2.second) {
-        DeviceNameUtils::ParsedName parsed;
-        ASSERT_TRUE(DeviceNameUtils::ParseFullName(dev_name, &parsed));
-        Notification note;
-        Status status;
-        DeviceAttributes attributes;
-        dres->GetDeviceAttributesAsync(dev_name, task_name, &attributes,
-                                       [&note, &status](const Status& s) {
-                                         status = s;
-                                         note.Notify();
-                                       });
-        note.WaitForNotification();
-        TF_EXPECT_OK(status);
-        EXPECT_EQ(parsed.id, attributes.locality().numa_node());
-      }
-    }
-  }
-}
+TEST_F(DeviceResDistTest, GetDeviceAttributesLocal) {
+  DeviceAttributes attributes;
+  TF_ASSERT_OK(dev_resolver_->GetDeviceAttributes(
+      "/job:worker/replica:0/task:0/device:CPU:0", &attributes));
+  EXPECT_EQ(attributes.name(), "/job:worker/replica:0/task:0/device:CPU:0");
+}
+
+TEST_F(DeviceResDistTest, GetDeviceAttributesLocalUnknown) {
+  DeviceAttributes attributes;
+  EXPECT_TRUE(errors::IsNotFound(dev_resolver_->GetDeviceAttributes(
+      "/job:worker/replica:0/task:0/device:CPU:9", &attributes)));
+}
+
+TEST_F(DeviceResDistTest, GetAllDeviceAttributes) {
+  std::vector<DeviceAttributes> attributes;
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:0", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:1")));
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:1", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:1/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:1/device:CPU:1")));
+}
+
+TEST_F(DeviceResDistTest, GetAllDeviceAttributesUnknown) {
+  std::vector<DeviceAttributes> attributes;
+  EXPECT_TRUE(errors::IsNotFound(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:3", &attributes)));
+}
+
+TEST_F(DeviceResDistTest, UpdateDeviceAttributes) {
+  std::vector<DeviceAttributes> attributes;
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:2/device:CPU:0")
+          ->attributes());
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:2/device:CPU:1")
+          ->attributes());
+  TF_ASSERT_OK(dev_resolver_->UpdateDeviceAttributes(attributes));
+  // Get the new task.
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:2", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:2/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:2/device:CPU:1")));
+  // Get an existing task.
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:0", &attributes));
+  EXPECT_THAT(attributes,
+              UnorderedElementsAre(
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:0"),
+                  Property(&DeviceAttributes::name,
+                           "/job:worker/replica:0/task:0/device:CPU:1")));
+}
+
+TEST_F(DeviceResDistTest, UpdateDeviceAttributesExisting) {
+  std::vector<DeviceAttributes> attributes;
+  TF_ASSERT_OK(dev_resolver_->GetAllDeviceAttributes(
+      "/job:worker/replica:0/task:0", &attributes));
+  TF_ASSERT_OK(dev_resolver_->UpdateDeviceAttributes(attributes));
+}
+
+TEST_F(DeviceResDistTest, UpdateDeviceAttributesDifferentIncarnation) {
+  std::vector<DeviceAttributes> attributes;
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:0")
+          ->attributes());
+  attributes.push_back(
+      NewDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:1")
+          ->attributes());
+  EXPECT_TRUE(errors::IsFailedPrecondition(
+      dev_resolver_->UpdateDeviceAttributes(attributes)));
+}
+
 }  // namespace
tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -279,8 +279,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
     }
   } else {
     std::unique_ptr<DeviceResolverDistributed> dev_resolver(
-        new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
-                                      default_worker_name));
+        new DeviceResolverDistributed(worker_env_.device_mgr));
     std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
         new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
                                                dev_resolver.get(), worker_cache,
@@ -448,8 +447,7 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) {
       return errors::Internal("Could not parse worker name.");
     }
     std::unique_ptr<DeviceResolverDistributed> dev_resolver(
-        new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
-                                      default_worker_name));
+        new DeviceResolverDistributed(worker_env_.device_mgr));
     std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
         new CollectiveParamResolverDistributed(
            server_def_.default_session_config(), worker_env_.device_mgr,
|
@ -45,8 +45,8 @@ class RpcCollectiveExecutorMgrTest : public ::testing::Test {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
|
||||
device_mgr_.get(), worker_cache, task_name));
|
||||
std::unique_ptr<DeviceResolverDistributed> dr(
|
||||
new DeviceResolverDistributed(device_mgr_.get()));
|
||||
std::unique_ptr<CollectiveParamResolverDistributed> cpr(
|
||||
new CollectiveParamResolverDistributed(options.config,
|
||||
device_mgr_.get(), dr.get(),
|
||||
|
@ -154,22 +154,18 @@ class DeviceResolverInterface {
|
||||
public:
|
||||
virtual ~DeviceResolverInterface() {}
|
||||
|
||||
// Collects DeviceAttributes protobufs from all of the devices identified
|
||||
// in 'col_params'.
|
||||
virtual void GetAllDeviceAttributesAsync(
|
||||
const std::vector<string>& devices, const std::vector<string>& tasks,
|
||||
std::vector<DeviceAttributes>* attributes,
|
||||
const StatusCallback& done) = 0;
|
||||
|
||||
// Populates *attributes with the DeviceAttributes of the specified device.
|
||||
virtual void GetDeviceAttributesAsync(const string& device,
|
||||
const string& task,
|
||||
DeviceAttributes* attributes,
|
||||
const StatusCallback& done) = 0;
|
||||
virtual Status GetDeviceAttributes(const string& device,
|
||||
DeviceAttributes* attributes) = 0;
|
||||
|
||||
// Returns the cached device attributes of a task.
|
||||
virtual Status GetTaskCached(const string& task,
|
||||
std::vector<DeviceAttributes>* attributes) = 0;
|
||||
// Returns all device attributes of a task.
|
||||
virtual Status GetAllDeviceAttributes(
|
||||
const string& task, std::vector<DeviceAttributes>* attributes) = 0;
|
||||
|
||||
// Updates device attributes. It returns error if any device already
|
||||
// exists in the DeviceResolver and has a different incarnation.
|
||||
virtual Status UpdateDeviceAttributes(
|
||||
const std::vector<DeviceAttributes>& attributes) = 0;
|
||||
};
|
||||
|
||||
// Interface that provides resolution of shared CollectiveParams fields.
|
||||
|
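Together, the three methods support the health-check pattern used by CheckPeerHealth above: fetch the cached attributes of a task and compare incarnations against a fresh GetStatus response. A condensed sketch of that comparison (the helper name and response handling here are illustrative, not part of the patch):

    Status VerifyPeerIncarnations(DeviceResolverInterface* resolver,
                                  const string& peer_task,
                                  const GetStatusResponse& resp) {
      std::vector<DeviceAttributes> cached_attrs;
      TF_RETURN_IF_ERROR(
          resolver->GetAllDeviceAttributes(peer_task, &cached_attrs));
      absl::flat_hash_set<uint64> remote_incarnations;
      for (const DeviceAttributes& da : resp.device_attributes()) {
        remote_incarnations.insert(da.incarnation());
      }
      for (const DeviceAttributes& cached : cached_attrs) {
        if (!remote_incarnations.contains(cached.incarnation())) {
          // A cached device reappeared with a new incarnation: the remote
          // worker has restarted since its attributes were cached.
          return errors::FailedPrecondition(cached.name(),
                                            " has a new incarnation");
        }
      }
      return Status::OK();
    }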