Use cancellation manager to abort collectives

We used to always abort collective ops in the executor when there are errors in graph execution. However, some errors are intended for the user to catch, and if we abort collective ops, the user program cannot continue. It's also not necessary to abort collective ops if there are no active ones.

Ideally we should have a proper cancellation story for collectives. Until then, we can at least abort collectives only when necessary, i.e. when there are pending or failed collective ops.
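
As an illustration of the intended workflow, here is a minimal user-level sketch (not part of this change; it mirrors the testOpErrorNotAbort case added below, and the two-logical-CPU setup, group/instance keys, and the all_reduce helper are assumptions for the example only):

import tensorflow as tf
from tensorflow.python.ops import collective_ops

# Split the physical CPU into two logical devices so a group_size=2
# collective can complete locally.
cpus = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2)

group_size, group_key, instance_key = 2, 100, 100
dataset = tf.data.Dataset.from_tensors([1.0])  # exactly one element

@tf.function
def all_reduce(t):
  # Launch the collective on both devices so the group can complete.
  for dev in ['/device:CPU:0', '/device:CPU:1']:
    with tf.device(dev):
      collective_ops.all_reduce(t, group_size, group_key, instance_key)

@tf.function
def run():
  it = iter(dataset)
  all_reduce(next(it))
  all_reduce(next(it))  # the second next() raises OutOfRangeError (EOF)

try:
  run()
except tf.errors.OutOfRangeError:
  pass  # EOF is expected; the user catches it and continues.

# No collective was pending when the EOF fired, so nothing was aborted and a
# later collective still succeeds.
all_reduce(tf.constant([1.0]))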

To make the catch-EOF workflow work, we also need to make all collectives in gather depend on the input tensors, so there is a better chance that they fire after the iterator's GetNext. Without that, the shape gathering may run in parallel with GetNext.
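
The dependency itself is just a control edge; here is a hedged sketch of the pattern (the real change lives in CollectiveReplicaLauncher in the diff below; gather_first_dims and its body are only a stand-in for the actual shape-gathering collective):

import tensorflow as tf

def gather_first_dims(input_tensor, device):
  # Anchor everything built in this scope on an identity of the input, so
  # the shape-gathering collective cannot start before the iterator's
  # GetNext has produced input_tensor.
  with tf.device(device), \
       tf.control_dependencies([tf.identity(input_tensor)]):
    first_dim = tf.expand_dims(tf.shape(input_tensor)[0], axis=0)
    # ... the all-gather of first_dim across replicas would go here ...
    return first_dim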

PiperOrigin-RevId: 337440792
Change-Id: I7caea917c858bcf99f6eb471abf46d94d5c255b3
Ran Chen 2020-10-15 21:28:33 -07:00 committed by TensorFlower Gardener
parent 15dd772865
commit f0844f4065
7 changed files with 195 additions and 89 deletions


@ -1119,11 +1119,15 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
if (rendezvous_) {
rendezvous_->StartAbort(s);
}
if (collective_executor_) {
collective_executor_->StartAbort(s);
}
if (cancellation_manager_) {
cancellation_manager_->StartCancel();
} else {
// If there's a cancellation_manager_, collective ops abort the
// collective_executor_ upon cancellation; otherwise we need to abort
// here.
if (collective_executor_) {
collective_executor_->StartAbort(s);
}
}
}
@ -1267,11 +1271,15 @@ void ExecutorState<PropagatorStateType>::Finish() {
if (rendezvous_) {
rendezvous_->StartAbort(status);
}
if (collective_executor_) {
collective_executor_->StartAbort(status);
}
if (cancellation_manager_) {
cancellation_manager_->StartCancel();
} else {
// If there's a cancellation_manager_, collective ops abort the
// collective_executor_ upon cancellation; otherwise we need to abort
// here.
if (collective_executor_) {
collective_executor_->StartAbort(status);
}
}
}
delete this;


@ -51,7 +51,54 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
class CollectiveOpKernel : public AsyncOpKernel {
public:
explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
explicit CollectiveOpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), name_(name()) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
CancellationToken token =
c->cancellation_manager()->get_cancellation_token();
bool cancel_registered =
c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
// StartAbort invokes done callback which contains DeregisterCallback,
// so we cannot block on that.
col_exec->RunClosure([col_exec]() {
col_exec->StartAbort(errors::Cancelled("op cancelled"));
});
});
OP_REQUIRES_ASYNC(c, cancel_registered,
errors::Cancelled("op cancelled ", name_), done);
auto deregister_and_done = [c, col_exec, token, done = std::move(done)]() {
c->cancellation_manager()->DeregisterCallback(token);
// Abort CollectiveExecutor so that this error can propagate to other
// workers.
if (!c->status().ok()) {
col_exec->StartAbort(c->status());
}
done();
};
ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
}
protected:
virtual void ComputeAsyncImpl(OpKernelContext* c,
CollectiveExecutor* col_exec,
DoneCallback done) = 0;
string name_;
};
class CollectiveOpV1Kernel : public CollectiveOpKernel {
public:
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {}
// A string encoding instance, frame and iter to be handed off to
// the implementation for use in generating RecvBuf keys.
@ -90,14 +137,15 @@ class CollectiveOpKernel : public AsyncOpKernel {
return true;
}
protected:
CollectiveParams col_params_;
std::vector<int32> dependencies_;
};
class CollectiveGatherOpKernel : public CollectiveOpKernel {
class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = GATHER_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -119,15 +167,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
auto output_shape = c->input(0).shape();
output_shape.set_dim(
0, output_shape.dim_size(0) * col_params_.group.group_size);
@ -171,10 +213,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
CollectiveGatherOpKernel);
class CollectiveReduceOpKernel : public CollectiveOpKernel {
class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = REDUCTION_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -231,14 +273,9 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@ -280,10 +317,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
CollectiveReduceOpKernel);
class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -309,14 +346,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@ -362,10 +394,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
CollectiveBcastSendOpKernel);
class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
: CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
@ -391,14 +423,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@ -437,10 +464,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
CollectiveBcastRecvOpKernel);
class CollectiveReduceV2OpKernel : public AsyncOpKernel {
class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
public:
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c) {
: CollectiveOpKernel(c) {
col_params_ = std::make_shared<CollectiveParams>();
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
string merge_op_name;
@ -481,14 +508,9 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
<< col_params_->instance.impl_details.communication_hint;
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_->name),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2);
@ -590,10 +612,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
.HostMemory("instance_key"),
CollectiveReduceV2OpKernel);
class CollectiveGatherV2OpKernel : public AsyncOpKernel {
class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
public:
explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
: CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
@ -603,14 +625,9 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
<< " communication_hint " << communication_hint_;
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2);
@ -712,7 +729,6 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
string communication_hint_;
float timeout_seconds_;
DeviceType device_type_;
string name_;
};
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),


@ -412,7 +412,8 @@ class CollectiveReplicaLauncher(object):
self._group_key, self._device)
instance_key_shape = self._collective_keys.get_instance_key(
self._group_key, self._device)
with ops.device(self._device):
with ops.device(self._device), \
ops.control_dependencies([array_ops.identity(input_tensor)]):
# 1. Transpose
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which


@ -29,7 +29,6 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
@ -1151,9 +1150,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
if mode == 'graph' and _is_tpu_strategy(distribution):
self.skipTest('partial batch not supported with TPU in graph mode.')
if isinstance(distribution,
collective_all_reduce_strategy.CollectiveAllReduceStrategy):
self.skipTest('EOF error causes subsequent collective ops fail.')
with self.cached_session():
with distribution.scope():
optimizer_fn = gradient_descent_keras.SGD
@ -1166,8 +1162,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
loss,
metrics=metrics)
inputs = np.zeros((1000, 3), dtype=np.float32)
targets = np.zeros((1000, 4), dtype=np.float32)
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
# steps/steps_per_epoch are calculated when using numpy arrays as
# input data.
fit_with_numpy = model.fit(


@ -2733,19 +2733,23 @@ def _collective_all_reduce_multi_worker(strategy):
def _multi_worker_concat(v, strategy):
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access
# v might not have the same shape on different replicas
if isinstance(v, ds_values.PerReplica):
shapes = array_ops.concat([
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
for single_value in v.values
],
axis=0)
all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy._gather( # pylint: disable=protected-access
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
axis=0)
# TODO(b/170435030): We now need to make sure these run after the iterator
# GetNext, so that we don't trigger aborting collective ops in the case of
# EOF. Remove after the issue is fixed.
with ops.control_dependencies([replicas]):
# v might not have the same shape on different replicas
if isinstance(v, ds_values.PerReplica):
shapes = array_ops.concat([
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
for single_value in v.values
],
axis=0)
all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy._gather( # pylint: disable=protected-access
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
axis=0)
replicas = array_ops.split(
replicas,


@ -298,6 +298,8 @@ cuda_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:test_util",
"//tensorflow/python/eager:context",


@ -24,6 +24,8 @@ import time
from absl.testing import parameterized
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.experimental.ops import testing as dataset_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context
@ -469,6 +471,83 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
_setup_context()
def_function.function(collective_fn)()
def testOpErrorNotAbort(self, collective_op, device, communication):
# Do not abort if there are no active collective ops. There could be
# exceptions like EOF which we expect users to catch; aborting collective
# ops on all op errors interferes with this workflow.
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
dataset = dataset_ops.Dataset.from_tensors([1.])
@def_function.function
def collective_fn(in_tensor):
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@def_function.function
def f():
iterator = iter(dataset)
collective_fn(next(iterator))
# This next(iterator) should raise EOF.
collective_fn(next(iterator))
with self.assertRaises(errors.OutOfRangeError):
f()
collective_fn(constant_op.constant([1.]))
def testOpErrorAbort(self, collective_op, device, communication):
# Abort collective ops if there are active collective ops at the time of an
# op error. This is due to the inability to cancel collective ops, and op
# errors may cause running collective ops to hang.
dev0 = '/device:%s:0' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
# Make the dataset sleep a while so that the collective is being executed
# when the EOF happens.
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
dataset_testing.sleep(sleep_microseconds=200))
@def_function.function
def f():
# Launch a collective op that won't be able to finish to test abortion
# when other ops error.
with ops.device(dev0):
ret = collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
iterator = iter(dataset)
next(iterator)
# This should raise EOF.
next(iterator)
return ret
with self.assertRaises(errors.OutOfRangeError):
f()
# Now that collective ops are aborted, subsequent collective ops should fail
# with the previous error.
with self.assertRaises(errors.CancelledError):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@combinations.generate(
combinations.times(