diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 8e19c9587fa..fe48b3f6079 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -42,6 +43,14 @@ limitations under the License.
 #define VALUE_IN_DEBUG_STRING false
 
 namespace tensorflow {
+
+namespace {
+bool IsCancelled(CancellationManager* cancel_mgr) {
+  return cancel_mgr != nullptr &&
+         (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
+}
+}  // namespace
+
 /*static*/
 int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                           int64 num_chunks) {
@@ -215,14 +224,12 @@ CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
 BaseCollectiveExecutor::~BaseCollectiveExecutor() {}
 
 void BaseCollectiveExecutor::StartAbort(const Status& s) {
-  VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
   Status status;
   {
     mutex_lock l(status_mu_);
     if (!status_.ok()) {
-      LOG(WARNING)
-          << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
-          << s;
+      VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
+              << s;
       return;
     }
     status_ = StatusGroup::MakeDerived(Status(
@@ -233,6 +240,7 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) {
         "program to reset.")));
     status = status_;
   }
+  LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
   cem_->GetParamResolver()->StartAbort(status);
   remote_access_->StartAbort(status);
   if (cem_->GetNcclCommunicator() != nullptr) {
@@ -261,9 +269,14 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                           StatusCallback done) {
   // See CompleteParamsAsync() how done() and the timeout callback interacts.
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
-  auto done_safe = [this, done, is_callback_called](const Status& s) {
+  auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
     bool called = is_callback_called->exchange(true);
     if (!called) {
+      if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
+        // This is a collective error. Abort CollectiveExecutor so that this
+        // error can propagate to other workers.
+        StartAbort(s);
+      }
       done(GetStatus(s));
     }
   };
@@ -341,9 +354,15 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
   // timeout callback executes, done_safe will become a no-op and the timeout
   // callback is responsible for invoking done() at the end.
   const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
-  auto done_safe = [this, is_callback_called, done](const Status& s) {
+  auto done_safe = [this, is_callback_called, cancel_mgr,
+                    done](const Status& s) {
     bool called = is_callback_called->exchange(true);
     if (!called) {
+      if (!s.ok() && !IsCancelled(cancel_mgr)) {
+        // This is a collective error. Abort CollectiveExecutor so that this
+        // error can propagate to other workers.
+        StartAbort(s);
+      }
       done(GetStatus(s));
     }
   };
diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc
index 91af06cf352..e664eb90865 100644
--- a/tensorflow/core/common_runtime/ring_alg.cc
+++ b/tensorflow/core/common_runtime/ring_alg.cc
@@ -278,12 +278,17 @@ void RingAlg::StartAbort(const Status& s) {
       status_.Update(s);
     }
   }
-  // If this is the initial entry to abort mode then invoke StartAbort
-  // on the CollectiveExecutor that invoked us. That should start
-  // cancellation on all of the outstanding CollectiveRemoteAccess
-  // actions.
+  // If this is the initial entry to abort mode and it's not a cancellation,
+  // then invoke StartAbort on the CollectiveExecutor that invoked us. That
+  // should start cancellation on all of the outstanding CollectiveRemoteAccess
+  // actions. If it's a cancellation, all pending send/recv should be cancelled
+  // as well and there's then no need to abort.
   if (abort_started) {
-    col_ctx_->col_exec->StartAbort(s);
+    if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
+        (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
+         !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
+      col_ctx_->col_exec->StartAbort(s);
+    }
   }
 }
 
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 9cccae51b15..357ae158ea1 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -49,9 +49,9 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
   return k;
 }
 
-class CollectiveOpKernel : public AsyncOpKernel {
+class CollectiveOpV1Kernel : public AsyncOpKernel {
  public:
-  explicit CollectiveOpKernel(OpKernelConstruction* c)
+  explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
       : AsyncOpKernel(c), name_(name()) {}
 
   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
@@ -79,29 +79,11 @@ class CollectiveOpKernel : public AsyncOpKernel {
       // don't need to block on the deregistration. Also StartAbort() may call
       // done() and DeregisterCallback may deadlock.
       c->cancellation_manager()->TryDeregisterCallback(token);
-      // Abort CollectiveExecutor so that this error can propagate to other
-      // workers.
-      if (!c->status().ok()) {
-        col_exec->StartAbort(c->status());
-      }
       done();
     };
     ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
   }
 
- protected:
-  virtual void ComputeAsyncImpl(OpKernelContext* c,
-                                CollectiveExecutor* col_exec,
-                                DoneCallback done) = 0;
-
-  string name_;
-};
-
-class CollectiveOpV1Kernel : public CollectiveOpKernel {
- public:
-  explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
-      : CollectiveOpKernel(c) {}
-
   // A string encoding instance, frame and iter to be handed off to
   // the implementation for use in generating RecvBuf keys.
   string GetCollectiveKey(OpKernelContext* c) {
@@ -140,6 +122,11 @@ class CollectiveOpV1Kernel : public CollectiveOpKernel {
   }
 
  protected:
+  virtual void ComputeAsyncImpl(OpKernelContext* c,
+                                CollectiveExecutor* col_exec,
+                                DoneCallback done) = 0;
+
+  string name_;
   CollectiveParams col_params_;
   std::vector<int32> dependencies_;
 };
@@ -470,10 +457,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
                         CollectiveBcastRecvOpKernel);
 
