diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 1cfb449ebd0..452c87b23b7 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -100,6 +100,17 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
   return device_assignment_;
 }
 
+ExecutableRunOptions& ExecutableRunOptions::set_gpu_executable_run_options(
+    const GpuExecutableRunOptions* gpu_executable_run_options) {
+  gpu_executable_run_options_ = gpu_executable_run_options;
+  return *this;
+}
+
+const GpuExecutableRunOptions*
+ExecutableRunOptions::gpu_executable_run_options() const {
+  return gpu_executable_run_options_;
+}
+
 ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
   rng_seed_ = rng_seed;
   return *this;
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index ed67bfbeb0d..b44d5f13b68 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -38,6 +38,7 @@ namespace xla {
 
 class DeviceAssignment;
 class ExecutionProfile;
+class GpuExecutableRunOptions;
 
 // A unique identifier for a particular "logical execution" of an XLA model.
 //
@@ -137,6 +138,12 @@ class ExecutableRunOptions {
     return then_execute_function_;
   }
 
+  // GPU-backend-specific options. These are kept out-of-line to avoid
+  // bloating the size of this dependency for CPU-only AOT builds.
+  ExecutableRunOptions& set_gpu_executable_run_options(
+      const GpuExecutableRunOptions* gpu_executable_run_options);
+  const GpuExecutableRunOptions* gpu_executable_run_options() const;
+
  private:
   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -148,6 +155,7 @@ class ExecutableRunOptions {
   stream_executor::Stream* host_to_device_stream_ = nullptr;
   ThenExecuteFunction* then_execute_function_ = nullptr;
   RunId run_id_;
+  const GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
 };
 
 }  // namespace xla
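Note on usage: `ExecutableRunOptions` stores only a raw pointer, so the `GpuExecutableRunOptions` object must outlive the execution. A minimal sketch of the wiring (client-side setup assumed, not part of this diff):

```c++
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

// Sketch: attach GPU-backend options to a run. The options object is only
// referenced, never copied, so it must outlive the execution that uses it.
xla::ExecutableRunOptions MakeRunOptions(
    const xla::GpuExecutableRunOptions* gpu_opts) {
  xla::ExecutableRunOptions run_options;
  run_options.set_gpu_executable_run_options(gpu_opts);
  return run_options;
}
// CPU-only clients simply never call the setter; the pointer stays null, and
// the forward declaration in the header keeps them from depending on the GPU
// target.
```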
diff --git a/tensorflow/compiler/xla/refcounting_hash_map.h b/tensorflow/compiler/xla/refcounting_hash_map.h
index 3ff6a50d85f..efa1b9e3a50 100644
--- a/tensorflow/compiler/xla/refcounting_hash_map.h
+++ b/tensorflow/compiler/xla/refcounting_hash_map.h
@@ -42,13 +42,7 @@ template <typename K, typename V>
 class RefcountingHashMap {
  public:
   // Default-constructs new values.
-  RefcountingHashMap()
-      : value_factory_([](const K&) { return absl::make_unique<V>(); }) {}
-
-  // Constructs new values according to the given factory function.
-  explicit RefcountingHashMap(
-      std::function<std::unique_ptr<V>(const K&)> value_factory)
-      : value_factory_(std::move(value_factory)) {}
+  RefcountingHashMap() = default;
 
   // Not copyable or movable because this contains internal pointers (namely,
   // instances of Deleter contain pointers to `this` and into `map_`).
@@ -60,8 +54,10 @@ class RefcountingHashMap {
   // Gets the value for the given key.
   //
   // If the map doesn't contain a live value for the key, constructs one
-  // according to the factory passed to the map's constructor.
-  std::shared_ptr<V> operator[](const K& key) {
+  // using `value_factory`.
+  std::shared_ptr<V> GetOrCreateIfAbsent(
+      const K& key,
+      const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
     absl::MutexLock lock(&mu_);
     auto it = map_.find(key);
     // We ensure that the entry has not expired in case deleter was running when
@@ -76,7 +72,7 @@ class RefcountingHashMap {
     // Create entry in the map and then set its value, so the value can
     // contain a pointer back into the map.
     it = map_.emplace(key, std::weak_ptr<V>()).first;
-    std::shared_ptr<V> value(value_factory_(key).release(),
+    std::shared_ptr<V> value(value_factory(key).release(),
                              Deleter{&it->first, this});
     it->second = value;  // Set the weak ptr to the shared ptr.
     return value;
@@ -112,7 +108,6 @@ class RefcountingHashMap {
     }
   };
 
-  std::function<std::unique_ptr<V>(const K&)> value_factory_;
   absl::Mutex mu_;
   absl::node_hash_map<K, std::weak_ptr<V>> map_ ABSL_GUARDED_BY(mu_);
 };
diff --git a/tensorflow/compiler/xla/refcounting_hash_map_test.cc b/tensorflow/compiler/xla/refcounting_hash_map_test.cc
index 753c30dafbe..acb7d7afb46 100644
--- a/tensorflow/compiler/xla/refcounting_hash_map_test.cc
+++ b/tensorflow/compiler/xla/refcounting_hash_map_test.cc
@@ -47,22 +47,25 @@ struct DeleteNotifier {
 
 TEST(RefcountingHashMapTest, PointerIdentity) {
   RefcountingHashMap<int, int> m;
-  std::shared_ptr<int> a = m[0];
-  std::shared_ptr<int> b = m[0];
-  std::shared_ptr<int> c = m[1];
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  std::shared_ptr<int> a = m.GetOrCreateIfAbsent(0, factory);
+  std::shared_ptr<int> b = m.GetOrCreateIfAbsent(0, factory);
+  std::shared_ptr<int> c = m.GetOrCreateIfAbsent(1, factory);
   EXPECT_EQ(a.get(), b.get());
   EXPECT_NE(a.get(), c.get());
 }
 
 TEST(RefcountingHashMapTest, DefaultInitialized) {
   RefcountingHashMap<int, int> m;
-  EXPECT_EQ(*m[42], 0);
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(42, factory), 0);
 }
 
 TEST(RefcountingHashMapTest, DeletesEagerly) {
   RefcountingHashMap<int, DeleteNotifier> m;
   bool deleted = false;
-  auto handle = m[0];
+  auto factory = [](const int&) { return absl::make_unique<DeleteNotifier>(); };
+  auto handle = m.GetOrCreateIfAbsent(0, factory);
   handle->fn = [&] { deleted = true; };
   EXPECT_FALSE(deleted);
   handle = nullptr;
@@ -70,10 +73,10 @@ TEST(RefcountingHashMapTest, DeletesEagerly) {
 }
 
 TEST(RefcountingHashMapTest, CustomFactory) {
-  RefcountingHashMap<int, int> m(
-      [](const int& x) { return absl::make_unique<int>(x + 1); });
-  EXPECT_EQ(*m[0], 1);
-  EXPECT_EQ(*m[100], 101);
+  RefcountingHashMap<int, int> m;
+  auto factory = [](const int& x) { return absl::make_unique<int>(x + 1); };
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(0, factory), 1);
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
 }
 
 TEST(RefcountingHashMapTest, ForEachEmpty) {
@@ -85,8 +88,9 @@ TEST(RefcountingHashMapTest, ForEachEmpty) {
 
 TEST(RefcountingHashMapTest, ForEachNonempty) {
   RefcountingHashMap<int, int> m;
-  auto a = m[0];
-  auto b = m[1];
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  auto a = m.GetOrCreateIfAbsent(0, factory);
+  auto b = m.GetOrCreateIfAbsent(1, factory);
 
   std::vector<int> seen_keys;
   std::vector<int*> seen_values;
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index da50e92de32..12f7cd4654c 100755
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4585,6 +4585,7 @@ cc_library(
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/service:pattern_matcher",
+        "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",  # fixdeps: keep
         "//tensorflow/stream_executor/lib",
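The new call-site idiom that replaces the constructor-injected factory is worth spelling out: the factory now travels with each lookup, so a call site can capture context the map itself never sees, which is exactly what the NCCL clique code later in this patch needs. A hedged sketch with hypothetical `MyKey`/`MyValue` types:

```c++
#include <memory>
#include <string>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"

using MyKey = std::string;  // hypothetical
struct MyValue {            // hypothetical
  explicit MyValue(const MyKey& k) : name(k) {}
  std::string name;
};

xla::RefcountingHashMap<MyKey, MyValue>& GlobalMap() {
  // The map is now default-constructed; it no longer owns a factory.
  static auto& m = *new xla::RefcountingHashMap<MyKey, MyValue>();
  return m;
}

std::shared_ptr<MyValue> Lookup(const MyKey& key) {
  // The factory is supplied per call and may capture call-site state.
  return GlobalMap().GetOrCreateIfAbsent(
      key, [](const MyKey& k) { return absl::make_unique<MyValue>(k); });
}
```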
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.cc b/tensorflow/compiler/xla/service/collective_ops_utils.cc
index cfe586c6c0b..a4eba334f31 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.cc
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.cc
@@ -44,7 +44,7 @@ absl::optional<ReductionKind> MatchReductionComputation(
 }
 
 StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
+    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
     int64 total_replica_count, const DeviceAssignment& device_assn) {
   std::vector<int64> participating_replicas;
 
@@ -58,7 +58,7 @@ StatusOr<std::vector<int64>> GetParticipatingReplicas(
   // Use the DeviceAssignment to figure out our replica-id.
   TF_ASSIGN_OR_RETURN(int replica_id,
-                      device_assn.ReplicaIdForDeviceOrdinal(device_ordinal));
+                      device_assn.ReplicaIdForDeviceOrdinal(device_id.value()));
 
   // Figure out the other replicas that go together with this one.
   absl::optional<ReplicaGroup> replica_group;
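A worked illustration of the new signature's semantics may help (all values hypothetical; the replica-to-device assignment is the identity here):

```c++
#include <vector>

#include "tensorflow/compiler/xla/service/collective_ops_utils.h"

// Hypothetical: 4 replicas, one computation, replica i on global device i.
void GetParticipatingReplicasExample() {
  xla::DeviceAssignment assn(/*replica_count=*/4, /*computation_count=*/1);
  for (int i = 0; i < 4; ++i) assn(i, 0) = i;

  xla::ReplicaGroup g0, g1;
  g0.add_replica_ids(0);
  g0.add_replica_ids(2);
  g1.add_replica_ids(1);
  g1.add_replica_ids(3);

  // The caller now names itself by global device ID, not local ordinal. The
  // assignment maps device 2 back to replica 2, whose group is {0, 2}.
  std::vector<xla::int64> replicas =
      xla::GetParticipatingReplicas(xla::GlobalDeviceId(2), {g0, g1},
                                    /*total_replica_count=*/4, assn)
          .ValueOrDie();  // replicas == {0, 2}
}
```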
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index 2524b4190e9..d9b6c48685b 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -37,9 +38,9 @@ absl::optional<ReductionKind> MatchReductionComputation(
     const HloComputation* computation);
 
 // Figures out which devices (named by their replica-ids) are participating in
-// the all-reduce subgroup that contains device_ordinal.
+// the all-reduce subgroup that contains device_id.
 StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
+    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
     int64 total_replica_count, const DeviceAssignment& device_assn);
 
 // Key that identifies a particular Rendezvous object in our global hashtable.
@@ -72,16 +73,18 @@ struct RendezvousKey {
   };
 
   explicit RendezvousKey(const RunId& run_id,
-                         std::vector<int64> participating_replicas,
+                         std::vector<GlobalDeviceId> global_devices,
+                         int num_local_participants,
                          CollectiveOpKind collective_op_kind, int64 op_id)
       : run_id(run_id),
-        participating_replicas(participating_replicas),
+        global_devices(std::move(global_devices)),
+        num_local_participants(num_local_participants),
         collective_op_kind(collective_op_kind),
         op_id(op_id) {}
 
   static RendezvousKey FromInstruction(
-      const RunId& run_id, std::vector<int64> participating_replicas,
-      const HloInstruction* instr) {
+      const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
+      int num_local_participants, const HloInstruction* instr) {
     CollectiveOpKind collective_op_kind;
     int64 op_id;
@@ -91,20 +94,19 @@ struct RendezvousKey {
             : std::make_pair(
                   kCrossReplica,
                   static_cast<int64>(instr->GetModule()->unique_id()));
-    return RendezvousKey(run_id, participating_replicas, collective_op_kind,
-                         op_id);
+    return RendezvousKey(run_id, std::move(global_devices),
+                         num_local_participants, collective_op_kind, op_id);
   }
 
-  int num_participants() const { return participating_replicas.size(); }
-
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
-    return H::combine(std::move(h), k.run_id, k.participating_replicas,
+    return H::combine(std::move(h), k.run_id, k.global_devices,
+                      k.num_local_participants,
                       static_cast<int>(k.collective_op_kind), k.op_id);
   }
   friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
-    return a.run_id == b.run_id &&
-           a.participating_replicas == b.participating_replicas &&
+    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
+           a.num_local_participants == b.num_local_participants &&
            a.collective_op_kind == b.collective_op_kind &&  //
            a.op_id == b.op_id;
   }
@@ -114,14 +116,15 @@ struct RendezvousKey {
 
   string ToString() const {
     return absl::StrFormat(
-        "RendezvousKey{run_id=%s, participating_replicas=[%s], "
-        "collective_op_kind=%d, op_id=%d}",
-        run_id.ToString(), absl::StrJoin(participating_replicas, ","),
-        static_cast<int>(collective_op_kind), op_id);
+        "RendezvousKey{run_id=%s, global_devices=[%s], "
+        "num_local_participants=%d, collective_op_kind=%d, op_id=%d}",
+        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
+        num_local_participants, static_cast<int>(collective_op_kind), op_id);
   }
 
   RunId run_id;
-  std::vector<int64> participating_replicas;
+  std::vector<GlobalDeviceId> global_devices;
+  int num_local_participants;
   CollectiveOpKind collective_op_kind;
   int64 op_id;
 };
@@ -164,10 +167,13 @@ struct AllReduceParticipantData {
   };
   std::vector<Buffer> buffers;
   se::Stream* stream;
+  const NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
 
   ReductionKind reduction_kind;
 
-  int num_participants() const { return rendezvous_key.num_participants(); }
+  // For each local all-reduce participant, a (global ID, local device
+  // ordinal) pair. Participants are in no particular order.
+  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
 
   string ToString() const {
     std::vector<string> buffer_strs;
@@ -303,12 +309,13 @@ class Rendezvous {
   const RendezvousKey key_;
 
   tensorflow::BlockingCounter all_participants_present_{
-      key_.num_participants()};
-  tensorflow::BlockingCounter done_{key_.num_participants()};
+      key_.num_local_participants};
+  tensorflow::BlockingCounter done_{key_.num_local_participants};
 
   // tensorflow::BlockingCounter returned by SubmitParticipant.
   std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
-      std::make_shared<tensorflow::BlockingCounter>(key_.num_participants())};
+      std::make_shared<tensorflow::BlockingCounter>(
+          key_.num_local_participants)};
 };
 
 }  // end namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 98c23b679fa..60e184411e9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -382,10 +382,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
 xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
 GlobalRendezvousMap() {
   static auto& m =
-      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>(
-          [](const xla::RendezvousKey& k) {
-            return absl::make_unique<CpuAllReduceRendezvous>(k);
-          });
+      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
   return m;
 }
 
@@ -411,18 +408,28 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   std::vector<xla::ReplicaGroup> group =
       xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
-  xla::int32 replica_count = run_options->device_assignment()->replica_count();
-  std::vector<xla::int64> participating_replicas_vec =
-      xla::GetParticipatingReplicas(device_ordinal, group, replica_count,
+  const xla::DeviceAssignment& device_assignment =
+      *run_options->device_assignment();
+  xla::int32 replica_count = device_assignment.replica_count();
+  CHECK_EQ(device_assignment.computation_count(), 1);
+  std::vector<xla::int64> participating_replicas =
+      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
+                                    replica_count,
                                     *run_options->device_assignment())
           .ValueOrDie();
 
   xla::RendezvousKey::CollectiveOpKind op_kind =
       channel_id_present ? xla::RendezvousKey::kCrossModule
                          : xla::RendezvousKey::kCrossReplica;
-  xla::RendezvousKey rendezvous_key(run_options->run_id(),
-                                    participating_replicas_vec, op_kind, op_id);
-
+  std::vector<xla::GlobalDeviceId> participating_devices;
+  participating_devices.reserve(participating_replicas.size());
+  for (xla::int64 replica : participating_replicas) {
+    participating_devices.push_back(
+        xla::GlobalDeviceId(device_assignment(replica, 0)));
+  }
+  xla::RendezvousKey rendezvous_key(
+      run_options->run_id(), std::move(participating_devices),
+      participating_replicas.size(), op_kind, op_id);
 
   auto shape_str = ShapeString(shape_ptr, shape_length);
   VLOG(2) << "All-reduce input/output shape : " << shape_str;
@@ -444,10 +451,17 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   participant.buffers = {buffer};
   participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
 
-  TF_CHECK_OK(
-      CpuAllReduceRendezvous::SubmitParticipant(
-          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant)
-          .status());
+  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
+    return absl::make_unique<CpuAllReduceRendezvous>(k);
+  };
+
+  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
+                  [&] {
+                    return GlobalRendezvousMap().GetOrCreateIfAbsent(
+                        rendezvous_key, make_cpu_rendezvous);
+                  },
+                  participant)
+                  .status());
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
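The replica-to-device translation above leans on the single-computation invariant (`CHECK_EQ(device_assignment.computation_count(), 1)`): column 0 of the assignment is the sole source of global device IDs. A small worked illustration (hypothetical numbers):

```c++
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

// Hypothetical: two replicas whose global device IDs are neither dense nor
// ordered. Nothing requires GlobalDeviceId(i) to run replica i.
void ReplicaToDeviceExample() {
  xla::DeviceAssignment assn(/*replica_count=*/2, /*computation_count=*/1);
  assn(0, 0) = 7;  // replica 0 runs on global device 7
  assn(1, 0) = 3;  // replica 1 runs on global device 3
  // For participating_replicas = {0, 1}, the loop above would build
  // participating_devices = {GlobalDeviceId(7), GlobalDeviceId(3)}.
}
```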
"//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "gpu_constants", srcs = ["gpu_constants.cc"], @@ -385,6 +401,7 @@ cc_library( hdrs = ["thunk.h"], deps = [ ":buffer_allocations", + ":gpu_executable_run_options", ":hlo_execution_profiler", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla/service:hlo", @@ -413,6 +430,7 @@ tf_cuda_library( ":buffer_allocations", ":hlo_execution_profiler", ":thunk", + ":gpu_executable_run_options", "@com_google_absl//absl/base:core_headers", "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla:refcounting_hash_map", @@ -522,6 +540,7 @@ cc_library( ":cudnn_batchnorm_runner", ":gpu_conv_runner", ":gpu_debug_info_manager", + ":gpu_executable_run_options", ":gpu_types", ":hlo_execution_profiler", ":infeed_manager", diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index 2fe359861f8..2a071cd658d 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -211,10 +211,7 @@ StatusOr> Rendezvous::SubmitParticipant( // Rendezvous objects are one-time use, so they're removed from this map once // we're through with them. RefcountingHashMap& GlobalRendezvousMap() { - static auto& m = *new RefcountingHashMap( - [](const RendezvousKey& key) { - return absl::make_unique(key); - }); + static auto& m = *new RefcountingHashMap(); return m; } @@ -233,7 +230,11 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { // Rendezvous with the threads for all other devices that are participating in // this CollectivePermute. RendezvousKey key{params.run_id, params.device_assn->replica_count()}; - std::shared_ptr rendezvous = GlobalRendezvousMap()[key]; + auto rendezvous_factory = [](const RendezvousKey& key) { + return absl::make_unique(key); + }; + std::shared_ptr rendezvous = + GlobalRendezvousMap().GetOrCreateIfAbsent(key, rendezvous_factory); TF_ASSIGN_OR_RETURN(int64 replica_id, params.device_assn->ReplicaIdForDeviceOrdinal( diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc index 7c3d76c1c92..998a3ccb4ee 100644 --- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc @@ -34,7 +34,7 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { NcclAllReduceThunk::~NcclAllReduceThunk() = default; -/*static*/ absl::flat_hash_set +/*static*/ absl::flat_hash_set NcclAllReduceThunk::DevicesWithOpenNcclChannels() { return {}; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index df44f379b99..d4797e094fd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -196,13 +197,21 @@ Status GpuExecutable::ExecuteThunks( VLOG(2) << "Executing the thunk for " << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; + const GpuExecutableRunOptions* gpu_options = + run_options->run_options().gpu_executable_run_options(); Thunk::ExecuteParams thunk_params{ &buffer_allocations, stream, run_options->run_options().run_id(), &profiler, run_options->run_options().device_assignment(), - &deferred_host_callbacks}; + &deferred_host_callbacks, + gpu_options && gpu_options->gpu_global_device_ids() + ? &*gpu_options->gpu_global_device_ids() + : nullptr, + gpu_options && gpu_options->nccl_unique_id_callback() + ? &gpu_options->nccl_unique_id_callback() + : nullptr}; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = absl::make_unique(main_stream->parent()); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc new file mode 100644 index 00000000000..b152962eb99 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+
+namespace xla {
+
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
+  std::vector<int64> values;
+  values.reserve(ids.size());
+  for (GlobalDeviceId id : ids) {
+    values.push_back(id.value());
+  }
+  return absl::StrJoin(values, ",");
+}
+
+NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
+    : devices_(std::move(devices)) {
+  absl::c_sort(devices_);
+  CHECK(absl::c_adjacent_find(devices_) == devices_.end())
+      << "Duplicate devices are not allowed: "
+      << GlobalDeviceIdsToString(devices_);
+}
+
+GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids(
+    absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids) {
+  gpu_global_device_ids_ = std::move(gpu_global_device_ids);
+  return *this;
+}
+
+const absl::optional<std::vector<GlobalDeviceId>>&
+GpuExecutableRunOptions::gpu_global_device_ids() const {
+  return gpu_global_device_ids_;
+}
+
+GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_unique_id_callback(
+    NcclUniqueIdCallback nccl_unique_id_callback) {
+  nccl_unique_id_callback_ = std::move(nccl_unique_id_callback);
+  return *this;
+}
+
+const NcclUniqueIdCallback& GpuExecutableRunOptions::nccl_unique_id_callback()
+    const {
+  return nccl_unique_id_callback_;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h
new file mode 100644
index 00000000000..7a43c80121b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h
@@ -0,0 +1,90 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+
+namespace xla {
+
+// Strongly-typed integer type for naming a device globally within a
+// distributed system. XLA doesn't have a strong opinion about what global
+// numbering scheme is applied to GPUs; the user must provide a local ->
+// global mapping via GpuExecutableRunOptions for the local GPUs.
+TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
+
+// Returns a comma-separated string of global device IDs.
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
+
+// Key for identifying a particular NCCL clique. This is just a set of unique
+// device IDs (i.e. GPU IDs). The device IDs must be global within a cluster.
+class NcclCliqueKey {
+ public:
+  explicit NcclCliqueKey(std::vector<GlobalDeviceId> devices);
+
+  template <typename H>
+  friend H AbslHashValue(H h, const NcclCliqueKey& k) {
+    return H::combine(std::move(h), k.devices_);
+  }
+  friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
+    return a.devices_ == b.devices_;
+  }
+
+  const std::vector<GlobalDeviceId>& devices() const { return devices_; }
+
+ private:
+  std::vector<GlobalDeviceId> devices_;
+};
+
+using NcclUniqueIdCallback =
+    std::function<StatusOr<std::string>(const NcclCliqueKey&)>;
+
+// GPU-specific executable options.
+// We keep these separate from ExecutableRunOptions to avoid adding
+// dependencies to ExecutableRunOptions.
+class GpuExecutableRunOptions {
+ public:
+  // Sets a mapping from local device ordinals to global device IDs.
+  // Used only on NVidia GPUs for cross-host NCCL collectives. If set, the
+  // elements of `device_assignment` are interpreted as global device IDs, not
+  // local device ordinals.
+  GpuExecutableRunOptions& set_gpu_global_device_ids(
+      absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids);
+  const absl::optional<std::vector<GlobalDeviceId>>& gpu_global_device_ids()
+      const;
+
+  // Callback that returns a ncclUniqueId encoded as a string for a group of
+  // communicating GPU devices. Used only on NVidia GPUs.
+  GpuExecutableRunOptions& set_nccl_unique_id_callback(
+      NcclUniqueIdCallback nccl_unique_id_callback);
+  const NcclUniqueIdCallback& nccl_unique_id_callback() const;
+
+ private:
+  absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids_;
+  NcclUniqueIdCallback nccl_unique_id_callback_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
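How a multi-host client is expected to fill these options in (a sketch; `ExchangeViaDistributedStore` is a hypothetical stand-in for whatever out-of-band channel the client runs, e.g. a distributed key-value store):

```c++
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/statusor.h"

// Hypothetical out-of-band exchange: one process calls ncclGetUniqueId() and
// publishes the result under a key derived from the clique's device set.
xla::StatusOr<std::string> ExchangeViaDistributedStore(
    const xla::NcclCliqueKey& key);

// Host 1 of 2 with two local GPUs: local ordinals {0, 1} are global devices
// {2, 3}. All numbers are illustrative.
xla::GpuExecutableRunOptions MakeGpuOptions() {
  xla::GpuExecutableRunOptions gpu_opts;
  gpu_opts.set_gpu_global_device_ids(std::vector<xla::GlobalDeviceId>{
      xla::GlobalDeviceId(2), xla::GlobalDeviceId(3)});
  // Without this callback, a clique spanning more devices than the local
  // participants fails with InvalidArgument (see nccl_all_reduce_thunk.cc).
  gpu_opts.set_nccl_unique_id_callback([](const xla::NcclCliqueKey& key) {
    return ExchangeViaDistributedStore(key);
  });
  return gpu_opts;
}
```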
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 4498793113a..719429771af 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -179,27 +180,18 @@ absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
   }
 }
 
-// Key for looking up a particular NCCL clique. This is just a set of unique
-// device ordinals (i.e. GPU IDs).
-struct NcclCliqueKey {
-  explicit NcclCliqueKey(absl::Span<const int64> devices)
-      : devices(devices.begin(), devices.end()) {
-    absl::c_sort(this->devices);
-    CHECK(absl::c_adjacent_find(devices) == devices.end())
-        << "Duplicate devices are not allowed: "
-        << absl::StrJoin(devices, ", ");
+Status StringToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
+  if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
+    return InvalidArgument(
+        "ncclUniqueId string must have %d bytes, got %d bytes",
+        NCCL_UNIQUE_ID_BYTES, str_id.size());
   }
-
-  template <typename H>
-  friend H AbslHashValue(H h, const NcclCliqueKey& k) {
-    return H::combine(std::move(h), k.devices);
-  }
-  friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
-    return a.devices == b.devices;
-  }
-
-  std::vector<int64> devices;
-};
+  // ncclUniqueId is internally just a char[].
+  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
+                "NCCL_UNIQUE_ID_BYTES");
+  std::memcpy(static_cast<void*>(nccl_id), str_id.data(),
+              NCCL_UNIQUE_ID_BYTES);
+  return Status::OK();
+}
 
 // Owns a clique of NCCL comms which can be used for collective operations among
 // a particular set of GPUs.
@@ -216,20 +208,29 @@ struct NcclCliqueKey {
 // GPUs, you'll need a different clique.
 class NcclClique {
  public:
-  explicit NcclClique(absl::Span<const int64> devices)
-      : devices_(devices.begin(), devices.end()) {
-    absl::c_sort(devices_);
-    status_ = Init();
+  explicit NcclClique(
+      int64 num_global_devices, std::vector<int64> local_device_ordinals,
+      std::vector<int64> local_device_ranks,
+      const StatusOr<absl::optional<std::string>>& nccl_unique_id)
+      : num_global_devices_(num_global_devices),
+        local_device_ordinals_(std::move(local_device_ordinals)),
+        local_device_ranks_(std::move(local_device_ranks)) {
+    CHECK_EQ(local_device_ordinals_.size(), local_device_ranks_.size());
+    // It's unusual to pass a StatusOr<> into a class, but since this class
+    // already has an erroneous state, it turns out to be a little easier to
+    // implement this way than to change RefcountingHashMap.
+    status_ = Init(nccl_unique_id);
   }
 
   Status status() { return status_; }
 
-  absl::Span<const int64> devices() {
-    TF_CHECK_OK(status_);
-    return devices_;
-  }
-  ncclComm_t comm(int64 device) {
-    int64 idx = std::distance(devices_.begin(), absl::c_find(devices_, device));
+  // An NCCL communicator is the NCCL state associated with a participant
+  // (rank) in a reduction. This method returns the state associated with a
+  // particular local device ordinal.
+  ncclComm_t comm(int64 device_ordinal) {
+    int64 idx =
+        std::distance(local_device_ordinals_.begin(),
+                      absl::c_find(local_device_ordinals_, device_ordinal));
     return comms_.at(idx).comm();
   }
 
@@ -249,10 +250,12 @@ class NcclClique {
   }
 
  private:
-  Status Init() {
+  Status Init(
+      const StatusOr<absl::optional<std::string>>& maybe_nccl_unique_id) {
     VLOG(3) << absl::StreamFormat(
-        "Initializing nccl comms for participant devices {%s}",
-        absl::StrJoin(devices_, ", "));
+        "Initializing nccl comms for participant device ordinals {%s} "
+        "ranks {%s}",
+        absl::StrJoin(local_device_ordinals_, ", "),
+        absl::StrJoin(local_device_ranks_, ", "));
 
     // Restore CUDA device after running this. XLA shouldn't care, but maybe
     // another consumer does.
@@ -264,15 +267,23 @@ class NcclClique {
     // When using ncclGroupStart/End it seems that the ncclComm_t's are not
     // populated until the End() call. This unfortunately makes error handling
     // tricky.
-    std::vector<ncclComm_t> raw_comms(devices_.size(), nullptr);
+    std::vector<ncclComm_t> raw_comms(local_device_ordinals_.size(), nullptr);
+    TF_ASSIGN_OR_RETURN(const absl::optional<std::string>& nccl_id_string,
+                        maybe_nccl_unique_id);
+
     ncclUniqueId nccl_id;
-    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
+    if (nccl_id_string) {
+      TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id));
+    } else {
+      XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
+    }
+
     XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
     Status status = [&] {
-      for (int i = 0; i < devices_.size(); ++i) {
-        XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(devices_[i]));
-        XLA_CUDA_RETURN_IF_ERROR(
-            ncclCommInitRank(&raw_comms[i], devices_.size(), nccl_id, i));
+      for (int i = 0; i < local_device_ordinals_.size(); ++i) {
+        XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(local_device_ordinals_[i]));
+        XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i],
+                                                  num_global_devices_, nccl_id,
+                                                  local_device_ranks_.at(i)));
       }
       return Status::OK();
     }();
@@ -282,9 +293,9 @@ class NcclClique {
     // Populate comms_ from the raw comms we created above. If we encountered
     // an error above we'll later clear comms_ thus destroying any raw comms
     // that were created before the error.
-    for (int i = 0; i < devices_.size(); ++i) {
-      VLOG(3) << absl::StreamFormat("Device %d assigned ncclComm %p",
-                                    devices_[i], raw_comms[i]);
+    for (int i = 0; i < local_device_ordinals_.size(); ++i) {
+      VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
+                                    local_device_ordinals_[i], raw_comms[i]);
       CHECK(raw_comms[i] != nullptr || !status.ok());
       comms_.emplace_back(raw_comms[i]);
     }
@@ -296,7 +307,11 @@ class NcclClique {
   }
 
   Status status_;
-  std::vector<int64> devices_;
+  int64 num_global_devices_;
+  std::vector<int64> local_device_ordinals_;
+  // NCCL communicator rank for each local device. The rank of a device is
+  // equal to the offset of the local device in the global device set.
+  std::vector<int64> local_device_ranks_;
   std::vector<NcclComm> comms_;
 
   // This mutex is in a unique_ptr so NcclClique can be movable.
@@ -312,10 +327,7 @@ class NcclClique {
 // have one clique alive for a given set of GPUs. This means that a process
 // will never do two collective operations concurrently on the same set of GPUs.
 RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
-  static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>(
-      [](const NcclCliqueKey& key) {
-        return absl::make_unique<NcclClique>(key.devices);
-      });
+  static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
   return m;
 }
 
@@ -341,10 +353,7 @@ class RendezvousNcclAllReduce : public Rendezvous<std::shared_ptr<NcclClique>> {
 RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>&
 GlobalRendezvousMap() {
   static auto& m =
-      *new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>(
-          [](const RendezvousKey& k) {
-            return absl::make_unique<RendezvousNcclAllReduce>(k);
-          });
+      *new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>();
   return m;
 }
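The rank convention in `local_device_ranks_` deserves a concrete illustration: a device's NCCL rank is its offset in the sorted global device set, which every participating host can compute identically without further coordination. A hedged sketch (all values hypothetical):

```c++
#include <cstdint>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

// For a clique over global devices {0,1,2,3}, a host whose two local GPUs
// carry global IDs {2,3} contributes NCCL ranks {2,3} of a 4-rank communicator.
std::vector<int64_t> ComputeLocalRanks(
    const std::vector<xla::GlobalDeviceId>& global_devices,  // sorted
    const std::vector<xla::GlobalDeviceId>& local_global_ids) {
  std::vector<int64_t> ranks;
  ranks.reserve(local_global_ids.size());
  for (xla::GlobalDeviceId id : local_global_ids) {
    auto it = absl::c_find(global_devices, id);
    ranks.push_back(std::distance(global_devices.begin(), it));
  }
  return ranks;  // e.g. {2, 3} for the example above
}
```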
@@ -365,12 +374,46 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
     // ensuring that there's a NCCL clique available for us to use.
     primary = !initialized_;
 
+    TF_RET_CHECK(participant.local_devices.size() ==
+                 participant.rendezvous_key.num_local_participants);
+
     // Look up or create the NCCL clique for this set of devices.
-    std::vector<int64> devices;
-    for (const auto& p : participants_) {
-      devices.push_back(p.device_ordinal);
-    }
-    clique = GlobalNcclCliqueMap()[NcclCliqueKey(devices)];
+    NcclCliqueKey clique_key(participant.rendezvous_key.global_devices);
+
+    auto clique_factory =
+        [&](const NcclCliqueKey& key) -> std::unique_ptr<NcclClique> {
+      std::vector<int64> local_device_ranks;
+      std::vector<int64> local_device_ordinals;
+      local_device_ranks.reserve(participant.local_devices.size());
+      local_device_ordinals.reserve(participant.local_devices.size());
+      for (const auto& l : participant.local_devices) {
+        auto it =
+            absl::c_find(participant.rendezvous_key.global_devices, l.first);
+        CHECK(it != participant.rendezvous_key.global_devices.end()) << l.first;
+        local_device_ranks.push_back(std::distance(
+            participant.rendezvous_key.global_devices.begin(), it));
+        local_device_ordinals.push_back(l.second);
+      }
+      StatusOr<absl::optional<std::string>> nccl_unique_id;
+      if (participant.nccl_unique_id_callback) {
+        nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key);
+      } else {
+        if (participant.rendezvous_key.global_devices.size() !=
+            participant.rendezvous_key.num_local_participants) {
+          nccl_unique_id = InvalidArgument(
+              "Multihost AllReduce on GPU requires a nccl_unique_id_callback "
+              "to be provided by the client.");
+        } else {
+          nccl_unique_id = absl::optional<std::string>();
+        }
+      }
+      return absl::make_unique<NcclClique>(
+          participant.rendezvous_key.global_devices.size(),
+          std::move(local_device_ordinals), std::move(local_device_ranks),
+          nccl_unique_id);
+    };
+    clique =
+        GlobalNcclCliqueMap().GetOrCreateIfAbsent(clique_key, clique_factory);
 
     if (primary) {
       VLOG(3) << "Primary initializing accounting data.";
@@ -463,12 +506,12 @@ struct NcclAllReduceThunk::AuxData {
          crs->IsCrossReplicaAllReduce() && operands_are_supported();
 }
 
-/*static*/ absl::flat_hash_set<int>
+/*static*/ absl::flat_hash_set<GlobalDeviceId>
 NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
-  absl::flat_hash_set<int> devices;
+  absl::flat_hash_set<GlobalDeviceId> devices;
   GlobalNcclCliqueMap().ForEach(
       [&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
-        devices.insert(k.devices.begin(), k.devices.end());
+        devices.insert(k.devices().begin(), k.devices().end());
      });
   return devices;
 }
@@ -491,23 +534,57 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
       params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
 
   auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
-  int64 device_ordinal = params.stream->parent()->device_ordinal();
+  int64 local_device_ordinal = params.stream->parent()->device_ordinal();
+  GlobalDeviceId global_device_id;
+  if (params.gpu_global_device_ids) {
+    TF_RET_CHECK(0 <= local_device_ordinal &&
+                 local_device_ordinal < params.gpu_global_device_ids->size());
+    global_device_id = (*params.gpu_global_device_ids)[local_device_ordinal];
+  } else {
+    // No local -> global mapping was provided; assume the identity mapping.
+    global_device_id = GlobalDeviceId(local_device_ordinal);
+  }
 
+  // Determine the set of global and local devices that are participating in
+  // the same collective group as the caller.
   TF_ASSIGN_OR_RETURN(
-      std::vector<int64> participating_replicas,
-      GetParticipatingReplicas(device_ordinal, instr->replica_groups(),
+      std::vector<int64> global_participating_replicas,
+      GetParticipatingReplicas(global_device_id, instr->replica_groups(),
                                replica_count_, *params.device_assn));
+  std::vector<GlobalDeviceId> global_devices;
+  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
+  local_devices.reserve(global_participating_replicas.size());
+  global_devices.reserve(global_participating_replicas.size());
+  TF_RET_CHECK(params.device_assn->computation_count() == 1)
+      << params.device_assn->ToString();
+  for (int64 replica : global_participating_replicas) {
+    GlobalDeviceId global_device(
+        (*params.device_assn)(replica, /*computation=*/0));
+    global_devices.push_back(global_device);
+    if (!params.gpu_global_device_ids) {
+      local_devices.emplace_back(global_device, global_device.value());
+    } else {
+      auto it = absl::c_find(*params.gpu_global_device_ids, global_device);
+      if (it != params.gpu_global_device_ids->end()) {
+        local_devices.emplace_back(
+            *it, std::distance(params.gpu_global_device_ids->begin(), it));
+      }
+    }
+  }
+  absl::c_sort(global_devices);
 
   // Find or create the rendezvous for this collective operation.
   RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
-      params.run_id, participating_replicas, hlo_instruction());
+      params.run_id, global_devices, local_devices.size(), hlo_instruction());
 
   VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
-          << ", participating replicas: "
-          << absl::StrJoin(participating_replicas, ", ");
+          << ", global participating replicas: "
+          << absl::StrJoin(global_participating_replicas, ", ")
+          << ", global participating devices: "
+          << GlobalDeviceIdsToString(global_devices);
 
   AllReduceParticipantData participant(rendezvous_key);
-  participant.device_ordinal = device_ordinal;
+  participant.device_ordinal = local_device_ordinal;
   for (size_t i = 0; i < buffers_.size(); ++i) {
     const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
     AllReduceParticipantData::Buffer pbuffer;
@@ -521,15 +598,24 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
     participant.buffers.push_back(pbuffer);
   }
   participant.stream = params.stream;
+  participant.local_devices = std::move(local_devices);
+  participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
   auto reduction_kind =
       MatchReductionComputation(hlo_instruction()->to_apply());
   CHECK(reduction_kind.has_value());
   participant.reduction_kind = *reduction_kind;
 
-  TF_ASSIGN_OR_RETURN(
-      std::shared_ptr<NcclClique> clique,
-      RendezvousNcclAllReduce::SubmitParticipant(
-          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant));
+  auto rendezvous_factory = [](const RendezvousKey& k) {
+    return absl::make_unique<RendezvousNcclAllReduce>(k);
+  };
+
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique,
+                      RendezvousNcclAllReduce::SubmitParticipant(
+                          [&] {
+                            return GlobalRendezvousMap().GetOrCreateIfAbsent(
+                                rendezvous_key, rendezvous_factory);
+                          },
+                          participant));
 
   // Keep the clique we used alive for as long as this Thunk lives. Creating
   // new NCCL cliques is expensive, and this is how we avoid thrashing them.
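Tracing the bookkeeping above for a concrete (hypothetical) two-host topology makes the global/local split visible: the `RendezvousKey` names all four global devices, but its blocking counters only wait for the threads that live in this process, while NCCL supplies the cross-host synchronization.

```c++
#include <vector>

#include "tensorflow/compiler/xla/service/collective_ops_utils.h"

// Host 1 of 2, gpu_global_device_ids = {2, 3}, replicas {0,1,2,3}, identity
// replica->device assignment (all values hypothetical).
xla::RendezvousKey MakeExampleKey(const xla::RunId& run_id,
                                  const xla::HloInstruction* all_reduce) {
  std::vector<xla::GlobalDeviceId> global_devices = {
      xla::GlobalDeviceId(0), xla::GlobalDeviceId(1), xla::GlobalDeviceId(2),
      xla::GlobalDeviceId(3)};
  // local_devices would be {(2, 0), (3, 1)}: global IDs 2 and 3 map back to
  // local ordinals 0 and 1, so num_local_participants == 2 on this host.
  return xla::RendezvousKey::FromInstruction(
      run_id, std::move(global_devices), /*num_local_participants=*/2,
      all_reduce);
}
```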
#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -46,7 +47,7 @@ class NcclAllReduceThunk : public Thunk { // (Indeed, because the NCCL channels are a global variable, in the real // world, the value returned here is stale as soon as you read it, so it's not // clear how you *could* use it for anything other than tests.) - static absl::flat_hash_set DevicesWithOpenNcclChannels(); + static absl::flat_hash_set DevicesWithOpenNcclChannels(); // TODO(b/125951860): Support all-reduces with replica groups, i.e. // all-reduces that compute multiple sums across subsets of all replicas. diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index abf829cee00..326c5a20716 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/lib/core/status.h" @@ -98,6 +99,8 @@ class Thunk { HloExecutionProfiler* profiler; // never null const DeviceAssignment* device_assn; // never null std::vector>* deferred_host_callbacks; // never null + const std::vector* gpu_global_device_ids; // may be null + const NcclUniqueIdCallback* nccl_unique_id_callback; // may be null }; // Execute the kernel for the thunk on the given stream. This method must be diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index 5cdf9633ca4..464865506f7 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc @@ -159,7 +159,7 @@ DeviceAssignment MakeDeviceAssn(std::vector devices) { } // Shorter alias for this function. -absl::flat_hash_set OpenNcclChannels() { +absl::flat_hash_set OpenNcclChannels() { return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels(); }