Support aborting RING communication in multi-worker collectives

For out-of-band abortion to work, StartAbort() needs access to the cancellation managers used in the RPCs, so we create a new cancellation manager for each call and track it in the object.

PiperOrigin-RevId: 338361732
Change-Id: I83d6efac1cf4f68af29ed9c7dd85cd0c9f4c8547
Ran Chen 2020-10-21 16:19:05 -07:00 committed by TensorFlower Gardener
parent c2239c1a2a
commit ec37857782
8 changed files with 169 additions and 25 deletions
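The mechanism described in the commit message, reduced to its essentials: each in-flight call registers a cancel callback with a shared abort cancellation manager, the call's completion callback deregisters it, and StartAbort() fires every registered callback so pending RPCs are cancelled out of band while later calls fail immediately. Below is a minimal standalone sketch of that pattern; MiniCancellationManager and FakeCall are illustrative stand-ins introduced here, not TensorFlow's CancellationManager or RecvBufCall.

// Minimal standalone sketch (not the real TensorFlow classes) of the
// out-of-band abort pattern in this change.
#include <functional>
#include <iostream>
#include <mutex>
#include <unordered_map>

// Illustrative stand-in for a cancellation manager: callbacks registered
// before StartCancel() run when it is called; registration afterwards fails.
class MiniCancellationManager {
 public:
  int GetToken() {
    std::lock_guard<std::mutex> l(mu_);
    return next_token_++;
  }
  bool RegisterCallback(int token, std::function<void()> cb) {
    std::lock_guard<std::mutex> l(mu_);
    if (cancelled_) return false;  // Already aborted; refuse registration.
    callbacks_[token] = std::move(cb);
    return true;
  }
  void DeregisterCallback(int token) {
    std::lock_guard<std::mutex> l(mu_);
    callbacks_.erase(token);
  }
  void StartCancel() {
    std::unordered_map<int, std::function<void()>> cbs;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (cancelled_) return;
      cancelled_ = true;
      cbs.swap(callbacks_);
    }
    for (auto& kv : cbs) kv.second();  // Cancel every in-flight call.
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  int next_token_ = 0;
  std::unordered_map<int, std::function<void()>> callbacks_;
};

// Illustrative stand-in for a cancellable RPC: Start() registers with the
// shared abort manager, the completion callback deregisters, and Cancel()
// can be driven out of band by an abort.
class FakeCall {
 public:
  explicit FakeCall(MiniCancellationManager* abort_mgr)
      : abort_mgr_(abort_mgr) {}

  void Start(std::function<void(bool ok)> done) {
    int token = abort_mgr_->GetToken();
    if (!abort_mgr_->RegisterCallback(token, [this] { Cancel(); })) {
      done(false);  // Already aborted; fail immediately.
      return;
    }
    // A real implementation would issue the RPC here; we complete right away.
    abort_mgr_->DeregisterCallback(token);
    done(true);
  }

  void Cancel() { std::cout << "call cancelled out of band\n"; }

 private:
  MiniCancellationManager* const abort_mgr_;  // Not owned.
};

int main() {
  MiniCancellationManager abort_mgr;
  FakeCall call(&abort_mgr);
  call.Start([](bool ok) { std::cout << "first call ok=" << ok << "\n"; });
  abort_mgr.StartCancel();  // What an abort does to in-flight calls.
  FakeCall late(&abort_mgr);
  late.Start([](bool ok) { std::cout << "late call ok=" << ok << "\n"; });
  return 0;
}

In the actual change below, CollectiveRemoteAccessDistributed holds the shared manager as abortion_cancel_mgr_, RecvFromPeer() does the register/deregister dance around each RecvBufCall, and StartAbort() calls abortion_cancel_mgr_.StartCancel().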


@@ -242,6 +242,7 @@ tf_cc_test(
cc_library(
name = "cancellable_call",
srcs = ["cancellable_call.cc"],
hdrs = ["cancellable_call.h"],
deps = [
":call_options",


@@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
namespace tensorflow {
void CancellableCall::Start(const StatusCallback& done) {
if (cancel_mgr_ == nullptr) {
IssueCall(done);
return;
}
CancellationToken token = cancel_mgr_->get_cancellation_token();
const bool not_yet_cancelled =
cancel_mgr_->RegisterCallback(token, [this]() { Cancel(); });
if (not_yet_cancelled) {
IssueCall([this, token, done](const Status& s) {
cancel_mgr_->DeregisterCallback(token);
done(s);
});
} else {
done(errors::Cancelled("RPC Request was cancelled"));
}
}
void CancellableCall::Cancel() {
{
mutex_lock l(mu_);
if (is_cancelled_) {
return;
}
is_cancelled_ = true;
}
opts_.StartCancel();
}
} // namespace tensorflow


@@ -29,7 +29,8 @@ class CancellableCall {
public:
CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
WorkerCacheInterface* wc)
: cancel_mgr_(cancel_mgr),
: is_cancelled_(false),
cancel_mgr_(cancel_mgr),
remote_worker_(remote_worker),
wc_(wc),
wi_(wc_->GetOrCreateWorker(remote_worker_)) {}
@@ -38,22 +39,17 @@
virtual void IssueCall(const StatusCallback& done) = 0;
void Start(const StatusCallback& done) {
CancellationToken token = cancel_mgr_->get_cancellation_token();
const bool not_yet_cancelled =
cancel_mgr_->RegisterCallback(token, [this]() { opts_.StartCancel(); });
if (not_yet_cancelled) {
IssueCall([this, token, done](const Status& s) {
cancel_mgr_->DeregisterCallback(token);
done(s);
});
} else {
done(errors::Cancelled("RPC Request was cancelled"));
}
}
void Start(const StatusCallback& done);
// Cancels the RPC if it's not cancelled yet. This must be called after
// Start(). This is normally used if there's a need to cancel the RPC from a
// sideband. If applicable, pass a cancellation manager to the constructor
// instead of using this method.
void Cancel() TF_LOCKS_EXCLUDED(mu_);
protected:
mutable mutex mu_;
mutex mu_;
bool is_cancelled_;
CancellationManager* const cancel_mgr_; // Not owned
const string remote_worker_;
WorkerCacheInterface* const wc_; // Not owned


@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@@ -167,16 +168,23 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
recv_buf_callback(s);
return;
}
// If a per-call `cancellation_manager` is passed to this function, prefer
// using that over `abortion_cancellation_manager_`. This is because abortion
// should also be accompanied by opkernel cancellation.
state->call.reset(new RecvBufCall(
step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, state->server_attributes,
cancellation_manager == nullptr ? &abortion_cancellation_manager_
: cancellation_manager,
worker_cache_));
state->call->Start(recv_buf_callback);
cancellation_manager, worker_cache_));
CancellationToken abortion_token =
abortion_cancel_mgr_.get_cancellation_token();
bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
abortion_token, [state] { state->call->Cancel(); });
if (already_aborted) {
recv_buf_callback(errors::Cancelled("collective ops already aborted"));
} else {
state->call->Start([this, abortion_token,
done = std::move(recv_buf_callback)](const Status& s) {
abortion_cancel_mgr_.DeregisterCallback(abortion_token);
done(s);
});
}
}
void CollectiveRemoteAccessDistributed::CheckPeerHealth(
@@ -241,7 +249,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(
void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
CollectiveRemoteAccessLocal::StartAbort(s);
abortion_cancellation_manager_.StartCancel();
abortion_cancel_mgr_.StartCancel();
}
} // namespace tensorflow


@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
@@ -55,7 +56,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
// Ownership of `work_queue_` is shared between `this` and
// `CollectiveExecutorMgr`.
std::shared_ptr<UnboundedWorkQueue> work_queue_;
CancellationManager abortion_cancellation_manager_;
CancellationManager abortion_cancel_mgr_;
string task_name_;
};


@@ -101,7 +101,7 @@ class ClusterParameters(combinations_lib.ParameterModifier):
else:
has_chief = kwargs.get("has_chief", False)
num_workers = kwargs.get("num_workers", 1)
runner = None
runner = kwargs.get("runner", None)
# Always set cluster parameters if they're requested, so that generate()
# works when there's no strategy in the combinations.


@@ -319,11 +319,13 @@ tf_py_test(
"//tensorflow/python:collective_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:errors",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"@absl_py//absl/testing:parameterized",
],
)


@@ -20,16 +20,21 @@ from __future__ import print_function
import copy
import os
import threading
import time
from absl.testing import parameterized
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import collective_ops
@@ -46,6 +51,12 @@ def enable_collective_ops(cluster_resolver):
context.context().enable_collective_ops(server_def)
device_combination = (
combinations.combine(device="CPU", communication="RING", required_gpus=0) +
combinations.combine(
device="GPU", communication=["RING", "NCCL"], required_gpus=1))
class CollectiveOpTest(test.TestCase):
def testCheckHealth(self):
@@ -138,5 +149,82 @@ class CollectiveOpTest(test.TestCase):
mpr.join()
two_worker_pool_runner = multi_process_runner.MultiProcessPoolRunner(
multi_worker_test_base.create_cluster_spec(num_workers=2),
initializer=lambda: enable_collective_ops(cluster_resolver_lib.
TFConfigClusterResolver()))
@combinations.generate(
combinations.times(
combinations.combine(
mode="eager", num_workers=2, runner=two_worker_pool_runner),
device_combination))
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
def testAbortCommunication(self, device, communication):
if communication == "NCCL":
self.skipTest("b/171358086: cannot test multi worker NCCL")
dev0 = "/device:%s:0" % device
cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
enable_collective_ops(cluster_resolver)
group_size = 2
group_key = 100
instance_key = 100
in_tensor = constant_op.constant([1.])
# First perform a normal all-reduce to complete the group and instance
# resolution.
with ops.device(dev0):
collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
if cluster_resolver.task_id == 1:
def abort_fn():
time.sleep(2)
context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
t = threading.Thread(target=abort_fn)
t.start()
with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
with ops.device(dev0):
collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
# After abortion, subsequent collectives should fail immediately.
with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
with ops.device(dev0):
collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
t.join()
# Enable collective ops again in order to reset the collective executor.
multi_process_runner.get_barrier().wait()
enable_collective_ops(cluster_resolver)
multi_process_runner.get_barrier().wait()
with ops.device(dev0):
collective_ops.all_reduce(
in_tensor,
group_size,
group_key,
instance_key,
communication_hint=communication)
if __name__ == "__main__":
multi_process_runner.test_main()