diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 9c46314af67..01b89494c0d 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -779,7 +779,7 @@ void CollectiveParamResolverLocal::StartAbort(const Status& s) {
   {
     mutex_lock l(status_mu_);
     if (!status_.ok()) {
-      VLOG(1) << "CollectiveParamResolverLocal already aborted. Ignoring "
+      VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
                  "subsequent abortion with status: "
               << s;
       return;
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 238d29065d2..9466c8ef96b 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -295,19 +295,30 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
     CompleteGroupCall* call =
         new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
                               group_leader_, worker_cache_);
-    call->Start([this, device, cp, call, done](const Status& s) {
-      if (s.ok()) {
-        Status status = UpdateGroupCache(call->resp_);
-        if (status.ok()) {
-          CompleteGroupLocal(device, cp, done);
-        } else {
-          done(status, nullptr);
-        }
-      } else {
-        done(s, nullptr);
-      }
+    CancellationToken abortion_token =
+        abortion_cancel_mgr_.get_cancellation_token();
+    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+        abortion_token, [call] { call->Cancel(); });
+    if (already_aborted) {
+      done(errors::Cancelled("collective ops already aborted"), nullptr);
       delete call;
-    });
+      return;
+    }
+    call->Start(
+        [this, device, cp, call, abortion_token, done](const Status& s) {
+          abortion_cancel_mgr_.DeregisterCallback(abortion_token);
+          if (s.ok()) {
+            Status status = UpdateGroupCache(call->resp_);
+            if (status.ok()) {
+              CompleteGroupLocal(device, cp, done);
+            } else {
+              done(status, nullptr);
+            }
+          } else {
+            done(s, nullptr);
+          }
+          delete call;
+        });
     return;
   } else {
     return CompleteGroupLocal(device, cp, done);
@@ -373,7 +384,17 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
     CompleteInstanceCall* call = new CompleteInstanceCall(
         cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
         group_leader_, worker_cache_);
-    call->Start([this, device, gr, cp, call, done](Status s) {
+    CancellationToken abortion_token =
+        abortion_cancel_mgr_.get_cancellation_token();
+    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+        abortion_token, [call] { call->Cancel(); });
+    if (already_aborted) {
+      done(errors::Cancelled("collective ops already aborted"));
+      delete call;
+      return;
+    }
+    call->Start([this, device, gr, cp, call, abortion_token, done](Status s) {
+      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
       if (s.ok()) {
         s = UpdateInstanceCache(gr, cp, call->resp_);
       }
@@ -388,4 +409,19 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
   }
 }
 
+void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
+  {
+    mutex_lock l(status_mu_);
+    if (!status_.ok()) {
+      VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
+                 "subsequent abortion with status: "
+              << s;
+      return;
+    }
+    status_ = s;
+  }
+  StartAbortLocal(s);
+  abortion_cancel_mgr_.StartCancel();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
index 89f923a800b..97445fa6cfd 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
 
 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/platform/status.h"
 
@@ -47,6 +48,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
                                 CancellationManager* cancel_mgr,
                                 const StatusCallback& done) override;
 
+  void StartAbort(const Status& s) override;
+
  protected:
   // Returns the cached group iff there's an entry for this group_key in the
   // local group_table_; returns nullptr otherwise.
@@ -87,6 +90,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
 
   WorkerCacheInterface* worker_cache_;  // Not owned
   const string group_leader_;
+  CancellationManager abortion_cancel_mgr_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
index cada2b8a99b..5c9a351e327 100644
--- a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
+++ b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import collective_ops
@@ -51,6 +52,12 @@ def enable_collective_ops(cluster_resolver):
   context.context().enable_collective_ops(server_def)
 
 
+def enable_collective_ops_with_barrier(cluster_resolver):
+  multi_process_runner.get_barrier().wait()
+  enable_collective_ops(cluster_resolver)
+  multi_process_runner.get_barrier().wait()
+
+
 device_combination = (
     combinations.combine(device="CPU", communication="RING", required_gpus=0) +
     combinations.combine(
@@ -167,7 +174,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
       self.skipTest("b/171358086: cannot test multi worker NCCL")
     dev0 = "/device:%s:0" % device
     cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
-    enable_collective_ops(cluster_resolver)
+    enable_collective_ops_with_barrier(cluster_resolver)
     group_size = 2
     group_key = 100
     instance_key = 100
@@ -214,9 +221,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
     t.join()
 
     # Enable collective ops again in order to reset the collective executor.
-    multi_process_runner.get_barrier().wait()
-    enable_collective_ops(cluster_resolver)
-    multi_process_runner.get_barrier().wait()
+    enable_collective_ops_with_barrier(cluster_resolver)
     with ops.device(dev0):
       collective_ops.all_reduce(
          in_tensor,
@@ -225,6 +230,102 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
           instance_key,
           communication_hint=communication)
 
+  def testAbortGroupParamsResolution(self, device, communication):
+    if communication == "NCCL":
+      self.skipTest("b/171358086: cannot test multi worker NCCL")
+    dev0 = "/device:%s:0" % device
+    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+    enable_collective_ops_with_barrier(cluster_resolver)
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    if cluster_resolver.task_id == 1:
+
+      def abort_fn():
+        time.sleep(2)
+        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
+
+      t = threading.Thread(target=abort_fn)
+      t.start()
+
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        # This hangs on params resolution since we're only launching one
+        # collective for a group size of 2.
+        with ops.device(dev0):
+          collective_ops.all_reduce(in_tensor, group_size, group_key,
+                                    instance_key)
+
+      # After abortion, subsequent collectives should fail immediately.
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        with ops.device(dev0):
+          collective_ops.all_reduce(in_tensor, group_size, group_key,
+                                    instance_key)
+
+      t.join()
+
+    # Enable collective ops again in order to reset the collective executor.
+    enable_collective_ops_with_barrier(cluster_resolver)
+    with ops.device(dev0):
+      collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
+
+  def testAbortInstanceParamsResolution(self, device, communication):
+    if communication == "NCCL":
+      self.skipTest("b/171358086: cannot test multi worker NCCL")
+    dev0 = "/device:%s:0" % device
+    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+    enable_collective_ops_with_barrier(cluster_resolver)
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    # First perform a normal all-reduce to complete the group resolution.
+    with ops.device(dev0):
+      collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
+
+    # We use broadcast to test aborting instance resolution since only
+    # broadcast waits for the group.
+
+    if cluster_resolver.task_id == 1:
+
+      def abort_fn():
+        time.sleep(2)
+        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
+
+      t = threading.Thread(target=abort_fn)
+      t.start()
+
+      # Use a different instance key to trigger another instance resolution.
+      instance_key = 101
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        # This hangs on params resolution since we're only launching one
+        # collective for a group size of 2.
+        with ops.device(dev0):
+          collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                        group_size, group_key, instance_key)
+
+      # After abortion, subsequent collectives should fail immediately.
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        with ops.device(dev0):
+          collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                        group_size, group_key, instance_key)
+
+      t.join()
+
+    # Enable collective ops again in order to reset the collective executor.
+    enable_collective_ops_with_barrier(cluster_resolver)
+    # Reassign instance_key so that it's the same on each worker.
+    instance_key = 100
+    with ops.device(dev0):
+      if cluster_resolver.task_id == 0:
+        collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                      group_size, group_key, instance_key)
+      else:
+        collective_ops.broadcast_recv((1,), dtypes.float32, group_size,
+                                      group_key, instance_key)
+
 
 if __name__ == "__main__":
   multi_process_runner.test_main()