Add GetParticipatingDevices() function.

This new function should deduplicate some common code, be slightly faster in the most common case, and remove a potential source of error, making the code more future proof (it supports multiple computations). In the simple case (no replica groups and only one computation), we can avoid looking up the replica ID or materializing the intermediate list of replica IDs, and just return the devices directly. Otherwise, looking up the replica ID gives us the computation ID for free, so we might as well use it, avoiding a potential source of bugs when multiple computations are used.

PiperOrigin-RevId: 345642270
Change-Id: If4da9fe00b104dd754dc1f6d18ba3c56d397f4b5
commit 0fd465c057 (parent d4047ea0c0)
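For orientation before the diff, here is a minimal usage sketch of the new helper, adapted from the unit test added in this change. The wrapper function name ExampleParticipants is hypothetical, and the 3x1 DeviceAssignment with global device IDs 42-44 is the illustrative setup used by the test, not a real topology.

#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"

namespace xla {

// Hypothetical wrapper, for illustration only.
StatusOr<std::vector<GlobalDeviceId>> ExampleParticipants() {
  // Three replicas, one computation; 42/43/44 are arbitrary global device IDs.
  DeviceAssignment device_assignment(/*replica_count=*/3,
                                     /*computation_count=*/1);
  device_assignment(0, 0) = 42;
  device_assignment(1, 0) = 43;
  device_assignment(2, 0) = 44;

  // Empty replica_groups means all replicas participate. With a single
  // computation this hits the fast path and never looks up logical IDs.
  return GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
                                 /*total_replica_count=*/3,
                                 /*replica_groups=*/{});  // {42, 43, 44}
}

}  // namespace xla

When the assignment has more than one computation (as in the GetParticipatingDevices_MultipleComputations test below), the call instead resolves the replica and computation IDs via DeviceAssignment::LogicalIdsForDevice and materializes the group through GetParticipatingReplicas.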
@@ -5181,14 +5181,26 @@ cc_library(
     hdrs = ["collective_ops_utils.h"],
     deps = [
         ":computation_placer",
+        ":global_device_id",
         ":hlo",
+        ":pattern_matcher",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:pattern_matcher",
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
-        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",  # fixdeps: keep
-        "//tensorflow/stream_executor/lib",
+    ],
+)
+
+tf_cc_test(
+    name = "collective_ops_utils_test",
+    srcs = ["collective_ops_utils_test.cc"],
+    deps = [
+        ":collective_ops_utils",
+        ":computation_placer",
+        ":global_device_id",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
     ],
 )
 
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+
 namespace xla {
 
 absl::optional<ReductionKind> MatchReductionComputation(
@@ -48,39 +50,59 @@ absl::optional<ReductionKind> MatchReductionComputation(
   }
 }
 
-StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
-    int64 total_replica_count, const DeviceAssignment& device_assn) {
-  std::vector<int64> participating_replicas;
-
-  // Empty replica_groups() means that all replicas participate in one big
-  // group.
+StatusOr<std::vector<int>> GetParticipatingReplicas(
+    int replica_id, int total_replica_count,
+    absl::Span<const ReplicaGroup> replica_groups) {
+  // Empty replica_groups() means that all replicas participate.
   if (replica_groups.empty()) {
-    participating_replicas.resize(total_replica_count);
-    absl::c_iota(participating_replicas, 0);
-    return participating_replicas;
+    std::vector<int> all_replicas(total_replica_count);
+    absl::c_iota(all_replicas, 0);
+    return all_replicas;
   }
 
-  // Use the DeviceAssignment to figure out our replica-id.
-  TF_ASSIGN_OR_RETURN(int replica_id,
-                      device_assn.ReplicaIdForDevice(device_id));
-
   // Figure out the other replicas that go together with this one.
   absl::optional<ReplicaGroup> replica_group;
   for (const ReplicaGroup& g : replica_groups) {
     if (absl::c_linear_search(g.replica_ids(), replica_id)) {
-      CHECK(!replica_group.has_value())
+      TF_RET_CHECK(!replica_group.has_value())
           << "Replica " << replica_id << " appears twice in replica groups";
       replica_group = g;
     }
   }
-  CHECK(replica_group.has_value())
-      << "Replica " << replica_id << " doesn't appear in replica groups? ";
+  TF_RET_CHECK(replica_group.has_value())
+      << "Replica " << replica_id << " doesn't appear in replica groups";
+  return std::vector<int>(replica_group->replica_ids().begin(),
+                          replica_group->replica_ids().end());
+}
 
-  participating_replicas.insert(participating_replicas.begin(),
-                                replica_group->replica_ids().begin(),
-                                replica_group->replica_ids().end());
-  return participating_replicas;
+StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
+    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
+    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups) {
+  std::vector<GlobalDeviceId> participants;
+  // Fast path for common case, avoiding logical IDs lookup.
+  if (replica_groups.empty() && device_assignment.computation_count() == 1) {
+    participants.reserve(total_replica_count);
+    for (int replica_id = 0; replica_id < total_replica_count; ++replica_id) {
+      participants.emplace_back(
+          device_assignment(replica_id, /*computation_id=*/0));
+    }
+    return participants;
+  }
+
+  std::pair<int, int> logical_ids;
+  TF_ASSIGN_OR_RETURN(logical_ids,
+                      device_assignment.LogicalIdsForDevice(device_id));
+  int replica_id = logical_ids.first;
+  int computation_id = logical_ids.second;
+  TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
+                      GetParticipatingReplicas(replica_id, total_replica_count,
+                                               replica_groups));
+
+  participants.reserve(participating_replicas.size());
+  for (int replica_id : participating_replicas) {
+    participants.emplace_back(device_assignment(replica_id, computation_id));
+  }
+  return participants;
 }
 
 }  // end namespace xla
@@ -21,13 +21,13 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 
@@ -37,11 +37,17 @@ enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
 absl::optional<ReductionKind> MatchReductionComputation(
     const HloComputation* computation);
 
-// Figures out which devices (named by their replica-ids) are participating in
-// the collective subgroup that contains device_id.
-StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
-    int64 total_replica_count, const DeviceAssignment& device_assn);
+// Figures out which replicas are participating in the collective subgroup.
+// An empty `replica_groups` indicates that all replicas are participating.
+StatusOr<std::vector<int>> GetParticipatingReplicas(
+    int replica_id, int total_replica_count,
+    absl::Span<const ReplicaGroup> replica_groups);
 
+// Figures out which devices are participating in the collective subgroup.
+// An empty `replica_groups` indicates that all replicas are participating.
+StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
+    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
+    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);
+
 // Key that identifies a particular Rendezvous object in our global hashtable.
 // This determines which calls to ExecuteOnStream communicate with each other.
tensorflow/compiler/xla/service/collective_ops_utils_test.cc (new file, 109 lines)
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_NoReplicaGroups) {
+  std::vector<int> actual =
+      GetParticipatingReplicas(
+          /*replica_id=*/0, /*total_replica_count=*/3, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<int> expected = {0, 1, 2};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_ReplicaGroups) {
+  std::vector<ReplicaGroup> replica_groups(3);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(4);
+  replica_groups[1].add_replica_ids(1);
+  replica_groups[1].add_replica_ids(5);
+  replica_groups[2].add_replica_ids(2);
+  replica_groups[2].add_replica_ids(3);
+
+  std::vector<int> actual =
+      GetParticipatingReplicas(
+          /*replica_id=*/1, /*total_replica_count=*/6, replica_groups)
+          .ConsumeValueOrDie();
+  std::vector<int> expected = {1, 5};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_NoReplicaGroups) {
+  DeviceAssignment device_assignment(/*replica_count=*/3,
+                                     /*computation_count=*/1);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(2, 0) = 44;
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
+                              /*total_replica_count=*/3, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {
+      GlobalDeviceId(42), GlobalDeviceId(43), GlobalDeviceId(44)};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_ReplicaGroups) {
+  DeviceAssignment device_assignment(/*replica_count=*/4,
+                                     /*computation_count=*/1);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(2, 0) = 44;
+  device_assignment(3, 0) = 45;
+
+  std::vector<ReplicaGroup> replica_groups(2);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(3);
+  replica_groups[1].add_replica_ids(1);
+  replica_groups[1].add_replica_ids(2);
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
+                              /*total_replica_count=*/4, replica_groups)
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {GlobalDeviceId(42),
+                                          GlobalDeviceId(45)};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_MultipleComputations) {
+  DeviceAssignment device_assignment(/*replica_count=*/2,
+                                     /*computation_count=*/2);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(0, 1) = 44;
+  device_assignment(1, 1) = 45;
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
+                              /*total_replica_count=*/2, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {GlobalDeviceId(44),
+                                          GlobalDeviceId(45)};
+  EXPECT_EQ(actual, expected);
+}
+
+}  // namespace
+}  // namespace xla
@@ -167,7 +167,7 @@ struct AllToAllParticipantData : xla::ParticipantData {
 
   // Replica ids participating in AllToAll, concatenation happens in the order
   // of appearence.
-  std::vector<xla::int64> replica_ids_to_copy_to;
+  std::vector<int> replica_ids_to_copy_to;
 
   std::string ToString() const override {
     auto addr_formatter = [](std::string* out,
@@ -348,7 +348,7 @@ class CpuAllToAllRendezvous
       replica_id_map[p.replica_id] = pos;
     }
 
-    const std::vector<xla::int64>& replica_ids_to_copy_to =
+    const std::vector<int>& replica_ids_to_copy_to =
         participants_[0].replica_ids_to_copy_to;
 
     // Replica id -> rank
@@ -598,26 +598,19 @@ xla::RendezvousKey GetRendezvousKey(
     xla::int64 op_id) {
   const xla::DeviceAssignment& device_assignment =
       *run_options->device_assignment();
-  xla::int32 replica_count = device_assignment.replica_count();
   int device_ordinal = GetDeviceOrdinal(run_options);
-  CHECK_EQ(device_assignment.computation_count(), 1);
-  std::vector<xla::int64> participating_replicas =
-      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
-                                    replica_count,
-                                    *run_options->device_assignment())
-          .ValueOrDie();
   xla::RendezvousKey::CollectiveOpKind op_kind =
       channel_id_present ? xla::RendezvousKey::kCrossModule
                          : xla::RendezvousKey::kCrossReplica;
-  std::vector<xla::GlobalDeviceId> participating_devices;
-  participating_devices.reserve(participating_replicas.size());
-  for (xla::int64 replica : participating_replicas) {
-    participating_devices.push_back(
-        xla::GlobalDeviceId(device_assignment(replica, 0)));
-  }
-  return xla::RendezvousKey{
-      run_options->run_id(), std::move(participating_devices),
-      static_cast<int>(participating_replicas.size()), op_kind, op_id};
+  std::vector<xla::GlobalDeviceId> participating_devices =
+      xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
+                                   device_assignment,
+                                   device_assignment.replica_count(), group)
+          .ValueOrDie();
+  int num_local_participants = participating_devices.size();
+  return xla::RendezvousKey{run_options->run_id(),
+                            std::move(participating_devices),
+                            num_local_participants, op_kind, op_id};
 }
 
 }  // namespace
@@ -644,9 +637,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
   participant.replica_id = replica_id;
   participant.replica_ids_to_copy_to =
       xla::GetParticipatingReplicas(
-          xla::GlobalDeviceId(device_ordinal), group,
-          run_options->device_assignment()->replica_count(),
-          *run_options->device_assignment())
+          replica_id, run_options->device_assignment()->replica_count(), group)
           .ValueOrDie();
   for (int i = 0; i < num_buffers; i++) {
     participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
@@ -24,6 +24,8 @@ limitations under the License.
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_format.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #if GOOGLE_CUDA
 #include "third_party/nccl/nccl.h"
 #elif TENSORFLOW_USE_ROCM
@@ -132,28 +134,18 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   int device_ordinal = executor->device_ordinal();
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  // Determines the set of global and local devices that are participating in
-  // the same collective group as the caller.
+
   TF_ASSIGN_OR_RETURN(
-      std::vector<int64> participating_replicas,
-      GetParticipatingReplicas(global_device_id, config_.replica_groups,
-                               config_.replica_count, *params.device_assn));
-  if (IsGlobalNcclConfig() &&
-      participating_replicas.size() != config_.replica_count) {
+      std::vector<GlobalDeviceId> participants,
+      GetParticipatingDevices(global_device_id, *params.device_assn,
+                              config_.replica_count, config_.replica_groups));
+
+  if (IsGlobalNcclConfig() && (participants.size() != config_.replica_count)) {
     return InvalidArgument(
         "Partial replica groups are not allowed when using NCCL_COMM_ID "
         "environment configuration.");
   }
 
-  TF_RET_CHECK(params.device_assn->computation_count() == 1)
-      << params.device_assn->ToString();
-  std::vector<GlobalDeviceId> participants;
-  participants.reserve(participating_replicas.size());
-  for (int64 replica : participating_replicas) {
-    participants.emplace_back(
-        (*params.device_assn)(replica, /*computation=*/0));
-  }
-
   TF_ASSIGN_OR_RETURN(
       std::vector<LocalParticipant> local_participants,
       GetLocalParticipants(participants, params.gpu_global_device_ids));