Ensure that CollectiveParams outlives all references to it.
				
					
				
			Before this change, it was possible to access a `const CollectiveParams&` after it was destroyed. For example, the call to `UnblockDependencies` in `NcclCommunicator::Enqueue` raced with the done_callback of the collective participant. This change makes `CollectiveParams` a refcounted object, and holds references everywhere it may be accessed.

			PiperOrigin-RevId: 355646163
			Change-Id: I7fd164afe8c1c9aa1c3b77a988930a0624977c7c
This commit is contained in:
		
							parent
							
								
									e743fcee33
								
							
						
					
					
						commit
						18eaf4e8f1
					
				@ -264,7 +264,7 @@ Status BaseCollectiveExecutor::GetStatus(const Status& s) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
 | 
			
		||||
                                          const CollectiveParams& col_params,
 | 
			
		||||
                                          const CollectiveParams* col_params,
 | 
			
		||||
                                          const string& exec_key,
 | 
			
		||||
                                          StatusCallback done) {
 | 
			
		||||
  // See CompleteParamsAsync() how done() and the timeout callback interacts.
 | 
			
		||||
@ -281,7 +281,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
  auto timeout_microseconds = static_cast<int64>(
 | 
			
		||||
      col_params.instance.impl_details.timeout_seconds * 1'000'000);
 | 
			
		||||
      col_params->instance.impl_details.timeout_seconds * 1'000'000);
 | 
			
		||||
  if (timeout_microseconds > 0) {
 | 
			
		||||
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
 | 
			
		||||
    SchedNonBlockingClosureAfter(
 | 
			
		||||
@ -297,15 +297,15 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Tensor* output = ctx->mutable_output(0);
 | 
			
		||||
  const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
 | 
			
		||||
                         col_params.instance.type == GATHER_COLLECTIVE ||
 | 
			
		||||
                         col_params.instance.type == PERMUTE_COLLECTIVE ||
 | 
			
		||||
                         (col_params.instance.type == BROADCAST_COLLECTIVE &&
 | 
			
		||||
                          col_params.is_source))
 | 
			
		||||
  const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
 | 
			
		||||
                         col_params->instance.type == GATHER_COLLECTIVE ||
 | 
			
		||||
                         col_params->instance.type == PERMUTE_COLLECTIVE ||
 | 
			
		||||
                         (col_params->instance.type == BROADCAST_COLLECTIVE &&
 | 
			
		||||
                          col_params->is_source))
 | 
			
		||||
                            ? &ctx->input(0)
 | 
			
		||||
                            : nullptr;
 | 
			
		||||
  CollectiveImplementationInterface* col_impl = nullptr;
 | 
			
		||||
  Status status = CreateCollective(col_params, &col_impl);
 | 
			
		||||
  Status status = CreateCollective(*col_params, &col_impl);
 | 
			
		||||
  if (!status.ok()) {
 | 
			
		||||
    done_safe(status);
 | 
			
		||||
    DCHECK_EQ(nullptr, col_impl);
 | 
			
		||||
 | 
			
		||||
@ -110,7 +110,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
 | 
			
		||||
 | 
			
		||||
  void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_);
 | 
			
		||||
 | 
			
		||||
  void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
 | 
			
		||||
  void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params,
 | 
			
		||||
                    const string& exec_key, StatusCallback done) override;
 | 
			
		||||
 | 
			
		||||
  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
 | 
			
		||||
 | 
			
		||||
@ -513,14 +513,14 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
 | 
			
		||||
 | 
			
		||||
void CollectiveParamResolverLocal::InitInstanceSharedParams(
 | 
			
		||||
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
 | 
			
		||||
  ir->shared.instance = cp->instance;
 | 
			
		||||
  ir->shared.default_rank = -1;
 | 
			
		||||
  ir->shared->instance = cp->instance;
 | 
			
		||||
  ir->shared->default_rank = -1;
 | 
			
		||||
 | 
			
		||||
  // Set is_local and task_names in *shared prior to invoking
 | 
			
		||||
  // GetDeviceAttributesAsync.  In a distributed context this function can be
 | 
			
		||||
  // called by a derived class, some of the devices may be non-local and
 | 
			
		||||
  // GetDeviceAttributesAsync will use those fields to launch RPCs.
 | 
			
		||||
  CompleteTaskIsLocal(task_name_, &ir->shared);
 | 
			
		||||
  CompleteTaskIsLocal(task_name_, ir->shared);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
 | 
			
		||||
@ -662,11 +662,11 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal(
 | 
			
		||||
  if (!created_irec) {
 | 
			
		||||
    // Check that the preexisting IRec is consistent with the params passed into
 | 
			
		||||
    // this invocation.
 | 
			
		||||
    if (ir->shared.instance.type != cp->instance.type ||
 | 
			
		||||
        ir->shared.instance.data_type != cp->instance.data_type) {
 | 
			
		||||
    if (ir->shared->instance.type != cp->instance.type ||
 | 
			
		||||
        ir->shared->instance.data_type != cp->instance.data_type) {
 | 
			
		||||
      done(errors::Internal("Collective instance ", cp->instance.instance_key,
 | 
			
		||||
                            " expected type ", ir->shared.instance.type,
 | 
			
		||||
                            " and data_type ", ir->shared.instance.data_type,
 | 
			
		||||
                            " expected type ", ir->shared->instance.type,
 | 
			
		||||
                            " and data_type ", ir->shared->instance.data_type,
 | 
			
		||||
                            " but got type ", cp->instance.type,
 | 
			
		||||
                            " and data_type ", cp->instance.data_type));
 | 
			
		||||
      return;
 | 
			
		||||
@ -686,7 +686,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
 | 
			
		||||
    status = ir->status;
 | 
			
		||||
    if (status.ok()) {
 | 
			
		||||
      // custom operator= does a deep copy.
 | 
			
		||||
      cp->instance = ir->shared.instance;
 | 
			
		||||
      cp->instance = ir->shared->instance;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  if (!status.ok()) {
 | 
			
		||||
 | 
			
		||||
@ -98,7 +98,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
 | 
			
		||||
  struct InstanceRec {
 | 
			
		||||
    mutex mu;
 | 
			
		||||
    // Values to be shared by all instances, constant after initialization.
 | 
			
		||||
    CollectiveParams shared;
 | 
			
		||||
    CollectiveParams* shared;
 | 
			
		||||
    // If an error occurs during initialization this structure stays in the
 | 
			
		||||
    // table with a non-OK status. Purging the table and restarting needs to be
 | 
			
		||||
    // done at a higher level.
 | 
			
		||||
@ -113,7 +113,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
 | 
			
		||||
    std::vector<bool> known TF_GUARDED_BY(mu);
 | 
			
		||||
    std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu);
 | 
			
		||||
 | 
			
		||||
    InstanceRec() : source_rank(-1), known_count(0) {}
 | 
			
		||||
    InstanceRec()
 | 
			
		||||
        : shared(new CollectiveParams()), source_rank(-1), known_count(0) {}
 | 
			
		||||
    ~InstanceRec() { shared->Unref(); }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // Find the InstanceRec with the same instance_key as cp.  If it doesn't
 | 
			
		||||
 | 
			
		||||
@ -161,11 +161,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
 | 
			
		||||
  CollectiveParams cps[NUM_DEVS];
 | 
			
		||||
  CollectiveParams* cps[NUM_DEVS];
 | 
			
		||||
  Status statuses[NUM_DEVS];
 | 
			
		||||
  Notification note[NUM_DEVS];
 | 
			
		||||
  for (int i = 0; i < NUM_DEVS; ++i) {
 | 
			
		||||
    CollectiveParams* cp = &cps[i];
 | 
			
		||||
    cps[i] = new CollectiveParams();
 | 
			
		||||
    CollectiveParams* cp = cps[i];
 | 
			
		||||
    cp->group.group_key = 1;
 | 
			
		||||
    cp->group.group_size = 3;
 | 
			
		||||
    cp->group.device_type = DeviceType("CPU");
 | 
			
		||||
@ -192,17 +193,18 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
 | 
			
		||||
  }
 | 
			
		||||
  for (int i = 0; i < NUM_DEVS; ++i) {
 | 
			
		||||
    TF_ASSERT_OK(statuses[i]);
 | 
			
		||||
    ASSERT_EQ(cps[i].group.device_names.size(), 3);
 | 
			
		||||
    ASSERT_EQ(cps[i]->group.device_names.size(), 3);
 | 
			
		||||
    for (int j = 0; j < NUM_DEVS; ++j) {
 | 
			
		||||
      EXPECT_EQ(
 | 
			
		||||
          strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
 | 
			
		||||
          cps[i].group.device_names[j]);
 | 
			
		||||
      EXPECT_TRUE(cps[i].task.is_local[j]);
 | 
			
		||||
          cps[i]->group.device_names[j]);
 | 
			
		||||
      EXPECT_TRUE(cps[i]->task.is_local[j]);
 | 
			
		||||
    }
 | 
			
		||||
    EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
 | 
			
		||||
    EXPECT_FALSE(cps[i].is_source);
 | 
			
		||||
    EXPECT_EQ(cps[i].default_rank, i);
 | 
			
		||||
    EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
 | 
			
		||||
    EXPECT_EQ(cps[i]->instance.impl_details.subdiv_source_rank.size(), 0);
 | 
			
		||||
    EXPECT_FALSE(cps[i]->is_source);
 | 
			
		||||
    EXPECT_EQ(cps[i]->default_rank, i);
 | 
			
		||||
    EXPECT_TRUE(cps[i]->group.same_num_devices_per_task);
 | 
			
		||||
    cps[i]->Unref();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -223,11 +225,12 @@ void InitializeCollectiveParamsForBroadcast(int instance_key, int device_idx,
 | 
			
		||||
 | 
			
		||||
TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
 | 
			
		||||
  constexpr int kInstanceKey = 5;
 | 
			
		||||
  CollectiveParams cps[NUM_DEVS];
 | 
			
		||||
  CollectiveParams* cps[NUM_DEVS];
 | 
			
		||||
  Status statuses[NUM_DEVS];
 | 
			
		||||
  Notification note[NUM_DEVS];
 | 
			
		||||
  for (int i = 0; i < NUM_DEVS; ++i) {
 | 
			
		||||
    CollectiveParams* cp = &cps[i];
 | 
			
		||||
    cps[i] = new CollectiveParams();
 | 
			
		||||
    CollectiveParams* cp = cps[i];
 | 
			
		||||
    InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp);
 | 
			
		||||
    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
 | 
			
		||||
      string device =
 | 
			
		||||
@ -245,16 +248,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
 | 
			
		||||
  }
 | 
			
		||||
  for (int i = 0; i < NUM_DEVS; ++i) {
 | 
			
		||||
    TF_ASSERT_OK(statuses[i]);
 | 
			
		||||
    ASSERT_EQ(cps[i].group.device_names.size(), 3);
 | 
			
		||||
    ASSERT_EQ(cps[i]->group.device_names.size(), 3);
 | 
			
		||||
    for (int j = 0; j < NUM_DEVS; ++j) {
 | 
			
		||||
      EXPECT_EQ(
 | 
			
		||||
          strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j),
 | 
			
		||||
          cps[i].group.device_names[j]);
 | 
			
		||||
      EXPECT_TRUE(cps[i].task.is_local[j]);
 | 
			
		||||
          cps[i]->group.device_names[j]);
 | 
			
		||||
      EXPECT_TRUE(cps[i]->task.is_local[j]);
 | 
			
		||||
    }
 | 
			
		||||
    EXPECT_EQ(cps[i].is_source, (i == 1));
 | 
			
		||||
    EXPECT_EQ(cps[i].default_rank, i);
 | 
			
		||||
    EXPECT_TRUE(cps[i].group.same_num_devices_per_task);
 | 
			
		||||
    EXPECT_EQ(cps[i]->is_source, (i == 1));
 | 
			
		||||
    EXPECT_EQ(cps[i]->default_rank, i);
 | 
			
		||||
    EXPECT_TRUE(cps[i]->group.same_num_devices_per_task);
 | 
			
		||||
    cps[i]->Unref();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -263,11 +267,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
 | 
			
		||||
// get an internal error from param resolution.
 | 
			
		||||
TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
 | 
			
		||||
  constexpr int kInstanceKey = 8;
 | 
			
		||||
  CollectiveParams cps[NUM_DEVS];
 | 
			
		||||
  CollectiveParams* cps[NUM_DEVS];
 | 
			
		||||
  Status statuses[NUM_DEVS];
 | 
			
		||||
  Notification note[NUM_DEVS];
 | 
			
		||||
  for (int i = 0; i < NUM_DEVS; ++i) {
 | 
			
		||||
    CollectiveParams* cp = &cps[i];
 | 
			
		||||
    cps[i] = new CollectiveParams();
 | 
			
		||||
    CollectiveParams* cp = cps[i];
 | 
			
		||||
    InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp);
 | 
			
		||||
    Env::Default()->SchedClosure([this, i, cp, &note, &statuses]() {
 | 
			
		||||
      string device =
 | 
			
		||||
@ -291,27 +296,28 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) {
 | 
			
		||||
                  " found no source for broadcast.  This could mean that there"
 | 
			
		||||
                  " were group_size=",
 | 
			
		||||
                  NUM_DEVS, " BcastRecvs but no BcastSend."));
 | 
			
		||||
    cps[i]->Unref();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
CollectiveParams MakeCollectiveParams(int group_key, int instance_key,
 | 
			
		||||
                                      bool is_source) {
 | 
			
		||||
  CollectiveParams cp;
 | 
			
		||||
  cp.group.group_key = group_key;
 | 
			
		||||
  cp.group.group_size = NUM_DEVS;
 | 
			
		||||
  cp.group.device_type = DeviceType("CPU");
 | 
			
		||||
  cp.group.num_tasks = 1;
 | 
			
		||||
  cp.instance.instance_key = instance_key;
 | 
			
		||||
CollectiveParams* MakeCollectiveParams(int group_key, int instance_key,
 | 
			
		||||
                                       bool is_source) {
 | 
			
		||||
  auto* cp = new CollectiveParams();
 | 
			
		||||
  cp->group.group_key = group_key;
 | 
			
		||||
  cp->group.group_size = NUM_DEVS;
 | 
			
		||||
  cp->group.device_type = DeviceType("CPU");
 | 
			
		||||
  cp->group.num_tasks = 1;
 | 
			
		||||
  cp->instance.instance_key = instance_key;
 | 
			
		||||
  // CompleteInstanceLocal only waits for the group for broadcasts.
 | 
			
		||||
  // Testing with broadcasts yields better coverage.
 | 
			
		||||
  cp.instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
  cp.is_source = is_source;
 | 
			
		||||
  cp->instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
  cp->is_source = is_source;
 | 
			
		||||
  return cp;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
 | 
			
		||||
  CancellationManager cancel_mgr;
 | 
			
		||||
  std::vector<CollectiveParams> cp(NUM_DEVS - 1);
 | 
			
		||||
  std::vector<CollectiveParams*> cp(NUM_DEVS - 1);
 | 
			
		||||
  BlockingCounter start(NUM_DEVS - 1);
 | 
			
		||||
  BlockingCounter done(NUM_DEVS - 1);
 | 
			
		||||
  for (int i = 0; i < NUM_DEVS - 1; ++i) {
 | 
			
		||||
@ -320,11 +326,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
 | 
			
		||||
          strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
 | 
			
		||||
      cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100,
 | 
			
		||||
                                   /*is_source*/ i == 0);
 | 
			
		||||
      prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
 | 
			
		||||
                                &cancel_mgr, [&done](const Status& s) {
 | 
			
		||||
      prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr,
 | 
			
		||||
                                [&done, cp = cp[i]](const Status& s) {
 | 
			
		||||
                                  EXPECT_EQ(s.code(), error::ABORTED);
 | 
			
		||||
                                  EXPECT_EQ(s.error_message(), "__aborted__");
 | 
			
		||||
                                  done.DecrementCount();
 | 
			
		||||
                                  cp->Unref();
 | 
			
		||||
                                });
 | 
			
		||||
      start.DecrementCount();
 | 
			
		||||
    });
 | 
			
		||||
@ -336,7 +343,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) {
 | 
			
		||||
 | 
			
		||||
TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
 | 
			
		||||
  CancellationManager cancel_mgr;
 | 
			
		||||
  std::vector<CollectiveParams> cp(NUM_DEVS);
 | 
			
		||||
  std::vector<CollectiveParams*> cp(NUM_DEVS);
 | 
			
		||||
  int group_key = 100;
 | 
			
		||||
  int instance_key = 100;
 | 
			
		||||
  // First do a normal CompleteParamsAsync to complete the group;
 | 
			
		||||
@ -349,10 +356,12 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
 | 
			
		||||
            strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
 | 
			
		||||
        cp[i] = MakeCollectiveParams(group_key, instance_key,
 | 
			
		||||
                                     /*is_source*/ i == 0);
 | 
			
		||||
        prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
 | 
			
		||||
                                  &cancel_mgr, [&done](const Status& s) {
 | 
			
		||||
        prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i],
 | 
			
		||||
                                  &cancel_mgr,
 | 
			
		||||
                                  [&done, cp = cp[i]](const Status& s) {
 | 
			
		||||
                                    EXPECT_EQ(s.code(), error::OK);
 | 
			
		||||
                                    done.DecrementCount();
 | 
			
		||||
                                    cp->Unref();
 | 
			
		||||
                                  });
 | 
			
		||||
      });
 | 
			
		||||
    }
 | 
			
		||||
@ -361,21 +370,21 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) {
 | 
			
		||||
  BlockingCounter start(NUM_DEVS - 1);
 | 
			
		||||
  BlockingCounter done(NUM_DEVS - 1);
 | 
			
		||||
  for (int i = 0; i < NUM_DEVS - 1; ++i) {
 | 
			
		||||
    Env::Default()->SchedClosure(
 | 
			
		||||
        [this, group_key, instance_key, i, &cancel_mgr, &cp, &start, &done] {
 | 
			
		||||
          string device =
 | 
			
		||||
              strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
 | 
			
		||||
          cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
 | 
			
		||||
                                       /*is_source*/ i == 0);
 | 
			
		||||
          prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
 | 
			
		||||
                                    &cancel_mgr, [&done](const Status& s) {
 | 
			
		||||
                                      EXPECT_EQ(s.code(), error::ABORTED);
 | 
			
		||||
                                      EXPECT_EQ(s.error_message(),
 | 
			
		||||
                                                "__aborted__");
 | 
			
		||||
                                      done.DecrementCount();
 | 
			
		||||
                                    });
 | 
			
		||||
          start.DecrementCount();
 | 
			
		||||
        });
 | 
			
		||||
    Env::Default()->SchedClosure([this, group_key, instance_key, i, &cancel_mgr,
 | 
			
		||||
                                  &cp, &start, &done] {
 | 
			
		||||
      string device =
 | 
			
		||||
          strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
 | 
			
		||||
      cp[i] = MakeCollectiveParams(group_key, instance_key + 1,
 | 
			
		||||
                                   /*is_source*/ i == 0);
 | 
			
		||||
      prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr,
 | 
			
		||||
                                [&done, cp = cp[i]](const Status& s) {
 | 
			
		||||
                                  EXPECT_EQ(s.code(), error::ABORTED);
 | 
			
		||||
                                  EXPECT_EQ(s.error_message(), "__aborted__");
 | 
			
		||||
                                  done.DecrementCount();
 | 
			
		||||
                                  cp->Unref();
 | 
			
		||||
                                });
 | 
			
		||||
      start.DecrementCount();
 | 
			
		||||
    });
 | 
			
		||||
  }
 | 
			
		||||
  start.Wait();
 | 
			
		||||
  prl_->StartAbort(Status(error::ABORTED, "__aborted__"));
 | 
			
		||||
