From 41667a5862fe6c70bee9e392747651285e7c9c69 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 5 Feb 2021 09:57:42 -0800 Subject: [PATCH] Allow cancellation of v2 collectives during param resolution Even if this is only for the first collective that runs for a particular group, having a non-cancellable deadlock is pretty annoying. PiperOrigin-RevId: 355871763 Change-Id: I76383e9f7eb0670fce943643d632c0eaf33b8c92 --- .../collective_param_resolver_local.cc | 38 ++++++++-- .../collective_param_resolver_local.h | 3 +- .../collective_param_resolver_distributed.cc | 34 ++++----- tensorflow/python/kernel_tests/BUILD | 1 + .../kernel_tests/collective_ops_test.py | 72 +++++++++++++++++++ 5 files changed, 124 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index db7ea4dcd3b..28ceec265c4 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -86,10 +86,34 @@ string TaskNameFromDeviceName(const string& device_name) { void CollectiveParamResolverLocal::CompleteGroupLocal( const DeviceAttributes& device, CollectiveParams* cp, - const GroupRecCallback& done) { + const GroupRecCallback& done, CancellationManager* cancel_mgr) { VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp << ": " << cp->ToString(); std::vector to_be_called; + // Keep a reference to `cp` to avoid racing with deletion due to cancellation. + cp->Ref(); + core::ScopedUnref cp_unref(cp); + + std::function done_with_cleanup; + if (cancel_mgr != nullptr) { + const CancellationToken token = cancel_mgr->get_cancellation_token(); + const bool already_cancelled = !cancel_mgr->RegisterCallback( + token, [done]() { done(errors::Cancelled("op cancelled"), nullptr); }); + if (already_cancelled) { + done(errors::Cancelled("op cancelled"), nullptr); + return; + } + done_with_cleanup = [cancel_mgr, done, token](const Status& s, + GroupRec* gr) { + if (cancel_mgr == nullptr || cancel_mgr->TryDeregisterCallback(token)) { + // The operation was never cancelled, so we'll return a normal status. 
+ done(s, gr); + } + }; + } else { + done_with_cleanup = done; + } + GroupRec* gr = nullptr; Status status; { @@ -121,7 +145,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( } if (!status.ok()) { - done(status, gr); + done_with_cleanup(status, gr); return; } @@ -140,7 +164,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( status = status_; } if (!status.ok()) { - done(status, nullptr); + done_with_cleanup(status, nullptr); return; } { @@ -211,7 +235,8 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( << gr->devices.size() << " gr " << gr; if (gr->devices.size() < gr->group.group_size) { - gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr)); + gr->waiting.push_back( + std::bind(done_with_cleanup, std::placeholders::_1, gr)); return; } CHECK_EQ(gr->devices.size(), gr->group.group_size); @@ -227,7 +252,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( } status = gr->status; } - done(status, gr); + done_with_cleanup(status, gr); for (int i = 0; i < to_be_called.size(); ++i) { to_be_called[i](status); } @@ -609,7 +634,8 @@ void CollectiveParamResolverLocal::CompleteParamsAsync( } else { done(s); } - }); + }, + cancel_mgr); } void CollectiveParamResolverLocal::CompleteInstanceAsync( diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 5a7cd54a0de..73ed8f2ae7b 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -85,7 +85,8 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { typedef std::function GroupRecCallback; void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp, - const GroupRecCallback& done) + const GroupRecCallback& done, + CancellationManager* cancel_mgr) TF_LOCKS_EXCLUDED(group_mu_); // Finishes the group parameters once all members of the group are there. diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index c5d846e1b57..33874ef2c5c 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -291,7 +291,7 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed( << " is_leader=" << (group_leader_.empty()); if (group_leader_.empty()) { // This is the group leader, so resolution is local. - return CompleteGroupLocal(device, cp, done); + return CompleteGroupLocal(device, cp, done, cancel_mgr); } else if (GetCachedGroup(cp->group.group_key) == nullptr) { // Need to update Group cache from the leader. 
CompleteGroupCall* call = @@ -306,24 +306,24 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed( delete call; return; } - call->Start( - [this, device, cp, call, abortion_token, done](const Status& s) { - abortion_cancel_mgr_.DeregisterCallback(abortion_token); - if (s.ok()) { - Status status = UpdateGroupCache(call->resp_); - if (status.ok()) { - CompleteGroupLocal(device, cp, done); - } else { - done(status, nullptr); - } - } else { - done(s, nullptr); - } - delete call; - }); + call->Start([this, device, cp, call, cancel_mgr, abortion_token, + done](const Status& s) { + abortion_cancel_mgr_.DeregisterCallback(abortion_token); + if (s.ok()) { + Status status = UpdateGroupCache(call->resp_); + if (status.ok()) { + CompleteGroupLocal(device, cp, done, cancel_mgr); + } else { + done(status, nullptr); + } + } else { + done(s, nullptr); + } + delete call; + }); return; } else { - return CompleteGroupLocal(device, cp, done); + return CompleteGroupLocal(device, cp, done, cancel_mgr); } } diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 82a8acb0040..2f778286c21 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -289,6 +289,7 @@ cuda_py_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:test_util", + "//tensorflow/python/eager:cancellation", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py index 945ab1c5f3d..3fb1ed3ac50 100644 --- a/tensorflow/python/kernel_tests/collective_ops_test.py +++ b/tensorflow/python/kernel_tests/collective_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import testing as dataset_testing from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import test_util +from tensorflow.python.eager import cancellation from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op @@ -36,6 +37,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import collective_ops as _collective_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -732,6 +734,76 @@ class OpCancellationTest(test.TestCase, parameterized.TestCase): # proceed. 
collective_fn() + @combinations.generate( + combinations.times( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce_v2', + CollectiveOpsV2.all_reduce), + combinations.NamedObject('all_gather_v2', + CollectiveOpsV2.all_gather), + ], + mode='eager'), device_combination)) + def testCancelDuringParamResolution(self, collective_op, device, + communication): + dev0 = '/device:%s:0' % device + dev1 = '/device:%s:1' % device + group_size = 2 + group_key = 100 + instance_key = 100 + in_tensor = constant_op.constant([1.]) + t1_cancellation_manager = cancellation.CancellationManager() + t2_cancellation_manager = cancellation.CancellationManager() + + @def_function.function + def _collective_fn(x): + # Run an assertion to crash one of the two function executions running + # collectives. We explicitly cancel the other in response. + assert_op = check_ops.assert_equal(x, in_tensor) + with ops.control_dependencies([assert_op]): + return collective_op( + in_tensor, + group_size, + group_key, + instance_key, + communication_hint=communication) + + collective_concrete = _collective_fn.get_concrete_function(in_tensor) + + finish_mu = threading.Lock() + finishes = 0 + + def _placement_wrapper(device, x, my_cancellation, other_cancellation): + try: + with ops.device(device): + cancelable_collective = my_cancellation.get_cancelable_function( + collective_concrete) + return cancelable_collective(x) + except errors.InvalidArgumentError: + # `assert_equal` failed for this execution of the function. The other + # function would deadlock without cancellation. + other_cancellation.start_cancel() + except errors.CancelledError: + pass + nonlocal finishes + with finish_mu: + finishes += 1 + + t1 = threading.Thread( + target=_placement_wrapper, + args=(dev0, constant_op.constant([1.]), t1_cancellation_manager, + t2_cancellation_manager)) + t2 = threading.Thread( + target=_placement_wrapper, + # Will cause the assertion to fail + args=(dev1, constant_op.constant([2.]), t2_cancellation_manager, + t1_cancellation_manager)) + t1.start() + t2.start() + t1.join() + t2.join() + self.assertEqual(finishes, 2) + @combinations.generate(collective_op_combinations) class TimeoutTest(test.TestCase, parameterized.TestCase):
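
Usage sketch (not part of the patch): the snippet below distills the cancellation path the new test exercises, for a single client in eager mode. The device string, group/instance keys, and helper names are illustrative assumptions rather than values taken from the change. With only one participant launching a group of size 2, the collective blocks in group/parameter resolution; with this change, start_cancel() on the wrapping CancellationManager unblocks it with CancelledError instead of deadlocking.

    import threading
    import time

    from tensorflow.python.eager import cancellation
    from tensorflow.python.eager import def_function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import errors
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import collective_ops

    @def_function.function
    def _all_reduce(x):
      # group_size=2 but only one participant ever launches, so the op waits
      # in group/parameter resolution for a second member that never arrives.
      return collective_ops.all_reduce_v2(
          x, group_size=2, group_key=100, instance_key=100)

    in_tensor = constant_op.constant([1.])
    concrete = _all_reduce.get_concrete_function(in_tensor)

    manager = cancellation.CancellationManager()
    cancelable = manager.get_cancelable_function(concrete)

    def _run():
      with ops.device('/device:CPU:0'):
        try:
          cancelable(in_tensor)
        except errors.CancelledError:
          print('collective cancelled while resolving params')

    t = threading.Thread(target=_run)
    t.start()
    time.sleep(1.)  # let the collective start waiting on group resolution
    manager.start_cancel()  # before this change, cancellation had no effect here
    t.join()

On the C++ side the same effect comes from registering a callback on the op's CancellationManager at the top of CompleteGroupLocal and deregistering it via TryDeregisterCallback inside done_with_cleanup, so exactly one of the cancelled path and the normal done path runs.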