Rely on cancellation in collective V2 kernels
For collective v2 kernels we stop aborting collective ops if they're cancelled. Most componenets except param resolution is able to respond to cancellation. Param resolution is not large concern in practice, since group resolution is likely not needed, and most instance resolution do not block. Technically we can do this to v1 kernels as well, but it doesn't seem safe since we reuse instance keys in v1 collectives. PiperOrigin-RevId: 338508607 Change-Id: Iab2f4e1061d7b384b83bc2712b849e42ba3677fc
This commit is contained in:
parent
f3e4258863
commit
861764c406
tensorflow
core
common_runtime
kernels
nccl
python/kernel_tests
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
@ -42,6 +43,14 @@ limitations under the License.
|
||||
#define VALUE_IN_DEBUG_STRING false
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
bool IsCancelled(CancellationManager* cancel_mgr) {
|
||||
return cancel_mgr != nullptr &&
|
||||
(cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/*static*/
|
||||
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
|
||||
int64 num_chunks) {
|
||||
@ -215,14 +224,12 @@ CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
|
||||
BaseCollectiveExecutor::~BaseCollectiveExecutor() {}
|
||||
|
||||
void BaseCollectiveExecutor::StartAbort(const Status& s) {
|
||||
VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
|
||||
Status status;
|
||||
{
|
||||
mutex_lock l(status_mu_);
|
||||
if (!status_.ok()) {
|
||||
LOG(WARNING)
|
||||
<< "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
|
||||
<< s;
|
||||
VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
|
||||
<< s;
|
||||
return;
|
||||
}
|
||||
status_ = StatusGroup::MakeDerived(Status(
|
||||
@ -233,6 +240,7 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) {
|
||||
"program to reset.")));
|
||||
status = status_;
|
||||
}
|
||||
LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
|
||||
cem_->GetParamResolver()->StartAbort(status);
|
||||
remote_access_->StartAbort(status);
|
||||
if (cem_->GetNcclCommunicator() != nullptr) {
|
||||
@ -261,9 +269,14 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
|
||||
StatusCallback done) {
|
||||
// See CompleteParamsAsync() how done() and the timeout callback interacts.
|
||||
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
|
||||
auto done_safe = [this, done, is_callback_called](const Status& s) {
|
||||
auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
|
||||
bool called = is_callback_called->exchange(true);
|
||||
if (!called) {
|
||||
if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
|
||||
// This is a collective error. Abort CollectiveExecutor so that this
|
||||
// error can propagate to other workers.
|
||||
StartAbort(s);
|
||||
}
|
||||
done(GetStatus(s));
|
||||
}
|
||||
};
|
||||
@ -341,9 +354,15 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
|
||||
// timeout callback executes, done_safe will become a no-op and the timeout
|
||||
// callback is responsible for invoking done() at the end.
|
||||
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
|
||||
auto done_safe = [this, is_callback_called, done](const Status& s) {
|
||||
auto done_safe = [this, is_callback_called, cancel_mgr,
|
||||
done](const Status& s) {
|
||||
bool called = is_callback_called->exchange(true);
|
||||
if (!called) {
|
||||
if (!s.ok() && !IsCancelled(cancel_mgr)) {
|
||||
// This is a collective error. Abort CollectiveExecutor so that this
|
||||
// error can propagate to other workers.
|
||||
StartAbort(s);
|
||||
}
|
||||
done(GetStatus(s));
|
||||
}
|
||||
};
|
||||
|
@ -278,12 +278,17 @@ void RingAlg::StartAbort(const Status& s) {
|
||||
status_.Update(s);
|
||||
}
|
||||
}
|
||||
// If this is the initial entry to abort mode then invoke StartAbort
|
||||
// on the CollectiveExecutor that invoked us. That should start
|
||||
// cancellation on all of the outstanding CollectiveRemoteAccess
|
||||
// actions.
|
||||
// If this is the initial entry to abort mode and it's not a cancellation,
|
||||
// then invoke StartAbort on the CollectiveExecutor that invoked us. That
|
||||
// should start cancellation on all of the outstanding CollectiveRemoteAccess
|
||||
// actions. If it's cancellation all pending send/recv should be cancelled as
|
||||
// well and there's then no need to abort.
|
||||
if (abort_started) {
|
||||
col_ctx_->col_exec->StartAbort(s);
|
||||
if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
|
||||
(!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
|
||||
!col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
|
||||
col_ctx_->col_exec->StartAbort(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,9 +49,9 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
|
||||
return k;
|
||||
}
|
||||
|
||||
class CollectiveOpKernel : public AsyncOpKernel {
|
||||
class CollectiveOpV1Kernel : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CollectiveOpKernel(OpKernelConstruction* c)
|
||||
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c), name_(name()) {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
@ -79,29 +79,11 @@ class CollectiveOpKernel : public AsyncOpKernel {
|
||||
// don't need to block on the deregistration. Also StartAbort() may call
|
||||
// done() and DeregisterCallback may deadlock.
|
||||
c->cancellation_manager()->TryDeregisterCallback(token);
|
||||
// Abort CollectiveExecutor so that this error can propagate to other
|
||||
// workers.
|
||||
if (!c->status().ok()) {
|
||||
col_exec->StartAbort(c->status());
|
||||
}
|
||||
done();
|
||||
};
|
||||
ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ComputeAsyncImpl(OpKernelContext* c,
|
||||
CollectiveExecutor* col_exec,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
string name_;
|
||||
};
|
||||
|
||||
class CollectiveOpV1Kernel : public CollectiveOpKernel {
|
||||
public:
|
||||
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {}
|
||||
|
||||
// A string encoding instance, frame and iter to be handed off to
|
||||
// the implementation for use in generating RecvBuf keys.
|
||||
string GetCollectiveKey(OpKernelContext* c) {
|
||||
@ -140,6 +122,11 @@ class CollectiveOpV1Kernel : public CollectiveOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ComputeAsyncImpl(OpKernelContext* c,
|
||||
CollectiveExecutor* col_exec,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
string name_;
|
||||
CollectiveParams col_params_;
|
||||
std::vector<int32> dependencies_;
|
||||
};
|
||||
@ -470,10 +457,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
|
||||
CollectiveBcastRecvOpKernel);
|
||||
|
||||
class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
|
||||
class CollectiveReduceV2OpKernel : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
|
||||
string merge_op_name;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
|
||||
@ -504,9 +491,14 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
|
||||
<< " communication_hint " << communication_hint_;
|
||||
}
|
||||
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
const Tensor& input = c->input(0);
|
||||
const Tensor& group_size = c->input(1);
|
||||
const Tensor& group_key = c->input(2);
|
||||
@ -597,6 +589,7 @@ class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
string name_;
|
||||
DataType data_type_ = DT_INVALID;
|
||||
string communication_hint_;
|
||||
float timeout_seconds_ = 0;
|
||||
@ -614,10 +607,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
|
||||
.HostMemory("instance_key"),
|
||||
CollectiveReduceV2OpKernel);
|
||||
|
||||
class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
|
||||
class CollectiveGatherV2OpKernel : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
|
||||
@ -627,9 +620,14 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
|
||||
<< " communication_hint " << communication_hint_;
|
||||
}
|
||||
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
const Tensor& input = c->input(0);
|
||||
const Tensor& group_size = c->input(1);
|
||||
const Tensor& group_key = c->input(2);
|
||||
@ -728,6 +726,7 @@ class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
string name_;
|
||||
DataType data_type_ = DT_INVALID;
|
||||
string communication_hint_;
|
||||
float timeout_seconds_ = 0;
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/nccl/collective_communicator.h"
|
||||
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
|
||||
#if TENSORFLOW_USE_NCCL && (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
@ -77,7 +79,25 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
|
||||
auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info();
|
||||
auto participant = absl::make_unique<NcclManager::Participant>(
|
||||
compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
|
||||
col_ctx->output, col_ctx->col_params.default_rank, std::move(done));
|
||||
col_ctx->output, col_ctx->col_params.default_rank,
|
||||
/*done_callback=*/nullptr);
|
||||
CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
|
||||
if (cancel_mgr == nullptr) {
|
||||
participant->done_callback = std::move(done);
|
||||
} else {
|
||||
CancellationToken cancel_token = cancel_mgr->get_cancellation_token();
|
||||
cancel_mgr->RegisterCallback(cancel_token, [this]() {
|
||||
nccl_manager_.StartAbort(errors::Cancelled("op cancelled"));
|
||||
nccl_manager_.Reset();
|
||||
});
|
||||
participant->done_callback = [cancel_mgr, cancel_token,
|
||||
done = std::move(done)](const Status& s) {
|
||||
// Do not block on deregistration since this can be invoked by
|
||||
// NcclManager::StartAbort() in the cancellation callback.
|
||||
cancel_mgr->TryDeregisterCallback(cancel_token);
|
||||
done(s);
|
||||
};
|
||||
}
|
||||
NcclManager::Context context(
|
||||
nccl_collective_key, num_local_devices, num_global_devices,
|
||||
col_params.group.runtime_details.communicator_key,
|
||||
|
@ -875,11 +875,12 @@ void NcclManager::StartAbort(const Status& s) {
|
||||
}
|
||||
item.second->Unref();
|
||||
}
|
||||
// Abort ncclComm. Note that there could be multiple ncclComm per device, and
|
||||
// ncclCommAbort contains cuda calls that requires device synchronization.
|
||||
// That is a collective on nccl_comm_0 can block ncclCommAbort(nccl_comm_1),
|
||||
// so we need to abort all ncclComm in a concurrent fashion. This assumes that
|
||||
// there's only one active NcclManager at a time.
|
||||
// Abort ncclComm. Note that there could be multiple ncclComm per device,
|
||||
// and ncclCommAbort contains cuda calls that requires device
|
||||
// synchronization. That is a collective on nccl_comm_0 can block
|
||||
// ncclCommAbort(nccl_comm_1), so we need to abort all ncclComm in a
|
||||
// concurrent fashion. This assumes that there's only one active NcclManager
|
||||
// at a time.
|
||||
UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
|
||||
int num_comms = 0;
|
||||
for (std::unique_ptr<Communicator>& communicator : communicators) {
|
||||
|
@ -471,7 +471,29 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
_setup_context()
|
||||
def_function.function(collective_fn)()
|
||||
|
||||
def testOpErrorNotAbort(self, collective_op, device, communication):
|
||||
|
||||
class OpCancellationTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
_setup_context()
|
||||
super().setUp()
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
collective_op=[
|
||||
combinations.NamedObject('all_reduce',
|
||||
CollectiveOpsV1.all_reduce),
|
||||
combinations.NamedObject('all_reduce_v2',
|
||||
CollectiveOpsV2.all_reduce),
|
||||
combinations.NamedObject('all_gather',
|
||||
CollectiveOpsV1.all_gather),
|
||||
combinations.NamedObject('all_gather_v2',
|
||||
CollectiveOpsV2.all_gather),
|
||||
],
|
||||
mode='eager'), device_combination))
|
||||
def testOpErrorNotAbortIfNoCollective(self, collective_op, device,
|
||||
communication):
|
||||
# Do not abort if there's no active collective ops. There could be
|
||||
# exceptions like EOF which we expect users to catch, aborting collective
|
||||
# ops on all op errors intervenes with this workflow.
|
||||
@ -504,9 +526,20 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
f()
|
||||
collective_fn(constant_op.constant([1.]))
|
||||
|
||||
def testOpErrorAbort(self, collective_op, device, communication):
|
||||
# Abort collective ops if there're active collective ops at the time of an
|
||||
# op error. This is due to the inability to cancel collective ops, and op
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
collective_op=[
|
||||
combinations.NamedObject('all_reduce',
|
||||
CollectiveOpsV1.all_reduce),
|
||||
combinations.NamedObject('all_gather',
|
||||
CollectiveOpsV1.all_gather),
|
||||
],
|
||||
mode='eager'), device_combination))
|
||||
def testOpErrorAbortWithCollective(self, collective_op, device,
|
||||
communication):
|
||||
# Abort v1 collective ops if there're active collective ops at the time of
|
||||
# an op error. This is due to the inability to cancel collective ops, and op
|
||||
# errors may cause running collective ops to hang.
|
||||
dev0 = '/device:%s:0' % device
|
||||
group_size = 2
|
||||
@ -548,6 +581,71 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(
|
||||
collective_op=[
|
||||
combinations.NamedObject('all_reduce_v2',
|
||||
CollectiveOpsV2.all_reduce),
|
||||
combinations.NamedObject('all_gather_v2',
|
||||
CollectiveOpsV2.all_gather),
|
||||
],
|
||||
mode='eager'), device_combination))
|
||||
def testOpErrorNotAbortWithCollective(self, collective_op, device,
|
||||
communication):
|
||||
# Do not abort v2 collective ops even if there're active collective ops at
|
||||
# the time of an op error. We rely cancellation to terminate active
|
||||
# collective ops.
|
||||
dev0 = '/device:%s:0' % device
|
||||
dev1 = '/device:%s:1' % device
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
in_tensor = constant_op.constant([1.])
|
||||
|
||||
@def_function.function
|
||||
def collective_fn():
|
||||
for device in [dev0, dev1]:
|
||||
with ops.device(device):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
|
||||
# Local params resolution cannot be cancelled yet, so we perform a normal
|
||||
# collective so that the group is resolved.
|
||||
collective_fn()
|
||||
|
||||
# Make the dataset sleep a while so that the collective is being executed
|
||||
# when the EOF happens.
|
||||
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
|
||||
dataset_testing.sleep(sleep_microseconds=200))
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
# Launch a collective op that won't be able to finish to test cancellation
|
||||
# when other ops error.
|
||||
with ops.device(dev0):
|
||||
ret = collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
iterator = iter(dataset)
|
||||
next(iterator)
|
||||
# This should raise EOF.
|
||||
next(iterator)
|
||||
return ret
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
f()
|
||||
# Collective ops shouldn't be aborted and new collectives should be able to
|
||||
# proceed.
|
||||
collective_fn()
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
|
Loading…
Reference in New Issue
Block a user