Support aborting RING communication in multi worker collectives

For the out-of-band abortion to work, we need access to the cancellation managers used in RPCs from StartAbort(), so we create a new cancellation manager for each call and track it in the object.

PiperOrigin-RevId: 338361732
Change-Id: I83d6efac1cf4f68af29ed9c7dd85cd0c9f4c8547
Authored by Ran Chen on 2020-10-21 16:19:05 -07:00, committed by TensorFlower Gardener
parent c2239c1a2a
commit ec37857782
8 changed files with 169 additions and 25 deletions
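
The change hinges on the registration pattern of TensorFlow's CancellationManager: each in-flight RPC registers a cancel callback under its own token, and StartAbort() calls StartCancel() on a manager owned by the collective RMA object, which runs every registered callback and thereby cancels all outstanding calls out of band. Below is a minimal, self-contained sketch of that pattern. The simplified manager and the Call type are illustrative stand-ins for this page only, not the actual TensorFlow classes; only the calls this commit relies on (get_cancellation_token, RegisterCallback, DeregisterCallback, StartCancel) are modeled.

#include <functional>
#include <iostream>
#include <mutex>
#include <unordered_map>
#include <utility>

// Simplified stand-in for tensorflow::CancellationManager.
class CancellationManager {
 public:
  using Token = int;

  Token get_cancellation_token() {
    std::lock_guard<std::mutex> l(mu_);
    return next_token_++;
  }

  // Returns false if cancellation has already started; in that case the
  // caller must fail the call immediately instead of issuing it.
  bool RegisterCallback(Token token, std::function<void()> cb) {
    std::lock_guard<std::mutex> l(mu_);
    if (cancelled_) return false;
    callbacks_.emplace(token, std::move(cb));
    return true;
  }

  void DeregisterCallback(Token token) {
    std::lock_guard<std::mutex> l(mu_);
    callbacks_.erase(token);
  }

  // Out-of-band abort: runs every registered callback once and rejects all
  // future registrations.
  void StartCancel() {
    std::unordered_map<Token, std::function<void()>> cbs;
    {
      std::lock_guard<std::mutex> l(mu_);
      cancelled_ = true;
      cbs.swap(callbacks_);
    }
    for (auto& kv : cbs) kv.second();
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  Token next_token_ = 0;
  std::unordered_map<Token, std::function<void()>> callbacks_;
};

// Illustrative stand-in for an in-flight RPC such as RecvBufCall.
struct Call {
  void Cancel() { std::cout << "call cancelled\n"; }
};

int main() {
  CancellationManager abortion_cancel_mgr;  // member of the RMA object
  Call call;

  // Per-call registration, mirroring RecvFromPeer().
  auto token = abortion_cancel_mgr.get_cancellation_token();
  bool already_aborted =
      !abortion_cancel_mgr.RegisterCallback(token, [&call] { call.Cancel(); });
  if (already_aborted) {
    std::cout << "collective ops already aborted\n";
    return 0;
  }

  // StartAbort() forwards to StartCancel(), cancelling every registered call.
  abortion_cancel_mgr.StartCancel();

  // On completion, the call's done-callback would deregister its token.
  abortion_cancel_mgr.DeregisterCallback(token);
  return 0;
}

In the diffs that follow, RecvFromPeer() plays the role of main() above: it registers each RecvBufCall with abortion_cancel_mgr_ before starting it and deregisters it when the call completes.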


@@ -242,6 +242,7 @@ tf_cc_test(
 cc_library(
     name = "cancellable_call",
+    srcs = ["cancellable_call.cc"],
     hdrs = ["cancellable_call.h"],
     deps = [
         ":call_options",


@@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cancellable_call.h"

namespace tensorflow {

void CancellableCall::Start(const StatusCallback& done) {
  if (cancel_mgr_ == nullptr) {
    IssueCall(done);
    return;
  }
  CancellationToken token = cancel_mgr_->get_cancellation_token();
  const bool not_yet_cancelled =
      cancel_mgr_->RegisterCallback(token, [this]() { Cancel(); });
  if (not_yet_cancelled) {
    IssueCall([this, token, done](const Status& s) {
      cancel_mgr_->DeregisterCallback(token);
      done(s);
    });
  } else {
    done(errors::Cancelled("RPC Request was cancelled"));
  }
}

void CancellableCall::Cancel() {
  {
    mutex_lock l(mu_);
    if (is_cancelled_) {
      return;
    }
    is_cancelled_ = true;
  }
  opts_.StartCancel();
}

}  // namespace tensorflow


@@ -29,7 +29,8 @@ class CancellableCall {
  public:
   CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
                   WorkerCacheInterface* wc)
-      : cancel_mgr_(cancel_mgr),
+      : is_cancelled_(false),
+        cancel_mgr_(cancel_mgr),
         remote_worker_(remote_worker),
         wc_(wc),
         wi_(wc_->GetOrCreateWorker(remote_worker_)) {}
@@ -38,22 +39,17 @@ class CancellableCall {
   virtual void IssueCall(const StatusCallback& done) = 0;

-  void Start(const StatusCallback& done) {
-    CancellationToken token = cancel_mgr_->get_cancellation_token();
-    const bool not_yet_cancelled =
-        cancel_mgr_->RegisterCallback(token, [this]() { opts_.StartCancel(); });
-    if (not_yet_cancelled) {
-      IssueCall([this, token, done](const Status& s) {
-        cancel_mgr_->DeregisterCallback(token);
-        done(s);
-      });
-    } else {
-      done(errors::Cancelled("RPC Request was cancelled"));
-    }
-  }
+  void Start(const StatusCallback& done);
+
+  // Cancels the RPC if it's not cancelled yet. This must be called after
+  // Start(). This is normally used if there's a need to cancel the RPC from a
+  // sideband. If applicable, pass a cancellation manager to the constructor
+  // instead of using this method.
+  void Cancel() TF_LOCKS_EXCLUDED(mu_);

  protected:
-  mutable mutex mu_;
+  mutex mu_;
+  bool is_cancelled_;
   CancellationManager* const cancel_mgr_;  // Not owned
   const string remote_worker_;
   WorkerCacheInterface* const wc_;  // Not owned


@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/platform/protobuf_internal.h"
 #include "tensorflow/core/protobuf/transport_options.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
@@ -167,16 +168,23 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
     recv_buf_callback(s);
     return;
   }
-  // If a per-call `cancellation_manager` is passed to this function, prefer
-  // using that over `abortion_cancellation_manager_`. This is because abortion
-  // should also be accompanied by opkernel cancellation.
   state->call.reset(new RecvBufCall(
       step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
       to_alloc_attr, to_tensor, client_locality, state->server_attributes,
-      cancellation_manager == nullptr ? &abortion_cancellation_manager_
-                                      : cancellation_manager,
-      worker_cache_));
-  state->call->Start(recv_buf_callback);
+      cancellation_manager, worker_cache_));
+  CancellationToken abortion_token =
+      abortion_cancel_mgr_.get_cancellation_token();
+  bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+      abortion_token, [state] { state->call->Cancel(); });
+  if (already_aborted) {
+    recv_buf_callback(errors::Cancelled("collective ops already aborted"));
+  } else {
+    state->call->Start([this, abortion_token,
+                        done = std::move(recv_buf_callback)](const Status& s) {
+      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
+      done(s);
+    });
+  }
 }

 void CollectiveRemoteAccessDistributed::CheckPeerHealth(
@@ -241,7 +249,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
 void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
   CollectiveRemoteAccessLocal::StartAbort(s);
-  abortion_cancellation_manager_.StartCancel();
+  abortion_cancel_mgr_.StartCancel();
 }

 }  // namespace tensorflow


@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_

 #include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/unbounded_work_queue.h"
@@ -55,7 +56,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
   // Ownership of `work_queue_` is shared between `this` and
   // `CollectiveExecutorMgr`.
   std::shared_ptr<UnboundedWorkQueue> work_queue_;
-  CancellationManager abortion_cancellation_manager_;
+  CancellationManager abortion_cancel_mgr_;
   string task_name_;
 };


@@ -101,7 +101,7 @@ class ClusterParameters(combinations_lib.ParameterModifier):
     else:
       has_chief = kwargs.get("has_chief", False)
       num_workers = kwargs.get("num_workers", 1)
-      runner = None
+      runner = kwargs.get("runner", None)
     # Always set cluster parameters if they're requested. So that generate()
     # works when there's no strategy in the combinations.


@@ -319,11 +319,13 @@ tf_py_test(
         "//tensorflow/python:collective_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:errors",
+        "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:multi_process_runner",
         "//tensorflow/python/distribute:multi_worker_test_base",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
+        "@absl_py//absl/testing:parameterized",
     ],
 )


@@ -20,16 +20,21 @@ from __future__ import print_function

 import copy
 import os
+import threading
 import time

+from absl.testing import parameterized
+
 from tensorflow.core.protobuf import tensorflow_server_pb2
 from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import collective_ops
@@ -46,6 +51,12 @@ def enable_collective_ops(cluster_resolver):
   context.context().enable_collective_ops(server_def)


+device_combination = (
+    combinations.combine(device="CPU", communication="RING", required_gpus=0) +
+    combinations.combine(
+        device="GPU", communication=["RING", "NCCL"], required_gpus=1))
+
+
 class CollectiveOpTest(test.TestCase):

   def testCheckHealth(self):
@@ -138,5 +149,82 @@ class CollectiveOpTest(test.TestCase):
     mpr.join()


+two_worker_pool_runner = multi_process_runner.MultiProcessPoolRunner(
+    multi_worker_test_base.create_cluster_spec(num_workers=2),
+    initializer=lambda: enable_collective_ops(cluster_resolver_lib.
+                                              TFConfigClusterResolver()))
+
+
+@combinations.generate(
+    combinations.times(
+        combinations.combine(
+            mode="eager", num_workers=2, runner=two_worker_pool_runner),
+        device_combination))
+class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
+
+  def testAbortCommunication(self, device, communication):
+    if communication == "NCCL":
+      self.skipTest("b/171358086: cannot test multi worker NCCL")
+    dev0 = "/device:%s:0" % device
+    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+    enable_collective_ops(cluster_resolver)
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    # First perform a normal all-reduce to complete the group and instance
+    # resolution.
+    with ops.device(dev0):
+      collective_ops.all_reduce(
+          in_tensor,
+          group_size,
+          group_key,
+          instance_key,
+          communication_hint=communication)
+
+    if cluster_resolver.task_id == 1:
+
+      def abort_fn():
+        time.sleep(2)
+        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
+
+      t = threading.Thread(target=abort_fn)
+      t.start()
+
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        with ops.device(dev0):
+          collective_ops.all_reduce(
+              in_tensor,
+              group_size,
+              group_key,
+              instance_key,
+              communication_hint=communication)
+
+      # After abortion, subsequent collectives should fail immediately.
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        with ops.device(dev0):
+          collective_ops.all_reduce(
+              in_tensor,
+              group_size,
+              group_key,
+              instance_key,
+              communication_hint=communication)
+
+      t.join()
+
+    # Enable collective ops again in order to reset the collective executor.
+    multi_process_runner.get_barrier().wait()
+    enable_collective_ops(cluster_resolver)
+    multi_process_runner.get_barrier().wait()
+
+    with ops.device(dev0):
+      collective_ops.all_reduce(
+          in_tensor,
+          group_size,
+          group_key,
+          instance_key,
+          communication_hint=communication)
+
+
 if __name__ == "__main__":
   multi_process_runner.test_main()