@ -388,7 +397,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
 | 
			
		||||
  int instance_key = 100;
 | 
			
		||||
  // First do a normal CompleteParamsAsync to complete the group;
 | 
			
		||||
  {
 | 
			
		||||
    std::vector<CollectiveParams> cp(NUM_DEVS);
 | 
			
		||||
    std::vector<CollectiveParams*> cp(NUM_DEVS);
 | 
			
		||||
    BlockingCounter done(NUM_DEVS);
 | 
			
		||||
    for (int i = 0; i < NUM_DEVS; ++i) {
 | 
			
		||||
      Env::Default()->SchedClosure([this, group_key, instance_key, i,
 | 
			
		||||
@ -397,10 +406,12 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
 | 
			
		||||
            strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i);
 | 
			
		||||
        cp[i] = MakeCollectiveParams(group_key, instance_key,
 | 
			
		||||
                                     /*is_source*/ i == 0);
 | 
			
		||||
        prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp[i],
 | 
			
		||||
                                  &cancel_mgr, [&done](const Status& s) {
 | 
			
		||||
        prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i],
 | 
			
		||||
                                  &cancel_mgr,
 | 
			
		||||
                                  [&done, cp = cp[i]](const Status& s) {
 | 
			
		||||
                                    EXPECT_EQ(s.code(), error::OK);
 | 
			
		||||
                                    done.DecrementCount();
 | 
			
		||||
                                    cp->Unref();
 | 
			
		||||
                                  });
 | 
			
		||||
      });
 | 
			
		||||
    }
 | 
			
		||||
