[retry]Use cancellation manager to abort collectives
The previous change may cause a use-after-free since StartAbort() runs a separate thread but accesses resources owned by CollectiveExecutiveMgr. Once all cancellation callbacks finish, the CollectiveExecutorMgr may already be deallocated while StartAbort() is in progress. Fixing the ownership is not trivial so we now call StartAbort() in the cancellation callback instead to ensure all resources are valid. Note that with this we need to use TryDeregisterCallback in done() instead of DeregisterCallback(), because the latter blocks until all cancellation callback is done. We used to always abort collective ops in executor when there're errors in graph execution. However there're some errors that are intended for the user to catch, and if we abort collective ops, the user program cannot continue. It's also not necessary to abort collective ops if there's no active ones. Ideally we should have a cancellation story for collectives. Before that, we can at least only abort collectives when it's necessary, i.e. when there're pending collective ops or failed collective ops. To make the the catching EOF workflow work, we also need to make all collectives in gather depend on the input tensors, so there's better chance they fire after iterator GetNext. Without that the shape gathering may run in parallel with GetNext. PiperOrigin-RevId: 337997169 Change-Id: I4a374f9ff00bdba38e012a96fb7f5837e049c85c
This commit is contained in:
parent
d9e09b0723
commit
d345c40688
@ -1119,11 +1119,13 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
|
||||
if (rendezvous_) {
|
||||
rendezvous_->StartAbort(s);
|
||||
}
|
||||
if (collective_executor_) {
|
||||
collective_executor_->StartAbort(s);
|
||||
}
|
||||
if (cancellation_manager_) {
|
||||
cancellation_manager_->StartCancel();
|
||||
} else if (collective_executor_) {
|
||||
// If there's cancellation_manager_, collective ops aborts
|
||||
// collective_executor_ upon cancellation; otherwise we need to abort
|
||||
// here.
|
||||
collective_executor_->StartAbort(s);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1267,11 +1269,13 @@ void ExecutorState<PropagatorStateType>::Finish() {
|
||||
if (rendezvous_) {
|
||||
rendezvous_->StartAbort(status);
|
||||
}
|
||||
if (collective_executor_) {
|
||||
collective_executor_->StartAbort(status);
|
||||
}
|
||||
if (cancellation_manager_) {
|
||||
cancellation_manager_->StartCancel();
|
||||
} else if (collective_executor_) {
|
||||
// If there's cancellation_manager_, collective ops aborts
|
||||
// collective_executor_ upon cancellation; otherwise we need to abort
|
||||
// here.
|
||||
collective_executor_->StartAbort(status);
|
||||
}
|
||||
}
|
||||
delete this;
|
||||
|
@ -51,7 +51,56 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
|
||||
|
||||
class CollectiveOpKernel : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
|
||||
explicit CollectiveOpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c), name_(name()) {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
const CancellationToken token =
|
||||
c->cancellation_manager()->get_cancellation_token();
|
||||
const bool already_cancelled =
|
||||
!c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
|
||||
// We must call StartAbort() within the callback. StartAbort() relies
|
||||
// on resources that may be deallocated if all execution of a graph is
|
||||
// finished.
|
||||
col_exec->StartAbort(errors::Cancelled("op cancelled"));
|
||||
});
|
||||
OP_REQUIRES_ASYNC(c, !already_cancelled,
|
||||
errors::Cancelled("op cancelled ", name_), done);
|
||||
|
||||
auto deregister_and_done = [c, col_exec, token, done = std::move(done)]() {
|
||||
// Once done() is called, StartAbort() won't have any effect, so we
|
||||
// don't need to block on the deregistration. Also StartAbort() may call
|
||||
// done() and DeregisterCallback may deadlock.
|
||||
c->cancellation_manager()->TryDeregisterCallback(token);
|
||||
// Abort CollectiveExecutor so that this error can propagate to other
|
||||
// workers.
|
||||
if (!c->status().ok()) {
|
||||
col_exec->StartAbort(c->status());
|
||||
}
|
||||
done();
|
||||
};
|
||||
ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ComputeAsyncImpl(OpKernelContext* c,
|
||||
CollectiveExecutor* col_exec,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
string name_;
|
||||
};
|
||||
|
||||
class CollectiveOpV1Kernel : public CollectiveOpKernel {
|
||||
public:
|
||||
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {}
|
||||
|
||||
// A string encoding instance, frame and iter to be handed off to
|
||||
// the implementation for use in generating RecvBuf keys.
|
||||
@ -90,14 +139,15 @@ class CollectiveOpKernel : public AsyncOpKernel {
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
CollectiveParams col_params_;
|
||||
std::vector<int32> dependencies_;
|
||||
};
|
||||
|
||||
class CollectiveGatherOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = GATHER_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -119,15 +169,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
|
||||
col_params_.group.device_type = c->device_type();
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
auto output_shape = c->input(0).shape();
|
||||
output_shape.set_dim(
|
||||
0, output_shape.dim_size(0) * col_params_.group.group_size);
|
||||
@ -171,10 +215,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
|
||||
CollectiveGatherOpKernel);
|
||||
|
||||
class CollectiveReduceOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = REDUCTION_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -231,14 +275,9 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
|
||||
col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
// Allocate output on the first pass through this function. This must be
|
||||
// done immediately, while we're still in the executor thread. Otherwise
|
||||
// the memory is not guaranteed to be unused by any concurrently executing
|
||||
@ -280,10 +319,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
|
||||
CollectiveReduceOpKernel);
|
||||
|
||||
class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = BROADCAST_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -309,14 +348,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
|
||||
col_params_.group.device_type = c->device_type();
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
// Allocate output on the first pass through this function. This must be
|
||||
// done immediately, while we're still in the executor thread. Otherwise
|
||||
// the memory is not guaranteed to be unused by any concurrently executing
|
||||
@ -362,10 +396,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
|
||||
CollectiveBcastSendOpKernel);
|
||||
|
||||
class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = BROADCAST_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -391,14 +425,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
|
||||
col_params_.group.device_type = c->device_type();
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
// Allocate output on the first pass through this function. This must be
|
||||
// done immediately, while we're still in the executor thread. Otherwise
|
||||
// the memory is not guaranteed to be unused by any concurrently executing
|
||||
@ -437,10 +466,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
|
||||
CollectiveBcastRecvOpKernel);
|
||||
|
||||
class CollectiveReduceV2OpKernel : public AsyncOpKernel {
|
||||
class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
|
||||
public:
|
||||
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c) {
|
||||
: CollectiveOpKernel(c) {
|
||||
col_params_ = std::make_shared<CollectiveParams>();
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
|
||||
string merge_op_name;
|
||||
@ -481,14 +510,9 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
|
||||
<< col_params_->instance.impl_details.communication_hint;
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_->name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
const Tensor& input = c->input(0);
|
||||
const Tensor& group_size = c->input(1);
|
||||
const Tensor& group_key = c->input(2);
|
||||
@ -590,10 +614,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
|
||||
.HostMemory("instance_key"),
|
||||
CollectiveReduceV2OpKernel);
|
||||
|
||||
class CollectiveGatherV2OpKernel : public AsyncOpKernel {
|
||||
class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
|
||||
public:
|
||||
explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
: CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
|
||||
@ -603,14 +627,9 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
|
||||
<< " communication_hint " << communication_hint_;
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
const Tensor& input = c->input(0);
|
||||
const Tensor& group_size = c->input(1);
|
||||
const Tensor& group_key = c->input(2);
|
||||
@ -712,7 +731,6 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
|
||||
string communication_hint_;
|
||||
float timeout_seconds_;
|
||||
DeviceType device_type_;
|
||||
string name_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),
|
||||
|
@ -412,7 +412,8 @@ class CollectiveReplicaLauncher(object):
|
||||
self._group_key, self._device)
|
||||
instance_key_shape = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
with ops.device(self._device):
|
||||
with ops.device(self._device), \
|
||||
ops.control_dependencies([array_ops.identity(input_tensor)]):
|
||||
# 1. Transpose
|
||||
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
|
||||
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.distribute import central_storage_strategy
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
@ -1151,9 +1150,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
if mode == 'graph' and _is_tpu_strategy(distribution):
|
||||
self.skipTest('partial batch not supported with TPU in graph mode.')
|
||||
|
||||
if isinstance(distribution,
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy):
|
||||
self.skipTest('EOF error causes subsequent collective ops fail.')
|
||||
with self.cached_session():
|
||||
with distribution.scope():
|
||||
optimizer_fn = gradient_descent_keras.SGD
|
||||
@ -1166,8 +1162,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
loss,
|
||||
metrics=metrics)
|
||||
|
||||
inputs = np.zeros((1000, 3), dtype=np.float32)
|
||||
targets = np.zeros((1000, 4), dtype=np.float32)
|
||||
inputs = np.zeros((100, 3), dtype=np.float32)
|
||||
targets = np.zeros((100, 4), dtype=np.float32)
|
||||
# steps/steps_per_epoch are calculated when using numpy arrays as
|
||||
# input data.
|
||||
fit_with_numpy = model.fit(
|
||||
|
@ -2756,19 +2756,23 @@ def _collective_all_reduce_multi_worker(strategy):
|
||||
def _multi_worker_concat(v, strategy):
|
||||
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
|
||||
replicas = strategy.gather(v, axis=0) # pylint: disable=protected-access
|
||||
# v might not have the same shape on different replicas
|
||||
if isinstance(v, ds_values.PerReplica):
|
||||
shapes = array_ops.concat([
|
||||
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
|
||||
for single_value in v.values
|
||||
],
|
||||
axis=0)
|
||||
all_shapes = strategy.gather(shapes, axis=0) # pylint: disable=protected-access
|
||||
else:
|
||||
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
|
||||
all_shapes = strategy.gather( # pylint: disable=protected-access
|
||||
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
|
||||
axis=0)
|
||||
# TODO(b/170435030): We now need to make sure these run after the iterator
|
||||
# GetNext, so that we don't trigger aborting collective ops in the case of
|
||||
# EOF. Remove after the issue is fixed.
|
||||
with ops.control_dependencies([replicas]):
|
||||
# v might not have the same shape on different replicas
|
||||
if isinstance(v, ds_values.PerReplica):
|
||||
shapes = array_ops.concat([
|
||||
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
|
||||
for single_value in v.values
|
||||
],
|
||||
axis=0)
|
||||
all_shapes = strategy.gather(shapes, axis=0)
|
||||
else:
|
||||
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
|
||||
all_shapes = strategy.gather(
|
||||
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
|
||||
axis=0)
|
||||
|
||||
replicas = array_ops.split(
|
||||
replicas,
|
||||
|
@ -298,6 +298,8 @@ cuda_py_test(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/data/experimental/ops:testing",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
|
@ -24,6 +24,8 @@ import time
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.experimental.ops import testing as dataset_testing
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -469,6 +471,83 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
_setup_context()
|
||||
def_function.function(collective_fn)()
|
||||
|
||||
def testOpErrorNotAbort(self, collective_op, device, communication):
|
||||
# Do not abort if there's no active collective ops. There could be
|
||||
# exceptions like EOF which we expect users to catch, aborting collective
|
||||
# ops on all op errors intervenes with this workflow.
|
||||
dev0 = '/device:%s:0' % device
|
||||
dev1 = '/device:%s:1' % device
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
dataset = dataset_ops.Dataset.from_tensors([1.])
|
||||
|
||||
@def_function.function
|
||||
def collective_fn(in_tensor):
|
||||
for device in [dev0, dev1]:
|
||||
with ops.device(device):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
iterator = iter(dataset)
|
||||
collective_fn(next(iterator))
|
||||
# This next(iterator) should raise EOF.
|
||||
collective_fn(next(iterator))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
f()
|
||||
collective_fn(constant_op.constant([1.]))
|
||||
|
||||
def testOpErrorAbort(self, collective_op, device, communication):
|
||||
# Abort collective ops if there're active collective ops at the time of an
|
||||
# op error. This is due to the inability to cancel collective ops, and op
|
||||
# errors may cause running collective ops to hang.
|
||||
dev0 = '/device:%s:0' % device
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
in_tensor = constant_op.constant([1.])
|
||||
# Make the dataset sleep a while so that the collective is being executed
|
||||
# when the EOF happens.
|
||||
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
|
||||
dataset_testing.sleep(sleep_microseconds=200))
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
# Launch a collective op that won't be able to finish to test abortion
|
||||
# when other ops error.
|
||||
with ops.device(dev0):
|
||||
ret = collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
iterator = iter(dataset)
|
||||
next(iterator)
|
||||
# This should raise EOF.
|
||||
next(iterator)
|
||||
return ret
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
f()
|
||||
# Now collective ops is aborted, subsequent collective ops should fail with
|
||||
# the previous error.
|
||||
with self.assertRaises(errors.CancelledError):
|
||||
with ops.device(dev0):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
|
Loading…
x
Reference in New Issue
Block a user