Add GetParticipatingDevices()
function.
This new function should deduplicate some common code, be slightly faster in the most common case, and remove a potential source of error / make it more future proof (supporting multiple computations). In the simple case (no replica groups and only one computation), we can avoid looking up the replica ID or materializing the intermediate replica IDs list, and just return the devices directly. Otherwise, if we're looking up the replica ID, we get the computation ID for free, so we might as well use it, avoiding a potential source of bugs where multiple computations are used. PiperOrigin-RevId: 345642270 Change-Id: If4da9fe00b104dd754dc1f6d18ba3c56d397f4b5
This commit is contained in:
parent
d4047ea0c0
commit
0fd465c057
tensorflow/compiler/xla/service
@ -5181,14 +5181,26 @@ cc_library(
|
||||
hdrs = ["collective_ops_utils.h"],
|
||||
deps = [
|
||||
":computation_placer",
|
||||
":global_device_id",
|
||||
":hlo",
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal", # fixdeps: keep
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "collective_ops_utils_test",
|
||||
srcs = ["collective_ops_utils_test.cc"],
|
||||
deps = [
|
||||
":collective_ops_utils",
|
||||
":computation_placer",
|
||||
":global_device_id",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
absl::optional<ReductionKind> MatchReductionComputation(
|
||||
@ -48,39 +50,59 @@ absl::optional<ReductionKind> MatchReductionComputation(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<std::vector<int64>> GetParticipatingReplicas(
|
||||
GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
|
||||
int64 total_replica_count, const DeviceAssignment& device_assn) {
|
||||
std::vector<int64> participating_replicas;
|
||||
|
||||
// Empty replica_groups() means that all replicas participate in one big
|
||||
// group.
|
||||
StatusOr<std::vector<int>> GetParticipatingReplicas(
|
||||
int replica_id, int total_replica_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups) {
|
||||
// Empty replica_groups() means that all replicas participate.
|
||||
if (replica_groups.empty()) {
|
||||
participating_replicas.resize(total_replica_count);
|
||||
absl::c_iota(participating_replicas, 0);
|
||||
return participating_replicas;
|
||||
std::vector<int> all_replicas(total_replica_count);
|
||||
absl::c_iota(all_replicas, 0);
|
||||
return all_replicas;
|
||||
}
|
||||
|
||||
// Use the DeviceAssignment to figure out our replica-id.
|
||||
TF_ASSIGN_OR_RETURN(int replica_id,
|
||||
device_assn.ReplicaIdForDevice(device_id));
|
||||
|
||||
// Figure out the other replicas that go together with this one.
|
||||
absl::optional<ReplicaGroup> replica_group;
|
||||
for (const ReplicaGroup& g : replica_groups) {
|
||||
if (absl::c_linear_search(g.replica_ids(), replica_id)) {
|
||||
CHECK(!replica_group.has_value())
|
||||
TF_RET_CHECK(!replica_group.has_value())
|
||||
<< "Replica " << replica_id << " appears twice in replica groups";
|
||||
replica_group = g;
|
||||
}
|
||||
}
|
||||
CHECK(replica_group.has_value())
|
||||
<< "Replica " << replica_id << " doesn't appear in replica groups? ";
|
||||
TF_RET_CHECK(replica_group.has_value())
|
||||
<< "Replica " << replica_id << " doesn't appear in replica groups";
|
||||
return std::vector<int>(replica_group->replica_ids().begin(),
|
||||
replica_group->replica_ids().end());
|
||||
}
|
||||
|
||||
participating_replicas.insert(participating_replicas.begin(),
|
||||
replica_group->replica_ids().begin(),
|
||||
replica_group->replica_ids().end());
|
||||
return participating_replicas;
|
||||
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
|
||||
GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
|
||||
int total_replica_count, absl::Span<const ReplicaGroup> replica_groups) {
|
||||
std::vector<GlobalDeviceId> participants;
|
||||
// Fast path for common case, avoiding logical IDs lookup.
|
||||
if (replica_groups.empty() && device_assignment.computation_count() == 1) {
|
||||
participants.reserve(total_replica_count);
|
||||
for (int replica_id = 0; replica_id < total_replica_count; ++replica_id) {
|
||||
participants.emplace_back(
|
||||
device_assignment(replica_id, /*computation_id=*/0));
|
||||
}
|
||||
return participants;
|
||||
}
|
||||
|
||||
std::pair<int, int> logical_ids;
|
||||
TF_ASSIGN_OR_RETURN(logical_ids,
|
||||
device_assignment.LogicalIdsForDevice(device_id));
|
||||
int replica_id = logical_ids.first;
|
||||
int computation_id = logical_ids.second;
|
||||
TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
|
||||
GetParticipatingReplicas(replica_id, total_replica_count,
|
||||
replica_groups));
|
||||
|
||||
participants.reserve(participating_replicas.size());
|
||||
for (int replica_id : participating_replicas) {
|
||||
participants.emplace_back(device_assignment(replica_id, computation_id));
|
||||
}
|
||||
return participants;
|
||||
}
|
||||
|
||||
} // end namespace xla
|
||||
|
@ -21,13 +21,13 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -37,11 +37,17 @@ enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
|
||||
absl::optional<ReductionKind> MatchReductionComputation(
|
||||
const HloComputation* computation);
|
||||
|
||||
// Figures out which devices (named by their replica-ids) are participating in
|
||||
// the collective subgroup that contains device_id.
|
||||
StatusOr<std::vector<int64>> GetParticipatingReplicas(
|
||||
GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
|
||||
int64 total_replica_count, const DeviceAssignment& device_assn);
|
||||
// Figures out which replicas are participating in the collective subgroup.
|
||||
// An empty `replica_groups` indicates that all replicas are participating.
|
||||
StatusOr<std::vector<int>> GetParticipatingReplicas(
|
||||
int replica_id, int total_replica_count,
|
||||
absl::Span<const ReplicaGroup> replica_groups);
|
||||
|
||||
// Figures out which devices are participating in the collective subgroup.
|
||||
// An empty `replica_groups` indicates that all replicas are participating.
|
||||
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
|
||||
GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
|
||||
int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);
|
||||
|
||||
// Key that identifies a particular Rendezvous object in our global hashtable.
|
||||
// This determines which calls to ExecuteOnStream communicate with each other.
|
||||
|
109
tensorflow/compiler/xla/service/collective_ops_utils_test.cc
Normal file
109
tensorflow/compiler/xla/service/collective_ops_utils_test.cc
Normal file
@ -0,0 +1,109 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_NoReplicaGroups) {
|
||||
std::vector<int> actual =
|
||||
GetParticipatingReplicas(
|
||||
/*replica_id=*/0, /*total_replica_count=*/3, /*replica_groups=*/{})
|
||||
.ConsumeValueOrDie();
|
||||
std::vector<int> expected = {0, 1, 2};
|
||||
EXPECT_EQ(actual, expected);
|
||||
}
|
||||
|
||||
TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_ReplicaGroups) {
|
||||
std::vector<ReplicaGroup> replica_groups(3);
|
||||
replica_groups[0].add_replica_ids(0);
|
||||
replica_groups[0].add_replica_ids(4);
|
||||
replica_groups[1].add_replica_ids(1);
|
||||
replica_groups[1].add_replica_ids(5);
|
||||
replica_groups[2].add_replica_ids(2);
|
||||
replica_groups[2].add_replica_ids(3);
|
||||
|
||||
std::vector<int> actual =
|
||||
GetParticipatingReplicas(
|
||||
/*replica_id=*/1, /*total_replica_count=*/6, replica_groups)
|
||||
.ConsumeValueOrDie();
|
||||
std::vector<int> expected = {1, 5};
|
||||
EXPECT_EQ(actual, expected);
|
||||
}
|
||||
|
||||
TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_NoReplicaGroups) {
|
||||
DeviceAssignment device_assignment(/*replica_count=*/3,
|
||||
/*computation_count=*/1);
|
||||
device_assignment(0, 0) = 42;
|
||||
device_assignment(1, 0) = 43;
|
||||
device_assignment(2, 0) = 44;
|
||||
|
||||
std::vector<GlobalDeviceId> actual =
|
||||
GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
|
||||
/*total_replica_count=*/3, /*replica_groups=*/{})
|
||||
.ConsumeValueOrDie();
|
||||
std::vector<GlobalDeviceId> expected = {
|
||||
GlobalDeviceId(42), GlobalDeviceId(43), GlobalDeviceId(44)};
|
||||
EXPECT_EQ(actual, expected);
|
||||
}
|
||||
|
||||
TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_ReplicaGroups) {
|
||||
DeviceAssignment device_assignment(/*replica_count=*/4,
|
||||
/*computation_count=*/1);
|
||||
device_assignment(0, 0) = 42;
|
||||
device_assignment(1, 0) = 43;
|
||||
device_assignment(2, 0) = 44;
|
||||
device_assignment(3, 0) = 45;
|
||||
|
||||
std::vector<ReplicaGroup> replica_groups(2);
|
||||
replica_groups[0].add_replica_ids(0);
|
||||
replica_groups[0].add_replica_ids(3);
|
||||
replica_groups[1].add_replica_ids(1);
|
||||
replica_groups[1].add_replica_ids(2);
|
||||
|
||||
std::vector<GlobalDeviceId> actual =
|
||||
GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
|
||||
/*total_replica_count=*/4, replica_groups)
|
||||
.ConsumeValueOrDie();
|
||||
std::vector<GlobalDeviceId> expected = {GlobalDeviceId(42),
|
||||
GlobalDeviceId(45)};
|
||||
EXPECT_EQ(actual, expected);
|
||||
}
|
||||
|
||||
TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_MultipleComputations) {
|
||||
DeviceAssignment device_assignment(/*replica_count=*/2,
|
||||
/*computation_count=*/2);
|
||||
device_assignment(0, 0) = 42;
|
||||
device_assignment(1, 0) = 43;
|
||||
device_assignment(0, 1) = 44;
|
||||
device_assignment(1, 1) = 45;
|
||||
|
||||
std::vector<GlobalDeviceId> actual =
|
||||
GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
|
||||
/*total_replica_count=*/2, /*replica_groups=*/{})
|
||||
.ConsumeValueOrDie();
|
||||
std::vector<GlobalDeviceId> expected = {GlobalDeviceId(44),
|
||||
GlobalDeviceId(45)};
|
||||
EXPECT_EQ(actual, expected);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -167,7 +167,7 @@ struct AllToAllParticipantData : xla::ParticipantData {
|
||||
|
||||
// Replica ids participating in AllToAll, concatenation happens in the order
|
||||
// of appearence.
|
||||
std::vector<xla::int64> replica_ids_to_copy_to;
|
||||
std::vector<int> replica_ids_to_copy_to;
|
||||
|
||||
std::string ToString() const override {
|
||||
auto addr_formatter = [](std::string* out,
|
||||
@ -348,7 +348,7 @@ class CpuAllToAllRendezvous
|
||||
replica_id_map[p.replica_id] = pos;
|
||||
}
|
||||
|
||||
const std::vector<xla::int64>& replica_ids_to_copy_to =
|
||||
const std::vector<int>& replica_ids_to_copy_to =
|
||||
participants_[0].replica_ids_to_copy_to;
|
||||
|
||||
// Replica id -> rank
|
||||
@ -598,26 +598,19 @@ xla::RendezvousKey GetRendezvousKey(
|
||||
xla::int64 op_id) {
|
||||
const xla::DeviceAssignment& device_assignment =
|
||||
*run_options->device_assignment();
|
||||
xla::int32 replica_count = device_assignment.replica_count();
|
||||
int device_ordinal = GetDeviceOrdinal(run_options);
|
||||
CHECK_EQ(device_assignment.computation_count(), 1);
|
||||
std::vector<xla::int64> participating_replicas =
|
||||
xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
|
||||
replica_count,
|
||||
*run_options->device_assignment())
|
||||
.ValueOrDie();
|
||||
xla::RendezvousKey::CollectiveOpKind op_kind =
|
||||
channel_id_present ? xla::RendezvousKey::kCrossModule
|
||||
: xla::RendezvousKey::kCrossReplica;
|
||||
std::vector<xla::GlobalDeviceId> participating_devices;
|
||||
participating_devices.reserve(participating_replicas.size());
|
||||
for (xla::int64 replica : participating_replicas) {
|
||||
participating_devices.push_back(
|
||||
xla::GlobalDeviceId(device_assignment(replica, 0)));
|
||||
}
|
||||
return xla::RendezvousKey{
|
||||
run_options->run_id(), std::move(participating_devices),
|
||||
static_cast<int>(participating_replicas.size()), op_kind, op_id};
|
||||
std::vector<xla::GlobalDeviceId> participating_devices =
|
||||
xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
|
||||
device_assignment,
|
||||
device_assignment.replica_count(), group)
|
||||
.ValueOrDie();
|
||||
int num_local_participants = participating_devices.size();
|
||||
return xla::RendezvousKey{run_options->run_id(),
|
||||
std::move(participating_devices),
|
||||
num_local_participants, op_kind, op_id};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -644,9 +637,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
|
||||
participant.replica_id = replica_id;
|
||||
participant.replica_ids_to_copy_to =
|
||||
xla::GetParticipatingReplicas(
|
||||
xla::GlobalDeviceId(device_ordinal), group,
|
||||
run_options->device_assignment()->replica_count(),
|
||||
*run_options->device_assignment())
|
||||
replica_id, run_options->device_assignment()->replica_count(), group)
|
||||
.ValueOrDie();
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
|
||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/global_device_id.h"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
@ -132,28 +134,18 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
int device_ordinal = executor->device_ordinal();
|
||||
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
|
||||
params.GetGlobalDeviceId());
|
||||
// Determines the set of global and local devices that are participating in
|
||||
// the same collective group as the caller.
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<int64> participating_replicas,
|
||||
GetParticipatingReplicas(global_device_id, config_.replica_groups,
|
||||
config_.replica_count, *params.device_assn));
|
||||
if (IsGlobalNcclConfig() &&
|
||||
participating_replicas.size() != config_.replica_count) {
|
||||
std::vector<GlobalDeviceId> participants,
|
||||
GetParticipatingDevices(global_device_id, *params.device_assn,
|
||||
config_.replica_count, config_.replica_groups));
|
||||
|
||||
if (IsGlobalNcclConfig() && (participants.size() != config_.replica_count)) {
|
||||
return InvalidArgument(
|
||||
"Partial replica groups are not allowed when using NCCL_COMM_ID "
|
||||
"environment configuration.");
|
||||
}
|
||||
|
||||
TF_RET_CHECK(params.device_assn->computation_count() == 1)
|
||||
<< params.device_assn->ToString();
|
||||
std::vector<GlobalDeviceId> participants;
|
||||
participants.reserve(participating_replicas.size());
|
||||
for (int64 replica : participating_replicas) {
|
||||
participants.emplace_back(
|
||||
(*params.device_assn)(replica, /*computation=*/0));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<LocalParticipant> local_participants,
|
||||
GetLocalParticipants(participants, params.gpu_global_device_ids));
|
||||
|
Loading…
Reference in New Issue
Block a user