diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc index 755cbdff31c..9e7513e60d7 100644 --- a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc @@ -93,7 +93,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { // State that needs to be threaded through a couple of async calls // in order to make this function completely non-blocking. struct State { - DeviceLocality server_locality; + DeviceAttributes server_attributes; std::unique_ptr call; }; State* state = new State; @@ -108,6 +108,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { } if (!s.ok() && errors::IsFailedPrecondition(s)) { dev_resolver_->ClearTask(peer_task); + done(s); } delete state; @@ -124,14 +125,15 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { } else { state->call.reset(new RecvBufCall( step_id_, peer_device, peer_task, key, to_device, to_device_ctx, - to_alloc_attr, to_tensor, client_locality, state->server_locality, - &cancel_mgr_, worker_cache_)); + to_alloc_attr, to_tensor, client_locality, + state->server_attributes.locality(), &cancel_mgr_, worker_cache_)); state->call->Start(recv_buf_callback); } }; - dev_resolver_->GetLocalityAsync( - peer_device, peer_task, &state->server_locality, dev_locality_callback); + dev_resolver_->GetDeviceAttributesAsync(peer_device, peer_task, + &state->server_attributes, + dev_locality_callback); } void StartAbort(const Status& s) override { diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 27156cfe362..2be3f623359 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -209,10 +209,10 @@ typedef std::unordered_map GlobalDeviceMap; // Create a populated GlobalDeviceMap from CollInstanceParams and localities. GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip, - const std::vector& localities) { + const std::vector& attributes) { GlobalDeviceMap gdm; CHECK_EQ(ip.device_names.size(), ip.task_names.size()); - CHECK_EQ(ip.device_names.size(), localities.size()); + CHECK_EQ(ip.device_names.size(), attributes.size()); for (int i = 0; i < ip.device_names.size(); ++i) { TaskDeviceMap& tdm = gdm[ip.task_names[i]]; DevRec* dr = &tdm[ip.device_names[i]]; @@ -221,7 +221,7 @@ GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip, dr->original_rank = i; dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap. dr->global_rank = 0; // Will be populated later by EstablishGlobalRank. - dr->locality = &localities[i]; + dr->locality = &attributes[i].locality(); } return gdm; } @@ -342,9 +342,9 @@ void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) { // sharing the same device group where there is more than one good // order. 
GlobalDeviceMap EstablishGlobalRank( - CollectiveParams* cp, const std::vector& localities) { + CollectiveParams* cp, const std::vector& attributes) { VLOG(1) << "EstablishGlobalRank"; - GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities); + GlobalDeviceMap gdm = BuildDevRecs(cp->instance, attributes); for (auto& iter : gdm) { TaskDeviceMap& tdm = iter.second; OrderTaskDeviceMap(cp->instance.gpu_ring_order, &tdm); @@ -472,20 +472,26 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams( // Get Locality data for all devices. // Set is_local and task_names in *shared prior to invoking - // GetDeviceLocalitiesAsync. In a distributed context this function can be + // GetDeviceAttributesAsync. In a distributed context this function can be // called by a derived class, some of the devices may be non-local and - // GetDeviceLocalitiesAsync will use those fields to launch RPCs. + // GetDeviceAttributesAsync will use those fields to launch RPCs. CompleteTaskIsLocal(task_name_, &ir->shared); // Because the callback may execute in a different thread, we release // ir->out_mu here. Before releasing, we mark it as unavailable for other // threads. ir->out_mu_available = false; + const auto device_names = ir->shared.instance.device_names; + const auto task_names = ir->shared.instance.task_names; ir->out_mu.unlock(); - std::vector* localities = new std::vector; - dev_resolver_->GetDeviceLocalitiesAsync( - ir->shared.instance, localities, - [this, gr, cp, ir, localities, done](const Status& s) + std::vector* attributes = new std::vector; + // Suppress linter warning about access to shared without mutex because in + // principle the members are locked due to out_mu_available=false. + dev_resolver_->GetAllDeviceAttributesAsync( + ir->shared.instance.device_names, // NOLINT + ir->shared.instance.task_names, // NOLINT + attributes, + [this, gr, cp, ir, attributes, done](const Status& s) EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) { // Then we recover the lock in the callback thread that will hold it // through the rest of the call chain. Signal the cv now, any @@ -495,26 +501,26 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams( ir->out_mu_available = true; ir->out_cv.notify_all(); if (s.ok()) { - CompleteDefaultRanking(gr, cp, ir, *localities); + CompleteDefaultRanking(gr, cp, ir, *attributes); done(Status::OK()); } else { done(s); } - delete localities; + delete attributes; }); } -// NOTE(ayushd): The DeviceLocality objects in localities will have LocalLinks +// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks // to all devices that they are physically connected to and visible to the // TensorFlow runtime. This set of devices may be a superset of the devices // participating in this instance of collectives. void CollectiveParamResolverLocal::CompleteDefaultRanking( const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir, - const std::vector& localities) { + const std::vector& attributes) { // Establish an instance-specific default rank order for devices // based on localities. This rank order should be a good ring // order, if possible. - GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, localities); + GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, attributes); // Reflect the new global ranking on shared size_t num_devices = ir->shared.group.group_size; std::vector new_device_names(num_devices, ""); @@ -599,7 +605,7 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams( // before all the function stack frames pop. 
The static analysis will // not allow that. // - // *the lock is dropped just before calling GetDeviceLocalitiesAsync, because + // *the lock is dropped just before calling GetDeviceAttributesAsync, because // there is no guarantee that the thread that executes the callback is the // same as the one that locked ir->out_mu. To prevent other threads from // grabbing ir->out_mu, we mark ir->out_mu_available as false. Hence, in diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 5b9eae9b78c..a912ffd1b1a 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -182,7 +182,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // ir->shared.instance.task_names by considering localities of all devices. void CompleteDefaultRanking(const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir, - const std::vector& localities) + const std::vector& attributes) EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu); // Finish populating *cp. diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index 70eb9f8081a..e7bae008a4b 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/collective_executor_mgr.h" - #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" + +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -47,7 +47,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { void RunCompleteDefaultRanking( const CollectiveParams& shared_cp, - const std::vector& localities, + const std::vector& attributes, const std::vector& gpu_ring_order, const std::vector& expected_device_order) { CollectiveParams cp; @@ -69,7 +69,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { ir.shared.instance.gpu_ring_order, gpu_ring_order.back()); } VLOG(2) << "gpu_ring_order " << ir.shared.instance.gpu_ring_order; - prl_->CompleteDefaultRanking(nullptr, &cp, &ir, localities); + prl_->CompleteDefaultRanking(nullptr, &cp, &ir, attributes); EXPECT_EQ(ir.shared.instance.device_names, expected_device_order); } } @@ -82,7 +82,7 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { constexpr int kNumGpus = 8; CollectiveParams cp; - std::vector localities(kNumGpus); + std::vector attributes(kNumGpus); cp.name = "PRLTest"; cp.group.device_type = DeviceType("GPU"); cp.group.num_tasks = 1; @@ -95,7 +95,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { cp.instance.task_names.push_back("/job:localhost/replica:0/task:0"); cp.instance.device_names.push_back(strings::StrCat( "/job:localhost/replica:0/task:0/device:GPU:", gpu_idx)); - DeviceLocality* locality = &localities[gpu_idx]; + 
DeviceLocality locality; // Build localities so that 0,1,6,7 and 2,3,4,5 form 2 strongly connected // components. Across components, connect 3 and 7. for (int link_idx = 0; link_idx < kNumGpus; ++link_idx) { @@ -104,20 +104,21 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { bool link_in_clique1 = clique1.find(link_idx) != clique1.end(); if ((gpu_in_clique1 && link_in_clique1) || (!gpu_in_clique1 && !link_in_clique1)) { - LocalLinks* links = locality->mutable_links(); + LocalLinks* links = locality.mutable_links(); InterconnectLink* ilink = links->add_link(); ilink->set_device_id(link_idx); ilink->set_strength(2); } else if ((gpu_idx == 3 && link_idx == 7) || (gpu_idx == 7 && link_idx == 3)) { - LocalLinks* links = locality->mutable_links(); + LocalLinks* links = locality.mutable_links(); InterconnectLink* ilink = links->add_link(); ilink->set_device_id(link_idx); ilink->set_strength(1); } } + *attributes[gpu_idx].mutable_locality() = locality; } - RunCompleteDefaultRanking(cp, localities, {1, 3, 5, 7, 6, 4, 2, 0}, + RunCompleteDefaultRanking(cp, attributes, {1, 3, 5, 7, 6, 4, 2, 0}, { "/job:localhost/replica:0/task:0/device:GPU:1", "/job:localhost/replica:0/task:0/device:GPU:3", @@ -128,7 +129,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { "/job:localhost/replica:0/task:0/device:GPU:2", "/job:localhost/replica:0/task:0/device:GPU:0", }); - RunCompleteDefaultRanking(cp, localities, {7, 6, 5, 4, 3, 2, 1, 0}, + RunCompleteDefaultRanking(cp, attributes, {7, 6, 5, 4, 3, 2, 1, 0}, { "/job:localhost/replica:0/task:0/device:GPU:7", "/job:localhost/replica:0/task:0/device:GPU:6", @@ -141,7 +142,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { }); // With no gpu_ring_order passed, automatic link detection should kick in. 
// Starting at dev 0, the best order would be: 0,1,6,7,3,2,4,5 - RunCompleteDefaultRanking(cp, localities, {}, + RunCompleteDefaultRanking(cp, attributes, {}, { "/job:localhost/replica:0/task:0/device:GPU:0", "/job:localhost/replica:0/task:0/device:GPU:1", diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc index d99565b49ab..bbf10244f1d 100644 --- a/tensorflow/core/common_runtime/collective_rma_local.cc +++ b/tensorflow/core/common_runtime/collective_rma_local.cc @@ -21,6 +21,9 @@ namespace tensorflow { void CollectiveRemoteAccessLocal::StartAbort(const Status& s) { buf_rendezvous_.StartAbort(s); + if (errors::IsFailedPrecondition(s)) { + dev_resolver_->ClearCache(); + } } void CollectiveRemoteAccessLocal::RecvFromPeer( diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h index 2188087957e..f8a0f58472d 100644 --- a/tensorflow/core/common_runtime/collective_rma_local.h +++ b/tensorflow/core/common_runtime/collective_rma_local.h @@ -52,22 +52,26 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess { const DeviceLocality& client_locality, const StatusCallback& done) override; - void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params, - std::vector* localities, - const StatusCallback& done) override { - dev_resolver_->GetDeviceLocalitiesAsync(ci_params, localities, done); + void GetAllDeviceAttributesAsync(const std::vector& devices, + const std::vector& tasks, + std::vector* attributes, + const StatusCallback& done) override { + dev_resolver_->GetAllDeviceAttributesAsync(devices, tasks, attributes, + done); } - void GetLocalityAsync(const string& device, const string& task, - DeviceLocality* locality, - const StatusCallback& done) override { - dev_resolver_->GetLocalityAsync(device, task, locality, done); + void GetDeviceAttributesAsync(const string& device, const string& task, + DeviceAttributes* attributes, + const StatusCallback& done) override { + dev_resolver_->GetDeviceAttributesAsync(device, task, attributes, done); } void ClearTask(const string& task) override { dev_resolver_->ClearTask(task); } + void ClearCache() override { dev_resolver_->ClearCache(); } + BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; } // Copy utility that always copies bytes from src to dst even if diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc index 17ef4a22844..12e1e28296d 100644 --- a/tensorflow/core/common_runtime/device_resolver_local.cc +++ b/tensorflow/core/common_runtime/device_resolver_local.cc @@ -18,30 +18,30 @@ limitations under the License. 
namespace tensorflow { -void DeviceResolverLocal::GetDeviceLocalitiesAsync( - const CollInstanceParams& ci_params, - std::vector* localities, const StatusCallback& done) { - localities->clear(); - for (const string& device_name : ci_params.device_names) { +void DeviceResolverLocal::GetAllDeviceAttributesAsync( + const std::vector& devices, const std::vector& tasks, + std::vector* attributes, const StatusCallback& done) { + attributes->clear(); + for (const string& device_name : devices) { Device* dev; Status s = dev_mgr_->LookupDevice(device_name, &dev); if (!s.ok()) { done(s); return; } - localities->push_back(dev->attributes().locality()); + attributes->push_back(dev->attributes()); } done(Status::OK()); } -void DeviceResolverLocal::GetLocalityAsync(const string& device, - const string& task, - DeviceLocality* locality, - const StatusCallback& done) { +void DeviceResolverLocal::GetDeviceAttributesAsync(const string& device, + const string& task, + DeviceAttributes* attributes, + const StatusCallback& done) { Device* dev; Status s = dev_mgr_->LookupDevice(device, &dev); if (s.ok()) { - *locality = dev->attributes().locality(); + *attributes = dev->attributes(); } done(s); } diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h index bb6ff2efa0c..53a3c87a158 100644 --- a/tensorflow/core/common_runtime/device_resolver_local.h +++ b/tensorflow/core/common_runtime/device_resolver_local.h @@ -30,16 +30,19 @@ class DeviceResolverLocal : public DeviceResolverInterface { virtual ~DeviceResolverLocal() {} - void GetDeviceLocalitiesAsync(const CollInstanceParams& ci_params, - std::vector* localities, + void GetAllDeviceAttributesAsync(const std::vector& devices, + const std::vector& tasks, + std::vector* attributes, + const StatusCallback& done) override; + + void GetDeviceAttributesAsync(const string& device, const string& task, + DeviceAttributes* attributes, const StatusCallback& done) override; - void GetLocalityAsync(const string& device, const string& task, - DeviceLocality* locality, - const StatusCallback& done) override; - void ClearTask(const string& task) override {} + void ClearCache() override {} + protected: const DeviceMgr* dev_mgr_; }; diff --git a/tensorflow/core/common_runtime/device_resolver_local_test.cc b/tensorflow/core/common_runtime/device_resolver_local_test.cc index b8dac8e0dd9..45c74184b28 100644 --- a/tensorflow/core/common_runtime/device_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/device_resolver_local_test.cc @@ -45,41 +45,36 @@ class DeviceResolverLocalTest : public ::testing::Test { std::unique_ptr drl_; }; -TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesKnown) { - CollectiveParams cp; - std::vector localities; - cp.instance.device_names.push_back( - "/job:localhost/replica:0/task:0/device:CPU:1"); - cp.instance.device_names.push_back( - "/job:localhost/replica:0/task:0/device:CPU:2"); +TEST_F(DeviceResolverLocalTest, GetDeviceAttributesKnown) { + std::vector attributes; + std::vector devices{"/job:localhost/replica:0/task:0/device:CPU:1", + "/job:localhost/replica:0/task:0/device:CPU:2"}; Notification note; Status status; - drl_->GetDeviceLocalitiesAsync(cp.instance, &localities, - [¬e, &status](const Status& s) { - status = s; - note.Notify(); - }); + drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes, + [¬e, &status](const Status& s) { + status = s; + note.Notify(); + }); note.WaitForNotification(); TF_EXPECT_OK(status); - EXPECT_EQ(2, localities.size()); + 
EXPECT_EQ(2, attributes.size()); } -TEST_F(DeviceResolverLocalTest, GetDeviceLocalitiesUnknown) { - CollectiveParams cp; - std::vector localities; +TEST_F(DeviceResolverLocalTest, GetDeviceAttributesUnknown) { + std::vector attributes; // In some builds there may be 1 GPU, but there should never be 9. - cp.instance.device_names.push_back( - "/job:localhost/replica:0/task:0/device:GPU:9"); + std::vector devices{"/job:localhost/replica:0/task:0/device:GPU:9"}; Notification note; Status status; - drl_->GetDeviceLocalitiesAsync(cp.instance, &localities, - [¬e, &status](const Status& s) { - status = s; - note.Notify(); - }); + drl_->GetAllDeviceAttributesAsync(devices, /*tasks=*/{}, &attributes, + [¬e, &status](const Status& s) { + status = s; + note.Notify(); + }); note.WaitForNotification(); EXPECT_FALSE(status.ok()); - EXPECT_EQ(0, localities.size()); + EXPECT_EQ(0, attributes.size()); } } // namespace diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 280facc0f3e..126316963e8 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -619,6 +619,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:worker_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index 1faba31b92d..901cc8f182a 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -89,7 +89,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( // State that needs to be threaded through a couple of async calls // in order to make this function completely non-blocking. 
struct State { - DeviceLocality server_locality; + DeviceAttributes server_attributes; std::unique_ptr call; }; State* state = new State; @@ -168,14 +168,14 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( } else { state->call.reset(new RecvBufCall( step_id_, peer_device, peer_task, key, to_device, to_device_ctx, - to_alloc_attr, to_tensor, client_locality, state->server_locality, - &cancel_mgr_, worker_cache_)); + to_alloc_attr, to_tensor, client_locality, + state->server_attributes.locality(), &cancel_mgr_, worker_cache_)); state->call->Start(recv_buf_callback); } }; - dev_resolver_->GetLocalityAsync( - peer_device, peer_task, &state->server_locality, dev_locality_callback); + dev_resolver_->GetDeviceAttributesAsync( + peer_device, peer_task, &state->server_attributes, dev_locality_callback); } void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) { diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc index 165a19641ef..d39e1cb47a4 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.cc @@ -23,16 +23,15 @@ DeviceResolverDistributed::DeviceResolverDistributed( const string& task_name) : dev_mgr_(dev_mgr), worker_cache_(worker_cache), task_name_(task_name) {} -void DeviceResolverDistributed::GetLocalityAsync(const string& device, - const string& task, - DeviceLocality* locality, - const StatusCallback& done) { +void DeviceResolverDistributed::GetDeviceAttributesAsync( + const string& device, const string& task, DeviceAttributes* attributes, + const StatusCallback& done) { if (task.empty() || task == task_name_) { // Device is local to this task. Device* dev; Status s = dev_mgr_->LookupDevice(device, &dev); if (s.ok()) { - *locality = dev->attributes().locality(); + *attributes = dev->attributes(); } done(s); return; @@ -43,7 +42,7 @@ void DeviceResolverDistributed::GetLocalityAsync(const string& device, mutex_lock l(mu_); auto it = attr_table_.find(device); if (it != attr_table_.end()) { - *locality = it->second.locality(); + *attributes = it->second; found = true; } } @@ -55,39 +54,38 @@ void DeviceResolverDistributed::GetLocalityAsync(const string& device, // Device is remote and no cache entry was found. Refresh the cache // then retry the lookup. 
RefreshRemoteAttributes( - device, task, [this, device, task, locality, done](const Status& s) { + device, task, [this, device, task, attributes, done](const Status& s) { if (!s.ok()) { done(s); } else { - GetLocalityAsync(device, task, locality, done); + GetDeviceAttributesAsync(device, task, attributes, done); } }); } -void DeviceResolverDistributed::GetDeviceLocalitiesAsync( - const CollInstanceParams& inst_params, - std::vector* localities, const StatusCallback& done) { - localities->clear(); - GetDeviceLocalitiesRecursive(inst_params, localities, done); +void DeviceResolverDistributed::GetAllDeviceAttributesAsync( + const std::vector& devices, const std::vector& tasks, + std::vector* attributes, const StatusCallback& done) { + attributes->clear(); + GetAllDeviceAttributesRecursive(devices, tasks, attributes, done); } -void DeviceResolverDistributed::GetDeviceLocalitiesRecursive( - const CollInstanceParams& inst_params, - std::vector* localities, const StatusCallback& done) { - size_t i = localities->size(); - if (i < inst_params.device_names.size()) { - localities->push_back(DeviceLocality()); - GetLocalityAsync(inst_params.device_names[i], inst_params.task_names[i], - &localities->back(), - [this, &inst_params, localities, done](const Status& s) { - if (!s.ok()) { - done(s); - return; - } else { - GetDeviceLocalitiesRecursive(inst_params, localities, - done); - } - }); +void DeviceResolverDistributed::GetAllDeviceAttributesRecursive( + const std::vector& devices, const std::vector& tasks, + std::vector* attributes, const StatusCallback& done) { + size_t i = attributes->size(); + if (i < devices.size()) { + attributes->push_back(DeviceAttributes()); + GetDeviceAttributesAsync( + devices[i], tasks[i], &attributes->back(), + [this, &devices, &tasks, attributes, done](const Status& s) { + if (!s.ok()) { + done(s); + return; + } else { + GetAllDeviceAttributesRecursive(devices, tasks, attributes, done); + } + }); } else { done(Status::OK()); } @@ -130,4 +128,9 @@ void DeviceResolverDistributed::ClearTask(const string& task) { } } +void DeviceResolverDistributed::ClearCache() { + mutex_lock l(mu_); + attr_table_.clear(); +} + } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/tensorflow/core/distributed_runtime/device_resolver_distributed.h index ac68ec68731..f4391f822f5 100644 --- a/tensorflow/core/distributed_runtime/device_resolver_distributed.h +++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.h @@ -18,9 +18,9 @@ limitations under the License. 
 #include <string>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace tensorflow {
 class DeviceMgr;
@@ -34,33 +34,36 @@ class DeviceResolverDistributed : public DeviceResolverInterface {
 
   virtual ~DeviceResolverDistributed() {}
 
-  void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
-                                std::vector<DeviceLocality>* localities,
+  void GetAllDeviceAttributesAsync(const std::vector<string>& devices,
+                                   const std::vector<string>& tasks,
+                                   std::vector<DeviceAttributes>* attributes,
+                                   const StatusCallback& done) override;
+
+  void GetDeviceAttributesAsync(const string& device, const string& task,
+                                DeviceAttributes* attributes,
                                 const StatusCallback& done) override;
 
-  void GetLocalityAsync(const string& device, const string& task,
-                        DeviceLocality* locality,
-                        const StatusCallback& done) override;
-
   void ClearTask(const string& task) override;
 
+  void ClearCache() override;
+
  protected:
   // Loads attr_table_ with device attributes retrieved from remote task.
   void RefreshRemoteAttributes(const string& device, const string& task,
                                const StatusCallback& done) LOCKS_EXCLUDED(mu_);
 
-  // Subroutine used by GetDeviceLocalitiesAsync. Recursively extends
-  // *localities with DeviceLocality of the corresponding device named
-  // by inst_params.instance.device_names.
-  void GetDeviceLocalitiesRecursive(const CollInstanceParams& inst_params,
-                                    std::vector<DeviceLocality>* localities,
-                                    const StatusCallback& done);
+  // Subroutine used by GetAllDeviceAttributesAsync. Recursively extends
+  // *attributes with the DeviceAttributes of the devices named in
+  // `devices`, hosted on the corresponding entries of `tasks`.
+  void GetAllDeviceAttributesRecursive(
+      const std::vector<string>& devices, const std::vector<string>& tasks,
+      std::vector<DeviceAttributes>* attributes, const StatusCallback& done);
 
   const DeviceMgr* dev_mgr_;            // Not owned
   WorkerCacheInterface* worker_cache_;  // Not owned
   const string task_name_;
   mutex mu_;
-  gtl::FlatMap<string, DeviceAttributes> attr_table_ GUARDED_BY(mu_);
+  absl::flat_hash_map<string, DeviceAttributes> attr_table_ GUARDED_BY(mu_);
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
index 521bc237957..ecd14db2b6f 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
@@ -37,13 +37,15 @@ class TestableDeviceResolverDistributed : public DeviceResolverDistributed {
                                     const string& task)
       : DeviceResolverDistributed(dev_mgr, worker_cache, task) {}
 
-  gtl::FlatMap<string, DeviceAttributes>& attr_table() { return attr_table_; }
+  absl::flat_hash_map<string, DeviceAttributes>& attr_table() {
+    return attr_table_;
+  }
 };
 
 // Create a fake 'Device' whose only interesting attribute is a non-default
-// DeviceLocality.
+// DeviceLocality and incarnation.
static std::unique_ptr NewDevice(const string& type, const string& name, - int numa_node) { + int numa_node, uint64 incarnation) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} @@ -54,6 +56,7 @@ static std::unique_ptr NewDevice(const string& type, const string& name, attr.set_name(name); attr.set_device_type(type); attr.mutable_locality()->set_numa_node(numa_node); + attr.set_incarnation(incarnation); return absl::make_unique(attr); } @@ -142,20 +145,23 @@ class DeviceResDistTest : public ::testing::Test { } void DefineWorkers(int num_workers, int num_devices, - const string& device_type) { + const string& device_type, + uint64 device_incarnation_base) { for (int w = 0; w < num_workers; ++w) { string name = strings::StrCat("/job:worker/replica:0/task:", w); - DefineWorker(name, device_type, num_devices); + DefineWorker(name, device_type, num_devices, + w * num_devices + device_incarnation_base); } } void DefineWorker(const string& worker_name, const string& device_type, - int num_devices) { + int num_devices, uint64 device_incarnation_base) { std::vector> devices; for (int i = 0; i < num_devices; ++i) { devices.push_back(NewDevice( device_type, - strings::StrCat(worker_name, "/device:", device_type, ":", i), i)); + strings::StrCat(worker_name, "/device:", device_type, ":", i), i, + device_incarnation_base + i)); } DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices)); TestableDeviceResolverDistributed* dev_res = @@ -163,6 +169,7 @@ class DeviceResDistTest : public ::testing::Test { resolvers_[worker_name] = dev_res; device_mgrs_.push_back(dev_mgr); std::vector* dv = &dev_by_task_[worker_name]; + dv->clear(); for (auto* d : dev_mgr->ListDevices()) { dv->push_back(d->name()); } @@ -171,6 +178,55 @@ class DeviceResDistTest : public ::testing::Test { wc_.AddWorker(worker_name, fw); } + void RestartWorker(const string& worker_name, const string& device_type, + int num_devices, uint64 device_incarnation_base) { + for (auto it : resolvers_) { + it.second->ClearCache(); + } + // `DefineWorker` creates a device resolver and a worker and adds them to + // resolvers_ and workers_. Recreating the worker would overwrite these map + // entries. We destroy the old device resolver here; all other objects are + // cleaned up in the destructor. + delete resolvers_[worker_name]; + DefineWorker(worker_name, device_type, num_devices, + device_incarnation_base); + } + + void ResolveIncarnationsAndValidate( + const int num_workers, const int num_devices, const string& worker_prefix, + const string& device_type, + const std::vector>& expected_incarnations) { + for (int w = 0; w < num_workers; ++w) { + const string worker_name = absl::StrCat(worker_prefix, w); + auto* device_resolver = resolvers_[worker_name]; + const string device_prefix = + absl::StrCat(worker_name, "/device:", device_type, ":"); + for (int peer_w = 0; peer_w < num_workers; ++peer_w) { + const string peer_worker_name = absl::StrCat(worker_prefix, peer_w); + for (int d = 0; d < num_devices; ++d) { + const string device_name = + absl::StrCat(peer_worker_name, "/device:", device_type, ":", d); + DeviceNameUtils::ParsedName parsed; + ASSERT_TRUE(DeviceNameUtils::ParseFullName(device_name, &parsed)); + // NOLINT prevents linter from suggesting absl::Notification as a + // replacement, which is not available in OSS. 
+ Notification note; // NOLINT + Status status; + DeviceAttributes attributes; + device_resolver->GetDeviceAttributesAsync( + device_name, peer_worker_name, &attributes, + [¬e, &status](const Status& s) { + status = s; + note.Notify(); + }); + note.WaitForNotification(); + TF_EXPECT_OK(status); + EXPECT_EQ(attributes.incarnation(), expected_incarnations[peer_w][d]); + } + } + } + } + FakeCache wc_; std::vector device_mgrs_; std::unordered_map resolvers_; @@ -179,7 +235,8 @@ class DeviceResDistTest : public ::testing::Test { }; TEST_F(DeviceResDistTest, Workers3Devices4) { - DefineWorkers(3, 4, "CPU"); + DefineWorkers(/*num_workers=*/3, /*num_devices=*/4, /*device_type=*/"CPU", + /*device_incarnation_base=*/1); // Check that every device is available from every task. for (auto it : resolvers_) { DeviceResolverDistributed* dres = it.second; @@ -190,15 +247,15 @@ TEST_F(DeviceResDistTest, Workers3Devices4) { ASSERT_TRUE(DeviceNameUtils::ParseFullName(dev_name, &parsed)); Notification note; Status status; - DeviceLocality locality; - dres->GetLocalityAsync(dev_name, task_name, &locality, - [¬e, &status](const Status& s) { - status = s; - note.Notify(); - }); + DeviceAttributes attributes; + dres->GetDeviceAttributesAsync(dev_name, task_name, &attributes, + [¬e, &status](const Status& s) { + status = s; + note.Notify(); + }); note.WaitForNotification(); TF_EXPECT_OK(status); - EXPECT_EQ(parsed.id, locality.numa_node()); + EXPECT_EQ(parsed.id, attributes.locality().numa_node()); } } } @@ -213,5 +270,42 @@ TEST_F(DeviceResDistTest, Workers3Devices4) { } } +TEST_F(DeviceResDistTest, DeviceIncarnationChangesOnFailure) { + constexpr int num_workers = 3; + constexpr int num_devices = 4; + constexpr int failing_worker_index = 1; + const string device_type = "CPU"; + constexpr uint64 device_incarnation_base = 100; + DefineWorkers(num_workers, num_devices, device_type, device_incarnation_base); + const string worker_prefix = "/job:worker/replica:0/task:"; + const string failing_worker = + absl::StrCat(worker_prefix, failing_worker_index); + + // Check device incarnations match expected. + std::vector> expected_incarnations(num_workers); + for (int w = 0; w < num_workers; ++w) { + expected_incarnations[w].resize(num_devices); + for (int d = 0; d < num_devices; ++d) { + expected_incarnations[w][d] = + w * num_devices + d + device_incarnation_base; + } + } + ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix, + device_type, expected_incarnations); + + // Restart worker `failing_worker`. + constexpr uint64 restart_incarnation_base = 200; + RestartWorker(failing_worker, device_type, num_devices, + restart_incarnation_base); + for (int d = 0; d < num_devices; ++d) { + expected_incarnations[failing_worker_index][d] = + d + restart_incarnation_base; + } + + // Check incarnations have changed for `failing worker`. + ResolveIncarnationsAndValidate(num_workers, num_devices, worker_prefix, + device_type, expected_incarnations); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index 3022e6156f0..f0511f0c164 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -138,21 +138,25 @@ class DeviceResolverInterface { public: virtual ~DeviceResolverInterface() {} - // Collects DeviceLocality protobufs from all of the devices identified + // Collects DeviceAttributes protobufs from all of the devices identified // in 'col_params'. 
-  virtual void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
-                                        std::vector<DeviceLocality>* localities,
+  virtual void GetAllDeviceAttributesAsync(
+      const std::vector<string>& devices, const std::vector<string>& tasks,
+      std::vector<DeviceAttributes>* attributes,
+      const StatusCallback& done) = 0;
+
+  // Populate *attributes with the DeviceAttributes of the specified
+  // device.
+  virtual void GetDeviceAttributesAsync(const string& device,
+                                        const string& task,
+                                        DeviceAttributes* attributes,
                                         const StatusCallback& done) = 0;
 
-  // Populate *locality with the DeviceLocality of the specified
-  // device.
-  virtual void GetLocalityAsync(const string& device, const string& task,
-                                DeviceLocality* locality,
-                                const StatusCallback& done) = 0;
-
-  // Clear the cache of device data belonging
-  // to the specified task.
+  // Clear the cache of device data belonging to the specified task.
   virtual void ClearTask(const string& task) = 0;
+
+  // Clear the cache of all device data.
+  virtual void ClearCache() = 0;
 };
 
 // Interface that provides resolution of shared CollectiveParams fields.
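Note for reviewers: the sketch below is not part of the diff; it shows one way calling code can drive the renamed `DeviceResolverInterface` API. The helper name `GetAllDeviceAttributesBlocking` and the surrounding setup are illustrative only, mirroring the Notification-based pattern the updated tests use.

```cpp
#include <string>
#include <vector>

#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Illustrative helper (not in this change): resolves DeviceAttributes for a
// set of devices by blocking on the async interface. `devices` and `tasks`
// are parallel vectors, as in GetAllDeviceAttributesAsync.
Status GetAllDeviceAttributesBlocking(
    DeviceResolverInterface* resolver, const std::vector<string>& devices,
    const std::vector<string>& tasks,
    std::vector<DeviceAttributes>* attributes) {
  Notification note;
  Status status;
  resolver->GetAllDeviceAttributesAsync(
      devices, tasks, attributes, [&note, &status](const Status& s) {
        status = s;
        note.Notify();
      });
  note.WaitForNotification();
  // On success, (*attributes)[i] carries both locality() and incarnation()
  // for devices[i]; a changed incarnation indicates that the device's task
  // has restarted since the attributes were last cached.
  return status;
}

}  // namespace tensorflow
```

Returning full DeviceAttributes rather than only DeviceLocality is what lets callers compare incarnation values across calls; that is how the new DeviceIncarnationChangesOnFailure test detects a restarted worker once the cached entries are dropped via ClearTask() or ClearCache().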