Unblock dependencies in RingReducer
.
In graphs that contain both ring and nccl collectives, we impose an execution order on all collectives. However, before this change we only unblocked dependent collectives in the NCCL kernel, not in the ring kernel. This would cause a hang in graphs that have both. This change adds the unblocking code early in the ring kernel. It should be a no-op for ring-only graphs. This change also renames `BaseCollectiveExecutor::Launched` to `BaseCollectiveExecutor::UnblockDependencies`. PiperOrigin-RevId: 275974863 Change-Id: Ief78d867356ad977cd8133ca9365038637a92d69
This commit is contained in:
parent
8f86c271c9
commit
09606bba45
@ -341,7 +341,8 @@ void BaseCollectiveExecutor::WaitForDependencies(
|
||||
VLOG(1) << "Unblocking collective " << col_params.ToString();
|
||||
}
|
||||
|
||||
void BaseCollectiveExecutor::Launched(const CollectiveParams& col_params) {
|
||||
void BaseCollectiveExecutor::UnblockDependencies(
|
||||
const CollectiveParams& col_params) {
|
||||
mutex_lock l(launch_mu_);
|
||||
if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
|
||||
const string& task_name =
|
||||
|
@ -154,7 +154,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
|
||||
// Record that this collective has completed the portion of the implementation
|
||||
// that needs to be ordered wrt other collectives, to unblock any of its
|
||||
// dependent ops.
|
||||
void Launched(const CollectiveParams& col_params) override;
|
||||
void UnblockDependencies(const CollectiveParams& col_params) override;
|
||||
|
||||
protected:
|
||||
const int64 step_id_;
|
||||
|
@ -56,6 +56,10 @@ Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
|
||||
void RingReducer::Run(StatusCallback done) {
|
||||
CHECK(col_ctx_);
|
||||
CHECK(col_params_);
|
||||
// Since `RingReducer` doesn't require non-overlapping collectives, unblock
|
||||
// any collective that is blocked on this instance.
|
||||
col_ctx_->col_exec->UnblockDependencies(*col_params_);
|
||||
|
||||
done_ = std::move(done);
|
||||
group_size_ = col_params_->group.group_size;
|
||||
num_subdivs_ = static_cast<int>(
|
||||
|
@ -238,8 +238,9 @@ class RingReducerTest : public ::testing::Test {
|
||||
|
||||
// Set up all of the fake device contexts.
|
||||
for (int wi = 0; wi < num_workers; ++wi) {
|
||||
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
|
||||
col_params_.instance.num_devices_per_task[task_name] = num_devices;
|
||||
for (int di = 0; di < num_devices; ++di) {
|
||||
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
|
||||
string dev_name = strings::StrCat(task_name, "/cpu:", di);
|
||||
if (device_type == DEVICE_GPU) {
|
||||
dev_name =
|
||||
|
@ -302,10 +302,10 @@ class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
|
||||
// execution, where safety is defined as: ordered with respect to the
|
||||
// collective instances defined in the callee's `wait_for` attribute.
|
||||
virtual void WaitForDependencies(const CollectiveParams& col_params) {}
|
||||
// `Launched` unblocks the dependent collective instances by recording that
|
||||
// this callee device has completed the critical portion of the collective
|
||||
// execution.
|
||||
virtual void Launched(const CollectiveParams& col_params) {}
|
||||
// `UnblockDependencies` unblocks the dependent collective instances by
|
||||
// recording that this caller's device has completed the critical portion of
|
||||
// the collective execution.
|
||||
virtual void UnblockDependencies(const CollectiveParams& col_params) {}
|
||||
|
||||
// Used to designate an invalid group or instance key.
|
||||
static int64 kInvalidId;
|
||||
|
@ -68,10 +68,10 @@ void NcclBroadcaster::Run(StatusCallback done) {
|
||||
{
|
||||
// When all devices at this worker have called `SignalMultiNodeReady`, the
|
||||
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
|
||||
// implementation of `Launched` keeps track of the number of devices that
|
||||
// have launched.
|
||||
// implementation of `UnblockDepdendencies` keeps track of the number of
|
||||
// devices that have launched.
|
||||
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
|
||||
col_ctx_->col_exec->Launched(*col_params_);
|
||||
col_ctx_->col_exec->UnblockDependencies(*col_params_);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,10 +58,10 @@ void NcclGatherer::Run(StatusCallback done) {
|
||||
{
|
||||
// When all devices at this worker have called `SignalMultiNodeReady`, the
|
||||
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
|
||||
// implementation of `Launched` keeps track of the number of devices that
|
||||
// have launched.
|
||||
// implementation of `UnblockDependencies` keeps track of the number of
|
||||
// devices that have launched.
|
||||
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
|
||||
col_ctx_->col_exec->Launched(*col_params_);
|
||||
col_ctx_->col_exec->UnblockDependencies(*col_params_);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -179,10 +179,10 @@ void NcclReducer::Run(StatusCallback done) {
|
||||
{
|
||||
// When all devices at this worker have called `SignalMultiNodeReady`, the
|
||||
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
|
||||
// implementation of `Launched` keeps track of the number of devices that
|
||||
// have launched.
|
||||
// implementation of `UnblockDependencies` keeps track of the number of
|
||||
// devices that have launched.
|
||||
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
|
||||
col_ctx_->col_exec->Launched(*col_params_);
|
||||
col_ctx_->col_exec->UnblockDependencies(*col_params_);
|
||||
}
|
||||
|
||||
// If no final_op, then this OpKernel is non-blocking.
|
||||
|
Loading…
Reference in New Issue
Block a user