Support aborting param resolution in multi worker collectives

PiperOrigin-RevId: 338374360
Change-Id: I36d70d33290831907b0ca970adfec802343a6f04
commit 6879015a25 (parent ec8ef1a4f2)
Author:    Ran Chen
Date:      2020-10-21 17:33:40 -07:00
Committer: TensorFlower Gardener

4 changed files with 159 additions and 18 deletions
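In outline: CollectiveParamResolverDistributed gains a CancellationManager (abortion_cancel_mgr_). CompleteGroupDistributed and CompleteInstanceDistributed register a callback that cancels their pending CompleteGroupCall/CompleteInstanceCall RPC before starting it and deregister it on completion, and the new StartAbort override records the first failure status, aborts the local resolver state via StartAbortLocal, and calls StartCancel() so that a worker blocked in multi-worker group or instance param resolution fails with an error instead of hanging. The new multi-worker tests drive this through context.context().abort_collective_ops while the peer never launches its collective.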


@@ -779,7 +779,7 @@ void CollectiveParamResolverLocal::StartAbort(const Status& s) {
   {
     mutex_lock l(status_mu_);
     if (!status_.ok()) {
-      VLOG(1) << "CollectiveParamResolverLocal already aborted. Ignoring "
+      VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
                  "subsequent abortion with status: "
               << s;
       return;


@@ -295,19 +295,30 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
     CompleteGroupCall* call =
         new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
                               group_leader_, worker_cache_);
-    call->Start([this, device, cp, call, done](const Status& s) {
-      if (s.ok()) {
-        Status status = UpdateGroupCache(call->resp_);
-        if (status.ok()) {
-          CompleteGroupLocal(device, cp, done);
-        } else {
-          done(status, nullptr);
-        }
-      } else {
-        done(s, nullptr);
-      }
-      delete call;
-    });
+    CancellationToken abortion_token =
+        abortion_cancel_mgr_.get_cancellation_token();
+    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+        abortion_token, [call] { call->Cancel(); });
+    if (already_aborted) {
+      done(errors::Cancelled("collective ops already aborted"), nullptr);
+      delete call;
+      return;
+    }
+    call->Start(
+        [this, device, cp, call, abortion_token, done](const Status& s) {
+          abortion_cancel_mgr_.DeregisterCallback(abortion_token);
+          if (s.ok()) {
+            Status status = UpdateGroupCache(call->resp_);
+            if (status.ok()) {
+              CompleteGroupLocal(device, cp, done);
+            } else {
+              done(status, nullptr);
+            }
+          } else {
+            done(s, nullptr);
+          }
+          delete call;
+        });
     return;
   } else {
     return CompleteGroupLocal(device, cp, done);
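The hunk above follows a register-before-start pattern: take a token from abortion_cancel_mgr_, register a callback that cancels the pending CompleteGroupCall before Start() is called (RegisterCallback returns false if an abort already happened, hence the already_aborted early-out), and deregister the callback as the first thing the completion callback does, before the call is deleted. Below is a minimal standalone sketch of that contract; FakeCancellationManager and PendingCall are simplified stand-ins invented here for illustration, not the real classes from tensorflow/core/framework/cancellation.h.

// Minimal stand-in for the cancellation pattern used above. Compile with
// g++ -std=c++17. Only the register/deregister/start-cancel contract that
// CompleteGroupDistributed relies on is mirrored here.
#include <functional>
#include <iostream>
#include <mutex>
#include <unordered_map>

using CancelToken = int;

class FakeCancellationManager {
 public:
  CancelToken get_cancellation_token() { return next_token_++; }

  // Returns false if cancellation has already started.
  bool RegisterCallback(CancelToken token, std::function<void()> cb) {
    std::lock_guard<std::mutex> l(mu_);
    if (cancelled_) return false;
    callbacks_[token] = std::move(cb);
    return true;
  }

  void DeregisterCallback(CancelToken token) {
    std::lock_guard<std::mutex> l(mu_);
    callbacks_.erase(token);
  }

  void StartCancel() {
    std::unordered_map<CancelToken, std::function<void()>> cbs;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (cancelled_) return;
      cancelled_ = true;
      cbs.swap(callbacks_);
    }
    for (auto& kv : cbs) kv.second();  // Cancel every in-flight call.
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  CancelToken next_token_ = 0;
  std::unordered_map<CancelToken, std::function<void()>> callbacks_;
};

// A pending RPC that can be cancelled while it waits for a remote peer.
struct PendingCall {
  void Cancel() { std::cout << "call cancelled\n"; }
};

int main() {
  FakeCancellationManager abortion_cancel_mgr;
  PendingCall call;

  // Same order as CompleteGroupDistributed: register the cancel callback
  // *before* starting the call so an abort can never slip in unnoticed.
  CancelToken token = abortion_cancel_mgr.get_cancellation_token();
  bool already_aborted =
      !abortion_cancel_mgr.RegisterCallback(token, [&call] { call.Cancel(); });
  if (already_aborted) {
    std::cout << "collective ops already aborted\n";
    return 0;
  }
  // ... call.Start(done) would go here; done() deregisters before deleting.
  abortion_cancel_mgr.DeregisterCallback(token);
  std::cout << "first call completed without abort\n";

  // After an abort, registration fails immediately, so later resolutions
  // return a Cancelled error instead of hanging.
  abortion_cancel_mgr.StartCancel();
  CancelToken token2 = abortion_cancel_mgr.get_cancellation_token();
  if (!abortion_cancel_mgr.RegisterCallback(token2,
                                            [&call] { call.Cancel(); })) {
    std::cout << "second call rejected: collective ops already aborted\n";
  }
  return 0;
}

Registering before Start() closes the window in which StartCancel() could run without seeing the in-flight call; deregistering before delete call keeps a late abort from cancelling a call that no longer exists.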
@@ -373,7 +384,17 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
     CompleteInstanceCall* call = new CompleteInstanceCall(
         cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
         group_leader_, worker_cache_);
-    call->Start([this, device, gr, cp, call, done](Status s) {
+    CancellationToken abortion_token =
+        abortion_cancel_mgr_.get_cancellation_token();
+    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+        abortion_token, [call] { call->Cancel(); });
+    if (already_aborted) {
+      done(errors::Cancelled("collective ops already aborted"));
+      delete call;
+      return;
+    }
+    call->Start([this, device, gr, cp, call, abortion_token, done](Status s) {
+      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
       if (s.ok()) {
         s = UpdateInstanceCache(gr, cp, call->resp_);
       }
@@ -388,4 +409,19 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
   }
 }
 
+void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
+  {
+    mutex_lock l(status_mu_);
+    if (!status_.ok()) {
+      VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
+                 "subsequent abortion with status: "
+              << s;
+      return;
+    }
+    status_ = s;
+  }
+  StartAbortLocal(s);
+  abortion_cancel_mgr_.StartCancel();
+}
+
 }  // namespace tensorflow
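The new StartAbort above is idempotent: the first failing status is recorded under status_mu_, later aborts are only logged at VLOG(2), and the local abort plus StartCancel() run once, outside the lock. A compressed standalone restatement of that guard follows; FakeResolver and the two-field Status are hypothetical stand-ins, not the TensorFlow types.

// Sketch of the double-abort guard: the first failure status wins, later
// aborts are ignored, and pending resolutions are cancelled exactly once.
#include <iostream>
#include <mutex>
#include <string>

struct Status {
  bool ok = true;
  std::string message;
};

class FakeResolver {
 public:
  void StartAbort(const Status& s) {
    {
      std::lock_guard<std::mutex> l(status_mu_);
      if (!status_.ok) {
        // Mirrors the VLOG(2) branch: a second abort is ignored.
        std::cout << "already aborted, ignoring: " << s.message << "\n";
        return;
      }
      status_ = s;
    }
    // Outside the lock: abort local state, then cancel in-flight RPCs, i.e.
    // StartAbortLocal(s); abortion_cancel_mgr_.StartCancel();
    std::cout << "aborting with: " << s.message << "\n";
  }

 private:
  std::mutex status_mu_;
  Status status_;
};

int main() {
  FakeResolver resolver;
  resolver.StartAbort({false, "peer down"});      // first abort is recorded
  resolver.StartAbort({false, "another error"});  // ignored
  return 0;
}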


@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
 
 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/platform/status.h"
@@ -47,6 +48,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
                                 CancellationManager* cancel_mgr,
                                 const StatusCallback& done) override;
 
+  void StartAbort(const Status& s) override;
+
  protected:
   // Returns the cached group iff there's an entry for this group_key in the
   // local group_table_; returns nullptr otherwise.
@@ -87,6 +90,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
 
   WorkerCacheInterface* worker_cache_;  // Not owned
   const string group_leader_;
+  CancellationManager abortion_cancel_mgr_;
 };
 
 }  // namespace tensorflow


@@ -33,6 +33,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import collective_ops
@@ -51,6 +52,12 @@ def enable_collective_ops(cluster_resolver):
   context.context().enable_collective_ops(server_def)
 
 
+def enable_collective_ops_with_barrier(cluster_resolver):
+  multi_process_runner.get_barrier().wait()
+  enable_collective_ops(cluster_resolver)
+  multi_process_runner.get_barrier().wait()
+
+
 device_combination = (
     combinations.combine(device="CPU", communication="RING", required_gpus=0) +
     combinations.combine(
@@ -167,7 +174,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
       self.skipTest("b/171358086: cannot test multi worker NCCL")
     dev0 = "/device:%s:0" % device
     cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
-    enable_collective_ops(cluster_resolver)
+    enable_collective_ops_with_barrier(cluster_resolver)
     group_size = 2
     group_key = 100
     instance_key = 100
@@ -214,9 +221,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
     t.join()
 
     # Enable collective ops again in order to reset the collective executor.
-    multi_process_runner.get_barrier().wait()
-    enable_collective_ops(cluster_resolver)
-    multi_process_runner.get_barrier().wait()
+    enable_collective_ops_with_barrier(cluster_resolver)
     with ops.device(dev0):
       collective_ops.all_reduce(
           in_tensor,
@@ -225,6 +230,102 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
           instance_key,
           communication_hint=communication)
 
+  def testAbortGroupParamsResolution(self, device, communication):
+    if communication == "NCCL":
+      self.skipTest("b/171358086: cannot test multi worker NCCL")
+    dev0 = "/device:%s:0" % device
+    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+    enable_collective_ops_with_barrier(cluster_resolver)
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    if cluster_resolver.task_id == 1:
+
+      def abort_fn():
+        time.sleep(2)
+        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
+
+      t = threading.Thread(target=abort_fn)
+      t.start()
+
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        # This hangs on params resolution since we're only launching one
+        # collective for a group size of 2.
+        with ops.device(dev0):
+          collective_ops.all_reduce(in_tensor, group_size, group_key,
+                                    instance_key)
+
+      # After abortion, subsequent collectives should fail immediately.
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        with ops.device(dev0):
+          collective_ops.all_reduce(in_tensor, group_size, group_key,
+                                    instance_key)
+
+      t.join()
+
+    # Enable collective ops again in order to reset the collective executor.
+    enable_collective_ops_with_barrier(cluster_resolver)
+    with ops.device(dev0):
+      collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
+
+  def testAbortInstanceParamsResolution(self, device, communication):
+    if communication == "NCCL":
+      self.skipTest("b/171358086: cannot test multi worker NCCL")
+    dev0 = "/device:%s:0" % device
+    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+    enable_collective_ops_with_barrier(cluster_resolver)
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    # First perform a normal all-reduce to complete the group resolution.
+    with ops.device(dev0):
+      collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
+
+    # We use broadcast to test aborting instance resolution since only broadcast
+    # waits for the group.
+
+    if cluster_resolver.task_id == 1:
+
+      def abort_fn():
+        time.sleep(2)
+        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
+
+      t = threading.Thread(target=abort_fn)
+      t.start()
+
+      # Use a different instance key to trigger another instance resolution.
+      instance_key = 101
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        # This hangs on params resolution since we're only launching one
+        # collective for a group size of 2.
+        with ops.device(dev0):
+          collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                        group_size, group_key, instance_key)
+
+      # After abortion, subsequent collectives should fail immediately.
+      with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+        with ops.device(dev0):
+          collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                        group_size, group_key, instance_key)
+
+      t.join()
+
+    # Enable collective ops again in order to reset the collective executor.
+    enable_collective_ops_with_barrier(cluster_resolver)
+    # Reassign instance_key so that it's the same on each worker.
+    instance_key = 100
+    with ops.device(dev0):
+      if cluster_resolver.task_id == 0:
+        collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                      group_size, group_key, instance_key)
+      else:
+        collective_ops.broadcast_recv((1,), dtypes.float32, group_size,
+                                      group_key, instance_key)
+
 
 if __name__ == "__main__":
   multi_process_runner.test_main()