[retry]Use cancellation manager to abort collectives

The previous change may cause a use-after-free since StartAbort() runs a separate thread but accesses resources owned by CollectiveExecutiveMgr. Once all cancellation callbacks finish, the CollectiveExecutorMgr may already be deallocated while StartAbort() is in progress. Fixing the ownership is not trivial so we now call StartAbort() in the cancellation callback instead to ensure all resources are valid. Note that with this we need to use TryDeregisterCallback in done() instead of DeregisterCallback(), because the latter blocks until all cancellation callback is done.

We used to always abort collective ops in executor when there're errors in graph execution. However there're some errors that are intended for the user to catch, and if we abort collective ops, the user program cannot continue. It's also not necessary
to abort collective ops if there's no active ones.

Ideally we should have a cancellation story for collectives. Before that, we can at least only abort collectives when it's necessary, i.e. when there're pending collective ops or failed collective ops.

To make the the catching EOF workflow work, we also need to make all collectives in gather depend on the input tensors, so there's better chance they fire after iterator GetNext. Without that the shape gathering may run in parallel with GetNext.

PiperOrigin-RevId: 337997169
Change-Id: I4a374f9ff00bdba38e012a96fb7f5837e049c85c
This commit is contained in:
Ran Chen 2020-10-19 22:05:50 -07:00 committed by TensorFlower Gardener
parent d9e09b0723
commit d345c40688
7 changed files with 193 additions and 89 deletions

View File

@ -1119,11 +1119,13 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
if (rendezvous_) {
rendezvous_->StartAbort(s);
}
if (collective_executor_) {
collective_executor_->StartAbort(s);
}
if (cancellation_manager_) {
cancellation_manager_->StartCancel();
} else if (collective_executor_) {
// If there's cancellation_manager_, collective ops aborts
// collective_executor_ upon cancellation; otherwise we need to abort
// here.
collective_executor_->StartAbort(s);
}
}
@ -1267,11 +1269,13 @@ void ExecutorState<PropagatorStateType>::Finish() {
if (rendezvous_) {
rendezvous_->StartAbort(status);
}
if (collective_executor_) {
collective_executor_->StartAbort(status);
}
if (cancellation_manager_) {
cancellation_manager_->StartCancel();
} else if (collective_executor_) {
// If there's cancellation_manager_, collective ops aborts
// collective_executor_ upon cancellation; otherwise we need to abort
// here.
collective_executor_->StartAbort(status);
}
}
delete this;

View File

@ -51,7 +51,56 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
class CollectiveOpKernel : public AsyncOpKernel {
public:
explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
explicit CollectiveOpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), name_(name()) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
const CancellationToken token =
c->cancellation_manager()->get_cancellation_token();
const bool already_cancelled =
!c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
// We must call StartAbort() within the callback. StartAbort() relies
// on resources that may be deallocated if all execution of a graph is
// finished.
col_exec->StartAbort(errors::Cancelled("op cancelled"));
});
OP_REQUIRES_ASYNC(c, !already_cancelled,
errors::Cancelled("op cancelled ", name_), done);
auto deregister_and_done = [c, col_exec, token, done = std::move(done)]() {
// Once done() is called, StartAbort() won't have any effect, so we
// don't need to block on the deregistration. Also StartAbort() may call
// done() and DeregisterCallback may deadlock.
c->cancellation_manager()->TryDeregisterCallback(token);
// Abort CollectiveExecutor so that this error can propagate to other
// workers.
if (!c->status().ok()) {
col_exec->StartAbort(c->status());
}
done();
};
ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
}
protected:
virtual void ComputeAsyncImpl(OpKernelContext* c,
CollectiveExecutor* col_exec,
DoneCallback done) = 0;
string name_;
};
class CollectiveOpV1Kernel : public CollectiveOpKernel {
public:
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {}
// A string encoding instance, frame and iter to be handed off to
// the implementation for use in generating RecvBuf keys.
@ -90,14 +139,15 @@ class CollectiveOpKernel : public AsyncOpKernel {
return true;
}
protected:
CollectiveParams col_params_;
std::vector<int32> dependencies_;
};
class CollectiveGatherOpKernel : public CollectiveOpKernel {
class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = GATHER_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -119,15 +169,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
auto output_shape = c->input(0).shape();
output_shape.set_dim(
0, output_shape.dim_size(0) * col_params_.group.group_size);
@ -171,10 +215,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
CollectiveGatherOpKernel);
class CollectiveReduceOpKernel : public CollectiveOpKernel {
class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = REDUCTION_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -231,14 +275,9 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@ -280,10 +319,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
CollectiveReduceOpKernel);
class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -309,14 +348,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@ -362,10 +396,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
CollectiveBcastSendOpKernel);
class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -391,14 +425,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@ -437,10 +466,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
CollectiveBcastRecvOpKernel);
class CollectiveReduceV2OpKernel : public AsyncOpKernel {
class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
public:
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c) {
: CollectiveOpKernel(c) {
col_params_ = std::make_shared<CollectiveParams>();
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
string merge_op_name;
@ -481,14 +510,9 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
<< col_params_->instance.impl_details.communication_hint;
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_->name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2);
@ -590,10 +614,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
.HostMemory("instance_key"),
CollectiveReduceV2OpKernel);
class CollectiveGatherV2OpKernel : public AsyncOpKernel {
class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
public:
explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
: CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
@ -603,14 +627,9 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
<< " communication_hint " << communication_hint_;
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2);
@ -712,7 +731,6 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
string communication_hint_;
float timeout_seconds_;
DeviceType device_type_;
string name_;
};
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),

