diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 93c81052cd8..482b5f37579 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -5181,14 +5181,26 @@ cc_library(
     hdrs = ["collective_ops_utils.h"],
     deps = [
         ":computation_placer",
+        ":global_device_id",
         ":hlo",
+        ":pattern_matcher",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla/service:pattern_matcher",
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
-        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",  # fixdeps: keep
-        "//tensorflow/stream_executor/lib",
+    ],
+)
+
+tf_cc_test(
+    name = "collective_ops_utils_test",
+    srcs = ["collective_ops_utils_test.cc"],
+    deps = [
+        ":collective_ops_utils",
+        ":computation_placer",
+        ":global_device_id",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.cc b/tensorflow/compiler/xla/service/collective_ops_utils.cc
index 331f26ca9c3..c2da4834943 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.cc
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+
 namespace xla {
 
 absl::optional<ReductionKind> MatchReductionComputation(
@@ -48,39 +50,59 @@ absl::optional<ReductionKind> MatchReductionComputation(
   }
 }
 
-StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
-    int64 total_replica_count, const DeviceAssignment& device_assn) {
-  std::vector<int64> participating_replicas;
-
-  // Empty replica_groups() means that all replicas participate in one big
-  // group.
+StatusOr<std::vector<int>> GetParticipatingReplicas(
+    int replica_id, int total_replica_count,
+    absl::Span<const ReplicaGroup> replica_groups) {
+  // Empty replica_groups() means that all replicas participate.
   if (replica_groups.empty()) {
-    participating_replicas.resize(total_replica_count);
-    absl::c_iota(participating_replicas, 0);
-    return participating_replicas;
+    std::vector<int> all_replicas(total_replica_count);
+    absl::c_iota(all_replicas, 0);
+    return all_replicas;
   }
 
-  // Use the DeviceAssignment to figure out our replica-id.
-  TF_ASSIGN_OR_RETURN(int replica_id,
-                      device_assn.ReplicaIdForDevice(device_id));
-
   // Figure out the other replicas that go together with this one.
   absl::optional<ReplicaGroup> replica_group;
   for (const ReplicaGroup& g : replica_groups) {
     if (absl::c_linear_search(g.replica_ids(), replica_id)) {
-      CHECK(!replica_group.has_value())
+      TF_RET_CHECK(!replica_group.has_value())
           << "Replica " << replica_id << " appears twice in replica groups";
       replica_group = g;
     }
   }
-  CHECK(replica_group.has_value())
-      << "Replica " << replica_id << " doesn't appear in replica groups? ";
+  TF_RET_CHECK(replica_group.has_value())
+      << "Replica " << replica_id << " doesn't appear in replica groups";
+  return std::vector<int>(replica_group->replica_ids().begin(),
+                          replica_group->replica_ids().end());
+}
 
-  participating_replicas.insert(participating_replicas.begin(),
-                                replica_group->replica_ids().begin(),
-                                replica_group->replica_ids().end());
-  return participating_replicas;
+StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
+    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
+    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups) {
+  std::vector<GlobalDeviceId> participants;
+  // Fast path for common case, avoiding logical IDs lookup.
+  if (replica_groups.empty() && device_assignment.computation_count() == 1) {
+    participants.reserve(total_replica_count);
+    for (int replica_id = 0; replica_id < total_replica_count; ++replica_id) {
+      participants.emplace_back(
+          device_assignment(replica_id, /*computation_id=*/0));
+    }
+    return participants;
+  }
+
+  std::pair<int, int> logical_ids;
+  TF_ASSIGN_OR_RETURN(logical_ids,
+                      device_assignment.LogicalIdsForDevice(device_id));
+  int replica_id = logical_ids.first;
+  int computation_id = logical_ids.second;
+  TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
+                      GetParticipatingReplicas(replica_id, total_replica_count,
+                                               replica_groups));
+
+  participants.reserve(participating_replicas.size());
+  for (int replica_id : participating_replicas) {
+    participants.emplace_back(device_assignment(replica_id, computation_id));
+  }
+  return participants;
 }
 
 }  // end namespace xla
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index c67d0739f5c..d7f63cb2663 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -21,13 +21,13 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
-#include "tensorflow/stream_executor/lib/statusor.h"
 
 namespace xla {
 
@@ -37,11 +37,17 @@ enum class ReductionKind { SUM, PRODUCT, MIN, MAX };
 absl::optional<ReductionKind> MatchReductionComputation(
     const HloComputation* computation);
 
-// Figures out which devices (named by their replica-ids) are participating in
-// the collective subgroup that contains device_id.
-StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
-    int64 total_replica_count, const DeviceAssignment& device_assn);
+// Figures out which replicas are participating in the collective subgroup.
+// An empty `replica_groups` indicates that all replicas are participating.
+StatusOr<std::vector<int>> GetParticipatingReplicas(
+    int replica_id, int total_replica_count,
+    absl::Span<const ReplicaGroup> replica_groups);
+
+// Figures out which devices are participating in the collective subgroup.
+// An empty `replica_groups` indicates that all replicas are participating.
+StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
+    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
+    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);
 
 // Key that identifies a particular Rendezvous object in our global hashtable.
 // This determines which calls to ExecuteOnStream communicate with each other.
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils_test.cc b/tensorflow/compiler/xla/service/collective_ops_utils_test.cc
new file mode 100644
index 00000000000..c08113fb453
--- /dev/null
+++ b/tensorflow/compiler/xla/service/collective_ops_utils_test.cc
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_NoReplicaGroups) {
+  std::vector<int> actual =
+      GetParticipatingReplicas(
+          /*replica_id=*/0, /*total_replica_count=*/3, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<int> expected = {0, 1, 2};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingReplicas_ReplicaGroups) {
+  std::vector<ReplicaGroup> replica_groups(3);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(4);
+  replica_groups[1].add_replica_ids(1);
+  replica_groups[1].add_replica_ids(5);
+  replica_groups[2].add_replica_ids(2);
+  replica_groups[2].add_replica_ids(3);
+
+  std::vector<int> actual =
+      GetParticipatingReplicas(
+          /*replica_id=*/1, /*total_replica_count=*/6, replica_groups)
+          .ConsumeValueOrDie();
+  std::vector<int> expected = {1, 5};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_NoReplicaGroups) {
+  DeviceAssignment device_assignment(/*replica_count=*/3,
+                                     /*computation_count=*/1);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(2, 0) = 44;
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
+                              /*total_replica_count=*/3, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {
+      GlobalDeviceId(42), GlobalDeviceId(43), GlobalDeviceId(44)};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_ReplicaGroups) {
+  DeviceAssignment device_assignment(/*replica_count=*/4,
+                                     /*computation_count=*/1);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(2, 0) = 44;
+  device_assignment(3, 0) = 45;
+
+  std::vector<ReplicaGroup> replica_groups(2);
+  replica_groups[0].add_replica_ids(0);
+  replica_groups[0].add_replica_ids(3);
+  replica_groups[1].add_replica_ids(1);
+  replica_groups[1].add_replica_ids(2);
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(42), device_assignment,
+                              /*total_replica_count=*/4, replica_groups)
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {GlobalDeviceId(42),
+                                          GlobalDeviceId(45)};
+  EXPECT_EQ(actual, expected);
+}
+
+TEST(CollectiveOpsUtilsTest, GetParticipatingDevices_MultipleComputations) {
+  DeviceAssignment device_assignment(/*replica_count=*/2,
+                                     /*computation_count=*/2);
+  device_assignment(0, 0) = 42;
+  device_assignment(1, 0) = 43;
+  device_assignment(0, 1) = 44;
+  device_assignment(1, 1) = 45;
+
+  std::vector<GlobalDeviceId> actual =
+      GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
+                              /*total_replica_count=*/2, /*replica_groups=*/{})
+          .ConsumeValueOrDie();
+  std::vector<GlobalDeviceId> expected = {GlobalDeviceId(44),
+                                          GlobalDeviceId(45)};
+  EXPECT_EQ(actual, expected);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index a8cc531e5c6..d4d78f5ac12 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -167,7 +167,7 @@ struct AllToAllParticipantData : xla::ParticipantData {
 
   // Replica ids participating in AllToAll, concatenation happens in the order
   // of appearence.
-  std::vector<xla::int64> replica_ids_to_copy_to;
+  std::vector<int> replica_ids_to_copy_to;
 
   std::string ToString() const override {
     auto addr_formatter = [](std::string* out,
@@ -348,7 +348,7 @@ class CpuAllToAllRendezvous
       replica_id_map[p.replica_id] = pos;
     }
 
-    const std::vector<xla::int64>& replica_ids_to_copy_to =
+    const std::vector<int>& replica_ids_to_copy_to =
         participants_[0].replica_ids_to_copy_to;
 
     // Replica id -> rank
@@ -598,26 +598,19 @@ xla::RendezvousKey GetRendezvousKey(
     xla::int64 op_id) {
   const xla::DeviceAssignment& device_assignment =
       *run_options->device_assignment();
-  xla::int32 replica_count = device_assignment.replica_count();
   int device_ordinal = GetDeviceOrdinal(run_options);
-  CHECK_EQ(device_assignment.computation_count(), 1);
-  std::vector<xla::int64> participating_replicas =
-      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
-                                    replica_count,
-                                    *run_options->device_assignment())
-          .ValueOrDie();
   xla::RendezvousKey::CollectiveOpKind op_kind =
       channel_id_present ? xla::RendezvousKey::kCrossModule
                          : xla::RendezvousKey::kCrossReplica;
-  std::vector<xla::GlobalDeviceId> participating_devices;
-  participating_devices.reserve(participating_replicas.size());
-  for (xla::int64 replica : participating_replicas) {
-    participating_devices.push_back(
-        xla::GlobalDeviceId(device_assignment(replica, 0)));
-  }
-  return xla::RendezvousKey{
-      run_options->run_id(), std::move(participating_devices),
-      static_cast<int>(participating_replicas.size()), op_kind, op_id};
+  std::vector<xla::GlobalDeviceId> participating_devices =
+      xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
+                                   device_assignment,
+                                   device_assignment.replica_count(), group)
+          .ValueOrDie();
+  int num_local_participants = participating_devices.size();
+  return xla::RendezvousKey{run_options->run_id(),
+                            std::move(participating_devices),
+                            num_local_participants, op_kind, op_id};
 }
 
 }  // namespace
