Support aborting param resolution in multi worker collectives

PiperOrigin-RevId: 338374360
Change-Id: I36d70d33290831907b0ca970adfec802343a6f04

Commit: 6879015a25
Parent: ec8ef1a4f2
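Before this change, a multi-worker collective could hang indefinitely in group or instance parameter resolution if a peer never joined, even after collective ops were aborted; the distributed param resolver now registers cancellation callbacks so that aborting collective ops also cancels pending resolution RPCs. The snippet below is a minimal sketch of the user-visible behavior, distilled from the tests added at the bottom of this diff. It assumes collective ops are already enabled on a two-worker cluster (see enable_collective_ops_with_barrier below) and that the peer worker never launches its collective, so the all_reduce blocks in parameter resolution until the abort fires.

import threading
import time

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.ops import collective_ops

in_tensor = constant_op.constant([1.])

def abort_fn():
  # Fire the abort after a delay; by then the all_reduce below is already
  # blocked in group param resolution because the peer never joins the group.
  time.sleep(2)
  context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")

threading.Thread(target=abort_fn).start()

try:
  # With this change the pending resolution is cancelled and the op raises
  # UnavailableError("peer down") instead of hanging forever.
  collective_ops.all_reduce(in_tensor, 2, 100, 100)  # group_size, group_key, instance_key
except errors.UnavailableError as e:
  print("collective aborted:", e)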
@@ -779,7 +779,7 @@ void CollectiveParamResolverLocal::StartAbort(const Status& s) {
   {
     mutex_lock l(status_mu_);
     if (!status_.ok()) {
-      VLOG(1) << "CollectiveParamResolverLocal already aborted. Ignoring "
+      VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
                  "subsequent abortion with status: "
               << s;
       return;
@@ -295,19 +295,30 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed(
     CompleteGroupCall* call =
         new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
                               group_leader_, worker_cache_);
-    call->Start([this, device, cp, call, done](const Status& s) {
-      if (s.ok()) {
-        Status status = UpdateGroupCache(call->resp_);
-        if (status.ok()) {
-          CompleteGroupLocal(device, cp, done);
-        } else {
-          done(status, nullptr);
-        }
-      } else {
-        done(s, nullptr);
-      }
-      delete call;
-    });
+    CancellationToken abortion_token =
+        abortion_cancel_mgr_.get_cancellation_token();
+    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+        abortion_token, [call] { call->Cancel(); });
+    if (already_aborted) {
+      done(errors::Cancelled("collective ops already aborted"), nullptr);
+      delete call;
+      return;
+    }
+    call->Start(
+        [this, device, cp, call, abortion_token, done](const Status& s) {
+          abortion_cancel_mgr_.DeregisterCallback(abortion_token);
+          if (s.ok()) {
+            Status status = UpdateGroupCache(call->resp_);
+            if (status.ok()) {
+              CompleteGroupLocal(device, cp, done);
+            } else {
+              done(status, nullptr);
+            }
+          } else {
+            done(s, nullptr);
+          }
+          delete call;
+        });
     return;
   } else {
     return CompleteGroupLocal(device, cp, done);
@@ -373,7 +384,17 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
     CompleteInstanceCall* call = new CompleteInstanceCall(
         cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
         group_leader_, worker_cache_);
-    call->Start([this, device, gr, cp, call, done](Status s) {
+    CancellationToken abortion_token =
+        abortion_cancel_mgr_.get_cancellation_token();
+    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
+        abortion_token, [call] { call->Cancel(); });
+    if (already_aborted) {
+      done(errors::Cancelled("collective ops already aborted"));
+      delete call;
+      return;
+    }
+    call->Start([this, device, gr, cp, call, abortion_token, done](Status s) {
+      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
       if (s.ok()) {
         s = UpdateInstanceCache(gr, cp, call->resp_);
       }
@@ -388,4 +409,19 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
   }
 }
 
+void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
+  {
+    mutex_lock l(status_mu_);
+    if (!status_.ok()) {
+      VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
+                 "subsequent abortion with status: "
+              << s;
+      return;
+    }
+    status_ = s;
+  }
+  StartAbortLocal(s);
+  abortion_cancel_mgr_.StartCancel();
+}
+
 }  // namespace tensorflow
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_
 
 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/platform/status.h"
 
@@ -47,6 +48,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
                              CancellationManager* cancel_mgr,
                              const StatusCallback& done) override;
 
+  void StartAbort(const Status& s) override;
+
  protected:
   // Returns the cached group iff there's an entry for this group_key in the
   // local group_table_; returns nullptr otherwise.
@@ -87,6 +90,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal {
 
   WorkerCacheInterface* worker_cache_;  // Not owned
   const string group_leader_;
+  CancellationManager abortion_cancel_mgr_;
 };
 
 }  // namespace tensorflow
@@ -33,6 +33,7 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import collective_ops
@@ -51,6 +52,12 @@ def enable_collective_ops(cluster_resolver):
   context.context().enable_collective_ops(server_def)
 
 
+def enable_collective_ops_with_barrier(cluster_resolver):
+  multi_process_runner.get_barrier().wait()
+  enable_collective_ops(cluster_resolver)
+  multi_process_runner.get_barrier().wait()
+
+
 device_combination = (
     combinations.combine(device="CPU", communication="RING", required_gpus=0) +
     combinations.combine(
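A brief, hedged note on the helper added above: each worker re-enables collective ops, which resets its collective executor, and the surrounding barriers keep the workers in lockstep so that neither one issues a collective against a peer that has not finished resetting (this reading is inferred from the "reset the collective executor" comments in the tests, not stated in the commit). Usage, with names as in the tests below:

# Every worker calls this at the same point in the test body.
cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
enable_collective_ops_with_barrier(cluster_resolver)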
@@ -167,7 +174,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
       self.skipTest("b/171358086: cannot test multi worker NCCL")
     dev0 = "/device:%s:0" % device
     cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
-    enable_collective_ops(cluster_resolver)
+    enable_collective_ops_with_barrier(cluster_resolver)
     group_size = 2
     group_key = 100
     instance_key = 100
@@ -214,9 +221,7 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
     t.join()
 
     # Enable collective ops again in order to reset the collective executor.
-    multi_process_runner.get_barrier().wait()
-    enable_collective_ops(cluster_resolver)
-    multi_process_runner.get_barrier().wait()
+    enable_collective_ops_with_barrier(cluster_resolver)
     with ops.device(dev0):
       collective_ops.all_reduce(
           in_tensor,
@@ -225,6 +230,102 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
           instance_key,
           communication_hint=communication)
 
+  def testAbortGroupParamsResolution(self, device, communication):
+    if communication == "NCCL":
+      self.skipTest("b/171358086: cannot test multi worker NCCL")
+    dev0 = "/device:%s:0" % device
+    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+    enable_collective_ops_with_barrier(cluster_resolver)
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    if cluster_resolver.task_id == 1:
+
+      def abort_fn():
+        time.sleep(2)
+        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
+
+      t = threading.Thread(target=abort_fn)
+      t.start()
+
+    with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+      # This hangs on params resolution since we're only launching one
+      # collective for a group size of 2.
+      with ops.device(dev0):
+        collective_ops.all_reduce(in_tensor, group_size, group_key,
+                                  instance_key)
+
+    # After abortion, subsequent collectives should fail immediately.
+    with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+      with ops.device(dev0):
+        collective_ops.all_reduce(in_tensor, group_size, group_key,
+                                  instance_key)
+
+    t.join()
+
+    # Enable collective ops again in order to reset the collective executor.
+    enable_collective_ops_with_barrier(cluster_resolver)
+    with ops.device(dev0):
+      collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
+
+  def testAbortInstanceParamsResolution(self, device, communication):
+    if communication == "NCCL":
+      self.skipTest("b/171358086: cannot test multi worker NCCL")
+    dev0 = "/device:%s:0" % device
+    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
+    enable_collective_ops_with_barrier(cluster_resolver)
+    group_size = 2
+    group_key = 100
+    instance_key = 100
+    in_tensor = constant_op.constant([1.])
+
+    # First perform a normal all-reduce to complete the group resolution.
+    with ops.device(dev0):
+      collective_ops.all_reduce(in_tensor, group_size, group_key, instance_key)
+
+    # We use broadcast to test aborting instance resolution since only broadcast
+    # waits for the group.
+
+    if cluster_resolver.task_id == 1:
+
+      def abort_fn():
+        time.sleep(2)
+        context.context().abort_collective_ops(errors.UNAVAILABLE, "peer down")
+
+      t = threading.Thread(target=abort_fn)
+      t.start()
+
+    # Use a different instance key to trigger another instance resolution.
+    instance_key = 101
+    with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+      # This hangs on params resolution since we're only launching one
+      # collective for a group size of 2.
+      with ops.device(dev0):
+        collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                      group_size, group_key, instance_key)
+
+    # After abortion, subsequent collectives should fail immediately.
+    with self.assertRaisesRegex(errors.UnavailableError, "peer down"):
+      with ops.device(dev0):
+        collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                      group_size, group_key, instance_key)
+
+    t.join()
+
+    # Enable collective ops again in order to reset the collective executor.
+    enable_collective_ops_with_barrier(cluster_resolver)
+    # Reassign instance_key so that it's the same on each worker.
+    instance_key = 100
+    with ops.device(dev0):
+      if cluster_resolver.task_id == 0:
+        collective_ops.broadcast_send(in_tensor, (1,), dtypes.float32,
+                                      group_size, group_key, instance_key)
+      else:
+        collective_ops.broadcast_recv((1,), dtypes.float32, group_size,
+                                      group_key, instance_key)
+
+
 if __name__ == "__main__":
   multi_process_runner.test_main()