-class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
+class CollectiveReduceV2OpKernel : public AsyncOpKernel {
  public:
   explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
-      : CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
     OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
     string merge_op_name;
     OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
@@ -504,9 +491,14 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
             << " communication_hint " << communication_hint_;
   }
 
- protected:
-  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
-                        DoneCallback done) override {
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
     const Tensor& input = c->input(0);
     const Tensor& group_size = c->input(1);
     const Tensor& group_key = c->input(2);
@@ -597,6 +589,7 @@
   }
 
  private:
+  string name_;
   DataType data_type_ = DT_INVALID;
   string communication_hint_;
   float timeout_seconds_ = 0;
@@ -614,10 +607,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
                             .HostMemory("instance_key"),
                         CollectiveReduceV2OpKernel);
 
-class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
+class CollectiveGatherV2OpKernel : public AsyncOpKernel {
  public:
   explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
-      : CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
+      : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
     OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
     OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
     OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
@@ -627,9 +620,14 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
             << " communication_hint " << communication_hint_;
   }
 
- protected:
-  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
-                        DoneCallback done) override {
+  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+    CollectiveExecutor* col_exec = c->collective_executor();
+    OP_REQUIRES_ASYNC(
+        c, col_exec,
+        errors::Internal(
+            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+            name_),
+        done);
     const Tensor& input = c->input(0);
     const Tensor& group_size = c->input(1);
     const Tensor& group_key = c->input(2);
@@ -728,6 +726,7 @@
   }
 
  private:
+  string name_;
   DataType data_type_ = DT_INVALID;
   string communication_hint_;
   float timeout_seconds_ = 0;
diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc
index 56e2255ae99..bcdee71be18 100644
--- a/tensorflow/core/nccl/collective_communicator.cc
+++ b/tensorflow/core/nccl/collective_communicator.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/core/nccl/collective_communicator.h"
 
+#include "tensorflow/core/framework/cancellation.h"
+
 #if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 
 #include "absl/memory/memory.h"
@@ -77,7 +79,25 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
   auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info();
   auto participant = absl::make_unique<NcclManager::Participant>(
       compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
-      col_ctx->output, col_ctx->col_params.default_rank, std::move(done));
+      col_ctx->output, col_ctx->col_params.default_rank,
+      /*done_callback=*/nullptr);
+  CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
+  if (cancel_mgr == nullptr) {
+    participant->done_callback = std::move(done);
+  } else {
+    CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
+    cancel_mgr->RegisterCallback(cancel_token, [this]() {
+      nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
+      nccl_manager_.Reset();
+    });
+    participant->done_callback = [cancel_mgr, cancel_token,
+                                  done = std::move(done)](const Status& s) {
+      // Do not block on deregistration since this can be invoked by
+      // NcclManager::StartAbort() in the cancellation callback.
+      cancel_mgr->TryDeregisterCallback(cancel_token);
+      done(s);
+    };
+  }
   NcclManager::Context context(
       nccl_collective_key, num_local_devices, num_global_devices,
       col_params.group.runtime_details.communicator_key,
diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc
index a31aafcdab1..eaa34d042ce 100644
--- a/tensorflow/core/nccl/nccl_manager.cc
+++ b/tensorflow/core/nccl/nccl_manager.cc
@@ -875,11 +875,12 @@ void NcclManager::StartAbort(const Status& s) {
     }
     item.second->Unref();
   }