@@ -644,9 +637,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
   participant.replica_id = replica_id;
   participant.replica_ids_to_copy_to =
       xla::GetParticipatingReplicas(
-          xla::GlobalDeviceId(device_ordinal), group,
-          run_options->device_assignment()->replica_count(),
-          *run_options->device_assignment())
+          replica_id, run_options->device_assignment()->replica_count(), group)
           .ValueOrDie();
   for (int i = 0; i < num_buffers; i++) {
     participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 71be037856e..9b797f7e9e6 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -24,6 +24,8 @@ limitations under the License.
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_format.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/global_device_id.h"
 #if GOOGLE_CUDA
 #include "third_party/nccl/nccl.h"
 #elif TENSORFLOW_USE_ROCM
@@ -132,28 +134,18 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   int device_ordinal = executor->device_ordinal();
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  // Determines the set of global and local devices that are participating in
-  // the same collective group as the caller.
+
   TF_ASSIGN_OR_RETURN(
-      std::vector<int64> participating_replicas,
-      GetParticipatingReplicas(global_device_id, config_.replica_groups,
-                               config_.replica_count, *params.device_assn));
-  if (IsGlobalNcclConfig() &&
-      participating_replicas.size() != config_.replica_count) {
+      std::vector<GlobalDeviceId> participants,
+      GetParticipatingDevices(global_device_id, *params.device_assn,
+                              config_.replica_count, config_.replica_groups));
+
+  if (IsGlobalNcclConfig() && (participants.size() != config_.replica_count)) {
     return InvalidArgument(
        "Partial replica groups are not allowed when using NCCL_COMM_ID "
        "environment configuration.");
   }
 
-  TF_RET_CHECK(params.device_assn->computation_count() == 1)
-      << params.device_assn->ToString();
-  std::vector<GlobalDeviceId> participants;
-  participants.reserve(participating_replicas.size());
-  for (int64 replica : participating_replicas) {
-    participants.emplace_back(
-        (*params.device_assn)(replica, /*computation=*/0));
-  }
-
   TF_ASSIGN_OR_RETURN(
       std::vector<LocalParticipant> local_participants,
       GetLocalParticipants(participants, params.gpu_global_device_ids));