Allow cancellation of v2 collectives during param resolution
Even if this is only for the first collective that runs for a particular group, having a non-cancellable deadlock is pretty annoying.

PiperOrigin-RevId: 355871763
Change-Id: I76383e9f7eb0670fce943643d632c0eaf33b8c92
commit 41667a5862 (parent aeeafe8f66)
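For context on the mechanism this change threads through the resolver: a CancellationManager hands out a token, lets an in-flight operation register a callback that fires if the operation is cancelled, and lets the operation later try to deregister that callback to learn whether it completed before cancellation started. Below is a minimal, self-contained sketch of that contract; MiniCancellationManager and its members are hypothetical stand-ins written for this note, not TensorFlow's actual CancellationManager.

// Minimal sketch of the register/deregister contract that the diff below
// relies on. MiniCancellationManager is a hypothetical stand-in for
// TensorFlow's CancellationManager, not the real class.
#include <functional>
#include <mutex>
#include <unordered_map>

using CancelToken = int;
using CancelCallback = std::function<void()>;

class MiniCancellationManager {
 public:
  CancelToken get_cancellation_token() {
    std::lock_guard<std::mutex> l(mu_);
    return next_token_++;
  }

  // Returns false if cancellation has already started; the caller must then
  // fail itself, because its callback will never run.
  bool RegisterCallback(CancelToken token, CancelCallback cb) {
    std::lock_guard<std::mutex> l(mu_);
    if (cancelled_) return false;
    callbacks_.emplace(token, std::move(cb));
    return true;
  }

  // Returns true only if the callback was still registered, i.e. the
  // operation completed before cancellation and may report a normal status.
  bool TryDeregisterCallback(CancelToken token) {
    std::lock_guard<std::mutex> l(mu_);
    if (cancelled_) return false;
    return callbacks_.erase(token) > 0;
  }

  // Runs every registered callback exactly once, outside the lock.
  void StartCancel() {
    std::unordered_map<CancelToken, CancelCallback> to_run;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (cancelled_) return;
      cancelled_ = true;
      to_run.swap(callbacks_);
    }
    for (auto& entry : to_run) entry.second();
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  CancelToken next_token_ = 0;
  std::unordered_map<CancelToken, CancelCallback> callbacks_;
};

The useful invariant: for a given token, either StartCancel() runs the callback or TryDeregisterCallback() returns true, never both. The diff below builds its exactly-once delivery of `done` on top of that property.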
@@ -86,10 +86,34 @@ string TaskNameFromDeviceName(const string& device_name) {
 void CollectiveParamResolverLocal::CompleteGroupLocal(
     const DeviceAttributes& device, CollectiveParams* cp,
-    const GroupRecCallback& done) {
+    const GroupRecCallback& done, CancellationManager* cancel_mgr) {
   VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
           << ": " << cp->ToString();
   std::vector<StatusCallback> to_be_called;
+  // Keep a reference to `cp` to avoid racing with deletion due to cancellation.
+  cp->Ref();
+  core::ScopedUnref cp_unref(cp);
+
+  std::function<void(const Status& s, GroupRec* gr)> done_with_cleanup;
+  if (cancel_mgr != nullptr) {
+    const CancellationToken token = cancel_mgr->get_cancellation_token();
+    const bool already_cancelled = !cancel_mgr->RegisterCallback(
+        token, [done]() { done(errors::Cancelled("op cancelled"), nullptr); });
+    if (already_cancelled) {
+      done(errors::Cancelled("op cancelled"), nullptr);
+      return;
+    }
+    done_with_cleanup = [cancel_mgr, done, token](const Status& s,
+                                                  GroupRec* gr) {
+      if (cancel_mgr == nullptr || cancel_mgr->TryDeregisterCallback(token)) {
+        // The operation was never cancelled, so we'll return a normal status.
+        done(s, gr);
+      }
+    };
+  } else {
+    done_with_cleanup = done;
+  }
+
   GroupRec* gr = nullptr;
   Status status;
   {
@@ -121,7 +145,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
   }
 
   if (!status.ok()) {
-    done(status, gr);
+    done_with_cleanup(status, gr);
     return;
   }
 
@@ -140,7 +164,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
     status = status_;
   }
   if (!status.ok()) {
-    done(status, nullptr);
+    done_with_cleanup(status, nullptr);
     return;
   }
   {
@@ -211,7 +235,8 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
             << gr->devices.size() << " gr " << gr;
 
     if (gr->devices.size() < gr->group.group_size) {
-      gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
+      gr->waiting.push_back(
+          std::bind(done_with_cleanup, std::placeholders::_1, gr));
       return;
     }
     CHECK_EQ(gr->devices.size(), gr->group.group_size);
@@ -227,7 +252,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
     }
     status = gr->status;
   }
-  done(status, gr);
+  done_with_cleanup(status, gr);
   for (int i = 0; i < to_be_called.size(); ++i) {
     to_be_called[i](status);
   }
@@ -609,7 +634,8 @@ void CollectiveParamResolverLocal::CompleteParamsAsync(
         } else {
          done(s);
         }
-      });
+      },
+      cancel_mgr);
 }
 
 void CollectiveParamResolverLocal::CompleteInstanceAsync(
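The heart of the hunks above is done_with_cleanup: every path that used to invoke `done` directly now goes through a wrapper that first tries to deregister the cancellation callback, so the caller sees exactly one status whether the op finishes or is cancelled. Factored out of the diff into a standalone shape, reusing the MiniCancellationManager sketch above (Status and GroupRec here are simplified stand-ins, and MakeCancellableDone is a name invented for this note):

// Sketch of the done_with_cleanup pattern from the hunks above. Assumes the
// MiniCancellationManager sketch from earlier is in scope; Status and
// GroupRec are reduced to stand-in types to keep the snippet compilable.
#include <functional>
#include <string>

struct Status {
  bool ok = true;
  std::string message;
};
struct GroupRec {};
using GroupRecCallback = std::function<void(const Status&, GroupRec*)>;

// Wraps `done` so it fires exactly once: either the cancellation callback
// reports "op cancelled", or the deregistering wrapper reports the real
// status once group resolution completes.
GroupRecCallback MakeCancellableDone(MiniCancellationManager* cancel_mgr,
                                     GroupRecCallback done) {
  if (cancel_mgr == nullptr) return done;
  const CancelToken token = cancel_mgr->get_cancellation_token();
  const bool already_cancelled = !cancel_mgr->RegisterCallback(
      token, [done] { done(Status{false, "op cancelled"}, nullptr); });
  if (already_cancelled) {
    // Too late to register: fail immediately, mirroring the early return in
    // CompleteGroupLocal. (The real code returns from the whole function;
    // this sketch hands back a no-op callback instead.)
    done(Status{false, "op cancelled"}, nullptr);
    return [](const Status&, GroupRec*) {};
  }
  return [cancel_mgr, done, token](const Status& s, GroupRec* gr) {
    // Only deliver the result if cancellation did not get there first.
    if (cancel_mgr->TryDeregisterCallback(token)) done(s, gr);
  };
}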
@@ -85,7 +85,8 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
   typedef std::function<void(const Status& s, const GroupRec* gr)>
       GroupRecCallback;
   void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
-                          const GroupRecCallback& done)
+                          const GroupRecCallback& done,
+                          CancellationManager* cancel_mgr)
       TF_LOCKS_EXCLUDED(group_mu_);
 
   // Finishes the group parameters once all members of the group are there.
@@ -291,7 +291,7 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
           << " is_leader=" << (group_leader_.empty());
   if (group_leader_.empty()) {
     // This is the group leader, so resolution is local.
-    return CompleteGroupLocal(device, cp, done);
+    return CompleteGroupLocal(device, cp, done, cancel_mgr);
   } else if (GetCachedGroup(cp->group.group_key) == nullptr) {
     // Need to update Group cache from the leader.
     CompleteGroupCall* call =
@@ -306,24 +306,24 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
       delete call;
       return;
     }
-    call->Start(
-        [this, device, cp, call, abortion_token, done](const Status& s) {
-          abortion_cancel_mgr_.DeregisterCallback(abortion_token);
-          if (s.ok()) {
-            Status status = UpdateGroupCache(call->resp_);
-            if (status.ok()) {
-              CompleteGroupLocal(device, cp, done);
-            } else {
-              done(status, nullptr);
-            }
-          } else {
-            done(s, nullptr);
-          }
-          delete call;
-        });
+    call->Start([this, device, cp, call, cancel_mgr, abortion_token,
+                 done](const Status& s) {
+      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
+      if (s.ok()) {
+        Status status = UpdateGroupCache(call->resp_);
+        if (status.ok()) {
+          CompleteGroupLocal(device, cp, done, cancel_mgr);
+        } else {
+          done(status, nullptr);
+        }
+      } else {
+        done(s, nullptr);
+      }
+      delete call;
+    });
     return;
   } else {
-    return CompleteGroupLocal(device, cp, done);
+    return CompleteGroupLocal(device, cp, done, cancel_mgr);
   }
 }
@@ -289,6 +289,7 @@ cuda_py_test(
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:test_util",
+        "//tensorflow/python/eager:cancellation",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
         "@absl_py//absl/testing:parameterized",
@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import testing as dataset_testing
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import test_util
+from tensorflow.python.eager import cancellation
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
@@ -36,6 +37,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import collective_ops as _collective_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
@@ -732,6 +734,76 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase):
     # proceed.
     collective_fn()
 
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              collective_op=[
+                  combinations.NamedObject('all_reduce_v2',
+                                           CollectiveOpsV2.all_reduce),
+                  combinations.NamedObject('all_gather_v2',
+                                           CollectiveOpsV2.all_gather),
+              ],
+              mode='eager'), device_combination))
+  def testCancelDuringParamResolution(self, collective_op, device,
+                                      communication):
+    dev0 = '/device:%s:0' % device
+    dev1 = '/device:%s:1' % device
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+    t1_cancellation_manager = cancellation.CancellationManager()
+    t2_cancellation_manager = cancellation.CancellationManager()
+
+    @def_function.function
+    def _collective_fn(x):
+      # Run an assertion to crash one of the two function executions running
+      # collectives. We explicitly cancel the other in response.
+      assert_op = check_ops.assert_equal(x, in_tensor)
+      with ops.control_dependencies([assert_op]):
+        return collective_op(
+            in_tensor,
+            group_size,
+            group_key,
+            instance_key,
+            communication_hint=communication)
+
+    collective_concrete = _collective_fn.get_concrete_function(in_tensor)
+
+    finish_mu = threading.Lock()
+    finishes = 0
+
+    def _placement_wrapper(device, x, my_cancellation, other_cancellation):
+      try:
+        with ops.device(device):
+          cancelable_collective = my_cancellation.get_cancelable_function(
+              collective_concrete)
+          return cancelable_collective(x)
+      except errors.InvalidArgumentError:
+        # `assert_equal` failed for this execution of the function. The other
+        # function would deadlock without cancellation.
+        other_cancellation.start_cancel()
+      except errors.CancelledError:
+        pass
+      nonlocal finishes
+      with finish_mu:
+        finishes += 1
+
+    t1 = threading.Thread(
+        target=_placement_wrapper,
+        args=(dev0, constant_op.constant([1.]), t1_cancellation_manager,
+              t2_cancellation_manager))
+    t2 = threading.Thread(
+        target=_placement_wrapper,
+        # Will cause the assertion to fail
+        args=(dev1, constant_op.constant([2.]), t2_cancellation_manager,
+              t1_cancellation_manager))
+    t1.start()
+    t2.start()
+    t1.join()
+    t2.join()
+    self.assertEqual(finishes, 2)
+
 
 @combinations.generate(collective_op_combinations)
 class TimeoutTest(test.TestCase, parameterized.TestCase):
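The new test drives two threads into exactly the scenario the commit message describes: t2 feeds a tensor that fails assert_equal, so its collective never joins the group and t1's param resolution would block forever; t2's exception handler then cancels t1, and both threads are expected to finish. Reduced to the earlier sketches, the unblocking looks like the toy program below (hypothetical code, assuming MiniCancellationManager and MakeCancellableDone from the sketches above are in scope; not part of the TensorFlow change):

// Toy re-enactment of the test: one party fails, the peer's group resolution
// would wait forever, and cancellation is what delivers a status.
#include <cassert>
#include <iostream>

int main() {
  MiniCancellationManager cancel_mgr;
  bool peer_finished = false;

  // The peer's `done` would normally fire when the group fills up; here the
  // group never fills (the other participant crashed), so only cancellation
  // can deliver a status.
  GroupRecCallback done = [&peer_finished](const Status& s, GroupRec*) {
    peer_finished = true;
    std::cout << (s.ok ? std::string("ok") : s.message) << "\n";
  };
  GroupRecCallback pending = MakeCancellableDone(&cancel_mgr, std::move(done));
  (void)pending;  // Never invoked: resolution is stuck waiting for the group.

  cancel_mgr.StartCancel();  // What the test does via start_cancel().
  assert(peer_finished);     // Without this commit: deadlock instead.
  return 0;
}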