@ -411,9 +422,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) {
 | 
			
		||||
  auto complete_params = [this, &cancel_mgr](int group_key, int instance_key) {
 | 
			
		||||
    string device = "/job:localhost/replica:0/task:0/device:CPU:0";
 | 
			
		||||
    Notification done;
 | 
			
		||||
    auto cp = MakeCollectiveParams(group_key, instance_key,
 | 
			
		||||
                                   /*is_source*/ true);
 | 
			
		||||
    prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp, &cancel_mgr,
 | 
			
		||||
    auto* cp = MakeCollectiveParams(group_key, instance_key,
 | 
			
		||||
                                    /*is_source*/ true);
 | 
			
		||||
    core::ScopedUnref unref(cp);
 | 
			
		||||
    prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr,
 | 
			
		||||
                              [&done](const Status& s) {
 | 
			
		||||
                                EXPECT_EQ(s.code(), error::ABORTED);
 | 
			
		||||
                                EXPECT_EQ(s.error_message(), "__aborted__");
 | 
			
		||||
@ -449,16 +461,17 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) {
 | 
			
		||||
            while (true) {
 | 
			
		||||
              Status status;
 | 
			
		||||
              Notification n;
 | 
			
		||||
              auto cp =
 | 
			
		||||
              auto* cp =
 | 
			
		||||
                  MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key,
 | 
			
		||||
                                       /*is_source*/ i == 0);
 | 
			
		||||
              prl_->CompleteParamsAsync(GetDeviceAttributes(device), &cp,
 | 
			
		||||
              prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp,
 | 
			
		||||
                                        &cancel_mgr,
 | 
			
		||||
                                        [&status, &n](const Status& s) {
 | 
			
		||||
                                          status = s;
 | 
			
		||||
                                          n.Notify();
 | 
			
		||||
                                        });
 | 
			
		||||
              n.WaitForNotification();
 | 
			
		||||
              cp->Unref();
 | 
			
		||||
              // The status should be either OK or the aborted status.
 | 
			
		||||
              if (!status.ok()) {
 | 
			
		||||
                EXPECT_EQ(status.code(), error::ABORTED);
 | 
			
		||||
 | 
			
		||||
@ -188,7 +188,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
 | 
			
		||||
    std::shared_ptr<CollectiveContext> col_ctx) {
 | 
			
		||||
  CHECK(col_ctx->dev_mgr);
 | 
			
		||||
  col_ctx_ = col_ctx;
 | 
			
		||||
  col_params_ = &col_ctx->col_params;
 | 
			
		||||
  col_params_ = col_ctx->col_params;
 | 
			
		||||
  return collective_util::InitializeDeviceAndLocality(
 | 
			
		||||
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
 | 
			
		||||
      &col_ctx->device_locality);
 | 
			
		||||
 | 
			
		||||
@ -56,23 +56,24 @@ class TrivialTest : public ::testing::Test {
 | 
			
		||||
// R = tested rank
 | 
			
		||||
// RF = receive-from rank
 | 
			
		||||
// ST = send_to rank vector
 | 
			
		||||
#define DEF_TL_TEST(D, S, R, RF, ST)                                 \
 | 
			
		||||
  TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) {   \
 | 
			
		||||
    CollectiveParams cp;                                             \
 | 
			
		||||
    cp.group.group_size = D;                                         \
 | 
			
		||||
    cp.instance.impl_details.subdiv_source_rank = {S};               \
 | 
			
		||||
    cp.instance.impl_details.subdiv_permutations.push_back(          \
 | 
			
		||||
        std::vector<int>(D, 0));                                     \
 | 
			
		||||
    cp.subdiv_rank = {R};                                            \
 | 
			
		||||
    cp.is_source = (S == R);                                         \
 | 
			
		||||
    EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \
 | 
			
		||||
    std::vector<int> expected = ST;                                  \
 | 
			
		||||
    std::vector<int> send_to;                                        \
 | 
			
		||||
    HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to);        \
 | 
			
		||||
    ASSERT_EQ(expected.size(), send_to.size());                      \
 | 
			
		||||
    for (int i = 0; i < expected.size(); ++i) {                      \
 | 
			
		||||
      EXPECT_EQ(expected[i], send_to[i]);                            \
 | 
			
		||||
    }                                                                \
 | 
			
		||||
#define DEF_TL_TEST(D, S, R, RF, ST)                                  \
 | 
			
		||||
  TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) {    \
 | 
			
		||||
    auto* cp = new CollectiveParams();                                \
 | 
			
		||||
    core::ScopedUnref unref(cp);                                      \
 | 
			
		||||
    cp->group.group_size = D;                                         \
 | 
			
		||||
    cp->instance.impl_details.subdiv_source_rank = {S};               \
 | 
			
		||||
    cp->instance.impl_details.subdiv_permutations.push_back(          \
 | 
			
		||||
        std::vector<int>(D, 0));                                      \
 | 
			
		||||
    cp->subdiv_rank = {R};                                            \
 | 
			
		||||
    cp->is_source = (S == R);                                         \
 | 
			
		||||
    EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(*cp, 0)); \
 | 
			
		||||
    std::vector<int> expected = ST;                                   \
 | 
			
		||||
    std::vector<int> send_to;                                         \
 | 
			
		||||
    HierarchicalTreeBroadcaster::TreeSendTo(*cp, 0, &send_to);        \
 | 
			
		||||
    ASSERT_EQ(expected.size(), send_to.size());                       \
 | 
			
		||||
    for (int i = 0; i < expected.size(); ++i) {                       \
 | 
			
		||||
      EXPECT_EQ(expected[i], send_to[i]);                             \
 | 
			
		||||
    }                                                                 \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#define V(...) std::vector<int>({__VA_ARGS__})
 | 
			
		||||
@ -196,12 +197,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
 | 
			
		||||
 | 
			
		||||
class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
  HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {}
 | 
			
		||||
  HierarchicalTreeBroadcasterTest()
 | 
			
		||||
      : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
  ~HierarchicalTreeBroadcasterTest() override {
 | 
			
		||||
    stop_ = true;
 | 
			
		||||
    for (auto i : instances_) delete i;
 | 
			
		||||
    if (col_exec_) col_exec_->Unref();
 | 
			
		||||
    if (col_params_) col_params_->Unref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | 
			
		||||
@ -262,30 +265,31 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
 | 
			
		||||
                                           dev_mgr_.get(),
 | 
			
		||||
                                           gpu_ring_order_.get(), work_queue_);
 | 
			
		||||
    col_params_.name = "test_collective";
 | 
			
		||||
    col_params_.instance.data_type = dtype;
 | 
			
		||||
    col_params_ = new CollectiveParams();
 | 
			
		||||
    col_params_->name = "test_collective";
 | 
			
		||||
    col_params_->instance.data_type = dtype;
 | 
			
		||||
    static const int kGroupKey = 6;
 | 
			
		||||
    col_params_.group.group_key = kGroupKey;
 | 
			
		||||
    col_params_->group.group_key = kGroupKey;
 | 
			
		||||
    static const int kInstanceKey = 18;
 | 
			
		||||
    col_params_.instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_.group.device_type = device_type;
 | 
			
		||||
    col_params_.group.group_size = num_workers * num_devices_per_worker;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
    col_params_.instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
    col_params_->instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_->group.device_type = device_type;
 | 
			
		||||
    col_params_->group.group_size = num_workers * num_devices_per_worker;
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
    col_params_->instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
 | 
			
		||||
    int num_subdivs = num_workers + (num_workers > 1 ? 1 : 0);
 | 
			
		||||
    VLOG(2) << "#subdiv=" << num_subdivs;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
 | 
			
		||||
    col_params_.subdiv_rank.resize(num_subdivs);
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
 | 
			
		||||
    col_params_->subdiv_rank.resize(num_subdivs);
 | 
			
		||||
 | 
			
		||||
    // Inter-machine broadcast.
 | 
			
		||||
    int subdiv_i = 0;
 | 
			
		||||
    if (num_workers > 1) {
 | 
			
		||||
      col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize(
 | 
			
		||||
      col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize(
 | 
			
		||||
          total_num_devices, -1);
 | 
			
		||||
      for (int i = 0, rank = 0; i < total_num_devices; i++) {
 | 
			
		||||
        if (i % num_devices_per_worker == 0) {
 | 
			
		||||
          col_params_.instance.impl_details
 | 
			
		||||
          col_params_->instance.impl_details
 | 
			
		||||
              .subdiv_permutations[subdiv_i][rank] = i;
 | 
			
		||||
          rank++;
 | 
			
		||||
        }
 | 
			
		||||
@ -293,7 +297,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
      if (VLOG_IS_ON(2)) {
 | 
			
		||||
        string sp_buf;
 | 
			
		||||
        for (int p :
 | 
			
		||||
             col_params_.instance.impl_details.subdiv_permutations[subdiv_i])
 | 
			
		||||
             col_params_->instance.impl_details.subdiv_permutations[subdiv_i])
 | 
			
		||||
          strings::StrAppend(&sp_buf, p, ", ");
 | 
			
		||||
        VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf;
 | 
			
		||||
      }
 | 
			
		||||
@ -301,22 +305,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
    }
 | 
			
		||||
    // Intra-machine broadcast.
 | 
			
		||||
    for (int i = 0; subdiv_i < num_subdivs; i++, subdiv_i++) {
 | 
			
		||||
      col_params_.instance.impl_details.subdiv_permutations[subdiv_i].resize(
 | 
			
		||||
      col_params_->instance.impl_details.subdiv_permutations[subdiv_i].resize(
 | 
			
		||||
          total_num_devices, -1);
 | 
			
		||||
      int perm_i_base = i * num_devices_per_worker;
 | 
			
		||||
      VLOG(2) << "subdiv_i=" << subdiv_i << " i=" << i
 | 
			
		||||
              << " perm_i_base=" << perm_i_base << " subdiv_perms.size="
 | 
			
		||||
              << col_params_.instance.impl_details.subdiv_permutations.size();
 | 
			
		||||
              << col_params_->instance.impl_details.subdiv_permutations.size();
 | 
			
		||||
      // subdiv for worker i.
 | 
			
		||||
      for (int j = perm_i_base, rank = 0;
 | 
			
		||||
           j < perm_i_base + num_devices_per_worker; j++, rank++) {
 | 
			
		||||
        col_params_.instance.impl_details.subdiv_permutations[subdiv_i][rank] =
 | 
			
		||||
        col_params_->instance.impl_details.subdiv_permutations[subdiv_i][rank] =
 | 
			
		||||
            j;
 | 
			
		||||
      }
 | 
			
		||||
      if (VLOG_IS_ON(2)) {
 | 
			
		||||
        string sp_buf;
 | 
			
		||||
        for (int p :
 | 
			
		||||
             col_params_.instance.impl_details.subdiv_permutations[subdiv_i])
 | 
			
		||||
             col_params_->instance.impl_details.subdiv_permutations[subdiv_i])
 | 
			
		||||
          strings::StrAppend(&sp_buf, p, ", ");
 | 
			
		||||
        VLOG(2) << "subdiv_i=" << subdiv_i << " perm=" << sp_buf;
 | 
			
		||||
      }
 | 
			
		||||
@ -333,16 +337,16 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
          dev_name = strings::StrCat(task_name, "/device:CPU:", di);
 | 
			
		||||
        }
 | 
			
		||||
        VLOG(2) << "dev=" << dev_name;
 | 
			
		||||
        col_params_.group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_.group.task_names.push_back(task_name);
 | 
			
		||||
        col_params_.task.is_local.push_back(true);
 | 
			
		||||
        col_params_->group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_->group.task_names.push_back(task_name);
 | 
			
		||||
        col_params_->task.is_local.push_back(true);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    for (int wi = 0; wi < num_workers; wi++) {
 | 
			
		||||
      for (int di = 0; di < num_devices_per_worker; di++) {
 | 
			
		||||
        int default_rank = wi * num_devices_per_worker + di;
 | 
			
		||||
        instances_.push_back(new DeviceInstance(
 | 
			
		||||
            default_rank, col_params_.group.device_names[default_rank],
 | 
			
		||||
            default_rank, col_params_->group.device_names[default_rank],
 | 
			
		||||
            device_type, this));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
@ -435,7 +439,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
 | 
			
		||||
    // Copy the expected value from the broadcast source tensor
 | 
			
		||||
    std::vector<T> expected(tensor_len, 0.0);
 | 
			
		||||
    const CollectiveParams& cp = instances_[0]->col_params_;
 | 
			
		||||
    const CollectiveParams& cp = *instances_[0]->col_params_;
 | 
			
		||||
    int broadcast_dev_id =
 | 
			
		||||
        cp.instance.impl_details.subdiv_permutations
 | 
			
		||||
            [0][cp.instance.impl_details.subdiv_source_rank[0]];
 | 
			
		||||
@ -558,27 +562,29 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
        : parent_(parent),
 | 
			
		||||
          dev_name_(dev_name),
 | 
			
		||||
          device_type_(device_type),
 | 
			
		||||
          rank_(rank) {
 | 
			
		||||
          rank_(rank),
 | 
			
		||||
          col_params_(new CollectiveParams()) {
 | 
			
		||||
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_));
 | 
			
		||||
      col_params_.name = parent_->col_params_.name;
 | 
			
		||||
      col_params_.instance.data_type = parent_->col_params_.instance.data_type;
 | 
			
		||||
      col_params_.group = parent_->col_params_.group;
 | 
			
		||||
      col_params_.instance.instance_key =
 | 
			
		||||
          parent_->col_params_.instance.instance_key;
 | 
			
		||||
      col_params_.task.is_local = parent_->col_params_.task.is_local;
 | 
			
		||||
      col_params_.instance.impl_details.subdiv_permutations =
 | 
			
		||||
          parent_->col_params_.instance.impl_details.subdiv_permutations;
 | 
			
		||||
      col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
 | 
			
		||||
      col_params_->name = parent_->col_params_->name;
 | 
			
		||||
      col_params_->instance.data_type =
 | 
			
		||||
          parent_->col_params_->instance.data_type;
 | 
			
		||||
      col_params_->group = parent_->col_params_->group;
 | 
			
		||||
      col_params_->instance.instance_key =
 | 
			
		||||
          parent_->col_params_->instance.instance_key;
 | 
			
		||||
      col_params_->task.is_local = parent_->col_params_->task.is_local;
 | 
			
		||||
      col_params_->instance.impl_details.subdiv_permutations =
 | 
			
		||||
          parent_->col_params_->instance.impl_details.subdiv_permutations;
 | 
			
		||||
      col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;
 | 
			
		||||
 | 
			
		||||
      int group_size = col_params_.group.group_size;
 | 
			
		||||
      CHECK_EQ(group_size, col_params_.group.device_names.size());
 | 
			
		||||
      int group_size = col_params_->group.group_size;
 | 
			
		||||
      CHECK_EQ(group_size, col_params_->group.device_names.size());
 | 
			
		||||
      // Default rank is order in device_names.
 | 
			
		||||
      col_params_.default_rank = rank;
 | 
			
		||||
      col_params_->default_rank = rank;
 | 
			
		||||
 | 
			
		||||
      auto& impl = col_params_.instance.impl_details;
 | 
			
		||||
      auto& impl = col_params_->instance.impl_details;
 | 
			
		||||
      size_t num_subdivs = impl.subdiv_permutations.size();
 | 
			
		||||
      impl.subdiv_source_rank.resize(num_subdivs, 0);
 | 
			
		||||
      col_params_.subdiv_rank.resize(num_subdivs);
 | 
			
		||||
      col_params_->subdiv_rank.resize(num_subdivs);
 | 
			
		||||
      for (size_t si = 0; si < num_subdivs; si++) {
 | 
			
		||||
        int perm_rank = -1;
 | 
			
		||||
        for (int i = 0; i < group_size; i++) {
 | 
			
		||||
@ -587,18 +593,20 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
        col_params_.subdiv_rank[si] = perm_rank;
 | 
			
		||||
        col_params_->subdiv_rank[si] = perm_rank;
 | 
			
		||||
      }
 | 
			
		||||
      string rank_buf;
 | 
			
		||||
      for (int r : col_params_.subdiv_rank) {
 | 
			
		||||
      for (int r : col_params_->subdiv_rank) {
 | 
			
		||||
        strings::StrAppend(&rank_buf, r, ", ");
 | 
			
		||||
      }
 | 
			
		||||
      VLOG(1) << "default=" << rank << " subdiv_ranks=" << rank_buf;
 | 
			
		||||
 | 
			
		||||
      col_params_.is_source =
 | 
			
		||||
          col_params_.subdiv_rank[0] == impl.subdiv_source_rank[0];
 | 
			
		||||
      col_params_->is_source =
 | 
			
		||||
          col_params_->subdiv_rank[0] == impl.subdiv_source_rank[0];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~DeviceInstance() { col_params_->Unref(); }
 | 
			
		||||
 | 
			
		||||
    void InitTensor(DataType dtype, const TensorShape& shape,
 | 
			
		||||
                    const InitFunc& f) {
 | 
			
		||||
      tensor_ =
 | 
			
		||||
@ -641,22 +649,22 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
      op_params.op_device_context = dev_ctx;
 | 
			
		||||
      int forward_from[] = {OpKernelContext::Params::kNeverForward};
 | 
			
		||||
      if (forward_input) forward_from[0] = 0;
 | 
			
		||||
      if (col_params_.is_source) {
 | 
			
		||||
      if (col_params_->is_source) {
 | 
			
		||||
        op_params.forward_from_array = &forward_from[0];
 | 
			
		||||
      }
 | 
			
		||||
      AllocatorAttributes generic_alloc_attr;
 | 
			
		||||
      op_params.output_attr_array = &generic_alloc_attr;
 | 
			
		||||
      std::unique_ptr<OpKernel> op =
 | 
			
		||||
          col_params_.is_source
 | 
			
		||||
              ? parent_->GetCollectiveBcastSend(col_params_, &tensor_,
 | 
			
		||||
          col_params_->is_source
 | 
			
		||||
              ? parent_->GetCollectiveBcastSend(*col_params_, &tensor_,
 | 
			
		||||
                                                DEVICE_CPU, device_)
 | 
			
		||||
              : parent_->GetCollectiveBcastRecv(col_params_, tensor_.shape(),
 | 
			
		||||
              : parent_->GetCollectiveBcastRecv(*col_params_, tensor_.shape(),
 | 
			
		||||
                                                DEVICE_CPU, device_);
 | 
			
		||||
      op_params.op_kernel = op.get();
 | 
			
		||||
      OpKernelContext ctx(&op_params, 1);
 | 
			
		||||
 | 
			
		||||
      Tensor* output_tensor_ptr = nullptr;
 | 
			
		||||
      if (col_params_.is_source) {
 | 
			
		||||
      if (col_params_->is_source) {
 | 
			
		||||
        TF_CHECK_OK(ctx.forward_input_or_allocate_output(
 | 
			
		||||
            {0}, 0, tensor_.shape(), &output_tensor_ptr));
 | 
			
		||||
      } else {
 | 
			
		||||
@ -665,11 +673,11 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
      }
 | 
			
		||||
      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
 | 
			
		||||
      const Tensor* input_tensor_ptr =
 | 
			
		||||
          col_params_.is_source ? &tensor_ : nullptr;
 | 
			
		||||
          col_params_->is_source ? &tensor_ : nullptr;
 | 
			
		||||
 | 
			
		||||
      // Prepare a Broadcaster instance.
 | 
			
		||||
      string exec_key =
 | 
			
		||||
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
 | 
			
		||||
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
 | 
			
		||||
      HierarchicalTreeBroadcaster* broadcaster =
 | 
			
		||||
          new HierarchicalTreeBroadcaster;
 | 
			
		||||
      core::ScopedUnref unref(broadcaster);
 | 
			
		||||
@ -694,7 +702,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
    int rank_;
 | 
			
		||||
    Tensor tensor_;
 | 
			
		||||
    Device* device_;
 | 
			
		||||
    CollectiveParams col_params_;
 | 
			
		||||
    CollectiveParams* col_params_;
 | 
			
		||||
    Status status_;
 | 
			
		||||
  };  // class DeviceInstance
 | 
			
		||||
 | 
			
		||||
@ -708,7 +716,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
 | 
			
		||||
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
 | 
			
		||||
  std::vector<DeviceInstance*> instances_;
 | 
			
		||||
  CollectiveParams col_params_;
 | 
			
		||||
  CollectiveParams* col_params_;
 | 
			
		||||
  std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
 | 
			
		||||
  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
 | 
			
		||||
  std::unique_ptr<string> gpu_ring_order_;
 | 
			
		||||
@ -720,33 +728,35 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) {
 | 
			
		||||
  CollectiveParams cp;
 | 
			
		||||
  PrepColParamsForSubdivPermsTest(&cp, 1, 8);
 | 
			
		||||
  auto* cp = new CollectiveParams();
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
  PrepColParamsForSubdivPermsTest(cp, 1, 8);
 | 
			
		||||
 | 
			
		||||
  // source 0 device 0
 | 
			
		||||
  cp.source_rank = 0;
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0});
 | 
			
		||||
  cp->source_rank = 0;
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0});
 | 
			
		||||
 | 
			
		||||
  // source 2 device 2
 | 
			
		||||
  cp.source_rank = 2;
 | 
			
		||||
  cp.default_rank = 2;
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2});
 | 
			
		||||
  cp->source_rank = 2;
 | 
			
		||||
  cp->default_rank = 2;
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2});
 | 
			
		||||
 | 
			
		||||
  // source 2 device 0
 | 
			
		||||
  cp.source_rank = 2;
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2});
 | 
			
		||||
  cp->source_rank = 2;
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
 | 
			
		||||
  CollectiveParams cp;
 | 
			
		||||
  PrepColParamsForSubdivPermsTest(&cp, 4, 8);
 | 
			
		||||
  auto* cp = new CollectiveParams();
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
  PrepColParamsForSubdivPermsTest(cp, 4, 8);
 | 
			
		||||
 | 
			
		||||
  // source 0 device 0
 | 
			
		||||
  cp.source_rank = 0;
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->source_rank = 0;
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{0, 8, 16, 24},
 | 
			
		||||
                      {0, 1, 2, 3, 4, 5, 6, 7},
 | 
			
		||||
                      {8, 9, 10, 11, 12, 13, 14, 15},
 | 
			
		||||
@ -755,9 +765,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
 | 
			
		||||
                     {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
 | 
			
		||||
 | 
			
		||||
  // source 2 device 0
 | 
			
		||||
  cp.source_rank = 2;
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->source_rank = 2;
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{2, 8, 16, 24},
 | 
			
		||||
                      {0, 1, 2, 3, 4, 5, 6, 7},
 | 
			
		||||
                      {8, 9, 10, 11, 12, 13, 14, 15},
 | 
			
		||||
@ -766,9 +776,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
 | 
			
		||||
                     {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
 | 
			
		||||
 | 
			
		||||
  // source 9 device 9
 | 
			
		||||
  cp.source_rank = 9;
 | 
			
		||||
  cp.default_rank = 9;
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->source_rank = 9;
 | 
			
		||||
  cp->default_rank = 9;
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{0, 9, 16, 24},
 | 
			
		||||
                      {0, 1, 2, 3, 4, 5, 6, 7},
 | 
			
		||||
                      {8, 9, 10, 11, 12, 13, 14, 15},
 | 
			
		||||
@ -778,28 +788,29 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
 | 
			
		||||
  CollectiveParams cp;
 | 
			
		||||
  auto* cp = new CollectiveParams();
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
  int num_tasks = 4;
 | 
			
		||||
  cp.group.device_type = DeviceType("GPU");
 | 
			
		||||
  cp.group.num_tasks = num_tasks;
 | 
			
		||||
  cp.group.group_size = 0;
 | 
			
		||||
  cp.instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
  cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
 | 
			
		||||
  cp->group.device_type = DeviceType("GPU");
 | 
			
		||||
  cp->group.num_tasks = num_tasks;
 | 
			
		||||
  cp->group.group_size = 0;
 | 
			
		||||
  cp->instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
  cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
 | 
			
		||||
  std::vector<int> dev_per_task = {4, 4, 6, 8};
 | 
			
		||||
  for (int ti = 0; ti < cp.group.num_tasks; ti++) {
 | 
			
		||||
  for (int ti = 0; ti < cp->group.num_tasks; ti++) {
 | 
			
		||||
    string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
 | 
			
		||||
    for (int di = 0; di < dev_per_task[ti]; di++) {
 | 
			
		||||
      string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
 | 
			
		||||
      cp.group.task_names.push_back(task_name);
 | 
			
		||||
      cp.group.device_names.push_back(dev_name);
 | 
			
		||||
      cp.group.group_size++;
 | 
			
		||||
      cp->group.task_names.push_back(task_name);
 | 
			
		||||
      cp->group.device_names.push_back(dev_name);
 | 
			
		||||
      cp->group.group_size++;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // source 0 device 0
 | 
			
		||||
  cp.source_rank = 0;
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->source_rank = 0;
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{0, 4, 8, 14},
 | 
			
		||||
                      {0, 1, 2, 3},
 | 
			
		||||
                      {4, 5, 6, 7},
 | 
			
		||||
@ -808,9 +819,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
 | 
			
		||||
                     {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
 | 
			
		||||
 | 
			
		||||
  // source 2 device 0
 | 
			
		||||
  cp.source_rank = 2;
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->source_rank = 2;
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{2, 4, 8, 14},
 | 
			
		||||
                      {0, 1, 2, 3},
 | 
			
		||||
                      {4, 5, 6, 7},
 | 
			
		||||
@ -819,9 +830,9 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
 | 
			
		||||
                     {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
 | 
			
		||||
 | 
			
		||||
  // source 9 device 5
 | 
			
		||||
  cp.source_rank = 9;
 | 
			
		||||
  cp.default_rank = 5;
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->source_rank = 9;
 | 
			
		||||
  cp->default_rank = 5;
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{0, 4, 9, 14},
 | 
			
		||||
                      {0, 1, 2, 3},
 | 
			
		||||
                      {4, 5, 6, 7},
 | 
			
		||||
 | 
			
		||||
@ -54,7 +54,7 @@ Status Permuter::InitializeCollectiveContext(
 | 
			
		||||
    std::shared_ptr<CollectiveContext> col_ctx) {
 | 
			
		||||
  DCHECK(col_ctx->dev_mgr);
 | 
			
		||||
  col_ctx_ = col_ctx;
 | 
			
		||||
  col_params_ = &col_ctx->col_params;
 | 
			
		||||
  col_params_ = col_ctx->col_params;
 | 
			
		||||
  return collective_util::InitializeDeviceAndLocality(
 | 
			
		||||
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
 | 
			
		||||
      &col_ctx->device_locality);
 | 
			
		||||
 | 
			
		||||
@ -107,12 +107,14 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
 | 
			
		||||
 | 
			
		||||
class PermuterTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
  PermuterTest() : device_type_(DEVICE_CPU) {}
 | 
			
		||||
  PermuterTest()
 | 
			
		||||
      : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
  ~PermuterTest() override {
 | 
			
		||||
    stop_ = true;
 | 
			
		||||
    for (auto i : instances_) delete i;
 | 
			
		||||
    if (col_exec_) col_exec_->Unref();
 | 
			
		||||
    if (col_params_) col_params_->Unref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | 
			
		||||
@ -170,12 +172,13 @@ class PermuterTest : public ::testing::Test {
 | 
			
		||||
    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
 | 
			
		||||
                                           dev_mgr_.get(),
 | 
			
		||||
                                           gpu_ring_order_.get(), work_queue_);
 | 
			
		||||
    col_params_.name = "test_collective";
 | 
			
		||||
    col_params_.instance.data_type = dtype;
 | 
			
		||||
    col_params_ = new CollectiveParams();
 | 
			
		||||
    col_params_->name = "test_collective";
 | 
			
		||||
    col_params_->instance.data_type = dtype;
 | 
			
		||||
    static const int kInstanceKey = 18;
 | 
			
		||||
    col_params_.instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_.group.device_type = device_type;
 | 
			
		||||
    col_params_.instance.type = PERMUTE_COLLECTIVE;
 | 
			
		||||
    col_params_->instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_->group.device_type = device_type;
 | 
			
		||||
    col_params_->instance.type = PERMUTE_COLLECTIVE;
 | 
			
		||||
 | 
			
		||||
    // Set up all the fake device contexts.
 | 
			
		||||
    for (int wi = 0; wi < num_workers; wi++) {
 | 
			
		||||
@ -187,12 +190,12 @@ class PermuterTest : public ::testing::Test {
 | 
			
		||||
        } else {
 | 
			
		||||
          dev_name = strings::StrCat(task_name, "/device:CPU:", di);
 | 
			
		||||
        }
 | 
			
		||||
        col_params_.group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_.instance.devices.push_back(dev_name);
 | 
			
		||||
        col_params_->group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_->instance.devices.push_back(dev_name);
 | 
			
		||||
        int default_rank = wi * num_devices_per_worker + di;
 | 
			
		||||
        permutation_.push_back(default_rank);
 | 
			
		||||
        col_params_.group.task_names.push_back(task_name);
 | 
			
		||||
        col_params_.task.is_local.push_back(true);
 | 
			
		||||
        col_params_->group.task_names.push_back(task_name);
 | 
			
		||||
        col_params_->task.is_local.push_back(true);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -210,13 +213,13 @@ class PermuterTest : public ::testing::Test {
 | 
			
		||||
      std::next_permutation(permutation_.begin() + i,
 | 
			
		||||
                            permutation_.begin() + i + 2);
 | 
			
		||||
    }
 | 
			
		||||
    col_params_.instance.permutation = permutation_;
 | 
			
		||||
    col_params_->instance.permutation = permutation_;
 | 
			
		||||
 | 
			
		||||
    for (int wi = 0; wi < num_workers; wi++) {
 | 
			
		||||
      for (int di = 0; di < num_devices_per_worker; di++) {
 | 
			
		||||
        int default_rank = wi * num_devices_per_worker + di;
 | 
			
		||||
        instances_.push_back(new DeviceInstance(
 | 
			
		||||
            default_rank, col_params_.group.device_names[default_rank],
 | 
			
		||||
            default_rank, col_params_->group.device_names[default_rank],
 | 
			
		||||
            device_type, this));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
@ -320,25 +323,30 @@ class PermuterTest : public ::testing::Test {
 | 
			
		||||
        : parent_(parent),
 | 
			
		||||
          dev_name_(dev_name),
 | 
			
		||||
          device_type_(device_type),
 | 
			
		||||
          rank_(rank) {
 | 
			
		||||
          rank_(rank),
 | 
			
		||||
          col_params_(new CollectiveParams()) {
 | 
			
		||||
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_));
 | 
			
		||||
      col_params_.name = parent_->col_params_.name;
 | 
			
		||||
      col_params_.instance.data_type = parent_->col_params_.instance.data_type;
 | 
			
		||||
      col_params_.instance.instance_key =
 | 
			
		||||
          parent_->col_params_.instance.instance_key;
 | 
			
		||||
      col_params_.group.device_type = parent_->col_params_.group.device_type;
 | 
			
		||||
      col_params_.group.device_names = parent_->col_params_.group.device_names;
 | 
			
		||||
      col_params_.instance.devices = parent_->col_params_.instance.devices;
 | 
			
		||||
      col_params_.instance.permutation =
 | 
			
		||||
          parent->col_params_.instance.permutation;
 | 
			
		||||
      col_params_.group.task_names = parent_->col_params_.group.task_names;
 | 
			
		||||
      col_params_.task.is_local = parent_->col_params_.task.is_local;
 | 
			
		||||
      CHECK_EQ(col_params_.instance.devices.size(),
 | 
			
		||||
               col_params_.group.device_names.size());
 | 
			
		||||
      col_params_->name = parent_->col_params_->name;
 | 
			
		||||
      col_params_->instance.data_type =
 | 
			
		||||
          parent_->col_params_->instance.data_type;
 | 
			
		||||
      col_params_->instance.instance_key =
 | 
			
		||||
          parent_->col_params_->instance.instance_key;
 | 
			
		||||
      col_params_->group.device_type = parent_->col_params_->group.device_type;
 | 
			
		||||
      col_params_->group.device_names =
 | 
			
		||||
          parent_->col_params_->group.device_names;
 | 
			
		||||
      col_params_->instance.devices = parent_->col_params_->instance.devices;
 | 
			
		||||
      col_params_->instance.permutation =
 | 
			
		||||
          parent->col_params_->instance.permutation;
 | 
			
		||||
      col_params_->group.task_names = parent_->col_params_->group.task_names;
 | 
			
		||||
      col_params_->task.is_local = parent_->col_params_->task.is_local;
 | 
			
		||||
      CHECK_EQ(col_params_->instance.devices.size(),
 | 
			
		||||
               col_params_->group.device_names.size());
 | 
			
		||||
      // Default rank is order in device_names.
 | 
			
		||||
      col_params_.default_rank = rank;
 | 
			
		||||
      col_params_->default_rank = rank;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~DeviceInstance() { col_params_->Unref(); }
 | 
			
		||||
 | 
			
		||||
    void InitTensor(DataType dtype, const TensorShape& shape,
 | 
			
		||||
                    const InitFunc& f) {
 | 
			
		||||
      tensor_input_ =
 | 
			
		||||
@ -387,7 +395,7 @@ class PermuterTest : public ::testing::Test {
 | 
			
		||||
 | 
			
		||||
      // Prepare a Permuter instance.
 | 
			
		||||
      string exec_key =
 | 
			
		||||
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
 | 
			
		||||
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
 | 
			
		||||
      Permuter* permuter = new Permuter;
 | 
			
		||||
      core::ScopedUnref unref(permuter);
 | 
			
		||||
      auto col_ctx = std::make_shared<CollectiveContext>(
 | 
			
		||||
@ -412,7 +420,7 @@ class PermuterTest : public ::testing::Test {
 | 
			
		||||
    Tensor tensor_input_;
 | 
			
		||||
    Tensor tensor_output_;
 | 
			
		||||
    Device* device_;
 | 
			
		||||
    CollectiveParams col_params_;
 | 
			
		||||
    CollectiveParams* col_params_;
 | 
			
		||||
    Status status_;
 | 
			
		||||
  };  // class DeviceInstance
 | 
			
		||||
 | 
			
		||||
@ -425,7 +433,7 @@ class PermuterTest : public ::testing::Test {
 | 
			
		||||
  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
 | 
			
		||||
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
 | 
			
		||||
  std::vector<DeviceInstance*> instances_;
 | 
			
		||||
  CollectiveParams col_params_;
 | 
			
		||||
  CollectiveParams* col_params_;
 | 
			
		||||
  std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
 | 
			
		||||
  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
 | 
			
		||||
  std::unique_ptr<string> gpu_ring_order_;
 | 
			
		||||
 | 
			
		||||
@ -245,7 +245,7 @@ Status RingAlg::InitializeCollectiveContext(
 | 
			
		||||
    std::shared_ptr<CollectiveContext> col_ctx) {
 | 
			
		||||
  DCHECK(col_ctx->dev_mgr);
 | 
			
		||||
  col_ctx_ = col_ctx;
 | 
			
		||||
  col_params_ = &col_ctx->col_params;
 | 
			
		||||
  col_params_ = col_ctx->col_params;
 | 
			
		||||
  return collective_util::InitializeDeviceAndLocality(
 | 
			
		||||
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
 | 
			
		||||
      &col_ctx->device_locality);
 | 
			
		||||
 | 
			
		||||
@ -115,7 +115,8 @@ static int64 kStepId = 123;
 | 
			
		||||
 | 
			
		||||
class RingGathererTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
  RingGathererTest() : device_type_(DEVICE_CPU) {}
 | 
			
		||||
  RingGathererTest()
 | 
			
		||||
      : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | 
			
		||||
  void InitGPUDevices() {
 | 
			
		||||
@ -132,6 +133,7 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
    stop_ = true;
 | 
			
		||||
    for (auto i : instances_) delete i;
 | 
			
		||||
    if (col_exec_) col_exec_->Unref();
 | 
			
		||||
    if (col_params_) col_params_->Unref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void Init(int num_workers, int num_devices, DataType dtype,
 | 
			
		||||
@ -180,24 +182,25 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
 | 
			
		||||
                                           dev_mgr_.get(),
 | 
			
		||||
                                           gpu_ring_order_.get(), work_queue_);
 | 
			
		||||
    col_params_.name = "test_collective";
 | 
			
		||||
    col_params_ = new CollectiveParams();
 | 
			
		||||
    col_params_->name = "test_collective";
 | 
			
		||||
    static const int kGroupKey = 5;
 | 
			
		||||
    col_params_.group.group_key = kGroupKey;
 | 
			
		||||
    col_params_.group.device_type = device_type;
 | 
			
		||||
    col_params_.group.group_size = num_workers * num_devices;
 | 
			
		||||
    col_params_->group.group_key = kGroupKey;
 | 
			
		||||
    col_params_->group.device_type = device_type;
 | 
			
		||||
    col_params_->group.group_size = num_workers * num_devices;
 | 
			
		||||
    static const int kInstanceKey = 17;
 | 
			
		||||
    col_params_.instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
    col_params_.instance.type = GATHER_COLLECTIVE;
 | 
			
		||||
    col_params_.instance.impl_details.collective_name = "RingGather";
 | 
			
		||||
    col_params_.instance.data_type = dtype;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
 | 
			
		||||
    col_params_.subdiv_rank.resize(num_subdivs);
 | 
			
		||||
    col_params_->instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
    col_params_->instance.type = GATHER_COLLECTIVE;
 | 
			
		||||
    col_params_->instance.impl_details.collective_name = "RingGather";
 | 
			
		||||
    col_params_->instance.data_type = dtype;
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
 | 
			
		||||
    col_params_->subdiv_rank.resize(num_subdivs);
 | 
			
		||||
    int subdiv_stride = num_devices / num_subdivs;
 | 
			
		||||
    for (int sdi = 0; sdi < num_subdivs; ++sdi) {
 | 
			
		||||
      col_params_.instance.impl_details.subdiv_offsets.push_back(sdi *
 | 
			
		||||
                                                                 subdiv_stride);
 | 
			
		||||
      col_params_.subdiv_rank[sdi] = sdi * subdiv_stride;
 | 
			
		||||
      col_params_->instance.impl_details.subdiv_offsets.push_back(
 | 
			
		||||
          sdi * subdiv_stride);
 | 
			
		||||
      col_params_->subdiv_rank[sdi] = sdi * subdiv_stride;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Set up a local device ring order that's not just 0,1,2...
 | 
			
		||||
@ -225,16 +228,16 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
          dev_name =
 | 
			
		||||
              strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
 | 
			
		||||
        }
 | 
			
		||||
        col_params_.group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_.group.task_names.push_back(task_name);
 | 
			
		||||
        col_params_->group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_->group.task_names.push_back(task_name);
 | 
			
		||||
        // Normally each device would set is_local to its own perspective but
 | 
			
		||||
        // this test runs in a single process so is_local is always true.
 | 
			
		||||
        col_params_.task.is_local.push_back(true);
 | 
			
		||||
        col_params_->task.is_local.push_back(true);
 | 
			
		||||
        for (int sdi = 0; sdi < num_subdivs; ++sdi) {
 | 
			
		||||
          int rotated_di =
 | 
			
		||||
              (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) %
 | 
			
		||||
              (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) %
 | 
			
		||||
              num_devices;
 | 
			
		||||
          col_params_.instance.impl_details.subdiv_permutations[sdi].push_back(
 | 
			
		||||
          col_params_->instance.impl_details.subdiv_permutations[sdi].push_back(
 | 
			
		||||
              wi * num_devices + local_ring_order[rotated_di]);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
@ -243,7 +246,7 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
      for (int di = 0; di < num_devices; ++di) {
 | 
			
		||||
        int rank = wi * num_devices + di;
 | 
			
		||||
        instances_.push_back(new DeviceInstance(
 | 
			
		||||
            rank, col_params_.group.device_names[rank], device_type_, this));
 | 
			
		||||
            rank, col_params_->group.device_names[rank], device_type_, this));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@ -387,39 +390,42 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
        : parent_(parent),
 | 
			
		||||
          dev_name_(dev_name),
 | 
			
		||||
          device_type_(device_type),
 | 
			
		||||
          rank_(rank) {
 | 
			
		||||
          rank_(rank),
 | 
			
		||||
          col_params_(new CollectiveParams()) {
 | 
			
		||||
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_))
 | 
			
		||||
          << "Couldn't find device " << dev_name
 | 
			
		||||
          << " existing devices: " << parent_->dev_mgr_->DebugString();
 | 
			
		||||
      col_params_.name = parent_->col_params_.name;
 | 
			
		||||
      col_params_.group = parent_->col_params_.group;
 | 
			
		||||
      col_params_.instance = parent->col_params_.instance;
 | 
			
		||||
      col_params_.task.is_local = parent_->col_params_.task.is_local;
 | 
			
		||||
      col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
 | 
			
		||||
      col_params_->name = parent_->col_params_->name;
 | 
			
		||||
      col_params_->group = parent_->col_params_->group;
 | 
			
		||||
      col_params_->instance = parent->col_params_->instance;
 | 
			
		||||
      col_params_->task.is_local = parent_->col_params_->task.is_local;
 | 
			
		||||
      col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;
 | 
			
		||||
 | 
			
		||||
      int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
 | 
			
		||||
      int group_size = col_params_.group.group_size;
 | 
			
		||||
      int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
 | 
			
		||||
      int group_size = col_params_->group.group_size;
 | 
			
		||||
      CHECK_EQ(group_size,
 | 
			
		||||
               static_cast<int>(col_params_.group.device_names.size()));
 | 
			
		||||
               static_cast<int>(col_params_->group.device_names.size()));
 | 
			
		||||
      // Id of this device is at rank position in first subdiv perm.
 | 
			
		||||
      int my_device_id =
 | 
			
		||||
          col_params_.instance.impl_details.subdiv_permutations[0][rank];
 | 
			
		||||
      col_params_.default_rank = my_device_id;
 | 
			
		||||
          col_params_->instance.impl_details.subdiv_permutations[0][rank];
 | 
			
		||||
      col_params_->default_rank = my_device_id;
 | 
			
		||||
      // Set rank for all other subdivs by finding that device_id.
 | 
			
		||||
      for (int sdi = 0; sdi < num_subdivs; ++sdi) {
 | 
			
		||||
        for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details
 | 
			
		||||
        for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details
 | 
			
		||||
                                                 .subdiv_permutations[sdi]
 | 
			
		||||
                                                 .size());
 | 
			
		||||
             ++r) {
 | 
			
		||||
          if (my_device_id ==
 | 
			
		||||
              col_params_.instance.impl_details.subdiv_permutations[sdi][r]) {
 | 
			
		||||
            col_params_.subdiv_rank[sdi] = r;
 | 
			
		||||
              col_params_->instance.impl_details.subdiv_permutations[sdi][r]) {
 | 
			
		||||
            col_params_->subdiv_rank[sdi] = r;
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~DeviceInstance() { col_params_->Unref(); }
 | 
			
		||||
 | 
			
		||||
    void InitTensor(DataType dtype, const TensorShape& shape,
 | 
			
		||||
                    const std::function<void(Tensor*)>& init_f) {
 | 
			
		||||
      input_tensor_ =
 | 
			
		||||
@ -464,7 +470,7 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
      AllocatorAttributes generic_alloc_attr;
 | 
			
		||||
      op_params.output_attr_array = &generic_alloc_attr;
 | 
			
		||||
      std::unique_ptr<OpKernel> op = parent_->GetCollectiveGather(
 | 
			
		||||
          col_params_, &input_tensor_, DEVICE_CPU, device_);
 | 
			
		||||
          *col_params_, &input_tensor_, DEVICE_CPU, device_);
 | 
			
		||||
      op_params.op_kernel = op.get();
 | 
			
		||||
      OpKernelContext ctx(&op_params, 1);
 | 
			
		||||
 | 
			
		||||
@ -478,7 +484,7 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
 | 
			
		||||
      // Prepare a RingGatherer instance.
 | 
			
		||||
      string exec_key =
 | 
			
		||||
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
 | 
			
		||||
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
 | 
			
		||||
      RingGatherer* gatherer = new RingGatherer;
 | 
			
		||||
      core::ScopedUnref unref(gatherer);
 | 
			
		||||
      auto col_ctx = std::make_shared<CollectiveContext>(
 | 
			
		||||
@ -507,7 +513,7 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
    Tensor input_tensor_;
 | 
			
		||||
    Tensor output_tensor_;
 | 
			
		||||
    Device* device_;
 | 
			
		||||
    CollectiveParams col_params_;
 | 
			
		||||
    CollectiveParams* col_params_;
 | 
			
		||||
    std::unique_ptr<CollectiveAdapter> ca_;
 | 
			
		||||
    std::unique_ptr<OpKernelContext> ctx_;
 | 
			
		||||
    Status status_;
 | 
			
		||||
@ -521,7 +527,7 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
 | 
			
		||||
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
 | 
			
		||||
  std::vector<DeviceInstance*> instances_;
 | 
			
		||||
  CollectiveParams col_params_;
 | 
			
		||||
  CollectiveParams* col_params_;
 | 
			
		||||
  std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
 | 
			
		||||
  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
 | 
			
		||||
  std::unique_ptr<string> gpu_ring_order_;
 | 
			
		||||
@ -530,28 +536,28 @@ class RingGathererTest : public ::testing::Test {
 | 
			
		||||
  CancellationManager cancellation_manager_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
 | 
			
		||||
                                       const int num_tasks) {
 | 
			
		||||
  CollectiveParams cp;
 | 
			
		||||
CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task,
 | 
			
		||||
                                        const int num_tasks) {
 | 
			
		||||
  auto* cp = new CollectiveParams();
 | 
			
		||||
  const int kNumDevs = num_devs_per_task * num_tasks;
 | 
			
		||||
  cp.group.group_key = 1;
 | 
			
		||||
  cp.group.group_size = kNumDevs;
 | 
			
		||||
  cp.group.device_type = DeviceType("GPU");
 | 
			
		||||
  cp.group.num_tasks = num_tasks;
 | 
			
		||||
  cp.instance.instance_key = 3;
 | 
			
		||||
  cp.instance.type = GATHER_COLLECTIVE;
 | 
			
		||||
  cp.instance.data_type = DataType(DT_FLOAT);
 | 
			
		||||
  cp.instance.shape = TensorShape({kNumDevs * kNumDevs});
 | 
			
		||||
  cp.instance.impl_details.collective_name = "RingGather";
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets.push_back(0);
 | 
			
		||||
  cp.is_source = false;
 | 
			
		||||
  cp->group.group_key = 1;
 | 
			
		||||
  cp->group.group_size = kNumDevs;
 | 
			
		||||
  cp->group.device_type = DeviceType("GPU");
 | 
			
		||||
  cp->group.num_tasks = num_tasks;
 | 
			
		||||
  cp->instance.instance_key = 3;
 | 
			
		||||
  cp->instance.type = GATHER_COLLECTIVE;
 | 
			
		||||
  cp->instance.data_type = DataType(DT_FLOAT);
 | 
			
		||||
  cp->instance.shape = TensorShape({kNumDevs * kNumDevs});
 | 
			
		||||
  cp->instance.impl_details.collective_name = "RingGather";
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets.push_back(0);
 | 
			
		||||
  cp->is_source = false;
 | 
			
		||||
  for (int i = 0; i < kNumDevs; ++i) {
 | 
			
		||||
    int task_id = i / num_devs_per_task;
 | 
			
		||||
    int dev_id = i % num_devs_per_task;
 | 
			
		||||
    string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
 | 
			
		||||
    string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
 | 
			
		||||
    cp.group.task_names.push_back(task_name);
 | 
			
		||||
    cp.group.device_names.push_back(device_name);
 | 
			
		||||
    cp->group.task_names.push_back(task_name);
 | 
			
		||||
    cp->group.device_names.push_back(device_name);
 | 
			
		||||
  }
 | 
			
		||||
  return cp;
 | 
			
		||||
}
 | 
			
		||||
@ -559,23 +565,24 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
 | 
			
		||||
TEST_F(RingGathererTest, InitializeParams) {
 | 
			
		||||
  const int kNumDevsPerTask = 8;
 | 
			
		||||
  const int kNumTasks = 3;
 | 
			
		||||
  CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets = {};
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                            12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets = {};
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
                     {0});
 | 
			
		||||
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets = {0};
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                            12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets = {0};
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
                     {0});
 | 
			
		||||
 | 
			
		||||
  cp.default_rank = 3;
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets = {};
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                            12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
  cp->default_rank = 3;
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets = {};
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
                     {3});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -138,7 +138,8 @@ static int64 kStepId = 123;
 | 
			
		||||
 | 
			
		||||
class RingReducerTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
  RingReducerTest() : device_type_(DEVICE_CPU) {}
 | 
			
		||||
  RingReducerTest()
 | 
			
		||||
      : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}
 | 
			
		||||
 | 
			
		||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 | 
			
		||||
  void InitGPUDevices() {
 | 
			
		||||
@ -155,6 +156,7 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
    stop_ = true;
 | 
			
		||||
    for (auto i : instances_) delete i;
 | 
			
		||||
    if (col_exec_) col_exec_->Unref();
 | 
			
		||||
    if (col_params_) col_params_->Unref();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void Init(int num_workers, int num_devices, DataType dtype,
 | 
			
		||||
@ -203,24 +205,25 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
 | 
			
		||||
                                           dev_mgr_.get(),
 | 
			
		||||
                                           gpu_ring_order_.get(), work_queue_);
 | 
			
		||||
    col_params_.name = "test_collective";
 | 
			
		||||
    col_params_ = new CollectiveParams();
 | 
			
		||||
    col_params_->name = "test_collective";
 | 
			
		||||
    static const int kGroupKey = 5;
 | 
			
		||||
    col_params_.group.group_key = kGroupKey;
 | 
			
		||||
    col_params_.group.device_type = device_type;
 | 
			
		||||
    col_params_.group.group_size = num_workers * num_devices;
 | 
			
		||||
    col_params_->group.group_key = kGroupKey;
 | 
			
		||||
    col_params_->group.device_type = device_type;
 | 
			
		||||
    col_params_->group.group_size = num_workers * num_devices;
 | 
			
		||||
    static const int kInstanceKey = 17;
 | 
			
		||||
    col_params_.instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
    col_params_.instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
    col_params_.instance.impl_details.collective_name = "RingReduce";
 | 
			
		||||
    col_params_.instance.data_type = dtype;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
 | 
			
		||||
    col_params_.subdiv_rank.resize(num_subdivs);
 | 
			
		||||
    col_params_->instance.instance_key = kInstanceKey;
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
    col_params_->instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
    col_params_->instance.impl_details.collective_name = "RingReduce";
 | 
			
		||||
    col_params_->instance.data_type = dtype;
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
 | 
			
		||||
    col_params_->subdiv_rank.resize(num_subdivs);
 | 
			
		||||
    int subdiv_stride = num_devices / num_subdivs;
 | 
			
		||||
    for (int sdi = 0; sdi < num_subdivs; ++sdi) {
 | 
			
		||||
      col_params_.instance.impl_details.subdiv_offsets.push_back(sdi *
 | 
			
		||||
                                                                 subdiv_stride);
 | 
			
		||||
      col_params_.subdiv_rank[sdi] = sdi * subdiv_stride;
 | 
			
		||||
      col_params_->instance.impl_details.subdiv_offsets.push_back(
 | 
			
		||||
          sdi * subdiv_stride);
 | 
			
		||||
      col_params_->subdiv_rank[sdi] = sdi * subdiv_stride;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Set up a local device ring order that's not just 0,1,2...
 | 
			
		||||
@ -242,23 +245,23 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
    // Set up all of the fake device contexts.
 | 
			
		||||
    for (int wi = 0; wi < num_workers; ++wi) {
 | 
			
		||||
      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
 | 
			
		||||
      col_params_.group.num_devices_per_task[task_name] = num_devices;
 | 
			
		||||
      col_params_->group.num_devices_per_task[task_name] = num_devices;
 | 
			
		||||
      for (int di = 0; di < num_devices; ++di) {
 | 
			
		||||
        string dev_name = strings::StrCat(task_name, "/cpu:", di);
 | 
			
		||||
        if (device_type == DEVICE_GPU) {
 | 
			
		||||
          dev_name =
 | 
			
		||||
              strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
 | 
			
		||||
        }
 | 
			
		||||
        col_params_.group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_.group.task_names.push_back(task_name);
 | 
			
		||||
        col_params_->group.device_names.push_back(dev_name);
 | 
			
		||||
        col_params_->group.task_names.push_back(task_name);
 | 
			
		||||
        // Normally each device would set is_local to its own perspective but
 | 
			
		||||
        // this test runs in a single process so is_local is always true.
 | 
			
		||||
        col_params_.task.is_local.push_back(true);
 | 
			
		||||
        col_params_->task.is_local.push_back(true);
 | 
			
		||||
        for (int sdi = 0; sdi < num_subdivs; ++sdi) {
 | 
			
		||||
          int rotated_di =
 | 
			
		||||
              (di + col_params_.instance.impl_details.subdiv_offsets[sdi]) %
 | 
			
		||||
              (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) %
 | 
			
		||||
              num_devices;
 | 
			
		||||
          col_params_.instance.impl_details.subdiv_permutations[sdi].push_back(
 | 
			
		||||
          col_params_->instance.impl_details.subdiv_permutations[sdi].push_back(
 | 
			
		||||
              wi * num_devices + local_ring_order[rotated_di]);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
@ -267,7 +270,7 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
      for (int di = 0; di < num_devices; ++di) {
 | 
			
		||||
        int rank = wi * num_devices + di;
 | 
			
		||||
        instances_.push_back(new DeviceInstance(
 | 
			
		||||
            rank, col_params_.group.device_names[rank], device_type_, this));
 | 
			
		||||
            rank, col_params_->group.device_names[rank], device_type_, this));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@ -413,39 +416,42 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
        : parent_(parent),
 | 
			
		||||
          dev_name_(dev_name),
 | 
			
		||||
          device_type_(device_type),
 | 
			
		||||
          rank_(rank) {
 | 
			
		||||
          rank_(rank),
 | 
			
		||||
          col_params_(new CollectiveParams()) {
 | 
			
		||||
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_))
 | 
			
		||||
          << "Couldn't find device " << dev_name
 | 
			
		||||
          << " existing devices: " << parent_->dev_mgr_->DebugString();
 | 
			
		||||
      col_params_.name = parent_->col_params_.name;
 | 
			
		||||
      col_params_.group = parent_->col_params_.group;
 | 
			
		||||
      col_params_.instance = parent->col_params_.instance;
 | 
			
		||||
      col_params_.task.is_local = parent_->col_params_.task.is_local;
 | 
			
		||||
      col_params_.subdiv_rank = parent_->col_params_.subdiv_rank;
 | 
			
		||||
      col_params_->name = parent_->col_params_->name;
 | 
			
		||||
      col_params_->group = parent_->col_params_->group;
 | 
			
		||||
      col_params_->instance = parent->col_params_->instance;
 | 
			
		||||
      col_params_->task.is_local = parent_->col_params_->task.is_local;
 | 
			
		||||
      col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;
 | 
			
		||||
 | 
			
		||||
      int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
 | 
			
		||||
      int group_size = col_params_.group.group_size;
 | 
			
		||||
      int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
 | 
			
		||||
      int group_size = col_params_->group.group_size;
 | 
			
		||||
      CHECK_EQ(group_size,
 | 
			
		||||
               static_cast<int>(col_params_.group.device_names.size()));
 | 
			
		||||
               static_cast<int>(col_params_->group.device_names.size()));
 | 
			
		||||
      // Id of this device is at rank position in first subdiv perm.
 | 
			
		||||
      int my_device_id =
 | 
			
		||||
          col_params_.instance.impl_details.subdiv_permutations[0][rank];
 | 
			
		||||
      col_params_.default_rank = my_device_id;
 | 
			
		||||
          col_params_->instance.impl_details.subdiv_permutations[0][rank];
 | 
			
		||||
      col_params_->default_rank = my_device_id;
 | 
			
		||||
      // Set rank for all other subdivs by finding that device_id.
 | 
			
		||||
      for (int sdi = 0; sdi < num_subdivs; ++sdi) {
 | 
			
		||||
        for (int r = 0; r < static_cast<int>(col_params_.instance.impl_details
 | 
			
		||||
        for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details
 | 
			
		||||
                                                 .subdiv_permutations[sdi]
 | 
			
		||||
                                                 .size());
 | 
			
		||||
             ++r) {
 | 
			
		||||
          if (my_device_id ==
 | 
			
		||||
              col_params_.instance.impl_details.subdiv_permutations[sdi][r]) {
 | 
			
		||||
            col_params_.subdiv_rank[sdi] = r;
 | 
			
		||||
              col_params_->instance.impl_details.subdiv_permutations[sdi][r]) {
 | 
			
		||||
            col_params_->subdiv_rank[sdi] = r;
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~DeviceInstance() { col_params_->Unref(); }
 | 
			
		||||
 | 
			
		||||
    void InitTensor(DataType dtype, const TensorShape& shape,
 | 
			
		||||
                    const std::function<void(Tensor*)>& init_f) {
 | 
			
		||||
      tensor_ =
 | 
			
		||||
@ -466,10 +472,12 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void DoReduce() {
 | 
			
		||||
      merge_op_ = GetAdd(col_params_.instance.data_type, device_type_, device_);
 | 
			
		||||
      final_op_ = GetDiv(col_params_.instance.data_type, device_type_, device_);
 | 
			
		||||
      col_params_.merge_op = merge_op_.get();
 | 
			
		||||
      col_params_.final_op = final_op_.get();
 | 
			
		||||
      merge_op_ =
 | 
			
		||||
          GetAdd(col_params_->instance.data_type, device_type_, device_);
 | 
			
		||||
      final_op_ =
 | 
			
		||||
          GetDiv(col_params_->instance.data_type, device_type_, device_);
 | 
			
		||||
      col_params_->merge_op = merge_op_.get();
 | 
			
		||||
      col_params_->final_op = final_op_.get();
 | 
			
		||||
 | 
			
		||||
      // Prepare an OpKernelContext.
 | 
			
		||||
      OpKernelContext::Params op_params;
 | 
			
		||||
@ -496,7 +504,7 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
      AllocatorAttributes generic_alloc_attr;
 | 
			
		||||
      op_params.output_attr_array = &generic_alloc_attr;
 | 
			
		||||
      std::unique_ptr<OpKernel> op = parent_->GetCollectiveReduce(
 | 
			
		||||
          col_params_, &tensor_, DEVICE_CPU, device_);
 | 
			
		||||
          *col_params_, &tensor_, DEVICE_CPU, device_);
 | 
			
		||||
      op_params.op_kernel = op.get();
 | 
			
		||||
      OpKernelContext ctx(&op_params, 1);
 | 
			
		||||
 | 
			
		||||
@ -509,7 +517,7 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
 | 
			
		||||
      // Prepare a RingReducer instance.
 | 
			
		||||
      string exec_key =
 | 
			
		||||
          strings::StrCat(col_params_.instance.instance_key, ":0:0");
 | 
			
		||||
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
 | 
			
		||||
      RingReducer* reducer = new RingReducer;
 | 
			
		||||
      core::ScopedUnref unref(reducer);
 | 
			
		||||
      auto col_ctx = std::make_shared<CollectiveContext>(
 | 
			
		||||
@ -535,7 +543,7 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
    int rank_;
 | 
			
		||||
    Tensor tensor_;
 | 
			
		||||
    Device* device_;
 | 
			
		||||
    CollectiveParams col_params_;
 | 
			
		||||
    CollectiveParams* col_params_;
 | 
			
		||||
    std::unique_ptr<OpKernel> merge_op_;
 | 
			
		||||
    std::unique_ptr<OpKernel> final_op_;
 | 
			
		||||
    std::unique_ptr<CollectiveAdapter> ca_;
 | 
			
		||||
@ -551,7 +559,7 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
 | 
			
		||||
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
 | 
			
		||||
  std::vector<DeviceInstance*> instances_;
 | 
			
		||||
  CollectiveParams col_params_;
 | 
			
		||||
  CollectiveParams* col_params_;
 | 
			
		||||
  std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
 | 
			
		||||
  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
 | 
			
		||||
  std::unique_ptr<string> gpu_ring_order_;
 | 
			
		||||
@ -560,28 +568,28 @@ class RingReducerTest : public ::testing::Test {
 | 
			
		||||
  CancellationManager cancellation_manager_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
 | 
			
		||||
                                       const int num_tasks) {
 | 
			
		||||
  CollectiveParams cp;
 | 
			
		||||
CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task,
 | 
			
		||||
                                        const int num_tasks) {
 | 
			
		||||
  auto cp = new CollectiveParams();
 | 
			
		||||
  const int kNumDevs = num_devs_per_task * num_tasks;
 | 
			
		||||
  cp.group.group_key = 1;
 | 
			
		||||
  cp.group.group_size = kNumDevs;
 | 
			
		||||
  cp.group.device_type = DeviceType("GPU");
 | 
			
		||||
  cp.group.num_tasks = num_tasks;
 | 
			
		||||
  cp.instance.instance_key = 3;
 | 
			
		||||
  cp.instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
  cp.instance.data_type = DataType(DT_FLOAT);
 | 
			
		||||
  cp.instance.shape = TensorShape({kNumDevs});
 | 
			
		||||
  cp.instance.impl_details.collective_name = "RingReduce";
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets.push_back(0);
 | 
			
		||||
  cp.is_source = false;
 | 
			
		||||
  cp->group.group_key = 1;
 | 
			
		||||
  cp->group.group_size = kNumDevs;
 | 
			
		||||
  cp->group.device_type = DeviceType("GPU");
 | 
			
		||||
  cp->group.num_tasks = num_tasks;
 | 
			
		||||
  cp->instance.instance_key = 3;
 | 
			
		||||
  cp->instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
  cp->instance.data_type = DataType(DT_FLOAT);
 | 
			
		||||
  cp->instance.shape = TensorShape({kNumDevs});
 | 
			
		||||
  cp->instance.impl_details.collective_name = "RingReduce";
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets.push_back(0);
 | 
			
		||||
  cp->is_source = false;
 | 
			
		||||
  for (int i = 0; i < kNumDevs; ++i) {
 | 
			
		||||
    int task_id = i / num_devs_per_task;
 | 
			
		||||
    int dev_id = i % num_devs_per_task;
 | 
			
		||||
    string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
 | 
			
		||||
    string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
 | 
			
		||||
    cp.group.task_names.push_back(task_name);
 | 
			
		||||
    cp.group.device_names.push_back(device_name);
 | 
			
		||||
    cp->group.task_names.push_back(task_name);
 | 
			
		||||
    cp->group.device_names.push_back(device_name);
 | 
			
		||||
  }
 | 
			
		||||
  return cp;
 | 
			
		||||
}
 | 
			
		||||
@ -589,28 +597,29 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
 | 
			
		||||
TEST_F(RingReducerTest, InitializeParams) {
 | 
			
		||||
  const int kNumDevsPerTask = 8;
 | 
			
		||||
  const int kNumTasks = 3;
 | 
			
		||||
  CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets = {0, 4};
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets = {0, 4};
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
 | 
			
		||||
                      {4, 5, 6,  7,  0,  1,  2,  3,  12, 13, 14, 15,
 | 
			
		||||
                       8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
 | 
			
		||||
                     {0, 4});
 | 
			
		||||
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets = {0, -4};
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets = {0, -4};
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
 | 
			
		||||
                      {3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,  8,
 | 
			
		||||
                       15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
 | 
			
		||||
                     {0, 3});
 | 
			
		||||
 | 
			
		||||
  cp.default_rank = 3;
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets = {3, -3};
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->default_rank = 3;
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets = {3, -3};
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{3,  4, 5, 6,  7,  0,  1,  2,  11, 12, 13, 14,
 | 
			
		||||
                       15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
 | 
			
		||||
                      {4, 3,  2,  1,  0,  7,  6,  5,  12, 11, 10, 9,
 | 
			
		||||
@ -622,13 +631,14 @@ TEST_F(RingReducerTest, AutomaticSubdivs) {
 | 
			
		||||
  const int kNumDevsPerTask = 8;
 | 
			
		||||
  const int kNumTasks = 3;
 | 
			
		||||
  const int kNumDevs = kNumDevsPerTask * kNumTasks;
 | 
			
		||||
  CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
 | 
			
		||||
  // Test automatic generation of subdiv offsets.
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                            12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
 | 
			
		||||
                     {0});
 | 
			
		||||
 | 
			
		||||
  // Set shape so that with 2 subdivs chunk_size is 3 MiB.  This should cause 2
 | 
			
		||||
@ -638,11 +648,11 @@ TEST_F(RingReducerTest, AutomaticSubdivs) {
 | 
			
		||||
    int num_chunks = kNumDevs * num_subdivs;
 | 
			
		||||
    size_t chunk_size = 3 * 1048576;  // 3 MiB
 | 
			
		||||
    size_t tensor_size = chunk_size * num_chunks;
 | 
			
		||||
    cp.instance.shape =
 | 
			
		||||
    cp->instance.shape =
 | 
			
		||||
        TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
 | 
			
		||||
  }
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
  RunSubdivPermsTest(&cp,
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
  RunSubdivPermsTest(cp,
 | 
			
		||||
                     {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
 | 
			
		||||
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
 | 
			
		||||
                      {3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,  8,
 | 
			
		||||
@ -653,12 +663,13 @@ TEST_F(RingReducerTest, AutomaticSubdivs) {
 | 
			
		||||
TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
 | 
			
		||||
  const int kNumDevsPerTask = 1;
 | 
			
		||||
  const int kNumTasks = 4;
 | 
			
		||||
  CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
 | 
			
		||||
  cp.default_rank = 0;
 | 
			
		||||
  cp.instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
  cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
 | 
			
		||||
  RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
 | 
			
		||||
  cp->default_rank = 0;
 | 
			
		||||
  cp->instance.impl_details.subdiv_offsets.clear();
 | 
			
		||||
  cp->instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
 | 
			
		||||
  RunSubdivPermsTest(cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(b/113171733): change to use TEST_P.
 | 
			
		||||
 | 
			
		||||
@ -137,13 +137,14 @@ void CollectiveParamResolverDistributed::CompleteGroupAsync(
 | 
			
		||||
        "running the same version of Tensorflow on all workers."));
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  CollectiveParams cp;
 | 
			
		||||
  cp.group.group_key = request->group_key();
 | 
			
		||||
  cp.group.group_size = request->group_size();
 | 
			
		||||
  cp.group.device_type = DeviceType(request->device_type());
 | 
			
		||||
  cp.instance.type = CollectiveType(request->collective_type());
 | 
			
		||||
  auto* cp = new CollectiveParams();
 | 
			
		||||
  core::ScopedUnref unref(cp);
 | 
			
		||||
  cp->group.group_key = request->group_key();
 | 
			
		||||
  cp->group.group_size = request->group_size();
 | 
			
		||||
  cp->group.device_type = DeviceType(request->device_type());
 | 
			
		||||
  cp->instance.type = CollectiveType(request->collective_type());
 | 
			
		||||
  CompleteGroupDistributed(
 | 
			
		||||
      request->device_attributes(), &cp, cancel_mgr,
 | 
			
		||||
      request->device_attributes(), cp, cancel_mgr,
 | 
			
		||||
      [response, done](const Status& s, const GroupRec* gr) {
 | 
			
		||||
        if (s.ok()) {
 | 
			
		||||
          mutex_lock l(gr->mu);
 | 
			
		||||
@ -196,7 +197,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
 | 
			
		||||
  }
 | 
			
		||||
  StatusCallback done_and_cleanup = [cp, done](const Status& s) {
 | 
			
		||||
    done(s);
 | 
			
		||||
    delete cp;
 | 
			
		||||
    cp->Unref();
 | 
			
		||||
  };
 | 
			
		||||
  CompleteInstanceDistributed(
 | 
			
		||||
      request->device(), gr, cp, cancel_mgr,
 | 
			
		||||
 | 
			
		||||
@ -127,6 +127,13 @@ class FakeCache : public TestWorkerCache {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class DeviceResDistTest : public ::testing::Test {
 | 
			
		||||
 public:
 | 
			
		||||
  ~DeviceResDistTest() override {
    // Drop the reference held on every CollectiveParams stored in cp_.
    for (auto it = cp_.begin(); it != cp_.end(); ++it) {
      it->second->Unref();
    }
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  void DefineWorkers(int num_workers, int num_devices,
 | 
			
		||||
                     const string& device_type, bool nccl) {
 | 
			
		||||
@ -181,20 +188,20 @@ class DeviceResDistTest : public ::testing::Test {
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  CollectiveParams CreateCollectiveParams(int num_workers, int num_devices,
 | 
			
		||||
                                          const string& device_type) {
 | 
			
		||||
  CollectiveParams* CreateCollectiveParams(int num_workers, int num_devices,
 | 
			
		||||
                                           const string& device_type) {
 | 
			
		||||
    const int kGroupKey = 5;
 | 
			
		||||
    const int kInstanceKey = 3;
 | 
			
		||||
    CollectiveParams cp;
 | 
			
		||||
    cp.group.group_key = kGroupKey;
 | 
			
		||||
    cp.group.group_size = num_workers * num_devices;
 | 
			
		||||
    cp.group.device_type = DeviceType(device_type);
 | 
			
		||||
    cp.group.num_tasks = num_workers;
 | 
			
		||||
    cp.instance.instance_key = kInstanceKey;
 | 
			
		||||
    cp.instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
    cp.instance.data_type = DT_FLOAT;
 | 
			
		||||
    cp.instance.shape = TensorShape({64});
 | 
			
		||||
    cp.instance.impl_details.subdiv_offsets.push_back(0);
 | 
			
		||||
    auto* cp = new CollectiveParams();
 | 
			
		||||
    cp->group.group_key = kGroupKey;
 | 
			
		||||
    cp->group.group_size = num_workers * num_devices;
 | 
			
		||||
    cp->group.device_type = DeviceType(device_type);
 | 
			
		||||
    cp->group.num_tasks = num_workers;
 | 
			
		||||
    cp->instance.instance_key = kInstanceKey;
 | 
			
		||||
    cp->instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
    cp->instance.data_type = DT_FLOAT;
 | 
			
		||||
    cp->instance.shape = TensorShape({64});
 | 
			
		||||
    cp->instance.impl_details.subdiv_offsets.push_back(0);
 | 
			
		||||
    return cp;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -217,7 +224,7 @@ class DeviceResDistTest : public ::testing::Test {
 | 
			
		||||
                    int group_size) {
 | 
			
		||||
    Device* device = nullptr;
 | 
			
		||||
    TF_CHECK_OK(device_mgrs_[task_name]->LookupDevice(device_name, &device));
 | 
			
		||||
    CollectiveParams* cp = &cp_[device_name];
 | 
			
		||||
    CollectiveParams* cp = cp_[device_name];
 | 
			
		||||
    CollectiveParamResolverDistributed* cp_res = cp_resolvers_[task_name].get();
 | 
			
		||||
    CHECK(cp_res);
 | 
			
		||||
    cp_res->CompleteParamsAsync(
 | 
			
		||||
@ -252,19 +259,19 @@ class DeviceResDistTest : public ::testing::Test {
 | 
			
		||||
        string device_name = strings::StrCat(task_name, "/device:CPU:", di);
 | 
			
		||||
        int idx = wi * num_devices + di;
 | 
			
		||||
        TF_ASSERT_OK(status_[device_name]);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name].default_rank, idx);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name].group.device_names.size(), dev_count);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name].group.device_names[idx], device_name);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name].group.task_names[idx], task_name);
 | 
			
		||||
        ValidateDeviceResolver(cp_[device_name], task_name);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name]->default_rank, idx);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name]->group.device_names.size(), dev_count);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name]->group.device_names[idx], device_name);
 | 
			
		||||
        EXPECT_EQ(cp_[device_name]->group.task_names[idx], task_name);
 | 
			
		||||
        ValidateDeviceResolver(*cp_[device_name], task_name);
 | 
			
		||||
        if (idx > 0) {
 | 
			
		||||
          EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key,
 | 
			
		||||
                    cp_[device_name].group.runtime_details.communicator_key);
 | 
			
		||||
          EXPECT_EQ(cp_[dev0]->group.runtime_details.communicator_key,
 | 
			
		||||
                    cp_[device_name]->group.runtime_details.communicator_key);
 | 
			
		||||
          for (int i = 0; i < dev_count; ++i) {
 | 
			
		||||
            EXPECT_EQ(cp_[dev0].group.device_names[i],
 | 
			
		||||
                      cp_[device_name].group.device_names[i]);
 | 
			
		||||
            EXPECT_EQ(cp_[dev0].group.task_names[i],
 | 
			
		||||
                      cp_[device_name].group.task_names[i]);
 | 
			
		||||
            EXPECT_EQ(cp_[dev0]->group.device_names[i],
 | 
			
		||||
                      cp_[device_name]->group.device_names[i]);
 | 
			
		||||
            EXPECT_EQ(cp_[dev0]->group.task_names[i],
 | 
			
		||||
                      cp_[device_name]->group.task_names[i]);
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
@ -287,6 +294,9 @@ class DeviceResDistTest : public ::testing::Test {
 | 
			
		||||
    for (int i = 0; i < num_devices; ++i) {
 | 
			
		||||
      string device_name =
 | 
			
		||||
          strings::StrCat(worker_name, "/device:", device_type, ":", i);
 | 
			
		||||
      if (cp_.find(device_name) != cp_.end()) {
 | 
			
		||||
        cp_[device_name]->Unref();
 | 
			
		||||
      }
 | 
			
		||||
      cp_[device_name] =
 | 
			
		||||
          CreateCollectiveParams(num_workers, num_devices, device_type);
 | 
			
		||||
      status_.erase(device_name);
 | 
			
		||||
@ -305,7 +315,7 @@ class DeviceResDistTest : public ::testing::Test {
 | 
			
		||||
  absl::flat_hash_map<string, std::vector<string>> dev_by_task_;
 | 
			
		||||
  absl::flat_hash_map<string, std::unique_ptr<FakeWorker>> workers_;
 | 
			
		||||
  // The maps below are keyed by device name.
 | 
			
		||||
  absl::flat_hash_map<string, CollectiveParams> cp_;
 | 
			
		||||
  absl::flat_hash_map<string, CollectiveParams*> cp_;
 | 
			
		||||
  absl::flat_hash_map<string, Status> status_;
 | 
			
		||||
  mutex mu_;
 | 
			
		||||
  int num_done_ TF_GUARDED_BY(mu_);
 | 
			
		||||
 | 
			
		||||
@ -169,7 +169,7 @@ string CollectiveParams::ToString() const {
 | 
			
		||||
CollectiveContext::CollectiveContext(
 | 
			
		||||
    CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
 | 
			
		||||
    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
 | 
			
		||||
    OpKernelContext::Params* op_params, const CollectiveParams& col_params,
 | 
			
		||||
    OpKernelContext::Params* op_params, const CollectiveParams* col_params,
 | 
			
		||||
    const string& exec_key, int64 step_id, const Tensor* input, Tensor* output)
 | 
			
		||||
    : col_exec(col_exec),
 | 
			
		||||
      nccl_communicator(nccl_communicator),
 | 
			
		||||
@ -182,7 +182,7 @@ CollectiveContext::CollectiveContext(
 | 
			
		||||
      input(input),
 | 
			
		||||
      output(output),
 | 
			
		||||
      device(nullptr),
 | 
			
		||||
      device_name(col_params.group.device_names[col_params.default_rank]) {}
 | 
			
		||||
      device_name(col_params->group.device_names[col_params->default_rank]) {}
 | 
			
		||||
 | 
			
		||||
/*static*/
 | 
			
		||||
int64 CollectiveExecutor::kInvalidId = -1;
 | 
			
		||||
 | 
			
		||||
@ -132,7 +132,7 @@ struct CollTaskParams {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Unique to a single CollectiveOp node.
 | 
			
		||||
struct CollectiveParams {
 | 
			
		||||
struct CollectiveParams : public core::RefCounted {
 | 
			
		||||
  CollGroupParams group;
 | 
			
		||||
  CollInstanceParams instance;
 | 
			
		||||
  CollTaskParams task;
 | 
			
		||||
@ -298,7 +298,7 @@ class CollectiveExecutor : public core::RefCounted {
 | 
			
		||||
  virtual void StartAbort(const Status& s) {}
 | 
			
		||||
 | 
			
		||||
  virtual void ExecuteAsync(OpKernelContext* ctx,
 | 
			
		||||
                            const CollectiveParams& col_params,
 | 
			
		||||
                            const CollectiveParams* col_params,
 | 
			
		||||
                            const string& exec_key, StatusCallback done) {
 | 
			
		||||
    done(errors::Internal(
 | 
			
		||||
        "A collective Op has been called in a context in which "
 | 
			
		||||
@ -367,7 +367,7 @@ struct CollectiveContext {
 | 
			
		||||
  const DeviceMgr* dev_mgr;                      // Not owned
 | 
			
		||||
  OpKernelContext* op_ctx;                       // Not owned
 | 
			
		||||
  OpKernelContext::Params* op_params;            // Not owned
 | 
			
		||||
  const CollectiveParams& col_params;
 | 
			
		||||
  const CollectiveParams* col_params;            // Not owned
 | 
			
		||||
  const string exec_key;
 | 
			
		||||
  const int64 step_id;
 | 
			
		||||
  const Tensor* input;  // Not owned
 | 
			
		||||
@ -380,7 +380,7 @@ struct CollectiveContext {
 | 
			
		||||
                    NcclCommunicatorInterface* nccl_communicator,
 | 
			
		||||
                    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
 | 
			
		||||
                    OpKernelContext::Params* op_params,
 | 
			
		||||
                    const CollectiveParams& col_params, const string& exec_key,
 | 
			
		||||
                    const CollectiveParams* col_params, const string& exec_key,
 | 
			
		||||
                    int64 step_id, const Tensor* input, Tensor* output);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -61,7 +61,7 @@ Status NcclBase::InitializeCollectiveParams(CollectiveParams* col_params) {
 | 
			
		||||
Status NcclBase::InitializeCollectiveContext(
 | 
			
		||||
    std::shared_ptr<CollectiveContext> col_ctx) {
 | 
			
		||||
  col_ctx_ = col_ctx;
 | 
			
		||||
  col_params_ = &col_ctx->col_params;
 | 
			
		||||
  col_params_ = col_ctx->col_params;
 | 
			
		||||
  return collective_util::InitializeDeviceAndLocality(
 | 
			
		||||
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
 | 
			
		||||
      &col_ctx->device_locality);
 | 
			
		||||
 | 
			
		||||
@ -88,6 +88,9 @@ void NcclReducer::Run(StatusCallback done) {
 | 
			
		||||
  } else {
 | 
			
		||||
    done_callback = std::move(done);
 | 
			
		||||
  }
 | 
			
		||||
  // Hold a ref to col_params for the rest of this function.
 | 
			
		||||
  col_params_->Ref();
 | 
			
		||||
  core::ScopedUnref unref(col_params_);
 | 
			
		||||
  col_ctx_->nccl_communicator->Enqueue(col_ctx_, std::move(done_callback));
 | 
			
		||||
 | 
			
		||||
  // If no final_op, then this OpKernel is non-blocking.
 | 
			
		||||
 | 
			
		||||
@ -53,7 +53,9 @@ static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
 | 
			
		||||
class CollectiveOpV1Kernel : public AsyncOpKernel {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
 | 
			
		||||
      : AsyncOpKernel(c), name_(name()) {}
 | 
			
		||||
      : AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {}
 | 
			
		||||
 | 
			
		||||
  // Drops this kernel's reference to col_params_, which is allocated in the
  // constructor. Callbacks that still use the params hold their own refs.
  ~CollectiveOpV1Kernel() override { col_params_->Unref(); }
 | 
			
		||||
 | 
			
		||||
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
 | 
			
		||||
    CollectiveExecutor* col_exec = c->collective_executor();
 | 
			
		||||
@ -88,28 +90,31 @@ class CollectiveOpV1Kernel : public AsyncOpKernel {
 | 
			
		||||
  // A string encoding instance, frame and iter to be handed off to
 | 
			
		||||
  // the implementation for use in generating RecvBuf keys.
 | 
			
		||||
  string GetCollectiveKey(OpKernelContext* c) {
 | 
			
		||||
    return CollectiveKey(c, col_params_.group.group_key,
 | 
			
		||||
                         col_params_.instance.instance_key);
 | 
			
		||||
    return CollectiveKey(c, col_params_->group.group_key,
 | 
			
		||||
                         col_params_->instance.instance_key);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Returns false if calling invocation of ComputeAsync should return
 | 
			
		||||
  // immediately.
 | 
			
		||||
  bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
 | 
			
		||||
                             const DoneCallback& done) {
 | 
			
		||||
    if (col_params_.group.group_size > col_params_.group.device_names.size()) {
 | 
			
		||||
    if (col_params_->group.group_size >
 | 
			
		||||
        col_params_->group.device_names.size()) {
 | 
			
		||||
      // This is the first invocation: Finish initializing col_params_.
 | 
			
		||||
      // Schedule the `CompleteParamsAsync` call on a work queue that can handle
 | 
			
		||||
      // blocking work because it's not guaranteed that this call cannot block.
 | 
			
		||||
      c->collective_executor()->RunClosure([this, c, done, col_exec]() {
 | 
			
		||||
      col_params_->Ref();
 | 
			
		||||
      c->collective_executor()->RunClosure([this, c, col_exec, done]() {
 | 
			
		||||
        VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
 | 
			
		||||
                << col_params_.name << " device " << c->device()->name()
 | 
			
		||||
                << " group " << col_params_.group.group_key << " instance "
 | 
			
		||||
                << col_params_.instance.instance_key;
 | 
			
		||||
                << col_params_->name << " device " << c->device()->name()
 | 
			
		||||
                << " group " << col_params_->group.group_key << " instance "
 | 
			
		||||
                << col_params_->instance.instance_key;
 | 
			
		||||
        col_exec->CompleteParamsAsync(
 | 
			
		||||
            c->device()->attributes(), &col_params_, c->cancellation_manager(),
 | 
			
		||||
            c->device()->attributes(), col_params_, c->cancellation_manager(),
 | 
			
		||||
            [this, c, done](const Status& s) {
 | 
			
		||||
              core::ScopedUnref unref(col_params_);
 | 
			
		||||
              if (s.ok()) {
 | 
			
		||||
                col_params_.instance.impl_details.dependencies = dependencies_;
 | 
			
		||||
                col_params_->instance.impl_details.dependencies = dependencies_;
 | 
			
		||||
                ComputeAsync(c, done);
 | 
			
		||||
              } else {
 | 
			
		||||
                c->SetStatus(s);
 | 
			
		||||
@ -128,7 +133,7 @@ class CollectiveOpV1Kernel : public AsyncOpKernel {
 | 
			
		||||
                                DoneCallback done) = 0;
 | 
			
		||||
 | 
			
		||||
  string name_;
 | 
			
		||||
  CollectiveParams col_params_;
 | 
			
		||||
  CollectiveParams* col_params_;
 | 
			
		||||
  std::vector<int32> dependencies_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -136,25 +141,25 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
 | 
			
		||||
      : CollectiveOpV1Kernel(c) {
 | 
			
		||||
    col_params_.instance.type = GATHER_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
 | 
			
		||||
    col_params_->instance.type = GATHER_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES(
 | 
			
		||||
        c, col_params_.group.group_size > 0,
 | 
			
		||||
        c, col_params_->group.group_size > 0,
 | 
			
		||||
        errors::InvalidArgument("group_size must be positive integer but got ",
 | 
			
		||||
                                col_params_.group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
 | 
			
		||||
                                col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("communication_hint",
 | 
			
		||||
                      &col_params_.instance.impl_details.communication_hint));
 | 
			
		||||
                      &col_params_->instance.impl_details.communication_hint));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("timeout_seconds",
 | 
			
		||||
                      &col_params_.instance.impl_details.timeout_seconds));
 | 
			
		||||
                      &col_params_->instance.impl_details.timeout_seconds));
 | 
			
		||||
    const NodeDef& real_node = c->def();
 | 
			
		||||
    col_params_.name = strings::StrCat(real_node.name(), ": Gather");
 | 
			
		||||
    col_params_.group.device_type = c->device_type();
 | 
			
		||||
    col_params_->name = strings::StrCat(real_node.name(), ": Gather");
 | 
			
		||||
    col_params_->group.device_type = c->device_type();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
@ -162,8 +167,8 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
                        DoneCallback done) override {
 | 
			
		||||
    auto output_shape = c->input(0).shape();
 | 
			
		||||
    output_shape.set_dim(
 | 
			
		||||
        0, output_shape.dim_size(0) * col_params_.group.group_size);
 | 
			
		||||
    col_params_.instance.shape = output_shape;
 | 
			
		||||
        0, output_shape.dim_size(0) * col_params_->group.group_size);
 | 
			
		||||
    col_params_->instance.shape = output_shape;
 | 
			
		||||
 | 
			
		||||
    // Allocate output on the first pass through this function.  This must be
 | 
			
		||||
    // done immediately, while we're still in the executor thread.  Otherwise
 | 
			
		||||
@ -173,24 +178,24 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
      // Allocate the output tensor.
 | 
			
		||||
      Tensor* output = nullptr;
 | 
			
		||||
      OP_REQUIRES_OK_ASYNC(
 | 
			
		||||
          c, c->allocate_output(0, col_params_.instance.shape, &output), done);
 | 
			
		||||
          c, c->allocate_output(0, col_params_->instance.shape, &output), done);
 | 
			
		||||
    }
 | 
			
		||||
    if (!CanProceedWithCompute(c, col_exec, done)) return;
 | 
			
		||||
 | 
			
		||||
    auto actual_done = [c, group_key = col_params_.group.group_key,
 | 
			
		||||
                        instance_key = col_params_.instance.instance_key,
 | 
			
		||||
                        done](const Status& s) {
 | 
			
		||||
    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
 | 
			
		||||
      VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective "
 | 
			
		||||
              << c->op_kernel().name() << " device " << c->device()->name()
 | 
			
		||||
              << " group " << group_key << " instance " << instance_key
 | 
			
		||||
              << " status " << s;
 | 
			
		||||
              << " group " << col_params->group.group_key << " instance "
 | 
			
		||||
              << col_params->instance.instance_key << " status " << s;
 | 
			
		||||
      OP_REQUIRES_OK_ASYNC(c, s, done);
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
    VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective "
 | 
			
		||||
            << col_params_.name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_.group.group_key << " instance "
 | 
			
		||||
            << col_params_.instance.instance_key;
 | 
			
		||||
            << col_params_->name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_->group.group_key << " instance "
 | 
			
		||||
            << col_params_->instance.instance_key;
 | 
			
		||||
    col_params_->Ref();
 | 
			
		||||
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -207,18 +212,18 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
 | 
			
		||||
      : CollectiveOpV1Kernel(c) {
 | 
			
		||||
    col_params_.instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
 | 
			
		||||
    col_params_->instance.type = REDUCTION_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES(
 | 
			
		||||
        c, col_params_.group.group_size > 0,
 | 
			
		||||
        c, col_params_->group.group_size > 0,
 | 
			
		||||
        errors::InvalidArgument("group_size must be positive integer but got ",
 | 
			
		||||
                                col_params_.group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
 | 
			
		||||
                                col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("subdiv_offsets",
 | 
			
		||||
                      &col_params_.instance.impl_details.subdiv_offsets));
 | 
			
		||||
                      &col_params_->instance.impl_details.subdiv_offsets));
 | 
			
		||||
    string merge_op_name;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
 | 
			
		||||
    if (merge_op_name == "Max") {
 | 
			
		||||
@ -232,24 +237,26 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
                errors::InvalidArgument(
 | 
			
		||||
                    "final_op must be one of {\"Id\", \"Div\"} but got ",
 | 
			
		||||
                    final_op_name));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("communication_hint",
 | 
			
		||||
                      &col_params_.instance.impl_details.communication_hint));
 | 
			
		||||
                      &col_params_->instance.impl_details.communication_hint));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("timeout_seconds",
 | 
			
		||||
                      &col_params_.instance.impl_details.timeout_seconds));
 | 
			
		||||
    VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key
 | 
			
		||||
            << " merge_op " << merge_op_name << " final_op " << final_op_name
 | 
			
		||||
                      &col_params_->instance.impl_details.timeout_seconds));
 | 
			
		||||
    VLOG(2) << "CollectiveReduce instance "
 | 
			
		||||
            << col_params_->instance.instance_key << " merge_op "
 | 
			
		||||
            << merge_op_name << " final_op " << final_op_name
 | 
			
		||||
            << " communication_hint "
 | 
			
		||||
            << col_params_.instance.impl_details.communication_hint
 | 
			
		||||
            << " timeout " << col_params_.instance.impl_details.timeout_seconds;
 | 
			
		||||
            << col_params_->instance.impl_details.communication_hint
 | 
			
		||||
            << " timeout "
 | 
			
		||||
            << col_params_->instance.impl_details.timeout_seconds;
 | 
			
		||||
 | 
			
		||||
    const NodeDef& real_node = c->def();
 | 
			
		||||
    col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
 | 
			
		||||
                                       merge_op_name, ",", final_op_name, ")");
 | 
			
		||||
    col_params_.group.device_type = c->device_type();
 | 
			
		||||
    col_params_->name = strings::StrCat(real_node.name(), ": Reduce(",
 | 
			
		||||
                                        merge_op_name, ",", final_op_name, ")");
 | 
			
		||||
    col_params_->group.device_type = c->device_type();
 | 
			
		||||
 | 
			
		||||
    // Find the OpKernels by name, type and device type.
 | 
			
		||||
    NodeDef sub_node;
 | 
			
		||||
@ -257,12 +264,12 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
    sub_node.add_input(real_node.input(0));
 | 
			
		||||
    sub_node.add_input(real_node.input(0));
 | 
			
		||||
    sub_node.set_device(real_node.device());
 | 
			
		||||
    SetAttrValue(col_params_.instance.data_type,
 | 
			
		||||
    SetAttrValue(col_params_->instance.data_type,
 | 
			
		||||
                 &(*sub_node.mutable_attr())["T"]);
 | 
			
		||||
    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
 | 
			
		||||
    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
 | 
			
		||||
    col_params_.merge_op = merge_op_.get();
 | 
			
		||||
    col_params_.final_op = final_op_.get();
 | 
			
		||||
    col_params_->merge_op = merge_op_.get();
 | 
			
		||||
    col_params_->final_op = final_op_.get();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
@ -279,24 +286,24 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
                           c->forward_input_or_allocate_output(
 | 
			
		||||
                               {0}, 0, c->input(0).shape(), &output),
 | 
			
		||||
                           done);
 | 
			
		||||
      col_params_.instance.shape = c->input(0).shape();
 | 
			
		||||
      col_params_->instance.shape = c->input(0).shape();
 | 
			
		||||
    }
 | 
			
		||||
    if (!CanProceedWithCompute(c, col_exec, done)) return;
 | 
			
		||||
 | 
			
		||||
    auto actual_done = [c, group_key = col_params_.group.group_key,
 | 
			
		||||
                        instance_key = col_params_.instance.instance_key,
 | 
			
		||||
                        done](const Status& s) {
 | 
			
		||||
    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
 | 
			
		||||
      VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective "
 | 
			
		||||
              << c->op_kernel().name() << " device " << c->device()->name()
 | 
			
		||||
              << " group " << group_key << " instance " << instance_key
 | 
			
		||||
              << " status " << s;
 | 
			
		||||
              << " group " << col_params->group.group_key << " instance "
 | 
			
		||||
              << col_params->instance.instance_key << " status " << s;
 | 
			
		||||
      OP_REQUIRES_OK_ASYNC(c, s, done);
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
    VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective "
 | 
			
		||||
            << col_params_.name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_.group.group_key << " instance "
 | 
			
		||||
            << col_params_.instance.instance_key;
 | 
			
		||||
            << col_params_->name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_->group.group_key << " instance "
 | 
			
		||||
            << col_params_->instance.instance_key;
 | 
			
		||||
    col_params_->Ref();
 | 
			
		||||
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -315,29 +322,29 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
 | 
			
		||||
      : CollectiveOpV1Kernel(c) {
 | 
			
		||||
    col_params_.instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
 | 
			
		||||
    col_params_->instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES(
 | 
			
		||||
        c, col_params_.group.group_size > 0,
 | 
			
		||||
        c, col_params_->group.group_size > 0,
 | 
			
		||||
        errors::InvalidArgument("group_size must be positive integer but got ",
 | 
			
		||||
                                col_params_.group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
 | 
			
		||||
                                col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("communication_hint",
 | 
			
		||||
                      &col_params_.instance.impl_details.communication_hint));
 | 
			
		||||
                      &col_params_->instance.impl_details.communication_hint));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("timeout_seconds",
 | 
			
		||||
                      &col_params_.instance.impl_details.timeout_seconds));
 | 
			
		||||
    col_params_.is_source = true;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_offsets = {0};
 | 
			
		||||
                      &col_params_->instance.impl_details.timeout_seconds));
 | 
			
		||||
    col_params_->is_source = true;
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_offsets = {0};
 | 
			
		||||
 | 
			
		||||
    col_params_.name =
 | 
			
		||||
        strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
 | 
			
		||||
    col_params_.group.device_type = c->device_type();
 | 
			
		||||
    col_params_->name =
 | 
			
		||||
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
 | 
			
		||||
    col_params_->group.device_type = c->device_type();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
@ -352,30 +359,30 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
      Tensor* output = nullptr;
 | 
			
		||||
      OP_REQUIRES_OK_ASYNC(c,
 | 
			
		||||
                           c->forward_input_or_allocate_output(
 | 
			
		||||
                               {0}, 0, col_params_.instance.shape, &output),
 | 
			
		||||
                               {0}, 0, col_params_->instance.shape, &output),
 | 
			
		||||
                           done);
 | 
			
		||||
    }
 | 
			
		||||
    if (!CanProceedWithCompute(c, col_exec, done)) return;
 | 
			
		||||
    OP_REQUIRES_ASYNC(
 | 
			
		||||
        c, col_params_.instance.shape.IsSameSize(c->input(0).shape()),
 | 
			
		||||
        errors::Internal("Declared shape of op ", col_params_.name,
 | 
			
		||||
        c, col_params_->instance.shape.IsSameSize(c->input(0).shape()),
 | 
			
		||||
        errors::Internal("Declared shape of op ", col_params_->name,
 | 
			
		||||
                         " does not match shape of input"),
 | 
			
		||||
        done);
 | 
			
		||||
 | 
			
		||||
    auto actual_done = [c, group_key = col_params_.group.group_key,
 | 
			
		||||
                        instance_key = col_params_.instance.instance_key,
 | 
			
		||||
                        done](const Status& s) {
 | 
			
		||||
    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
 | 
			
		||||
      VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective "
 | 
			
		||||
              << c->op_kernel().name() << " device " << c->device()->name()
 | 
			
		||||
              << " group " << group_key << " instance " << instance_key
 | 
			
		||||
              << " status " << s;
 | 
			
		||||
              << " group " << col_params->group.group_key << " instance "
 | 
			
		||||
              << col_params->instance.instance_key << " status " << s;
 | 
			
		||||
      OP_REQUIRES_OK_ASYNC(c, s, done);
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
    VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
 | 
			
		||||
            << col_params_.name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_.group.group_key << " instance "
 | 
			
		||||
            << col_params_.instance.instance_key;
 | 
			
		||||
            << col_params_->name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_->group.group_key << " instance "
 | 
			
		||||
            << col_params_->instance.instance_key;
 | 
			
		||||
    col_params_->Ref();
 | 
			
		||||
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -392,29 +399,29 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
 | 
			
		||||
      : CollectiveOpV1Kernel(c) {
 | 
			
		||||
    col_params_.instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
 | 
			
		||||
    col_params_->instance.type = BROADCAST_COLLECTIVE;
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES(
 | 
			
		||||
        c, col_params_.group.group_size > 0,
 | 
			
		||||
        c, col_params_->group.group_size > 0,
 | 
			
		||||
        errors::InvalidArgument("group_size must be positive integer but got ",
 | 
			
		||||
                                col_params_.group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
 | 
			
		||||
                                col_params_->group.group_size));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
 | 
			
		||||
        c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
 | 
			
		||||
    OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("communication_hint",
 | 
			
		||||
                      &col_params_.instance.impl_details.communication_hint));
 | 
			
		||||
                      &col_params_->instance.impl_details.communication_hint));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        c, c->GetAttr("timeout_seconds",
 | 
			
		||||
                      &col_params_.instance.impl_details.timeout_seconds));
 | 
			
		||||
    col_params_.is_source = false;
 | 
			
		||||
    col_params_.instance.impl_details.subdiv_offsets = {0};
 | 
			
		||||
                      &col_params_->instance.impl_details.timeout_seconds));
 | 
			
		||||
    col_params_->is_source = false;
 | 
			
		||||
    col_params_->instance.impl_details.subdiv_offsets = {0};
 | 
			
		||||
 | 
			
		||||
    col_params_.name =
 | 
			
		||||
        strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
 | 
			
		||||
    col_params_.group.device_type = c->device_type();
 | 
			
		||||
    col_params_->name =
 | 
			
		||||
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
 | 
			
		||||
    col_params_->group.device_type = c->device_type();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
@ -428,24 +435,24 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
 | 
			
		||||
      // No input, so must allocate output.
 | 
			
		||||
      Tensor* output = nullptr;
 | 
			
		||||
      OP_REQUIRES_OK_ASYNC(
 | 
			
		||||
          c, c->allocate_output(0, col_params_.instance.shape, &output), done);
 | 
			
		||||
          c, c->allocate_output(0, col_params_->instance.shape, &output), done);
 | 
			
		||||
    }
 | 
			
		||||
    if (!CanProceedWithCompute(c, col_exec, done)) return;
 | 
			
		||||
 | 
			
		||||
    auto actual_done = [c, group_key = col_params_.group.group_key,
 | 
			
		||||
                        instance_key = col_params_.instance.instance_key,
 | 
			
		||||
                        done](const Status& s) {
 | 
			
		||||
    auto actual_done = [c, col_params = col_params_, done](const Status& s) {
 | 
			
		||||
      VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective "
 | 
			
		||||
              << c->op_kernel().name() << " device " << c->device()->name()
 | 
			
		||||
              << " group " << group_key << " instance_key " << instance_key
 | 
			
		||||
              << " status  " << s;
 | 
			
		||||
              << " group " << col_params->group.group_key << " instance_key "
 | 
			
		||||
              << col_params->instance.instance_key << " status  " << s;
 | 
			
		||||
      OP_REQUIRES_OK_ASYNC(c, s, done);
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
    VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
 | 
			
		||||
            << col_params_.name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_.group.group_key << " instance "
 | 
			
		||||
            << col_params_.instance.instance_key;
 | 
			
		||||
            << col_params_->name << " device " << c->device()->name()
 | 
			
		||||
            << " group " << col_params_->group.group_key << " instance "
 | 
			
		||||
            << col_params_->instance.instance_key;
 | 
			
		||||
    col_params_->Ref();
 | 
			
		||||
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -534,8 +541,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
            << col_params->instance.instance_key;
 | 
			
		||||
 | 
			
		||||
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
 | 
			
		||||
      delete col_params;
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // Allocate the output tensor, trying to reuse the input.
 | 
			
		||||
@ -577,7 +584,7 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
                      << " group " << col_params->group.group_key
 | 
			
		||||
                      << " instance " << col_params->instance.instance_key;
 | 
			
		||||
              col_exec->ExecuteAsync(
 | 
			
		||||
                  c, *col_params,
 | 
			
		||||
                  c, col_params,
 | 
			
		||||
                  CollectiveKey(c, col_params->group.group_key,
 | 
			
		||||
                                col_params->instance.instance_key),
 | 
			
		||||
                  actual_done);
 | 
			
		||||
@ -673,8 +680,8 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
    col_params->instance.shape = output_shape;
 | 
			
		||||
 | 
			
		||||
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
 | 
			
		||||
      delete col_params;
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    Tensor* output = nullptr;
 | 
			
		||||
@ -714,7 +721,7 @@ class CollectiveGatherV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
                      << " group " << col_params->group.group_key
 | 
			
		||||
                      << " instance " << col_params->instance.instance_key;
 | 
			
		||||
              col_exec->ExecuteAsync(
 | 
			
		||||
                  c, *col_params,
 | 
			
		||||
                  c, col_params,
 | 
			
		||||
                  CollectiveKey(c, col_params->group.group_key,
 | 
			
		||||
                                col_params->instance.instance_key),
 | 
			
		||||
                  actual_done);
 | 
			
		||||
@ -797,8 +804,8 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
            << col_params->instance.instance_key;
 | 
			
		||||
 | 
			
		||||
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
 | 
			
		||||
      delete col_params;
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // Allocate the output tensor, trying to reuse the input.
 | 
			
		||||
@ -840,7 +847,7 @@ class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
                      << " group " << col_params->group.group_key
 | 
			
		||||
                      << " instance " << col_params->instance.instance_key;
 | 
			
		||||
              col_exec->ExecuteAsync(
 | 
			
		||||
                  c, *col_params,
 | 
			
		||||
                  c, col_params,
 | 
			
		||||
                  CollectiveKey(c, col_params->group.group_key,
 | 
			
		||||
                                col_params->instance.instance_key),
 | 
			
		||||
                  actual_done);
 | 
			
		||||
@ -905,8 +912,8 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
 | 
			
		||||
    auto col_params = new CollectiveParams();
 | 
			
		||||
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
 | 
			
		||||
      delete col_params;
 | 
			
		||||
      done();
 | 
			
		||||
      col_params->Unref();
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    OP_REQUIRES_OK_ASYNC(
 | 
			
		||||
@ -969,7 +976,7 @@ class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
 | 
			
		||||
                      << " group " << col_params->group.group_key
 | 
			
		||||
                      << " instance " << col_params->instance.instance_key;
 | 
			
		||||
              col_exec->ExecuteAsync(
 | 
			
		||||
                  c, *col_params,
 | 
			
		||||
                  c, col_params,
 | 
			
		||||
                  CollectiveKey(c, col_params->group.group_key,
 | 
			
		||||
                                col_params->instance.instance_key),
 | 
			
		||||
                  actual_done);
 | 
			
		||||
 | 
			
		||||
@ -69,17 +69,17 @@ std::unique_ptr<NcclCommunicatorInterface> MaybeCreateNcclCommunicator() {
 | 
			
		||||
 | 
			
		||||
void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 | 
			
		||||
                               StatusCallback done) {
 | 
			
		||||
  const CollectiveParams& col_params = col_ctx->col_params;
 | 
			
		||||
  const int num_global_devices = col_params.group.group_size;
 | 
			
		||||
  const int num_local_devices = col_params.group.num_devices_per_task.at(
 | 
			
		||||
      col_params.group.task_names[col_params.default_rank]);
 | 
			
		||||
  const CollectiveParams* col_params = col_ctx->col_params;
 | 
			
		||||
  const int num_global_devices = col_params->group.group_size;
 | 
			
		||||
  const int num_local_devices = col_params->group.num_devices_per_task.at(
 | 
			
		||||
      col_params->group.task_names[col_params->default_rank]);
 | 
			
		||||
  const string nccl_collective_key =
 | 
			
		||||
      NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id);
 | 
			
		||||
  auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream();
 | 
			
		||||
  auto* gpu_info = col_ctx->op_ctx->device()->tensorflow_gpu_device_info();
 | 
			
		||||
  auto participant = absl::make_unique<NcclManager::Participant>(
 | 
			
		||||
      compute_stream->parent(), compute_stream, gpu_info, col_ctx->input,
 | 
			
		||||
      col_ctx->output, col_ctx->col_params.default_rank,
 | 
			
		||||
      col_ctx->output, col_ctx->col_params->default_rank,
 | 
			
		||||
      /*done_callback=*/nullptr);
 | 
			
		||||
  CancellationManager* cancel_mgr = col_ctx->op_ctx->cancellation_manager();
 | 
			
		||||
  if (cancel_mgr == nullptr) {
 | 
			
		||||
@ -105,15 +105,24 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 | 
			
		||||
  }
 | 
			
		||||
  NcclManager::Context context(
 | 
			
		||||
      nccl_collective_key, num_local_devices, num_global_devices,
 | 
			
		||||
      col_params.group.runtime_details.communicator_key,
 | 
			
		||||
      col_params.source_rank);
 | 
			
		||||
  VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type
 | 
			
		||||
          << " num_tasks " << col_params.group.num_tasks << " current task "
 | 
			
		||||
          << col_params.group.task_names[col_params.default_rank]
 | 
			
		||||
      col_params->group.runtime_details.communicator_key,
 | 
			
		||||
      col_params->source_rank);
 | 
			
		||||
  VLOG(1) << "NcclCommunicator::Enqueue type " << col_params->instance.type
 | 
			
		||||
          << " num_tasks " << col_params->group.num_tasks << " current task "
 | 
			
		||||
          << col_params->group.task_names[col_params->default_rank]
 | 
			
		||||
          << " num local devices " << num_local_devices
 | 
			
		||||
          << " num global devices " << num_global_devices << " device "
 | 
			
		||||
          << col_ctx->device_name << " instance "
 | 
			
		||||
          << col_params.instance.instance_key;
 | 
			
		||||
          << col_params->instance.instance_key;
 | 
			
		||||
  // Hold a ref to col_params for the rest of this function.
 | 
			
		||||
  // NOTE: an alternate design can be one in which CollectiveParams is not
 | 
			
		||||
  // refcounted.  In such a design, we would need to ensure that the
 | 
			
		||||
  // done_callback of each participant is called only after this function is
 | 
			
		||||
  // done with accessing the params.  This would likely require some
 | 
			
		||||
  // coordination mechanism, and may even require the participant thread to
 | 
			
		||||
  // block until after UnblockDependencies is called below.
 | 
			
		||||
  col_params->Ref();
 | 
			
		||||
  core::ScopedUnref unref(col_params);
 | 
			
		||||
  // `AddTo*` performs consistency checks for the NCCL call and enqueues the
 | 
			
		||||
  // `Participant` struct locally.  When all local participants with this
 | 
			
		||||
  // `nccl_collective_key` have called `AddToAllReduce` and
 | 
			
		||||
@ -123,10 +132,11 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 | 
			
		||||
  // The `NcclManager` uses a dedicated CUDA stream for NCCL kernels.  At this
 | 
			
		||||
  // point, it synchronizes the NCCL stream with the compute stream, and then
 | 
			
		||||
  // enqueues the NCCL kernel on the NCCL stream.
 | 
			
		||||
  switch (col_params.instance.type) {
 | 
			
		||||
  switch (col_params->instance.type) {
 | 
			
		||||
    case REDUCTION_COLLECTIVE: {
 | 
			
		||||
      ncclRedOp_t reduction_op;
 | 
			
		||||
      Status s = ReductionOp(col_params.merge_op->type_string(), &reduction_op);
 | 
			
		||||
      Status s =
 | 
			
		||||
          ReductionOp(col_params->merge_op->type_string(), &reduction_op);
 | 
			
		||||
      if (!s.ok()) {
 | 
			
		||||
        participant->done_callback(s);
 | 
			
		||||
        return;
 | 
			
		||||
@ -140,7 +150,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    case BROADCAST_COLLECTIVE: {
 | 
			
		||||
      if (col_params.is_source) {
 | 
			
		||||
      if (col_params->is_source) {
 | 
			
		||||
        nccl_manager_.AddBroadcastSend(std::move(participant), context);
 | 
			
		||||
      } else {
 | 
			
		||||
        nccl_manager_.AddBroadcastRecv(std::move(participant), context);
 | 
			
		||||
@ -149,7 +159,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 | 
			
		||||
    }
 | 
			
		||||
    default: {
 | 
			
		||||
      participant->done_callback(errors::Internal("Unexpected CollectiveType ",
 | 
			
		||||
                                                  col_params.instance.type));
 | 
			
		||||
                                                  col_params->instance.type));
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
@ -175,7 +185,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 | 
			
		||||
    // ready to go.
 | 
			
		||||
    profiler::TraceMe activity("WaitForDependencies",
 | 
			
		||||
                               profiler::TraceMeLevel::kInfo);
 | 
			
		||||
    col_ctx->col_exec->WaitForDependencies(col_params);
 | 
			
		||||
    col_ctx->col_exec->WaitForDependencies(*col_params);
 | 
			
		||||
    nccl_manager_.SignalMultiNodeReady(nccl_collective_key);
 | 
			
		||||
  }
 | 
			
		||||
  {
 | 
			
		||||
@ -184,7 +194,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
 | 
			
		||||
    // implementation of `UnblockDependencies` keeps track of the number of
 | 
			
		||||
    // devices that have launched.
 | 
			
		||||
    profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
 | 
			
		||||
    col_ctx->col_exec->UnblockDependencies(col_params);
 | 
			
		||||
    col_ctx->col_exec->UnblockDependencies(*col_params);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user