Support aborting param resolution in multi worker collectives
PiperOrigin-RevId: 338374360 Change-Id: I36d70d33290831907b0ca970adfec802343a6f04
This commit is contained in:
parent
ec8ef1a4f2
commit
6879015a25
@ -779,7 +779,7 @@ void CollectiveParamResolverLocal::StartAbort(const Status& s) {
|
||||
{
|
||||
mutex_lock l(status_mu_);
|
||||
if (!status_.ok()) {
|
||||
VLOG(1) << "CollectiveParamResolverLocal already aborted. Ignoring "
|
||||
VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
|
||||
"subsequent abortion with status: "
|
||||
<< s;
|
||||
return;
|
||||
|
@ -295,19 +295,30 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
|
||||
CompleteGroupCall* call =
|
||||
new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
|
||||
group_leader_, worker_cache_);
|
||||
call->Start([this, device, cp, call, done](const Status& s) {
|
||||
if (s.ok()) {
|
||||
Status status = UpdateGroupCache(call->resp_);
|
||||
if (status.ok()) {
|
||||
CompleteGroupLocal(device, cp, done);
|
||||
} else {
|
||||
done(status, nullptr);
|
||||
}
|
||||
} else {
|
||||
done(s, nullptr);
|
||||
}
|
||||
CancellationToken abortion_token =
|
||||
abortion_cancel_mgr_.get_cancellation_token();
|
||||
bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
|
||||
abortion_token, [call] { call->Cancel(); });
|
||||
if (already_aborted) {
|
||||
done(errors::Cancelled("collective ops already aborted"), nullptr);
|
||||
delete call;
|
||||
});
|
||||
return;
|
||||
}
|
||||
call->Start(
|
||||
[this, device, cp, call, abortion_token, done](const Status& s) {
|
||||
abortion_cancel_mgr_.DeregisterCallback(abortion_token);
|
||||
if (s.ok()) {
|
||||
Status status = UpdateGroupCache(call->resp_);
|
||||
if (status.ok()) {
|
||||
CompleteGroupLocal(device, cp, done);
|
||||
} else {
|
||||
done(status, nullptr);
|
||||
}
|
||||
} else {
|
||||
done(s, nullptr);
|
||||
}
|
||||
delete call;
|
||||
});
|
||||
return;
|
||||
} else {
|
||||
return CompleteGroupLocal(device, cp, done);
|
||||
@ -373,7 +384,17 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
|
||||
CompleteInstanceCall* call = new CompleteInstanceCall(
|
||||
cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
|
||||
group_leader_, worker_cache_);
|
||||
call->Start([this, device, gr, cp, call, done](Status s) {
|
||||
CancellationToken abortion_token =
|
||||
abortion_cancel_mgr_.get_cancellation_token();
|
||||
bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
|
||||
abortion_token, [call] { call->Cancel(); });
|
||||
if (already_aborted) {
|
||||
done(errors::Cancelled("collective ops already aborted"));
|
||||
delete call;
|
||||
return;
|
||||
}
|
||||
call->Start([this, device, gr, cp, call, abortion_token, done](Status s) {
|
||||
abortion_cancel_mgr_.DeregisterCallback(abortion_token);
|
||||
if (s.ok()) {
|
||||
s = UpdateInstanceCache(gr, cp, call->resp_);
|
||||
}
|
||||
@ -388,4 +409,19 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
|
||||
}
|
||||
}
|
||||
|
||||
void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
|
||||
{
|
||||
mutex_lock l(status_mu_);
|
||||
if (!status_.ok()) {
|
||||
VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
|
||||
"subsequent abortion with status: "
|
||||
<< s;
|
||||
return;
|
||||
}
|
||||
status_ = s;
|
||||
}
|
||||
StartAbortLocal(s);
|
||||
abortion_cancel_mgr_.StartCancel();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
@ -47,6 +48,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
|
||||
CancellationManager* cancel_mgr,
|
||||
const StatusCallback& done) override;
|
||||
|
||||
void StartAbort(const Status& s) override;
|
||||
|
||||
protected:
|
||||
// Returns the cached group iff there's an entry for this group_key in the
|
||||
// local group_table_; returns nullptr otherwise.
|
||||
@ -87,6 +90,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
|
||||
|
||||
WorkerCacheInterface* worker_cache_; // Not owned
|
||||
const string group_leader_;
|
||||
CancellationManager abortion_cancel_mgr_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.distribute import multi_worker_test_base
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
@ -51,6 +52,12 @@ def enable_collective_ops(cluster_resolver):
|
||||
context.context().enable_collective_ops(server_def)
|
||||
|
||||
|
||||
def enable_collective_ops_with_barrier(cluster_resolver):
|
||||
multi_process_runner.get_barrier().wait()
|
||||
enable_collective_ops(cluster_resolver)
|
||||
multi_process_runner.get_barrier().wait()
|
||||
|
||||
|
||||
device_combination = (
|
||||
combinations.combine(device="CPU", communication="RING", required_gpus=0) +
|
||||
combinations.combine(
|
||||
@ -167,7 +174,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
self.skipTest("b/171358086: cannot test multi worker NCCL")
|
||||
dev0 = "/device:%s:0" % device
|
||||
cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
|
||||
enable_collective_ops(cluster_resolver)
|
||||
enable_collective_ops_with_barrier(cluster_resolver)
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
@ -214,9 +221,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
t.join()
|
||||
|
||||
# Enable collective ops again in order to reset the collective executor.
|
||||
multi_process_runner.get_barrier().wait()
|
||||
enable_collective_ops(cluster_resolver)
|
||||
multi_process_runner.get_barrier().wait()
|
||||
enable_collective_ops_with_barrier(cluster_resolver)
|
||||
with ops.device(dev0):
|
||||
collective_ops.all_reduce(
|
||||
in_tensor,
|
||||
@ -225,6 +230,102 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
instance_key,
|
||||
communication_hint=communication)
|
||||
|
||||
def testAbortGroupParamsResolution(self, device, communication):
|
||||
if communication == "NCCL":
|
||||
self.skipTest("b/171358086: cannot test multi worker NCCL")
|
||||
dev0 = "/device:%s:0" % device
|
||||
cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
|
||||
enable_collective_ops_with_barrier(cluster_resolver)
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
in_tensor = constant_op.constant([1.])
|
||||
|
||||
if cluster_resolver.task_id == 1:
|
||||
|
||||
def abort_fn():
|
||||
time.sleep(2)
|
||||
context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
|
||||
|
||||
t = threading.Thread(target=abort_fn)
|
||||
t.start()
|
||||
|
||||
with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
|
||||
# This hangs on params resolution since we're only launching one
|
||||
# collective for a group size of 2.
|
||||
with ops.device(dev0):
|
||||
collective_ops.all_reduce(in_tensor, group_size, group_key,
|
||||
instance_key)
|
||||
|
||||
# After abortion, subsequent collectives should fail immediately.
|
||||
with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
|
||||
with ops.device(dev0):
|
||||
collective_ops.all_reduce(in_tensor, group_size, group_key,
|
||||
instance_key)
|
||||
|
||||
t.join()
|
||||
|
||||
# Enable collective ops again in order to reset the collective executor.
|
||||
enable_collective_ops_with_barrier(cluster_resolver)
|
||||
with ops.device(dev0):
|
||||
collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
|
||||
|
||||
def testAbortInstanceParamsResolution(self, device, communication):
|
||||
if communication == "NCCL":
|
||||
self.skipTest("b/171358086: cannot test multi worker NCCL")
|
||||
dev0 = "/device:%s:0" % device
|
||||
cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
|
||||
enable_collective_ops_with_barrier(cluster_resolver)
|
||||
group_size = 2
|
||||
group_key = 100
|
||||
instance_key = 100
|
||||
in_tensor = constant_op.constant([1.])
|
||||
|
||||
# First perform a normal all-reduce to complete the group resolution.
|
||||
with ops.device(dev0):
|
||||
collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
|
||||
|
||||
# We use broadcast to test aborting instance resolution since only broadcast
|
||||
# waits for the group.
|
||||
|
||||
if cluster_resolver.task_id == 1:
|
||||
|
||||
def abort_fn():
|
||||
time.sleep(2)
|
||||
context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
|
||||
|
||||
t = threading.Thread(target=abort_fn)
|
||||
t.start()
|
||||
|
||||
# Use a different instance key to trigger another instance resolution.
|
||||
instance_key = 101
|
||||
with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
|
||||
# This hangs on params resolution since we're only launching one
|
||||
# collective for a group size of 2.
|
||||
with ops.device(dev0):
|
||||
collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
|
||||
group_size, group_key, instance_key)
|
||||
|
||||
# After abortion, subsequent collectives should fail immediately.
|
||||
with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
|
||||
with ops.device(dev0):
|
||||
collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
|
||||
group_size, group_key, instance_key)
|
||||
|
||||
t.join()
|
||||
|
||||
# Enable collective ops again in order to reset the collective executor.
|
||||
enable_collective_ops_with_barrier(cluster_resolver)
|
||||
# Reassign instance_key so that it's the same on each worker.
|
||||
instance_key = 100
|
||||
with ops.device(dev0):
|
||||
if cluster_resolver.task_id == 0:
|
||||
collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
|
||||
group_size, group_key, instance_key)
|
||||
else:
|
||||
collective_ops.broadcast_recv((1,), dtypes.float32, group_size,
|
||||
group_key, instance_key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
multi_process_runner.test_main()
|
||||
|
Loading…
Reference in New Issue
Block a user