Add DeviceAttributes to the collectives device resolver interface.
This change surfaces DeviceAttributes via the DeviceResolverInterface by changing GetDeviceLocalityAsync to GetDeviceAttributesAsync. The eventual goal is to pass the device incarnation number of the remote device as part of collective ops' RPCs and use it to detect worker failures. DeviceLocality is still accessible, since it is part of DeviceAttributes.

As part of this change, we also add the ability to clear the device resolver cache upon a FailedPrecondition error. The intention is to raise this error when a worker restart is detected via a device incarnation mismatch.

PiperOrigin-RevId: 254784520
parent 7259c86fd0
commit 9070aa6976
Changed files:
  tensorflow/contrib/gdr
  tensorflow/core/common_runtime/collective_param_resolver_local.cc
  tensorflow/core/common_runtime/collective_param_resolver_local.h
  tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
  tensorflow/core/common_runtime/collective_rma_local.cc
  tensorflow/core/common_runtime/collective_rma_local.h
  tensorflow/core/common_runtime/device_resolver_local.cc
  tensorflow/core/common_runtime/device_resolver_local.h
  tensorflow/core/common_runtime/device_resolver_local_test.cc
  tensorflow/core/distributed_runtime/BUILD
  tensorflow/core/distributed_runtime/collective_rma_distributed.cc
  tensorflow/core/distributed_runtime/device_resolver_distributed.cc
  tensorflow/core/distributed_runtime/device_resolver_distributed.h
  tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
  tensorflow/core/framework
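The following standalone sketch is not part of the commit; every type and function name in it is illustrative only. It models the failure-detection idea the commit message describes: cache a device's incarnation number, and treat a mismatch on a later resolution as evidence that the hosting worker restarted, dropping the stale cache entry so it can be repopulated.

// Standalone sketch (not TensorFlow code): a cached device incarnation is
// compared against a freshly resolved one to detect a worker restart.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

struct DeviceRecord {
  std::string name;
  uint64_t incarnation;  // Changes every time the hosting worker restarts.
};

class DeviceCache {
 public:
  // Returns false (a stand-in for a FailedPrecondition status) when the
  // incarnation reported by the remote task differs from the cached one,
  // which indicates the worker restarted since the entry was populated.
  bool CheckOrInsert(const DeviceRecord& fresh) {
    auto it = cache_.find(fresh.name);
    if (it == cache_.end()) {
      cache_.emplace(fresh.name, fresh);
      return true;
    }
    if (it->second.incarnation != fresh.incarnation) {
      // Mirrors the new behavior: drop the stale entry so the next
      // resolution repopulates it from the restarted worker.
      cache_.erase(it);
      return false;
    }
    return true;
  }

 private:
  std::unordered_map<std::string, DeviceRecord> cache_;
};

int main() {
  DeviceCache cache;
  DeviceRecord gpu0{"/job:worker/replica:0/task:1/device:GPU:0", 100};
  std::cout << cache.CheckOrInsert(gpu0) << "\n";  // 1: first sighting, cached
  gpu0.incarnation = 200;                          // worker restarted
  std::cout << cache.CheckOrInsert(gpu0) << "\n";  // 0: mismatch detected
}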
@@ -93,7 +93,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
    // State that needs to be threaded through a couple of async calls
    // in order to make this function completely non-blocking.
    struct State {
      DeviceLocality server_locality;
      DeviceAttributes server_attributes;
      std::unique_ptr<RecvBufCall> call;
    };
    State* state = new State;
@@ -108,6 +108,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
      }
      if (!s.ok() && errors::IsFailedPrecondition(s)) {
        dev_resolver_->ClearTask(peer_task);
        done(s);
      }

      delete state;
@@ -124,14 +125,15 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
      } else {
        state->call.reset(new RecvBufCall(
            step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
            to_alloc_attr, to_tensor, client_locality, state->server_locality,
            &cancel_mgr_, worker_cache_));
            to_alloc_attr, to_tensor, client_locality,
            state->server_attributes.locality(), &cancel_mgr_, worker_cache_));
        state->call->Start(recv_buf_callback);
      }
    };

    dev_resolver_->GetLocalityAsync(
        peer_device, peer_task, &state->server_locality, dev_locality_callback);
    dev_resolver_->GetDeviceAttributesAsync(peer_device, peer_task,
                                            &state->server_attributes,
                                            dev_locality_callback);
  }

  void StartAbort(const Status& s) override {
@@ -209,10 +209,10 @@ typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;

// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
                             const std::vector<DeviceLocality>& localities) {
                             const std::vector<DeviceAttributes>& attributes) {
  GlobalDeviceMap gdm;
  CHECK_EQ(ip.device_names.size(), ip.task_names.size());
  CHECK_EQ(ip.device_names.size(), localities.size());
  CHECK_EQ(ip.device_names.size(), attributes.size());
  for (int i = 0; i < ip.device_names.size(); ++i) {
    TaskDeviceMap& tdm = gdm[ip.task_names[i]];
    DevRec* dr = &tdm[ip.device_names[i]];
@@ -221,7 +221,7 @@ GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
    dr->original_rank = i;
    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    dr->locality = &localities[i];
    dr->locality = &attributes[i].locality();
  }
  return gdm;
}
@@ -342,9 +342,9 @@ void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
// sharing the same device group where there is more than one good
// order.
GlobalDeviceMap EstablishGlobalRank(
    CollectiveParams* cp, const std::vector<DeviceLocality>& localities) {
    CollectiveParams* cp, const std::vector<DeviceAttributes>& attributes) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities);
  GlobalDeviceMap gdm = BuildDevRecs(cp->instance, attributes);
  for (auto& iter : gdm) {
    TaskDeviceMap& tdm = iter.second;
    OrderTaskDeviceMap(cp->instance.gpu_ring_order, &tdm);
@@ -472,20 +472,26 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
  // Get Locality data for all devices.

  // Set is_local and task_names in *shared prior to invoking
  // GetDeviceLocalitiesAsync. In a distributed context this function can be
  // GetDeviceAttributesAsync. In a distributed context this function can be
  // called by a derived class, some of the devices may be non-local and
  // GetDeviceLocalitiesAsync will use those fields to launch RPCs.
  // GetDeviceAttributesAsync will use those fields to launch RPCs.
  CompleteTaskIsLocal(task_name_, &ir->shared);

  // Because the callback may execute in a different thread, we release
  // ir->out_mu here. Before releasing, we mark it as unavailable for other
  // threads.
  ir->out_mu_available = false;
  const auto device_names = ir->shared.instance.device_names;
  const auto task_names = ir->shared.instance.task_names;
  ir->out_mu.unlock();
  std::vector<DeviceLocality>* localities = new std::vector<DeviceLocality>;
  dev_resolver_->GetDeviceLocalitiesAsync(
      ir->shared.instance, localities,
      [this, gr, cp, ir, localities, done](const Status& s)
  std::vector<DeviceAttributes>* attributes = new std::vector<DeviceAttributes>;
  // Suppress linter warning about access to shared without mutex because in
  // principle the members are locked due to out_mu_available=false.
  dev_resolver_->GetAllDeviceAttributesAsync(
      ir->shared.instance.device_names,  // NOLINT
      ir->shared.instance.task_names,    // NOLINT
      attributes,
      [this, gr, cp, ir, attributes, done](const Status& s)
          EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
            // Then we recover the lock in the callback thread that will hold it
            // through the rest of the call chain. Signal the cv now, any
@@ -495,26 +501,26 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
            ir->out_mu_available = true;
            ir->out_cv.notify_all();
            if (s.ok()) {
              CompleteDefaultRanking(gr, cp, ir, *localities);
              CompleteDefaultRanking(gr, cp, ir, *attributes);
              done(Status::OK());
            } else {
              done(s);
            }
            delete localities;
            delete attributes;
          });
}

// NOTE(ayushd): The DeviceLocality objects in localities will have LocalLinks
// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime. This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const std::vector<DeviceLocality>& localities) {
    const std::vector<DeviceAttributes>& attributes) {
  // Establish an instance-specific default rank order for devices
  // based on localities. This rank order should be a good ring
  // order, if possible.
  GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, localities);
  GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, attributes);
  // Reflect the new global ranking on shared
  size_t num_devices = ir->shared.group.group_size;
  std::vector<string> new_device_names(num_devices, "");
@@ -599,7 +605,7 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
// before all the function stack frames pop. The static analysis will
// not allow that.
//
// *the lock is dropped just before calling GetDeviceLocalitiesAsync, because
// *the lock is dropped just before calling GetDeviceAttributesAsync, because
// there is no guarantee that the thread that executes the callback is the
// same as the one that locked ir->out_mu. To prevent other threads from
// grabbing ir->out_mu, we mark ir->out_mu_available as false. Hence, in
@@ -182,7 +182,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
  // ir->shared.instance.task_names by considering localities of all devices.
  void CompleteDefaultRanking(const GroupRec* gr, const CollectiveParams* cp,
                              InstanceRec* ir,
                              const std::vector<DeviceLocality>& localities)
                              const std::vector<DeviceAttributes>& attributes)
      EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu);

  // Finish populating *cp.
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"

#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -47,7 +47,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {

  void RunCompleteDefaultRanking(
      const CollectiveParams& shared_cp,
      const std::vector<DeviceLocality>& localities,
      const std::vector<DeviceAttributes>& attributes,
      const std::vector<int32>& gpu_ring_order,
      const std::vector<string>& expected_device_order) {
    CollectiveParams cp;
@@ -69,7 +69,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
          ir.shared.instance.gpu_ring_order, gpu_ring_order.back());
    }
    VLOG(2) << "gpu_ring_order " << ir.shared.instance.gpu_ring_order;
    prl_->CompleteDefaultRanking(nullptr, &cp, &ir, localities);
    prl_->CompleteDefaultRanking(nullptr, &cp, &ir, attributes);
    EXPECT_EQ(ir.shared.instance.device_names, expected_device_order);
  }
  }
@@ -82,7 +82,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
  constexpr int kNumGpus = 8;
  CollectiveParams cp;
  std::vector<DeviceLocality> localities(kNumGpus);
  std::vector<DeviceAttributes> attributes(kNumGpus);
  cp.name = "PRLTest";
  cp.group.device_type = DeviceType("GPU");
  cp.group.num_tasks = 1;
@@ -95,7 +95,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
    cp.instance.task_names.push_back("/job:localhost/replica:0/task:0");
    cp.instance.device_names.push_back(strings::StrCat(
        "/job:localhost/replica:0/task:0/device:GPU:", gpu_idx));
    DeviceLocality* locality = &localities[gpu_idx];
    DeviceLocality locality;
    // Build localities so that 0,1,6,7 and 2,3,4,5 form 2 strongly connected
    // components. Across components, connect 3 and 7.
    for (int link_idx = 0; link_idx < kNumGpus; ++link_idx) {
@@ -104,20 +104,21 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
      bool link_in_clique1 = clique1.find(link_idx) != clique1.end();
      if ((gpu_in_clique1 && link_in_clique1) ||
          (!gpu_in_clique1 && !link_in_clique1)) {
        LocalLinks* links = locality->mutable_links();
        LocalLinks* links = locality.mutable_links();
        InterconnectLink* ilink = links->add_link();
        ilink->set_device_id(link_idx);
        ilink->set_strength(2);
      } else if ((gpu_idx == 3 && link_idx == 7) ||
                 (gpu_idx == 7 && link_idx == 3)) {
        LocalLinks* links = locality->mutable_links();
        LocalLinks* links = locality.mutable_links();
        InterconnectLink* ilink = links->add_link();
        ilink->set_device_id(link_idx);
        ilink->set_strength(1);
      }
    }
    *attributes[gpu_idx].mutable_locality() = locality;
  }
  RunCompleteDefaultRanking(cp, localities, {1, 3, 5, 7, 6, 4, 2, 0},
  RunCompleteDefaultRanking(cp, attributes, {1, 3, 5, 7, 6, 4, 2, 0},
                            {
                                "/job:localhost/replica:0/task:0/device:GPU:1",
                                "/job:localhost/replica:0/task:0/device:GPU:3",
@@ -128,7 +129,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
                                "/job:localhost/replica:0/task:0/device:GPU:2",
                                "/job:localhost/replica:0/task:0/device:GPU:0",
                            });
  RunCompleteDefaultRanking(cp, localities, {7, 6, 5, 4, 3, 2, 1, 0},
  RunCompleteDefaultRanking(cp, attributes, {7, 6, 5, 4, 3, 2, 1, 0},
                            {
                                "/job:localhost/replica:0/task:0/device:GPU:7",
                                "/job:localhost/replica:0/task:0/device:GPU:6",
@@ -141,7 +142,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
                            });
  // With no gpu_ring_order passed, automatic link detection should kick in.
  // Starting at dev 0, the best order would be: 0,1,6,7,3,2,4,5
  RunCompleteDefaultRanking(cp, localities, {},
  RunCompleteDefaultRanking(cp, attributes, {},
                            {
                                "/job:localhost/replica:0/task:0/device:GPU:0",
                                "/job:localhost/replica:0/task:0/device:GPU:1",
@@ -21,6 +21,9 @@ namespace tensorflow {

void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
  buf_rendezvous_.StartAbort(s);
  if (errors::IsFailedPrecondition(s)) {
    dev_resolver_->ClearCache();
  }
}

void CollectiveRemoteAccessLocal::RecvFromPeer(
@@ -52,22 +52,26 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
                    const DeviceLocality& client_locality,
                    const StatusCallback& done) override;

  void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
                                std::vector<DeviceLocality>* localities,
                                const StatusCallback& done) override {
    dev_resolver_->GetDeviceLocalitiesAsync(ci_params, localities, done);
  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
                                   const std::vector<string>& tasks,
                                   std::vector<DeviceAttributes>* attributes,
                                   const StatusCallback& done) override {
    dev_resolver_->GetAllDeviceAttributesAsync(devices, tasks, attributes,
                                               done);
  }

  void GetLocalityAsync(const string& device, const string& task,
                        DeviceLocality* locality,
                        const StatusCallback& done) override {
    dev_resolver_->GetLocalityAsync(device, task, locality, done);
  void GetDeviceAttributesAsync(const string& device, const string& task,
                                DeviceAttributes* attributes,
                                const StatusCallback& done) override {
    dev_resolver_->GetDeviceAttributesAsync(device, task, attributes, done);
  }

  void ClearTask(const string& task) override {
    dev_resolver_->ClearTask(task);
  }

  void ClearCache() override { dev_resolver_->ClearCache(); }

  BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; }

  // Copy utility that always copies bytes from src to dst even if
@@ -18,30 +18,30 @@ limitations under the License.

namespace tensorflow {

void DeviceResolverLocal::GetDeviceLocalitiesAsync(
    const CollInstanceParams& ci_params,
    std::vector<DeviceLocality>* localities, const StatusCallback& done) {
  localities->clear();
  for (const string& device_name : ci_params.device_names) {
void DeviceResolverLocal::GetAllDeviceAttributesAsync(
    const std::vector<string>& devices, const std::vector<string>& tasks,
    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
  attributes->clear();
  for (const string& device_name : devices) {
    Device* dev;
    Status s = dev_mgr_->LookupDevice(device_name, &dev);
    if (!s.ok()) {
      done(s);
      return;
    }
    localities->push_back(dev->attributes().locality());
    attributes->push_back(dev->attributes());
  }
  done(Status::OK());
}

void DeviceResolverLocal::GetLocalityAsync(const string& device,
                                           const string& task,
                                           DeviceLocality* locality,
                                           const StatusCallback& done) {
void DeviceResolverLocal::GetDeviceAttributesAsync(const string& device,
                                                   const string& task,
                                                   DeviceAttributes* attributes,
                                                   const StatusCallback& done) {
  Device* dev;
  Status s = dev_mgr_->LookupDevice(device, &dev);
  if (s.ok()) {
    *locality = dev->attributes().locality();
    *attributes = dev->attributes();
  }
  done(s);
}
@@ -30,16 +30,19 @@ class DeviceResolverLocal : public DeviceResolverInterface {

  virtual ~DeviceResolverLocal() {}

  void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params,
                                std::vector<DeviceLocality>* localities,
  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
                                   const std::vector<string>& tasks,
                                   std::vector<DeviceAttributes>* attributes,
                                   const StatusCallback& done) override;

  void GetDeviceAttributesAsync(const string& device, const string& task,
                                DeviceAttributes* attributes,
                                const StatusCallback& done) override;

  void GetLocalityAsync(const string& device, const string& task,
                        DeviceLocality* locality,
                        const StatusCallback& done) override;

  void ClearTask(const string& task) override {}

  void ClearCache() override {}

 protected:
  const DeviceMgr* dev_mgr_;
};
@@ -45,41 +45,36 @@ class DeviceResolverLocalTest : public ::testing::Test {
  std::unique_ptr<DeviceResolverLocal> drl_;
};

TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesKnown) {
  CollectiveParams cp;
  std::vector<DeviceLocality> localities;
  cp.instance.device_names.push_back(
      "/job:localhost/replica:0/task:0/device:CPU:1");
  cp.instance.device_names.push_back(
      "/job:localhost/replica:0/task:0/device:CPU:2");
TEST_F(DeviceResolverLocalTest, GetDeviceAttributesKnown) {
  std::vector<DeviceAttributes> attributes;
  std::vector<string> devices{"/job:localhost/replica:0/task:0/device:CPU:1",
                              "/job:localhost/replica:0/task:0/device:CPU:2"};
  Notification note;
  Status status;
  drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
                                 [&note, &status](const Status& s) {
                                   status = s;
                                   note.Notify();
                                 });
  drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes,
                                    [&note, &status](const Status& s) {
                                      status = s;
                                      note.Notify();
                                    });
  note.WaitForNotification();
  TF_EXPECT_OK(status);
  EXPECT_EQ(2, localities.size());
  EXPECT_EQ(2, attributes.size());
}

TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesUnknown) {
  CollectiveParams cp;
  std::vector<DeviceLocality> localities;
TEST_F(DeviceResolverLocalTest, GetDeviceAttributesUnknown) {
  std::vector<DeviceAttributes> attributes;
  // In some builds there may be 1 GPU, but there should never be 9.
  cp.instance.device_names.push_back(
      "/job:localhost/replica:0/task:0/device:GPU:9");
  std::vector<string> devices{"/job:localhost/replica:0/task:0/device:GPU:9"};
  Notification note;
  Status status;
  drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
                                 [&note, &status](const Status& s) {
                                   status = s;
                                   note.Notify();
                                 });
  drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes,
                                    [&note, &status](const Status& s) {
                                      status = s;
                                      note.Notify();
                                    });
  note.WaitForNotification();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(0, localities.size());
  EXPECT_EQ(0, attributes.size());
}

}  // namespace
@@ -619,6 +619,7 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:worker_proto_cc",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)
@@ -89,7 +89,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
  // State that needs to be threaded through a couple of async calls
  // in order to make this function completely non-blocking.
  struct State {
    DeviceLocality server_locality;
    DeviceAttributes server_attributes;
    std::unique_ptr<RecvBufCall> call;
  };
  State* state = new State;
@@ -168,14 +168,14 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
    } else {
      state->call.reset(new RecvBufCall(
          step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
          to_alloc_attr, to_tensor, client_locality, state->server_locality,
          &cancel_mgr_, worker_cache_));
          to_alloc_attr, to_tensor, client_locality,
          state->server_attributes.locality(), &cancel_mgr_, worker_cache_));
      state->call->Start(recv_buf_callback);
    }
  };

  dev_resolver_->GetLocalityAsync(
      peer_device, peer_task, &state->server_locality, dev_locality_callback);
  dev_resolver_->GetDeviceAttributesAsync(
      peer_device, peer_task, &state->server_attributes, dev_locality_callback);
}

void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
@@ -23,16 +23,15 @@ DeviceResolverDistributed::DeviceResolverDistributed(
    const string& task_name)
    : dev_mgr_(dev_mgr), worker_cache_(worker_cache), task_name_(task_name) {}

void DeviceResolverDistributed::GetLocalityAsync(const string& device,
                                                 const string& task,
                                                 DeviceLocality* locality,
                                                 const StatusCallback& done) {
void DeviceResolverDistributed::GetDeviceAttributesAsync(
    const string& device, const string& task, DeviceAttributes* attributes,
    const StatusCallback& done) {
  if (task.empty() || task == task_name_) {
    // Device is local to this task.
    Device* dev;
    Status s = dev_mgr_->LookupDevice(device, &dev);
    if (s.ok()) {
      *locality = dev->attributes().locality();
      *attributes = dev->attributes();
    }
    done(s);
    return;
@@ -43,7 +42,7 @@ void DeviceResolverDistributed::GetLocalityAsync(const string& device,
    mutex_lock l(mu_);
    auto it = attr_table_.find(device);
    if (it != attr_table_.end()) {
      *locality = it->second.locality();
      *attributes = it->second;
      found = true;
    }
  }
@@ -55,39 +54,38 @@ void DeviceResolverDistributed::GetLocalityAsync(const string& device,
  // Device is remote and no cache entry was found. Refresh the cache
  // then retry the lookup.
  RefreshRemoteAttributes(
      device, task, [this, device, task, locality, done](const Status& s) {
      device, task, [this, device, task, attributes, done](const Status& s) {
        if (!s.ok()) {
          done(s);
        } else {
          GetLocalityAsync(device, task, locality, done);
          GetDeviceAttributesAsync(device, task, attributes, done);
        }
      });
}

void DeviceResolverDistributed::GetDeviceLocalitiesAsync(
    const CollInstanceParams& inst_params,
    std::vector<DeviceLocality>* localities, const StatusCallback& done) {
  localities->clear();
  GetDeviceLocalitiesRecursive(inst_params, localities, done);
void DeviceResolverDistributed::GetAllDeviceAttributesAsync(
    const std::vector<string>& devices, const std::vector<string>& tasks,
    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
  attributes->clear();
  GetAllDeviceAttributesRecursive(devices, tasks, attributes, done);
}

void DeviceResolverDistributed::GetDeviceLocalitiesRecursive(
    const CollInstanceParams& inst_params,
    std::vector<DeviceLocality>* localities, const StatusCallback& done) {
  size_t i = localities->size();
  if (i < inst_params.device_names.size()) {
    localities->push_back(DeviceLocality());
    GetLocalityAsync(inst_params.device_names[i], inst_params.task_names[i],
                     &localities->back(),
                     [this, &inst_params, localities, done](const Status& s) {
                       if (!s.ok()) {
                         done(s);
                         return;
                       } else {
                         GetDeviceLocalitiesRecursive(inst_params, localities,
                                                      done);
                       }
                     });
void DeviceResolverDistributed::GetAllDeviceAttributesRecursive(
    const std::vector<string>& devices, const std::vector<string>& tasks,
    std::vector<DeviceAttributes>* attributes, const StatusCallback& done) {
  size_t i = attributes->size();
  if (i < devices.size()) {
    attributes->push_back(DeviceAttributes());
    GetDeviceAttributesAsync(
        devices[i], tasks[i], &attributes->back(),
        [this, &devices, &tasks, attributes, done](const Status& s) {
          if (!s.ok()) {
            done(s);
            return;
          } else {
            GetAllDeviceAttributesRecursive(devices, tasks, attributes, done);
          }
        });
  } else {
    done(Status::OK());
  }
@@ -130,4 +128,9 @@ void DeviceResolverDistributed::ClearTask(const string& task) {
  }
}

void DeviceResolverDistributed::ClearCache() {
  mutex_lock l(mu_);
  attr_table_.clear();
}

}  // namespace tensorflow
@@ -18,9 +18,9 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"

namespace tensorflow {
class DeviceMgr;
@@ -34,33 +34,36 @@ class DeviceResolverDistributed : public DeviceResolverInterface {

  virtual ~DeviceResolverDistributed() {}

  void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
                                std::vector<DeviceLocality>* localities,
  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
                                   const std::vector<string>& tasks,
                                   std::vector<DeviceAttributes>* attributes,
                                   const StatusCallback& done) override;

  void GetDeviceAttributesAsync(const string& device, const string& task,
                                DeviceAttributes* attributes,
                                const StatusCallback& done) override;

  void GetLocalityAsync(const string& device, const string& task,
                        DeviceLocality* locality,
                        const StatusCallback& done) override;

  void ClearTask(const string& task) override;

  void ClearCache() override;

 protected:
  // Loads attr_table_ with device attributes retrieved from remote task.
  void RefreshRemoteAttributes(const string& device, const string& task,
                               const StatusCallback& done) LOCKS_EXCLUDED(mu_);

  // Subroutine used by GetDeviceLocalitiesAsync. Recursively extends
  // *localities with DeviceLocality of the corresponding device named
  // Subroutine used by GetAllDeviceAttributesAsync. Recursively extends
  // *attributes with DeviceAttributes of the corresponding device named
  // by inst_params.instance.device_names.
  void GetDeviceLocalitiesRecursive(const CollInstanceParams& inst_params,
                                    std::vector<DeviceLocality>* localities,
                                    const StatusCallback& done);
  void GetAllDeviceAttributesRecursive(
      const std::vector<string>& devices, const std::vector<string>& tasks,
      std::vector<DeviceAttributes>* attributes, const StatusCallback& done);

  const DeviceMgr* dev_mgr_;            // Not owned
  WorkerCacheInterface* worker_cache_;  // Not owned
  const string task_name_;
  mutex mu_;
  gtl::FlatMap<string, DeviceAttributes> attr_table_ GUARDED_BY(mu_);
  absl::flat_hash_map<string, DeviceAttributes> attr_table_ GUARDED_BY(mu_);
};

}  // namespace tensorflow
@@ -37,13 +37,15 @@ class TestableDeviceResolverDistributed : public DeviceResolverDistributed {
                                    const string& task)
      : DeviceResolverDistributed(dev_mgr, worker_cache, task) {}

  gtl::FlatMap<string, DeviceAttributes>& attr_table() { return attr_table_; }
  absl::flat_hash_map<string, DeviceAttributes>& attr_table() {
    return attr_table_;
  }
};

// Create a fake 'Device' whose only interesting attribute is a non-default
// DeviceLocality.
// DeviceLocality and incarnation.
static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
                                         int numa_node) {
                                         int numa_node, uint64 incarnation) {
  class FakeDevice : public Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@@ -54,6 +56,7 @@ static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
  attr.set_name(name);
  attr.set_device_type(type);
  attr.mutable_locality()->set_numa_node(numa_node);
  attr.set_incarnation(incarnation);
  return absl::make_unique<FakeDevice>(attr);
}

@@ -142,20 +145,23 @@ class DeviceResDistTest : public ::testing::Test {
  }

  void DefineWorkers(int num_workers, int num_devices,
                     const string& device_type) {
                     const string& device_type,
                     uint64 device_incarnation_base) {
    for (int w = 0; w < num_workers; ++w) {
      string name = strings::StrCat("/job:worker/replica:0/task:", w);
      DefineWorker(name, device_type, num_devices);
      DefineWorker(name, device_type, num_devices,
                   w * num_devices + device_incarnation_base);
    }
  }

  void DefineWorker(const string& worker_name, const string& device_type,
                    int num_devices) {
                    int num_devices, uint64 device_incarnation_base) {
    std::vector<std::unique_ptr<Device>> devices;
    for (int i = 0; i < num_devices; ++i) {
      devices.push_back(NewDevice(
          device_type,
          strings::StrCat(worker_name, "/device:", device_type, ":", i), i));
          strings::StrCat(worker_name, "/device:", device_type, ":", i), i,
          device_incarnation_base + i));
    }
    DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
    TestableDeviceResolverDistributed* dev_res =
@@ -163,6 +169,7 @@ class DeviceResDistTest : public ::testing::Test {
    resolvers_[worker_name] = dev_res;
    device_mgrs_.push_back(dev_mgr);
    std::vector<string>* dv = &dev_by_task_[worker_name];
    dv->clear();
    for (auto* d : dev_mgr->ListDevices()) {
      dv->push_back(d->name());
    }
@@ -171,6 +178,55 @@ class DeviceResDistTest : public ::testing::Test {
    wc_.AddWorker(worker_name, fw);
  }

  void RestartWorker(const string& worker_name, const string& device_type,
                     int num_devices, uint64 device_incarnation_base) {
    for (auto it : resolvers_) {
      it.second->ClearCache();
    }
    // `DefineWorker` creates a device resolver and a worker and adds them to
    // resolvers_ and workers_. Recreating the worker would overwrite these map
    // entries. We destroy the old device resolver here; all other objects are
    // cleaned up in the destructor.
    delete resolvers_[worker_name];
    DefineWorker(worker_name, device_type, num_devices,
                 device_incarnation_base);
  }

  void ResolveIncarnationsAndValidate(
      const int num_workers, const int num_devices, const string& worker_prefix,
      const string& device_type,
      const std::vector<std::vector<uint64>>& expected_incarnations) {
    for (int w = 0; w < num_workers; ++w) {
      const string worker_name = absl::StrCat(worker_prefix, w);
      auto* device_resolver = resolvers_[worker_name];
      const string device_prefix =
          absl::StrCat(worker_name, "/device:", device_type, ":");
      for (int peer_w = 0; peer_w < num_workers; ++peer_w) {
        const string peer_worker_name = absl::StrCat(worker_prefix, peer_w);
        for (int d = 0; d < num_devices; ++d) {
          const string device_name =
              absl::StrCat(peer_worker_name, "/device:", device_type, ":", d);
          DeviceNameUtils::ParsedName parsed;
          ASSERT_TRUE(DeviceNameUtils::ParseFullName(device_name, &parsed));
          // NOLINT prevents linter from suggesting absl::Notification as a
          // replacement, which is not available in OSS.
          Notification note;  // NOLINT
          Status status;
          DeviceAttributes attributes;
          device_resolver->GetDeviceAttributesAsync(
              device_name, peer_worker_name, &attributes,
              [&note, &status](const Status& s) {
                status = s;
                note.Notify();
              });
          note.WaitForNotification();
          TF_EXPECT_OK(status);
          EXPECT_EQ(attributes.incarnation(), expected_incarnations[peer_w][d]);
        }
      }
    }
  }

  FakeCache wc_;
  std::vector<DeviceMgr*> device_mgrs_;
  std::unordered_map<string, TestableDeviceResolverDistributed*> resolvers_;
@@ -179,7 +235,8 @@ class DeviceResDistTest : public ::testing::Test {
};

TEST_F(DeviceResDistTest, Workers3Devices4) {
  DefineWorkers(3, 4, "CPU");
  DefineWorkers(/*num_workers=*/3, /*num_devices=*/4, /*device_type=*/"CPU",
                /*device_incarnation_base=*/1);
  // Check that every device is available from every task.
  for (auto it : resolvers_) {
    DeviceResolverDistributed* dres = it.second;
@@ -190,15 +247,15 @@ TEST_F(DeviceResDistTest, Workers3Devices4) {
      ASSERT_TRUE(DeviceNameUtils::ParseFullName(dev_name, &parsed));
      Notification note;
      Status status;
      DeviceLocality locality;
      dres->GetLocalityAsync(dev_name, task_name, &locality,
                             [&note, &status](const Status& s) {
                               status = s;
                               note.Notify();
                             });
      DeviceAttributes attributes;
      dres->GetDeviceAttributesAsync(dev_name, task_name, &attributes,
                                     [&note, &status](const Status& s) {
                                       status = s;
                                       note.Notify();
                                     });
      note.WaitForNotification();
      TF_EXPECT_OK(status);
      EXPECT_EQ(parsed.id, locality.numa_node());
      EXPECT_EQ(parsed.id, attributes.locality().numa_node());
    }
  }
}
@@ -213,5 +270,42 @@ TEST_F(DeviceResDistTest, Workers3Devices4) {
  }
}

TEST_F(DeviceResDistTest, DeviceIncarnationChangesOnFailure) {
  constexpr int num_workers = 3;
  constexpr int num_devices = 4;
  constexpr int failing_worker_index = 1;
  const string device_type = "CPU";
  constexpr uint64 device_incarnation_base = 100;
  DefineWorkers(num_workers, num_devices, device_type, device_incarnation_base);
  const string worker_prefix = "/job:worker/replica:0/task:";
  const string failing_worker =
      absl::StrCat(worker_prefix, failing_worker_index);

  // Check device incarnations match expected.
  std::vector<std::vector<uint64>> expected_incarnations(num_workers);
  for (int w = 0; w < num_workers; ++w) {
    expected_incarnations[w].resize(num_devices);
    for (int d = 0; d < num_devices; ++d) {
      expected_incarnations[w][d] =
          w * num_devices + d + device_incarnation_base;
    }
  }
  ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix,
                                 device_type, expected_incarnations);

  // Restart worker `failing_worker`.
  constexpr uint64 restart_incarnation_base = 200;
  RestartWorker(failing_worker, device_type, num_devices,
                restart_incarnation_base);
  for (int d = 0; d < num_devices; ++d) {
    expected_incarnations[failing_worker_index][d] =
        d + restart_incarnation_base;
  }

  // Check incarnations have changed for `failing worker`.
  ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix,
                                 device_type, expected_incarnations);
}

}  // namespace
}  // namespace tensorflow
@@ -138,21 +138,25 @@ class DeviceResolverInterface {
 public:
  virtual ~DeviceResolverInterface() {}

  // Collects DeviceLocality protobufs from all of the devices identified
  // Collects DeviceAttributes protobufs from all of the devices identified
  // in 'col_params'.
  virtual void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
                                        std::vector<DeviceLocality>* localities,
  virtual void GetAllDeviceAttributesAsync(
      const std::vector<string>& devices, const std::vector<string>& tasks,
      std::vector<DeviceAttributes>* attributes,
      const StatusCallback& done) = 0;

  // Populate *attributes with the DeviceAttributes of the specified
  // device.
  virtual void GetDeviceAttributesAsync(const string& device,
                                        const string& task,
                                        DeviceAttributes* attributes,
                                        const StatusCallback& done) = 0;

  // Populate *locality with the DeviceLocality of the specified
  // device.
  virtual void GetLocalityAsync(const string& device, const string& task,
                                DeviceLocality* locality,
                                const StatusCallback& done) = 0;

  // Clear the cache of device data belonging
  // to the specified task.
  // Clear the cache of device data belonging to the specified task.
  virtual void ClearTask(const string& task) = 0;

  // Clear the cache of all device data.
  virtual void ClearCache() = 0;
};

// Interface that provides resolution of shared CollectiveParams fields.
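For orientation, here is a minimal standalone sketch of the caller-side shape of this interface change. It is not the TensorFlow API; the types and the GetDeviceAttributesAsync stand-in below are illustrative only. Instead of asking the resolver for a DeviceLocality, the caller asks for the full attributes and reads both the locality and the incarnation number out of the result, much as the diffs above do with state->server_attributes.locality().

// Illustrative sketch only: a synchronous stand-in for the async resolver call
// and toy attribute types, showing how locality stays reachable through the
// attributes object.
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

struct Locality { int numa_node = 0; };
struct Attributes {
  Locality locality;
  uint64_t incarnation = 0;
};
using DoneCallback = std::function<void(bool ok)>;

// Stand-in for an async resolver call: fills *out, then invokes done.
void GetDeviceAttributesAsync(const std::string& device, Attributes* out,
                              const DoneCallback& done) {
  out->locality.numa_node = 1;
  out->incarnation = 42;
  done(true);
}

int main() {
  Attributes attrs;
  GetDeviceAttributesAsync("/job:worker/task:0/device:GPU:0", &attrs,
                           [&attrs](bool ok) {
                             if (!ok) return;
                             // Locality is still available, nested inside the
                             // attributes, alongside the incarnation number.
                             std::cout << "numa_node="
                                       << attrs.locality.numa_node
                                       << " incarnation=" << attrs.incarnation
                                       << "\n";
                           });
}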