Use cancellation manager to abort collectives

We used to always abort collective ops in the executor when there are errors in graph execution. However, there are some errors that are intended for the user to catch, and if we abort collective ops, the user program cannot continue. It's also not necessary
to abort collective ops if there are no active ones.

Ideally we should have a cancellation story for collectives. Before that, we can at least abort collectives only when it's necessary, i.e. when there are pending collective ops or failed collective ops.

To make the EOF-catching workflow work, we also need to make all collectives in gather depend on the input tensors, so there's a better chance they fire after the iterator's GetNext. Without that, the shape gathering may run in parallel with GetNext.

PiperOrigin-RevId: 337440792
Change-Id: I7caea917c858bcf99f6eb471abf46d94d5c255b3
This commit is contained in:
Ran Chen 2020-10-15 21:28:33 -07:00 committed by TensorFlower Gardener
parent 15dd772865
commit f0844f4065
7 changed files with 195 additions and 89 deletions

View File

@ -1119,11 +1119,15 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
if (rendezvous_) { if (rendezvous_) {
rendezvous_->StartAbort(s); rendezvous_->StartAbort(s);
} }
if (collective_executor_) {
collective_executor_->StartAbort(s);
}
if (cancellation_manager_) { if (cancellation_manager_) {
cancellation_manager_->StartCancel(); cancellation_manager_->StartCancel();
} else {
// If there's cancellation_manager_, collective ops aborts
// collective_executor_ upon cancellation; otherwise we need to abort
// here.
if (collective_executor_) {
collective_executor_->StartAbort(s);
}
} }
} }
@ -1267,11 +1271,15 @@ void ExecutorState<PropagatorStateType>::Finish() {
if (rendezvous_) { if (rendezvous_) {
rendezvous_->StartAbort(status); rendezvous_->StartAbort(status);
} }
if (collective_executor_) {
collective_executor_->StartAbort(status);
}
if (cancellation_manager_) { if (cancellation_manager_) {
cancellation_manager_->StartCancel(); cancellation_manager_->StartCancel();
} else {
// If there's cancellation_manager_, collective ops aborts
// collective_executor_ upon cancellation; otherwise we need to abort
// here.
if (collective_executor_) {
collective_executor_->StartAbort(status);
}
} }
} }
delete this; delete this;

View File

