Use cancellation manager to abort collectives
We used to always abort collective ops in executor when there're errors in graph execution. However there're some errors that are intended for the user to catch, and if we abort collective ops, the user program cannot continue. It's also not necessary to abort collective ops if there's no active ones. Ideally we should have a cancellation story for collectives. Before that, we can at least only abort collectives when it's necessary, i.e. when there're pending collective ops or failed collective ops. To make the the catching EOF workflow work, we also need to make all collectives in gather depend on the input tensors, so there's better chance they fire after iterator GetNext. Without that the shape gathering may run in parallel with GetNext. PiperOrigin-RevId: 337440792 Change-Id: I7caea917c858bcf99f6eb471abf46d94d5c255b3
This commit is contained in:
parent
15dd772865
commit
f0844f4065
@ -1119,11 +1119,15 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
|
||||
if (rendezvous_) {
|
||||
rendezvous_->StartAbort(s);
|
||||
}
|
||||
if (collective_executor_) {
|
||||
collective_executor_->StartAbort(s);
|
||||
}
|
||||
if (cancellation_manager_) {
|
||||
cancellation_manager_->StartCancel();
|
||||
} else {
|
||||
// If there's cancellation_manager_, collective ops aborts
|
||||
// collective_executor_ upon cancellation; otherwise we need to abort
|
||||
// here.
|
||||
if (collective_executor_) {
|
||||
collective_executor_->StartAbort(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1267,11 +1271,15 @@ void ExecutorState<PropagatorStateType>::Finish() {
|
||||
if (rendezvous_) {
|
||||
rendezvous_->StartAbort(status);
|
||||
}
|
||||
if (collective_executor_) {
|
||||
collective_executor_->StartAbort(status);
|
||||
}
|
||||
if (cancellation_manager_) {
|
||||
cancellation_manager_->StartCancel();
|
||||
} else {
|
||||
// If there's cancellation_manager_, collective ops aborts
|
||||
// collective_executor_ upon cancellation; otherwise we need to abort
|
||||
// here.
|
||||
if (collective_executor_) {
|
||||
collective_executor_->StartAbort(status);
|
||||
}
|
||||
}
|
||||
}
|
||||
delete this;
|
||||
|
@ -51,7 +51,54 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
|
||||
|
||||
class CollectiveOpKernel : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
|
||||
explicit CollectiveOpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c), name_(name()) {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
CancellationToken token =
|
||||
c->cancellation_manager()->get_cancellation_token();
|
||||
bool cancel_registered =
|
||||
c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
|
||||
// StartAbort invokes done callback which contains DeregisterCallback,
|
||||
// so we cannot block on that.
|
||||
col_exec->RunClosure([col_exec]() {
|
||||
col_exec->StartAbort(errors::Cancelled("op cancelled"));
|
||||
});
|
||||
});
|
||||
OP_REQUIRES_ASYNC(c, cancel_registered,
|
||||
errors::Cancelled("op cancelled ", name_), done);
|
||||
|
||||
auto deregister_and_done = [c, col_exec, token, done = std::move(done)]() {
|
||||
c->cancellation_manager()->DeregisterCallback(token);
|
||||
// Abort CollectiveExecutor so that this error can propagate to other
|
||||
// workers.
|
||||
if (!c->status().ok()) {
|
||||
col_exec->StartAbort(c->status());
|
||||
}
|
||||
done();
|
||||
};
|
||||
ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ComputeAsyncImpl(OpKernelContext* c,
|
||||
CollectiveExecutor* col_exec,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
string name_;
|
||||
};
|
||||
|
||||
class CollectiveOpV1Kernel : public CollectiveOpKernel {
|
||||
public:
|
||||
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {}
|
||||
|
||||
// A string encoding instance, frame and iter to be handed off to
|
||||
// the implementation for use in generating RecvBuf keys.
|
||||
@ -90,14 +137,15 @@ class CollectiveOpKernel : public AsyncOpKernel {
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
CollectiveParams col_params_;
|
||||
std::vector<int32> dependencies_;
|
||||
};
|
||||
|
||||
class CollectiveGatherOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = GATHER_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -119,15 +167,9 @@ class CollectiveGatherOpKernel : public CollectiveOpKernel {
|
||||
col_params_.group.device_type = c->device_type();
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
auto output_shape = c->input(0).shape();
|
||||
output_shape.set_dim(
|
||||
0, output_shape.dim_size(0) * col_params_.group.group_size);
|
||||
@ -171,10 +213,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
|
||||
CollectiveGatherOpKernel);
|
||||
|
||||
class CollectiveReduceOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = REDUCTION_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -231,14 +273,9 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
|
||||
col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
// Allocate output on the first pass through this function. This must be
|
||||
// done immediately, while we're still in the executor thread. Otherwise
|
||||
// the memory is not guaranteed to be unused by any concurrently executing
|
||||
@ -280,10 +317,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
|
||||
CollectiveReduceOpKernel);
|
||||
|
||||
class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = BROADCAST_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -309,14 +346,9 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
|
||||
col_params_.group.device_type = c->device_type();
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
// Allocate output on the first pass through this function. This must be
|
||||
// done immediately, while we're still in the executor thread. Otherwise
|
||||
// the memory is not guaranteed to be unused by any concurrently executing
|
||||
@ -362,10 +394,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
|
||||
CollectiveBcastSendOpKernel);
|
||||
|
||||
class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
|
||||
class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
|
||||
public:
|
||||
explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
|
||||
: CollectiveOpKernel(c) {
|
||||
: CollectiveOpV1Kernel(c) {
|
||||
col_params_.instance.type = BROADCAST_COLLECTIVE;
|
||||
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
|
||||
OP_REQUIRES(
|
||||
@ -391,14 +423,9 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
|
||||
col_params_.group.device_type = c->device_type();
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_.name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
// Allocate output on the first pass through this function. This must be
|
||||
// done immediately, while we're still in the executor thread. Otherwise
|
||||
// the memory is not guaranteed to be unused by any concurrently executing
|
||||
@ -437,10 +464,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
|
||||
CollectiveBcastRecvOpKernel);
|
||||
|
||||
class CollectiveReduceV2OpKernel : public AsyncOpKernel {
|
||||
class CollectiveReduceV2OpKernel : public CollectiveOpKernel {
|
||||
public:
|
||||
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c) {
|
||||
: CollectiveOpKernel(c) {
|
||||
col_params_ = std::make_shared<CollectiveParams>();
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
|
||||
string merge_op_name;
|
||||
@ -481,14 +508,9 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
|
||||
<< col_params_->instance.impl_details.communication_hint;
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
col_params_->name),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
const Tensor& input = c->input(0);
|
||||
const Tensor& group_size = c->input(1);
|
||||
const Tensor& group_key = c->input(2);
|
||||
@ -590,10 +612,10 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
|
||||
.HostMemory("instance_key"),
|
||||
CollectiveReduceV2OpKernel);
|
||||
|
||||
class CollectiveGatherV2OpKernel : public AsyncOpKernel {
|
||||
class CollectiveGatherV2OpKernel : public CollectiveOpKernel {
|
||||
public:
|
||||
explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
: CollectiveOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
|
||||
@ -603,14 +625,9 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
|
||||
<< " communication_hint " << communication_hint_;
|
||||
}
|
||||
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
protected:
|
||||
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
|
||||
DoneCallback done) override {
|
||||
const Tensor& input = c->input(0);
|
||||
const Tensor& group_size = c->input(1);
|
||||
const Tensor& group_key = c->input(2);
|
||||
@ -712,7 +729,6 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
|
||||
string communication_hint_;
|
||||
float timeout_seconds_;
|
||||
DeviceType device_type_;
|
||||
string name_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),
|
||||
|
@ -412,7 +412,8 @@ class CollectiveReplicaLauncher(object):
|
||||
self._group_key, self._device)
|
||||
instance_key_shape = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
with ops.device(self._device):
|
||||
with ops.device(self._device), \
|
||||
ops.control_dependencies([array_ops.identity(input_tensor)]):
|
||||
# 1. Transpose
|
||||
# E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3,
|
||||
# we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.data.experimental.ops.distribute_options import AutoShard
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.distribute import central_storage_strategy
|
||||
from tensorflow.python.distribute import collective_all_reduce_strategy
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
@ -1151,9 +1150,6 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
if mode == 'graph' and _is_tpu_strategy(distribution):
|
||||
self.skipTest('partial batch not supported with TPU in graph mode.')
|
||||
|
||||
if isinstance(distribution,
|
||||
collective_all_reduce_strategy.CollectiveAllReduceStrategy):
|
||||
self.skipTest('EOF error causes subsequent collective ops fail.')
|
||||
with self.cached_session():
|
||||
with distribution.scope():
|
||||
optimizer_fn = gradient_descent_keras.SGD
|
||||
@ -1166,8 +1162,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
|
||||
loss,
|
||||
metrics=metrics)
|
||||
|
||||
inputs = np.zeros((1000, 3), dtype=np.float32)
|
||||
targets = np.zeros((1000, 4), dtype=np.float32)
|
||||
inputs = np.zeros((10, 3), dtype=np.float32)
|
||||
targets = np.zeros((10, 4), dtype=np.float32)
|
||||
# steps/steps_per_epoch are calculated when using numpy arrays as
|
||||
# input data.
|
||||
fit_with_numpy = model.fit(
|
||||
|
@ -2733,19 +2733,23 @@ def _collective_all_reduce_multi_worker(strategy):
|
||||
def _multi_worker_concat(v, strategy):
|
||||
"""Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
|
||||
replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access
|
||||
# v might not have the same shape on different replicas
|
||||
if isinstance(v, ds_values.PerReplica):
|
||||
shapes = array_ops.concat([
|
||||
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
|
||||
for single_value in v.values
|
||||
],
|
||||
axis=0)
|
||||
all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
|
||||
else:
|
||||
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
|
||||
all_shapes = strategy._gather( # pylint: disable=protected-access
|
||||
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
|
||||
axis=0)
|
||||
# TODO(b/170435030): We now need to make sure these run after the iterator
|
||||
# GetNext, so that we don't trigger aborting collective ops in the case of
|
||||
# EOF. Remove after the issue is fixed.
|
||||
with ops.control_dependencies([replicas]):
|
||||
# v might not have the same shape on different replicas
|
||||
if isinstance(v, ds_values.PerReplica):
|
||||
shapes = array_ops.concat([
|
||||
array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0)
|
||||
for single_value in v.values
|
||||
],
|
||||
axis=0)
|
||||
all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access
|
||||
else:
|
||||
# v is a tensor. This may happen when, say, we have 2x1 multi-worker.
|
||||
all_shapes = strategy._gather( # pylint: disable=protected-access
|
||||
array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0),
|
||||
axis=0)
|
||||
|
||||
replicas = array_ops.split(
|
||||
replicas,
|
||||
|
@ -298,6 +298,8 @@ cuda_py_test(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/data/experimental/ops:testing",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:test_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
|
@ -24,6 +24,8 @@ import time
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.experimental.ops import testing as dataset_testing
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -469,6 +471,83 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
_setup_context()
|
||||
def_function.function(collective_fn)()
|
||||
|
||||
def testOpErrorNotAbort(self, collective_op, device, communication):
|
||||
# Do not abort if there's no active collective ops. There could be
|
||||
# exceptions like EOF which we expect users to catch, aborting collective
|
||||
# ops on all op errors intervenes with this workflow.
|
||||
dev0 = '/device:%s:0' % device
|
||||
dev1 = '/device:%s:1' % device
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
dataset = dataset_ops.Dataset.from_tensors([1.])
|
||||
|
||||
@def_function.function
|
||||
def collective_fn(in_tensor):
|
||||
for device in [dev0, dev1]:
|
||||
with ops.device(device):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
iterator = iter(dataset)
|
||||
collective_fn(next(iterator))
|
||||
# This next(iterator) should raise EOF.
|
||||
collective_fn(next(iterator))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
f()
|
||||
collective_fn(constant_op.constant([1.]))
|
||||
|
||||
def testOpErrorAbort(self, collective_op, device, communication):
|
||||
# Abort collective ops if there're active collective ops at the time of an
|
||||
# op error. This is due to the inability to cancel collective ops, and op
|
||||
# errors may cause running collective ops to hang.
|
||||
dev0 = '/device:%s:0' % device
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
in_tensor = constant_op.constant([1.])
|
||||
# Make the dataset sleep a while so that the collective is being executed
|
||||
# when the EOF happens.
|
||||
dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
|
||||
dataset_testing.sleep(sleep_microseconds=200))
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
# Launch a collective op that won't be able to finish to test abortion
|
||||
# when other ops error.
|
||||
with ops.device(dev0):
|
||||
ret = collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
iterator = iter(dataset)
|
||||
next(iterator)
|
||||
# This should raise EOF.
|
||||
next(iterator)
|
||||
return ret
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
f()
|
||||
# Now collective ops is aborted, subsequent collective ops should fail with
|
||||
# the previous error.
|
||||
with self.assertRaises(errors.CancelledError):
|
||||
with ops.device(dev0):
|
||||
collective_op(
|
||||
in_tensor,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
|
Loading…
x
Reference in New Issue
Block a user