Unblock dependencies in RingReducer.

In graphs that contain both ring and nccl collectives, we impose an execution
order on all collectives.  However, before this change we only unblocked
dependent collectives in the NCCL kernel, not in the ring kernel.  This would
cause a hang in graphs that have both.

This change adds the unblocking code early in the ring kernel.  It should be a
no-op for ring-only graphs.

This change also renames `BaseCollectiveExecutor::Launched` to
`BaseCollectiveExecutor::UnblockDependencies`.

PiperOrigin-RevId: 275974863
Change-Id: Ief78d867356ad977cd8133ca9365038637a92d69
This commit is contained in:
Ayush Dubey 2019-10-21 19:24:23 -07:00 committed by TensorFlower Gardener
parent 8f86c271c9
commit 09606bba45
8 changed files with 22 additions and 16 deletions

View File

@ -341,7 +341,8 @@ void BaseCollectiveExecutor::WaitForDependencies(
VLOG(1) << "Unblocking collective " << col_params.ToString();
}
void BaseCollectiveExecutor::Launched(const CollectiveParams& col_params) {
void BaseCollectiveExecutor::UnblockDependencies(
const CollectiveParams& col_params) {
mutex_lock l(launch_mu_);
if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
const string& task_name =

View File

@ -154,7 +154,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
// Record that this collective has completed the portion of the implementation
// that needs to be ordered wrt other collectives, to unblock any of its
// dependent ops.
void Launched(const CollectiveParams& col_params) override;
void UnblockDependencies(const CollectiveParams& col_params) override;
protected:
const int64 step_id_;

View File

@ -56,6 +56,10 @@ Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
void RingReducer::Run(StatusCallback done) {
CHECK(col_ctx_);
CHECK(col_params_);
// Since `RingReducer` doesn't require non-overlapping collectives, unblock
// any collective that is blocked on this instance.
col_ctx_->col_exec->UnblockDependencies(*col_params_);
done_ = std::move(done);
group_size_ = col_params_->group.group_size;
num_subdivs_ = static_cast<int>(

View File

@ -238,8 +238,9 @@ class RingReducerTest : public ::testing::Test {
// Set up all of the fake device contexts.
for (int wi = 0; wi < num_workers; ++wi) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
col_params_.instance.num_devices_per_task[task_name] = num_devices;
for (int di = 0; di < num_devices; ++di) {
string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
string dev_name = strings::StrCat(task_name, "/cpu:", di);
if (device_type == DEVICE_GPU) {
dev_name =

View File

@ -302,10 +302,10 @@ class CollectiveExecutor : public PeerAccessInterface, public core::RefCounted {
// execution, where safety is defined as: ordered with respect to the
// collective instances defined in the callee's `wait_for` attribute.
virtual void WaitForDependencies(const CollectiveParams& col_params) {}
// `Launched` unblocks the dependent collective instances by recording that
// this callee device has completed the critical portion of the collective
// execution.
virtual void Launched(const CollectiveParams& col_params) {}
// `UnblockDependencies` unblocks the dependent collective instances by
// recording that this caller's device has completed the critical portion of
// the collective execution.
virtual void UnblockDependencies(const CollectiveParams& col_params) {}
// Used to designate an invalid group or instance key.
static int64 kInvalidId;

View File

@ -68,10 +68,10 @@ void NcclBroadcaster::Run(StatusCallback done) {
{
// When all devices at this worker have called `SignalMultiNodeReady`, the
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
// implementation of `Launched` keeps track of the number of devices that
// have launched.
// implementation of `UnblockDepdendencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->Launched(*col_params_);
col_ctx_->col_exec->UnblockDependencies(*col_params_);
}
}

View File

@ -58,10 +58,10 @@ void NcclGatherer::Run(StatusCallback done) {
{
// When all devices at this worker have called `SignalMultiNodeReady`, the
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
// implementation of `Launched` keeps track of the number of devices that
// have launched.
// implementation of `UnblockDependencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->Launched(*col_params_);
col_ctx_->col_exec->UnblockDependencies(*col_params_);
}
}

View File

@ -179,10 +179,10 @@ void NcclReducer::Run(StatusCallback done) {
{
// When all devices at this worker have called `SignalMultiNodeReady`, the
// `NcclManager` will enqueue the NCCL kernel on the NCCL stream. Thus the
// implementation of `Launched` keeps track of the number of devices that
// have launched.
// implementation of `UnblockDependencies` keeps track of the number of
// devices that have launched.
profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
col_ctx_->col_exec->Launched(*col_params_);
col_ctx_->col_exec->UnblockDependencies(*col_params_);
}
// If no final_op, then this OpKernel is non-blocking.