Enable resetting NcclManager after a previous StartAbort.
After calling Reset, the NcclManager is again available to launch collectives.

PiperOrigin-RevId: 338296784
Change-Id: I4ebc0fedbd84318ba08a1571e2a1c91e765f6b5a
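The intended recovery flow, as exercised by the AbortThenReset test below, is roughly the following. This is a minimal sketch distilled from the change, not part of the change itself: StartAbort, Reset, and GenerateCommunicatorKey are the real NcclManager methods touched here, while the AbortAndRecover wrapper and the elided wait for outstanding done callbacks are illustrative assumptions.

// Minimal sketch: abort a manager after a peer failure, then reuse it.
// AbortAndRecover is a hypothetical helper, not part of this change.
#include <string>

#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/errors.h"

void AbortAndRecover(tensorflow::NcclManager* nccl_manager) {
  // Fail every pending and in-flight collective with the given status.
  nccl_manager->StartAbort(tensorflow::errors::Unavailable("peer down"));

  // ... wait until the done callbacks of outstanding collectives have run ...

  // New in this change: clear the aborted status so the same manager can
  // launch collectives again instead of having to be recreated.
  nccl_manager->Reset();

  // StartAbort destroyed the communicators, so multi-node collectives need a
  // freshly generated communicator_key before the next launch.
  std::string communicator_key = nccl_manager->GenerateCommunicatorKey();
  (void)communicator_key;  // Passed to the next round of participants.
}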
parent 259ffa9ea6
commit 1c7803fa56

tensorflow/core/nccl/nccl_manager.cc
@@ -739,6 +739,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
         VLOG(2) << "call NcclAllReduce collective_key "
                 << collective->collective_key << " participant " << p_idx
+                << " num_participants " << collective->participants.size()
                 << " sendbuff " << sendbuff << " recvbuff " << recvbuff
                 << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
                 << " cuda_stream " << cu_stream;
@@ -849,7 +850,6 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
 }
 
 void NcclManager::StartAbort(const Status& s) {
-  VLOG(1) << "NcclManager StartAbort";
   absl::flat_hash_map<string, Collective*> collectives;
   std::vector<std::unique_ptr<Communicator>> communicators;
   {
@@ -864,6 +864,9 @@ void NcclManager::StartAbort(const Status& s) {
     collectives.swap(collectives_);
     communicators.swap(communicators_);
   }
+  VLOG(2) << "Aborted NcclManager " << this << " with " << collectives.size()
+          << " collectives and " << communicators.size()
+          << " comms with status " << s;
   // collectives_ contains pending launches that haven't been dispatched to
   // kernel launch threads, so we can simply invoke the done callbacks of them.
   for (const auto& item : collectives) {
@@ -895,6 +898,12 @@ void NcclManager::StartAbort(const Status& s) {
   pending.Wait();
 }
 
+void NcclManager::Reset() {
+  mutex_lock l(mu_);
+  status_ = Status();
+  VLOG(2) << "Reset NcclManager " << this;
+}
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/nccl/nccl_manager.h
@@ -195,6 +195,10 @@ class NcclManager {
   // launched with this NcclManager.
   void StartAbort(const Status& s);
 
+  // Resets a previously aborted NcclManager, making it available for future
+  // collectives.
+  void Reset();
+
  private:
   enum CollectiveType {
     kAllReduce = 1,

tensorflow/core/nccl/nccl_manager_test.cc
@@ -17,8 +17,6 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-#include "tensorflow/core/nccl/nccl_manager.h"
-
 #include <algorithm>
 #include <random>
 #include <vector>
@@ -27,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/gpu/gpu_device.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/nccl/nccl_manager.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/unbounded_work_queue.h"
 
@@ -863,20 +862,19 @@ TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
 // environment, on a single node with multiple GPUS. So tests that rely
 // upon such simulation need to be skipped on the ROCm platform
 
-TYPED_TEST(NcclManagerTest, Abort) {
+TYPED_TEST(NcclManagerTest, AbortThenReset) {
   using NodeState = typename TestFixture::NodeState;
   using TestCase = typename TestFixture::TestCase;
-  int num_nodes = 2;
+  const int num_nodes = 2;
   std::vector<NodeState> nodes(num_nodes);
   // First do a normal all-reduce to simulate the the case when there're
   // multiple communicators.
   this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1);
 
-  // Use a new communicator_key, which uses a new set of ncclComm underneath.
-  string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
-  string collective_key = "allreduce";
+  const string collective_key = "allreduce";
   ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(0);
-  auto node_fn = [&](TestCase* test_case, int node) {
+  auto node_fn = [&](TestCase* test_case, int node,
+                     const string& communicator_key) {
     auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node,
                                    /* local_rank */ 0);
     auto* info = device->tensorflow_gpu_device_info();
@@ -894,6 +892,8 @@ TYPED_TEST(NcclManagerTest, Abort) {
     nodes[node].nccl_manager.SignalMultiNodeReady(collective_key);
   };
 
+  // Use a new communicator_key, which uses a new set of ncclComm underneath.
+  string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
   // Do a normal all-reduce with this communicator key to initialize ncclComm.
   // This is because ncclCommInitRank waits for all ranks and is blocking.
   {
@@ -903,7 +903,9 @@ TYPED_TEST(NcclManagerTest, Abort) {
             TensorShape({2, 3}), 0.0f));
     for (int i = 0; i < num_nodes; ++i) {
       this->work_queue_->Schedule(
-          [&node_fn, &test_case, i]() { node_fn(test_case.get(), i); });
+          [&node_fn, &test_case, i, communicator_key]() {
+            node_fn(test_case.get(), i, communicator_key);
+          });
     }
     this->VerifyResults(test_case.get());
   }
@@ -914,16 +916,41 @@ TYPED_TEST(NcclManagerTest, Abort) {
         this->MakeReductionTestCase(
             /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
             TensorShape({2, 3}), 0.0f));
-  node_fn(test_case.get(), 0);
+  node_fn(test_case.get(), 0, communicator_key);
   Env::Default()->SleepForMicroseconds(1000000);
-  nodes[0].nccl_manager.StartAbort(errors::Unavailable("peer down"));
+  for (auto& node : nodes) {
+    node.nccl_manager.StartAbort(errors::Unavailable("peer down"));
+  }
   {
     mutex_lock l(test_case->mu);
     while (test_case->num_completed != 1) {
       test_case->done_cv.wait(l);
     }
   }
+
+  // Reset the aborted NcclManager and then run another all-reduce with the
+  // resetted NcclManagers.
+  for (auto& node : nodes) {
+    node.nccl_manager.Reset();
+  }
+  // Regenerate the communicator_key, because this is needed to create new
+  // communicators.
+  communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
+  {
+    std::unique_ptr<typename TestFixture::TestCase> test_case(
+        this->MakeReductionTestCase(
+            /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
+            TensorShape({2, 3}), 0.0f));
+    for (int i = 0; i < num_nodes; ++i) {
+      this->work_queue_->Schedule(
+          [&node_fn, &test_case, i, communicator_key]() {
+            node_fn(test_case.get(), i, communicator_key);
+          });
+    }
+    this->VerifyResults(test_case.get());
+  }
 }
 
 #endif
 
 }  // namespace tensorflow