Support aborting RING communication in multi worker collectives

For the out of band abortion to work, we need access to the cancellation managers used in RPCs from StartAbort(), so we create a new cancellation manager for each call and track it in the object.

PiperOrigin-RevId: 338361732
Change-Id: I83d6efac1cf4f68af29ed9c7dd85cd0c9f4c8547
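The pattern this sets up, roughly: the collective executor owns a CancellationManager dedicated to abortion; every outgoing RPC registers a cancellation callback on it before starting and deregisters it once the RPC finishes, so StartAbort() can cancel all in-flight calls with a single StartCancel(). A minimal sketch of that registration dance, where StartCall(), IssueRpc() and CancelPendingRpc() are hypothetical placeholders for the real RPC plumbing in the diff below:

  CancellationManager abortion_cancel_mgr_;  // one per collective executor

  void StartCall(const StatusCallback& done) {
    CancellationToken token = abortion_cancel_mgr_.get_cancellation_token();
    // RegisterCallback returns false if StartCancel() has already run, i.e.
    // the collective ops were already aborted.
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        token, [this] { CancelPendingRpc(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      return;
    }
    IssueRpc([this, token, done](const Status& s) {
      // The call has finished (success or failure); stop tracking it.
      abortion_cancel_mgr_.DeregisterCallback(token);
      done(s);
    });
  }

  void StartAbort(const Status& s) {
    // Cancels every callback registered above, i.e. every RPC still in flight.
    abortion_cancel_mgr_.StartCancel();
  }

The per-call CancellationManager passed in by the op kernel is still honored separately; the abortion manager only tracks calls so they can be cancelled out of band.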
parent c2239c1a2a
commit ec37857782
tensorflow/core/distributed_runtime/BUILD
@@ -242,6 +242,7 @@ tf_cc_test(

cc_library(
    name = "cancellable_call",
+    srcs = ["cancellable_call.cc"],
    hdrs = ["cancellable_call.h"],
    deps = [
        ":call_options",
tensorflow/core/distributed_runtime/cancellable_call.cc (new file, 48 lines)
@@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cancellable_call.h"

namespace tensorflow {

void CancellableCall::Start(const StatusCallback& done) {
  if (cancel_mgr_ == nullptr) {
    IssueCall(done);
    return;
  }
  CancellationToken token = cancel_mgr_->get_cancellation_token();
  const bool not_yet_cancelled =
      cancel_mgr_->RegisterCallback(token, [this]() { Cancel(); });
  if (not_yet_cancelled) {
    IssueCall([this, token, done](const Status& s) {
      cancel_mgr_->DeregisterCallback(token);
      done(s);
    });
  } else {
    done(errors::Cancelled("RPC Request was cancelled"));
  }
}

void CancellableCall::Cancel() {
  {
    mutex_lock l(mu_);
    if (is_cancelled_) {
      return;
    }
    is_cancelled_ = true;
  }
  opts_.StartCancel();
}

}  // namespace tensorflow
tensorflow/core/distributed_runtime/cancellable_call.h
@@ -29,7 +29,8 @@ class CancellableCall {
 public:
  CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
                  WorkerCacheInterface* wc)
-      : cancel_mgr_(cancel_mgr),
+      : is_cancelled_(false),
+        cancel_mgr_(cancel_mgr),
        remote_worker_(remote_worker),
        wc_(wc),
        wi_(wc_->GetOrCreateWorker(remote_worker_)) {}
@@ -38,22 +39,17 @@ class CancellableCall {

  virtual void IssueCall(const StatusCallback& done) = 0;

-  void Start(const StatusCallback& done) {
-    CancellationToken token = cancel_mgr_->get_cancellation_token();
-    const bool not_yet_cancelled =
-        cancel_mgr_->RegisterCallback(token, [this]() { opts_.StartCancel(); });
-    if (not_yet_cancelled) {
-      IssueCall([this, token, done](const Status& s) {
-        cancel_mgr_->DeregisterCallback(token);
-        done(s);
-      });
-    } else {
-      done(errors::Cancelled("RPC Request was cancelled"));
-    }
-  }
+  void Start(const StatusCallback& done);
+
+  // Cancels the RPC if it's not cancelled yet. This must be called after
+  // Start(). This is normally used when there's a need to cancel the RPC from
+  // a sideband. If applicable, pass a cancellation manager to the constructor
+  // instead of using this method.
+  void Cancel() TF_LOCKS_EXCLUDED(mu_);

 protected:
-  mutable mutex mu_;
+  mutex mu_;
+  bool is_cancelled_;
  CancellationManager* const cancel_mgr_;  // Not owned
  const string remote_worker_;
  WorkerCacheInterface* const wc_;  // Not owned
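For orientation, a CancellableCall subclass only implements IssueCall() against the worker interface; Start() and Cancel() in the base class handle token registration and CallOptions cancellation. A hypothetical sketch of that shape (EchoCall, EchoAsync, EchoRequest and EchoResponse are illustrative placeholders, not real worker-cache APIs; the actual subclass used in this change is RecvBufCall in collective_rma_distributed.cc):

  // Hypothetical subclass; the RPC name and message types are placeholders.
  class EchoCall : public CancellableCall {
   public:
    EchoCall(CancellationManager* cancel_mgr, const string& remote_worker,
             WorkerCacheInterface* wc)
        : CancellableCall(cancel_mgr, remote_worker, wc) {}

    void IssueCall(const StatusCallback& done) override {
      // opts_ is the CallOptions member of the base class; Cancel() calls
      // opts_.StartCancel(), which aborts this pending RPC.
      wi_->EchoAsync(&opts_, &req_, &resp_, done);
    }

    EchoRequest req_;
    EchoResponse resp_;
  };

A caller that needs sideband cancellation either passes a CancellationManager to the constructor or registers call->Cancel() with its own manager, which is what RecvFromPeer() does below with abortion_cancel_mgr_.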
tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@@ -167,16 +168,23 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
    recv_buf_callback(s);
    return;
  }
-  // If a per-call `cancellation_manager` is passed to this function, prefer
-  // using that over `abortion_cancellation_manager_`. This is because abortion
-  // should also be accompanied by opkernel cancellation.
  state->call.reset(new RecvBufCall(
      step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
      to_alloc_attr, to_tensor, client_locality, state->server_attributes,
-      cancellation_manager == nullptr ? &abortion_cancellation_manager_
-                                      : cancellation_manager,
-      worker_cache_));
-  state->call->Start(recv_buf_callback);
+      cancellation_manager, worker_cache_));
+  CancellationToken abortion_token =
+      abortion_cancel_mgr_.get_cancellation_token();
+  bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+      abortion_token, [state] { state->call->Cancel(); });
+  if (already_aborted) {
+    recv_buf_callback(errors::Cancelled("collective ops already aborted"));
+  } else {
+    state->call->Start([this, abortion_token,
+                        done = std::move(recv_buf_callback)](const Status& s) {
+      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
+      done(s);
+    });
+  }
}

void CollectiveRemoteAccessDistributed::CheckPeerHealth(
@@ -241,7 +249,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth(

void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
  CollectiveRemoteAccessLocal::StartAbort(s);
-  abortion_cancellation_manager_.StartCancel();
+  abortion_cancel_mgr_.StartCancel();
}

}  // namespace tensorflow
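The "fail immediately after abortion" behavior exercised by the new test below follows from CancellationManager semantics: once StartCancel() has run, RegisterCallback() returns false, so a later RecvFromPeer() reports errors::Cancelled without ever issuing an RPC. A standalone sketch of just that contract (not collective code):

  CancellationManager cm;
  cm.StartCancel();  // the abort already happened
  CancellationToken token = cm.get_cancellation_token();
  bool registered = cm.RegisterCallback(token, [] { /* cancel the call */ });
  // registered == false here, so the caller invokes its done callback with
  // errors::Cancelled instead of starting the RPC.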
tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"

@@ -55,7 +56,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
  // Ownership of `work_queue_` is shared between `this` and
  // `CollectiveExecutorMgr`.
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
-  CancellationManager abortion_cancellation_manager_;
+  CancellationManager abortion_cancel_mgr_;
  string task_name_;
};
tensorflow/python/distribute/combinations.py
@@ -101,7 +101,7 @@ class ClusterParameters(combinations_lib.ParameterModifier):
    else:
      has_chief = kwargs.get("has_chief", False)
      num_workers = kwargs.get("num_workers", 1)
-      runner = None
+      runner = kwargs.get("runner", None)

    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
@@ -319,11 +319,13 @@ tf_py_test(
        "//tensorflow/python:collective_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:errors",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:multi_process_runner",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)
@@ -20,16 +20,21 @@ from __future__ import print_function

import copy
import os
import threading
import time

from absl.testing import parameterized

from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import collective_ops
@@ -46,6 +51,12 @@ def enable_collective_ops(cluster_resolver):
  context.context().enable_collective_ops(server_def)


device_combination = (
    combinations.combine(device="CPU", communication="RING", required_gpus=0) +
    combinations.combine(
        device="GPU", communication=["RING", "NCCL"], required_gpus=1))


class CollectiveOpTest(test.TestCase):

  def testCheckHealth(self):
@@ -138,5 +149,82 @@ class CollectiveOpTest(test.TestCase):
    mpr.join()


two_worker_pool_runner = multi_process_runner.MultiProcessPoolRunner(
    multi_worker_test_base.create_cluster_spec(num_workers=2),
    initializer=lambda: enable_collective_ops(cluster_resolver_lib.
                                              TFConfigClusterResolver()))


@combinations.generate(
    combinations.times(
        combinations.combine(
            mode="eager", num_workers=2, runner=two_worker_pool_runner),
        device_combination))
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def testAbortCommunication(self, device, communication):
    if communication == "NCCL":
      self.skipTest("b/171358086: cannot test multi worker NCCL")
    dev0 = "/device:%s:0" % device
    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
    enable_collective_ops(cluster_resolver)
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    # First perform a normal all-reduce to complete the group and instance
    # resolution.
    with ops.device(dev0):
      collective_ops.all_reduce(
          in_tensor,
          group_size,
          group_key,
          instance_key,
          communication_hint=communication)

    if cluster_resolver.task_id == 1:

      def abort_fn():
        time.sleep(2)
        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")

      t = threading.Thread(target=abort_fn)
      t.start()

      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
        with ops.device(dev0):
          collective_ops.all_reduce(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

      # After abortion, subsequent collectives should fail immediately.
      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
        with ops.device(dev0):
          collective_ops.all_reduce(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

      t.join()

    # Enable collective ops again in order to reset the collective executor.
    multi_process_runner.get_barrier().wait()
    enable_collective_ops(cluster_resolver)
    multi_process_runner.get_barrier().wait()
    with ops.device(dev0):
      collective_ops.all_reduce(
          in_tensor,
          group_size,
          group_key,
          instance_key,
          communication_hint=communication)


if __name__ == "__main__":
  multi_process_runner.test_main()