@ -51,7 +51,54 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
class CollectiveOpKernel : public AsyncOpKernel { class CollectiveOpKernel : public AsyncOpKernel {
public: public:
explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {} explicit CollectiveOpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), name_(name()) {}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
CancellationToken token =
c->cancellation_manager()->get_cancellation_token();
bool cancel_registered =
c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
// StartAbort invokes done callback which contains DeregisterCallback,
// so we cannot block on that.
col_exec->RunClosure([col_exec]() {
col_exec->StartAbort(errors::Cancelled("op cancelled"));
});
});
OP_REQUIRES_ASYNC(c, cancel_registered,
errors::Cancelled("op cancelled ", name_), done);
auto deregister_and_done = [c, col_exec, token, done = std::move(done)]() {
c->cancellation_manager()->DeregisterCallback(token);
// Abort CollectiveExecutor so that this error can propagate to other
// workers.
if (!c->status().ok()) {
col_exec->StartAbort(c->status());
}
done();
};
ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
}
protected:
virtual void ComputeAsyncImpl(OpKernelContext* c,
CollectiveExecutor* col_exec,
DoneCallback done) = 0;
string name_;
};
class CollectiveOpV1Kernel : public CollectiveOpKernel {
public:
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {}
// A string encoding instance, frame and iter to be handed off to // A string encoding instance, frame and iter to be handed off to
// the implementation for use in generating RecvBuf keys. // the implementation for use in generating RecvBuf keys.
@ -90,14 +137,15 @@ class CollectiveOpKernel : public AsyncOpKernel {
return true; return true;
} }
protected:
CollectiveParams col_params_; CollectiveParams col_params_;
std::vector<int32> dependencies_; std::vector<int32> dependencies_;
}; };
class CollectiveGatherOpKernel : public CollectiveOpKernel { class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
public: public:
explicit CollectiveGatherOpKernel(OpKernelConstruction* c) explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) { : CollectiveOpV1Kernel(c) {
col_params_.instance.type = GATHER_COLLECTIVE; col_params_.instance.type = GATHER_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES( OP_REQUIRES(
@ -119,15 +167,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type(); col_params_.group.device_type = c->device_type();
} }
void ComputeAsync(OpKernelContext* c, DoneCallback done) override { protected:
CollectiveExecutor* col_exec = c->collective_executor(); void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
OP_REQUIRES_ASYNC( DoneCallback done) override {
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
auto output_shape = c->input(0).shape(); auto output_shape = c->input(0).shape();
output_shape.set_dim( output_shape.set_dim(
0, output_shape.dim_size(0) * col_params_.group.group_size); 0, output_shape.dim_size(0) * col_params_.group.group_size);
@ -171,10 +213,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU), REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
CollectiveGatherOpKernel); CollectiveGatherOpKernel);
class CollectiveReduceOpKernel : public CollectiveOpKernel { class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
public: public:
explicit CollectiveReduceOpKernel(OpKernelConstruction* c) explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) { : CollectiveOpV1Kernel(c) {
col_params_.instance.type = REDUCTION_COLLECTIVE; col_params_.instance.type = REDUCTION_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES( OP_REQUIRES(
@ -231,14 +273,9 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node); col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
} }
void ComputeAsync(OpKernelContext* c, DoneCallback done) override { protected:
CollectiveExecutor* col_exec = c->collective_executor(); void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
OP_REQUIRES_ASYNC( DoneCallback done) override {
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
// Allocate output on the first pass through this function. This must be // Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise // done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing // the memory is not guaranteed to be unused by any concurrently executing
@ -280,10 +317,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU), REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
CollectiveReduceOpKernel); CollectiveReduceOpKernel);
class CollectiveBcastSendOpKernel : public CollectiveOpKernel { class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
public: public:
explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c) explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) { : CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE; col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES( OP_REQUIRES(
@ -309,14 +346,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type(); col_params_.group.device_type = c->device_type();
} }
void ComputeAsync(OpKernelContext* c, DoneCallback done) override { protected:
CollectiveExecutor* col_exec = c->collective_executor(); void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
OP_REQUIRES_ASYNC( DoneCallback done) override {
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
// Allocate output on the first pass through this function. This must be // Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise // done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing // the memory is not guaranteed to be unused by any concurrently executing
@ -362,10 +394,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU), REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
CollectiveBcastSendOpKernel); CollectiveBcastSendOpKernel);
class CollectiveBcastRecvOpKernel : public CollectiveOpKernel { class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
public: public:
explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c) explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) { : CollectiveOpV1Kernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE; col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size)); OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES( OP_REQUIRES(
@ -391,14 +423,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
col_params_.group.device_type = c->device_type(); col_params_.group.device_type = c->device_type();
} }
void ComputeAsync(OpKernelContext* c, DoneCallback done) override { protected:
CollectiveExecutor* col_exec = c->collective_executor(); void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
OP_REQUIRES_ASYNC( DoneCallback done) override {
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
// Allocate output on the first pass through this function. This must be // Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise // done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing // the memory is not guaranteed to be unused by any concurrently executing
@ -437,10 +464,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU), REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
CollectiveBcastRecvOpKernel); CollectiveBcastRecvOpKernel);
class CollectiveReduceV2OpKernel : public AsyncOpKernel { class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
public: public:
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c) explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c) { : CollectiveOpKernel(c) {
col_params_ = std::make_shared<CollectiveParams>(); col_params_ = std::make_shared<CollectiveParams>();
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type)); OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
string merge_op_name; string merge_op_name;
@ -481,14 +508,9 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
<< col_params_->instance.impl_details.communication_hint; << col_params_->instance.impl_details.communication_hint;
} }
void ComputeAsync(OpKernelContext* c, DoneCallback done) override { protected:
CollectiveExecutor* col_exec = c->collective_executor(); void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
OP_REQUIRES_ASYNC( DoneCallback done) override {
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_->name),
done);
const Tensor& input = c->input(0); const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1); const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2); const Tensor& group_key = c->input(2);
@ -590,10 +612,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
.HostMemory("instance_key"), .HostMemory("instance_key"),
CollectiveReduceV2OpKernel); CollectiveReduceV2OpKernel);
class CollectiveGatherV2OpKernel : public AsyncOpKernel { class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
public: public:
explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c) explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) { : CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_)); OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_)); OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_)); OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
@ -603,14 +625,9 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
<< " communication_hint " << communication_hint_; << " communication_hint " << communication_hint_;
} }
void ComputeAsync(OpKernelContext* c, DoneCallback done) override { protected:
CollectiveExecutor* col_exec = c->collective_executor(); void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
OP_REQUIRES_ASYNC( DoneCallback done) override {
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
const Tensor& input = c->input(0); const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1); const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2); const Tensor& group_key = c->input(2);
@ -712,7 +729,6 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
string communication_hint_; string communication_hint_;
float timeout_seconds_; float timeout_seconds_;
DeviceType device_type_; DeviceType device_type_;
string name_;
}; };
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),

View File