View File

@ -412,7 +412,8 @@ class CollectiveReplicaLauncher(object):
self._group_key, self._device)
instance_key_shape = self._collective_keys.get_instance_key(
self._group_key, self._device)
with ops.device(self._device):
with ops.device(self._device), \
ops.control_dependencies([array_ops.identity(input_tensor)]):
# 1. Transpose
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which

View File

@ -29,7 +29,6 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
@ -1151,9 +1150,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
if mode == 'graph' and _is_tpu_strategy(distribution):
self.skipTest('partial batch not supported with TPU in graph mode.')
if isinstance(distribution,
collective_all_reduce_strategy.CollectiveAllReduceStrategy):
self.skipTest('EOF error causes subsequent collective ops fail.')
with self.cached_session():
with distribution.scope():
optimizer_fn = gradient_descent_keras.SGD
@ -1166,8 +1162,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
loss,
metrics=metrics)
inputs = np.zeros((1000, 3), dtype=np.float32)
targets = np.zeros((1000, 4), dtype=np.float32)
inputs = np.zeros((100, 3), dtype=np.float32)
targets = np.zeros((100, 4), dtype=np.float32)
# steps/steps_per_epoch are calculated when using numpy arrays as
# input data.
fit_with_numpy = model.fit(

View File

@ -2756,19 +2756,23 @@ def _collective_all_reduce_multi_worker(strategy):
def _multi_worker_concat(v, strategy):
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access
# v might not have the same shape on different replicas
if isinstance(v, ds_values.PerReplica):
shapes = array_ops.concat([
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
for single_value in v.values
],
axis=0)
all_shapes = strategy.gather(shapes, axis=0) # pylint: disable=protected-access
else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy.gather( # pylint: disable=protected-access
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
axis=0)
# TODO(b/170435030): We now need to make sure these run after the iterator
# GetNext, so that we don't trigger aborting collective ops in the case of
# EOF. Remove after the issue is fixed.
with ops.control_dependencies([replicas]):
# v might not have the same shape on different replicas
if isinstance(v, ds_values.PerReplica):
shapes = array_ops.concat([
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
for single_value in v.values
],
axis=0)
all_shapes = strategy.gather(shapes, axis=0)
else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy.gather(
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
axis=0)
replicas = array_ops.split(
replicas,

View File

@ -298,6 +298,8 @@ cuda_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:test_util",
"//tensorflow/python/eager:context",

View File

@ -24,6 +24,8 @@ import time
from absl.testing import parameterized
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.experimental.ops import testing as dataset_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context
@ -469,6 +471,83 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
_setup_context()
def_function.function(collective_fn)()
def testOpErrorNotAbort(self, collective_op, device, communication):
# Do not abort if there's no active collective ops. There could be
# exceptions like EOF which we expect users to catch, aborting collective
# ops on all op errors intervenes with this workflow.
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
dataset = dataset_ops.Dataset.from_tensors([1.])
@def_function.function
def collective_fn(in_tensor):
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@def_function.function
def f():
iterator = iter(dataset)
collective_fn(next(iterator))
# This next(iterator) should raise EOF.
collective_fn(next(iterator))
with self.assertRaises(errors.OutOfRangeError):
f()
collective_fn(constant_op.constant([1.]))
def testOpErrorAbort(self, collective_op, device, communication):
# Abort collective ops if there're active collective ops at the time of an
# op error. This is due to the inability to cancel collective ops, and op
# errors may cause running collective ops to hang.
dev0 = '/device:%s:0' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
# Make the dataset sleep a while so that the collective is being executed
# when the EOF happens.
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
dataset_testing.sleep(sleep_microseconds=200))
@def_function.function
def f():
# Launch a collective op that won't be able to finish to test abortion
# when other ops error.
with ops.device(dev0):
ret = collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
iterator = iter(dataset)
next(iterator)
# This should raise EOF.
next(iterator)
return ret
with self.assertRaises(errors.OutOfRangeError):
f()
# Now collective ops is aborted, subsequent collective ops should fail with
# the previous error.
with self.assertRaises(errors.CancelledError):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@combinations.generate(
combinations.times(