Allow cancellation of v2 collectives during param resolution

Even if this only matters for the first collective that runs for a particular group (group resolution happens once per group key and is cached afterwards), having a non-cancellable deadlock is pretty annoying.

PiperOrigin-RevId: 355871763
Change-Id: I76383e9f7eb0670fce943643d632c0eaf33b8c92
Allen Lavoie 2021-02-05 09:57:42 -08:00 committed by TensorFlower Gardener
parent aeeafe8f66
commit 41667a5862
5 changed files with 124 additions and 24 deletions
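For context on the user-visible effect: a v2 collective whose group never fully assembles used to hang with no way to interrupt it, because group resolution ignored the op's cancellation manager. Below is a minimal sketch of the pattern this commit enables, modeled on the test added at the bottom of the diff; the `cancellation` and `collective_ops` modules are TensorFlow-internal, and the specific group/instance keys are illustrative only.

import threading

from tensorflow.python.eager import cancellation
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.ops import collective_ops


@def_function.function
def _lonely_all_reduce(x):
  # group_size=2, but only one participant ever launches, so group
  # resolution cannot complete and previously blocked forever.
  return collective_ops.all_reduce_v2(
      x, group_size=2, group_key=100, instance_key=100)


x = constant_op.constant([1.])
concrete = _lonely_all_reduce.get_concrete_function(x)

mgr = cancellation.CancellationManager()
cancelable = mgr.get_cancelable_function(concrete)

# Cancel from another thread; with this change the pending group resolution
# observes the cancellation and the call raises CancelledError instead of
# deadlocking.
threading.Timer(1.0, mgr.start_cancel).start()
try:
  cancelable(x)
except errors.CancelledError:
  print('collective cancelled during param resolution')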

View File

@@ -86,10 +86,34 @@ string TaskNameFromDeviceName(const string& device_name) {
void CollectiveParamResolverLocal::CompleteGroupLocal(
const DeviceAttributes& device, CollectiveParams* cp,
const GroupRecCallback& done) {
const GroupRecCallback& done, CancellationManager* cancel_mgr) {
VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
<< ": " << cp->ToString();
std::vector<StatusCallback> to_be_called;
// Keep a reference to `cp` to avoid racing with deletion due to cancellation.
cp->Ref();
core::ScopedUnref cp_unref(cp);
std::function<void(const Status& s, GroupRec* gr)> done_with_cleanup;
if (cancel_mgr != nullptr) {
const CancellationToken token = cancel_mgr->get_cancellation_token();
const bool already_cancelled = !cancel_mgr->RegisterCallback(
token, [done]() { done(errors::Cancelled("op cancelled"), nullptr); });
if (already_cancelled) {
done(errors::Cancelled("op cancelled"), nullptr);
return;
}
done_with_cleanup = [cancel_mgr, done, token](const Status& s,
GroupRec* gr) {
if (cancel_mgr == nullptr || cancel_mgr->TryDeregisterCallback(token)) {
// The operation was never cancelled, so we'll return a normal status.
done(s, gr);
}
};
} else {
done_with_cleanup = done;
}
GroupRec* gr = nullptr;
Status status;
{
@@ -121,7 +145,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
}
if (!status.ok()) {
done(status, gr);
done_with_cleanup(status, gr);
return;
}
@@ -140,7 +164,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
status = status_;
}
if (!status.ok()) {
done(status, nullptr);
done_with_cleanup(status, nullptr);
return;
}
{
@@ -211,7 +235,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
<< gr->devices.size() << " gr " << gr;
if (gr->devices.size() < gr->group.group_size) {
gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
gr->waiting.push_back(
std::bind(done_with_cleanup, std::placeholders::_1, gr));
return;
}
CHECK_EQ(gr->devices.size(), gr->group.group_size);
@@ -227,7 +252,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal(
}
status = gr->status;
}
done(status, gr);
done_with_cleanup(status, gr);
for (int i = 0; i < to_be_called.size(); ++i) {
to_be_called[i](status);
}
@@ -609,7 +634,8 @@ void CollectiveParamResolverLocal::CompleteParamsAsync(
} else {
done(s);
}
});
},
cancel_mgr);
}
void CollectiveParamResolverLocal::CompleteInstanceAsync(
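The subtle part of the change above is how RegisterCallback is paired with TryDeregisterCallback so that exactly one path invokes `done`: either the cancellation callback runs (when start_cancel wins the race), or the normal completion path deregisters it first and reports the real status. A rough, hedged Python model of that invariant follows; the toy manager merely stands in for the real C++ CancellationManager, and every name in it is illustrative.

import threading


class ToyCancellationManager:
  """Simplified stand-in for the C++ CancellationManager used above."""

  def __init__(self):
    self._mu = threading.Lock()
    self._cancelled = False
    self._callbacks = {}
    self._next_token = 0

  def get_cancellation_token(self):
    with self._mu:
      self._next_token += 1
      return self._next_token

  def register_callback(self, token, callback):
    # Like RegisterCallback: returns False if cancellation already started.
    with self._mu:
      if self._cancelled:
        return False
      self._callbacks[token] = callback
      return True

  def try_deregister_callback(self, token):
    # Like TryDeregisterCallback: returns False once cancellation has started.
    with self._mu:
      if self._cancelled:
        return False
      self._callbacks.pop(token, None)
      return True

  def start_cancel(self):
    with self._mu:
      self._cancelled = True
      callbacks = list(self._callbacks.values())
      self._callbacks.clear()
    for callback in callbacks:
      callback()


def complete_group(cancel_mgr, done):
  """Mirrors the done_with_cleanup wiring added to CompleteGroupLocal."""
  token = cancel_mgr.get_cancellation_token()
  already_cancelled = not cancel_mgr.register_callback(
      token, lambda: done('Cancelled', None))
  if already_cancelled:
    done('Cancelled', None)  # Cancelled before group resolution even began.
    return None

  def done_with_cleanup(status, group_rec):
    # Whichever side loses the race with start_cancel() stays silent, so
    # `done` runs exactly once.
    if cancel_mgr.try_deregister_callback(token):
      done(status, group_rec)

  # Group resolution (possibly much later, from another thread) finishes by
  # calling done_with_cleanup(status, group_rec).
  return done_with_cleanup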

View File

@@ -85,7 +85,8 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
typedef std::function<void(const Status& s, const GroupRec* gr)>
GroupRecCallback;
void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
const GroupRecCallback& done)
const GroupRecCallback& done,
CancellationManager* cancel_mgr)
TF_LOCKS_EXCLUDED(group_mu_);
// Finishes the group parameters once all members of the group are there.

View File

@@ -291,7 +291,7 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
<< " is_leader=" << (group_leader_.empty());
if (group_leader_.empty()) {
// This is the group leader, so resolution is local.
return CompleteGroupLocal(device, cp, done);
return CompleteGroupLocal(device, cp, done, cancel_mgr);
} else if (GetCachedGroup(cp->group.group_key) == nullptr) {
// Need to update Group cache from the leader.
CompleteGroupCall* call =
@@ -306,24 +306,24 @@
delete call;
return;
}
call->Start(
[this, device, cp, call, abortion_token, done](const Status& s) {
abortion_cancel_mgr_.DeregisterCallback(abortion_token);
if (s.ok()) {
Status status = UpdateGroupCache(call->resp_);
if (status.ok()) {
CompleteGroupLocal(device, cp, done);
} else {
done(status, nullptr);
}
} else {
done(s, nullptr);
}
delete call;
});
call->Start([this, device, cp, call, cancel_mgr, abortion_token,
done](const Status& s) {
abortion_cancel_mgr_.DeregisterCallback(abortion_token);
if (s.ok()) {
Status status = UpdateGroupCache(call->resp_);
if (status.ok()) {
CompleteGroupLocal(device, cp, done, cancel_mgr);
} else {
done(status, nullptr);
}
} else {
done(s, nullptr);
}
delete call;
});
return;
} else {
return CompleteGroupLocal(device, cp, done);
return CompleteGroupLocal(device, cp, done, cancel_mgr);
}
}

View File

@@ -289,6 +289,7 @@ cuda_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:test_util",
"//tensorflow/python/eager:cancellation",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"@absl_py//absl/testing:parameterized",

View File

@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import testing as dataset_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
@@ -36,6 +37,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import collective_ops as _collective_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -732,6 +734,76 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase):
# proceed.
collective_fn()
@combinations.generate(
combinations.times(
combinations.combine(
collective_op=[
combinations.NamedObject('all_reduce_v2',
CollectiveOpsV2.all_reduce),
combinations.NamedObject('all_gather_v2',
CollectiveOpsV2.all_gather),
],
mode='eager'), device_combination))
def testCancelDuringParamResolution(self, collective_op, device,
communication):
dev0 = '/device:%s:0' % device
dev1 = '/device:%s:1' % device
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
t1_cancellation_manager = cancellation.CancellationManager()
t2_cancellation_manager = cancellation.CancellationManager()
@def_function.function
def _collective_fn(x):
# Run an assertion to crash one of the two function executions running
# collectives. We explicitly cancel the other in response.
assert_op = check_ops.assert_equal(x, in_tensor)
with ops.control_dependencies([assert_op]):
return collective_op(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
collective_concrete = _collective_fn.get_concrete_function(in_tensor)
finish_mu = threading.Lock()
finishes = 0
def _placement_wrapper(device, x, my_cancellation, other_cancellation):
try:
with ops.device(device):
cancelable_collective = my_cancellation.get_cancelable_function(
collective_concrete)
return cancelable_collective(x)
except errors.InvalidArgumentError:
# `assert_equal` failed for this execution of the function. The other
# function would deadlock without cancellation.
other_cancellation.start_cancel()
except errors.CancelledError:
pass
nonlocal finishes
with finish_mu:
finishes += 1
t1 = threading.Thread(
target=_placement_wrapper,
args=(dev0, constant_op.constant([1.]), t1_cancellation_manager,
t2_cancellation_manager))
t2 = threading.Thread(
target=_placement_wrapper,
# Will cause the assertion to fail
args=(dev1, constant_op.constant([2.]), t2_cancellation_manager,
t1_cancellation_manager))
t1.start()
t2.start()
t1.join()
t2.join()
self.assertEqual(finishes, 2)
@combinations.generate(collective_op_combinations)
class TimeoutTest(test.TestCase, parameterized.TestCase):