@ -412,7 +412,8 @@ class CollectiveReplicaLauncher(object):
self._group_key, self._device) self._group_key, self._device)
instance_key_shape = self._collective_keys.get_instance_key( instance_key_shape = self._collective_keys.get_instance_key(
self._group_key, self._device) self._group_key, self._device)
with ops.device(self._device): with ops.device(self._device), \
ops.control_dependencies([array_ops.identity(input_tensor)]):
# 1. Transpose # 1. Transpose
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3, # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which

View File

@ -29,7 +29,6 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import central_storage_strategy from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import mirrored_strategy
@ -1151,9 +1150,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
if mode == 'graph' and _is_tpu_strategy(distribution): if mode == 'graph' and _is_tpu_strategy(distribution):
self.skipTest('partial batch not supported with TPU in graph mode.') self.skipTest('partial batch not supported with TPU in graph mode.')
if isinstance(distribution,
collective_all_reduce_strategy.CollectiveAllReduceStrategy):
self.skipTest('EOF error causes subsequent collective ops fail.')
with self.cached_session(): with self.cached_session():
with distribution.scope(): with distribution.scope():
optimizer_fn = gradient_descent_keras.SGD optimizer_fn = gradient_descent_keras.SGD
@ -1166,8 +1162,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
loss, loss,
metrics=metrics) metrics=metrics)
inputs = np.zeros((1000, 3), dtype=np.float32) inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((1000, 4), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32)
# steps/steps_per_epoch are calculated when using numpy arrays as # steps/steps_per_epoch are calculated when using numpy arrays as
# input data. # input data.
fit_with_numpy = model.fit( fit_with_numpy = model.fit(

View File

@ -2733,19 +2733,23 @@ def _collective_all_reduce_multi_worker(strategy):
def _multi_worker_concat(v, strategy): def _multi_worker_concat(v, strategy):
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access
# v might not have the same shape on different replicas # TODO(b/170435030): We now need to make sure these run after the iterator
if isinstance(v, ds_values.PerReplica): # GetNext, so that we don't trigger aborting collective ops in the case of
shapes = array_ops.concat([ # EOF. Remove after the issue is fixed.
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0) with ops.control_dependencies([replicas]):
for single_value in v.values # v might not have the same shape on different replicas
], if isinstance(v, ds_values.PerReplica):
axis=0) shapes = array_ops.concat([
all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
else: for single_value in v.values
# v is a tensor. This may happen when, say, we have 2x1 multi-worker. ],
all_shapes = strategy._gather( # pylint: disable=protected-access axis=0)
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
axis=0) else:
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
all_shapes = strategy._gather( # pylint: disable=protected-access
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
axis=0)
replicas = array_ops.split( replicas = array_ops.split(
replicas, replicas,

View File

@ -298,6 +298,8 @@ cuda_py_test(
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/compat:v2_compat", "//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:test_util", "//tensorflow/python/distribute:test_util",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",

View File

@ -24,6 +24,8 @@ import time
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.compat import v2_compat from tensorflow.python.compat import v2_compat
from tensorflow.python.data.experimental.ops import testing as dataset_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -469,6 +471,83 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
_setup_context() _setup_context()
def_function.function(collective_fn)() def_function.function(collective_fn)()
def testOpErrorNotAbort(self, collective_op, device, communication):
# Do not abort if there's no active collective ops. There could be
# exceptions like EOF which we expect users to catch, aborting collective
# ops on all op errors intervenes with this workflow.
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
dataset = dataset_ops.Dataset.from_tensors([1.])
@def_function.function
def collective_fn(in_tensor):
for device in [dev0, dev1]:
with ops.device(device):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@def_function.function
def f():
iterator = iter(dataset)
collective_fn(next(iterator))
# This next(iterator) should raise EOF.
collective_fn(next(iterator))
with self.assertRaises(errors.OutOfRangeError):
f()
collective_fn(constant_op.constant([1.]))
def testOpErrorAbort(self, collective_op, device, communication):
# Abort collective ops if there're active collective ops at the time of an
# op error. This is due to the inability to cancel collective ops, and op
# errors may cause running collective ops to hang.
dev0 = '/device:%s:0' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
# Make the dataset sleep a while so that the collective is being executed
# when the EOF happens.
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
dataset_testing.sleep(sleep_microseconds=200))
@def_function.function
def f():
# Launch a collective op that won't be able to finish to test abortion
# when other ops error.
with ops.device(dev0):
ret = collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
iterator = iter(dataset)
next(iterator)
# This should raise EOF.
next(iterator)
return ret
with self.assertRaises(errors.OutOfRangeError):
f()
# Now collective ops is aborted, subsequent collective ops should fail with
# the previous error.
with self.assertRaises(errors.CancelledError):
with ops.device(dev0):
collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
@combinations.generate( @combinations.generate(
combinations.times( combinations.times(