-  // Abort ncclComm. Note that there could be multiple ncclComm per device, and
-  // ncclCommAbort contains cuda calls that requires device synchronization.
-  // That is a collective on nccl_comm_0 can block ncclCommAbort(nccl_comm_1),
-  // so we need to abort all ncclComm in a concurrent fashion. This assumes that
-  // there's only one active NcclManager at a time.
+  // Abort ncclComm. Note that there could be multiple ncclComm per device,
+  // and ncclCommAbort contains cuda calls that requires device
+  // synchronization. That is a collective on nccl_comm_0 can block
+  // ncclCommAbort(nccl_comm_1), so we need to abort all ncclComm in a
+  // concurrent fashion. This assumes that there's only one active NcclManager
+  // at a time.
   UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
   int num_comms = 0;
   for (std::unique_ptr<Communicator>& communicator : communicators) {
diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py
index 669aae49b41..fe558bcae64 100644
--- a/tensorflow/python/kernel_tests/collective_ops_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_test.py
@@ -471,7 +471,29 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
     _setup_context()
     def_function.function(collective_fn)()
 
-  def testOpErrorNotAbort(self, collective_op, device, communication):
+
+class OpCancellationTest(test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    _setup_context()
+    super().setUp()
+
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              collective_op=[
+                  combinations.NamedObject('all_reduce',
+                                           CollectiveOpsV1.all_reduce),
+                  combinations.NamedObject('all_reduce_v2',
+                                           CollectiveOpsV2.all_reduce),
+                  combinations.NamedObject('all_gather',
+                                           CollectiveOpsV1.all_gather),
+                  combinations.NamedObject('all_gather_v2',
+                                           CollectiveOpsV2.all_gather),
+              ],
+              mode='eager'), device_combination))
+  def testOpErrorNotAbortIfNoCollective(self, collective_op, device,
+                                        communication):
     # Do not abort if there's no active collective ops. There could be
     # exceptions like EOF which we expect users to catch, aborting collective
     # ops on all op errors intervenes with this workflow.
@@ -504,9 +526,20 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
       f()
     collective_fn(constant_op.constant([1.]))
 
-  def testOpErrorAbort(self, collective_op, device, communication):
-    # Abort collective ops if there're active collective ops at the time of an
-    # op error. This is due to the inability to cancel collective ops, and op
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              collective_op=[
+                  combinations.NamedObject('all_reduce',
+                                           CollectiveOpsV1.all_reduce),
+                  combinations.NamedObject('all_gather',
+                                           CollectiveOpsV1.all_gather),
+              ],
+              mode='eager'), device_combination))
+  def testOpErrorAbortWithCollective(self, collective_op, device,
+                                     communication):
+    # Abort v1 collective ops if there're active collective ops at the time of
+    # an op error. This is due to the inability to cancel collective ops, and op
     # errors may cause running collective ops to hang.
     dev0 = '/device:%s:0' % device
     group_size = 2
@@ -548,6 +581,71 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
             instance_key,
             communication_hint=communication)
 
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              collective_op=[
+                  combinations.NamedObject('all_reduce_v2',
+                                           CollectiveOpsV2.all_reduce),
+                  combinations.NamedObject('all_gather_v2',
+                                           CollectiveOpsV2.all_gather),
+              ],
+              mode='eager'), device_combination))
+  def testOpErrorNotAbortWithCollective(self, collective_op, device,
+                                        communication):
+    # Do not abort v2 collective ops even if there're active collective ops at
+    # the time of an op error. We rely on cancellation to terminate active
+    # collective ops.
+    dev0 = '/device:%s:0' % device
+    dev1 = '/device:%s:1' % device
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    @def_function.function
+    def collective_fn():
+      for device in [dev0, dev1]:
+        with ops.device(device):
+          collective_op(
+              in_tensor,
+              group_size,
+              group_key,
+              instance_key,
+              communication_hint=communication)
+
+    # Local params resolution cannot be cancelled yet, so we perform a normal
+    # collective so that the group is resolved.
+    collective_fn()
+
+    # Make the dataset sleep a while so that the collective is being executed
+    # when the EOF happens.
+    dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
+        dataset_testing.sleep(sleep_microseconds=200))
+
+    @def_function.function
+    def f():
+      # Launch a collective op that won't be able to finish to test cancellation
+      # when other ops error.
+      with ops.device(dev0):
+        ret = collective_op(
+            in_tensor,
+            group_size,
+            group_key,
+            instance_key,
+            communication_hint=communication)
+      iterator = iter(dataset)
+      next(iterator)
+      # This should raise EOF.
+      next(iterator)
+      return ret
+
+    with self.assertRaises(errors.OutOfRangeError):
+      f()
+    # Collective ops shouldn't be aborted and new collectives should be able to
+    # proceed.
+    collective_fn()
+
   @combinations.generate(
       combinations.times(