[XLA:GPU] Add experimental, lightly tested support for multi-host and multi-process NCCL AllReduce.

This change makes two API changes:
* We allow the client to provide a mapping from local device ordinals on each machine to global device IDs. If provided, the device IDs in the DeviceAssignment supplied by the client are interpreted as global IDs rather than local device ordinals, which lets us describe computations that cross a host boundary.
* We allow the client to provide a callback for manufacturing a ncclUniqueId for a particular subset of global devices. The idea is that the client should use some distributed system of their own (e.g., MPI) to share the ncclUniqueId values needed for a computation; NCCL allows cross-host/cross-process collectives iff all participants use the same ncclUniqueId value.

Refactors the common collective logic, and the NCCL collective logic in particular, to support the local/global device distinction.
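To make the intended usage concrete, here is a rough sketch of the client-side wiring for a two-process, one-GPU-per-process job. ExchangeUniqueId and MakeRunOptions are hypothetical stand-ins for the client's own distribution layer (e.g., process 0 calls ncclGetUniqueId and broadcasts the bytes over MPI); only the GpuExecutableRunOptions and ExecutableRunOptions calls below come from this change.

#include <string>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

// Hypothetical out-of-band transport, NOT part of this change: shares the
// ncclUniqueId bytes for `key` among the processes driving the clique.
xla::StatusOr<std::string> ExchangeUniqueId(const xla::NcclCliqueKey& key);

xla::ExecutableRunOptions MakeRunOptions(int process_index) {
  // Two processes with one GPU each: local ordinal 0 maps to global device
  // ID 0 on process 0 and to global device ID 1 on process 1.
  static xla::GpuExecutableRunOptions gpu_options;  // must outlive the run
  gpu_options.set_gpu_global_device_ids(
      std::vector<xla::GlobalDeviceId>{xla::GlobalDeviceId(process_index)});
  gpu_options.set_nccl_unique_id_callback(
      [](const xla::NcclCliqueKey& key) { return ExchangeUniqueId(key); });

  // ExecutableRunOptions stores only a pointer to the GPU options.
  xla::ExecutableRunOptions run_options;
  run_options.set_gpu_executable_run_options(&gpu_options);
  return run_options;
}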

PiperOrigin-RevId: 296505571
Change-Id: I5ed42d65597b0960df78890745421f77e9789ba3
Peter Hawkins 2020-02-21 13:56:28 -08:00 committed by TensorFlower Gardener
parent 371c29f627
commit 8a72c4466a
18 changed files with 450 additions and 139 deletions


@@ -100,6 +100,17 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
   return device_assignment_;
 }
 
+ExecutableRunOptions& ExecutableRunOptions::set_gpu_executable_run_options(
+    const GpuExecutableRunOptions* gpu_executable_run_options) {
+  gpu_executable_run_options_ = gpu_executable_run_options;
+  return *this;
+}
+
+const GpuExecutableRunOptions*
+ExecutableRunOptions::gpu_executable_run_options() const {
+  return gpu_executable_run_options_;
+}
+
 ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
   rng_seed_ = rng_seed;
   return *this;


@@ -38,6 +38,7 @@ namespace xla {
 class DeviceAssignment;
 class ExecutionProfile;
+class GpuExecutableRunOptions;
 
 // A unique identifier for a particular "logical execution" of an XLA model.
 //
@@ -137,6 +138,12 @@ class ExecutableRunOptions {
     return then_execute_function_;
   }
 
+  // GPU-backend specific options. These are kept out-of-line to avoid bloating
+  // the size of this dependency for CPU-only AOT builds.
+  ExecutableRunOptions& set_gpu_executable_run_options(
+      const GpuExecutableRunOptions* gpu_executable_run_options);
+  const GpuExecutableRunOptions* gpu_executable_run_options() const;
+
  private:
   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -148,6 +155,7 @@ class ExecutableRunOptions {
   stream_executor::Stream* host_to_device_stream_ = nullptr;
   ThenExecuteFunction* then_execute_function_ = nullptr;
   RunId run_id_;
+  const GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
 };
 
 }  // namespace xla


@@ -42,13 +42,7 @@ template <typename K, typename V>
 class RefcountingHashMap {
  public:
   // Default-constructs new values.
-  RefcountingHashMap()
-      : value_factory_([](const K&) { return absl::make_unique<V>(); }) {}
-
-  // Constructs new values according to the given factory function.
-  explicit RefcountingHashMap(
-      std::function<std::unique_ptr<V>(const K&)> value_factory)
-      : value_factory_(std::move(value_factory)) {}
+  RefcountingHashMap() = default;
 
   // Not copyable or movable because this contains internal pointers (namely,
   // instances of Deleter contain pointers to `this` and into `map_`).
@@ -60,8 +54,10 @@ class RefcountingHashMap {
   // Gets the value for the given key.
   //
   // If the map doesn't contain a live value for the key, constructs one
-  // according to the factory passed to the map's constructor.
-  std::shared_ptr<V> operator[](const K& key) {
+  // using `value_factory`.
+  std::shared_ptr<V> GetOrCreateIfAbsent(
+      const K& key,
+      const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
     absl::MutexLock lock(&mu_);
     auto it = map_.find(key);
     // We ensure that the entry has not expired in case deleter was running when
@@ -76,7 +72,7 @@ class RefcountingHashMap {
     // Create entry in the map and then set its value, so the value can
     // contain a pointer back into the map.
     it = map_.emplace(key, std::weak_ptr<V>()).first;
-    std::shared_ptr<V> value(value_factory_(key).release(),
+    std::shared_ptr<V> value(value_factory(key).release(),
                              Deleter{&it->first, this});
     it->second = value;  // Set the weak ptr to the shared ptr.
     return value;
@@ -112,7 +108,6 @@ class RefcountingHashMap {
     }
   };
 
-  std::function<std::unique_ptr<V>(const K&)> value_factory_;
   absl::Mutex mu_;
   absl::node_hash_map<K, std::weak_ptr<V>> map_ ABSL_GUARDED_BY(mu_);
 };
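The upshot at call sites: the value factory now travels with each lookup instead of being fixed at map construction, which is what later lets the NCCL code build a clique from per-call participant state. A minimal usage sketch of the new API (names illustrative):

#include <memory>
#include <string>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"

// Returns the live value for key 42, default-constructing it if no live
// entry exists; the factory is supplied per call.
std::shared_ptr<std::string> Lookup(
    xla::RefcountingHashMap<int, std::string>& m) {
  return m.GetOrCreateIfAbsent(
      42, [](const int&) { return absl::make_unique<std::string>(); });
}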


@@ -47,22 +47,25 @@ struct DeleteNotifier {
 TEST(RefcountingHashMapTest, PointerIdentity) {
   RefcountingHashMap<int, int> m;
-  std::shared_ptr<int> a = m[0];
-  std::shared_ptr<int> b = m[0];
-  std::shared_ptr<int> c = m[1];
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  std::shared_ptr<int> a = m.GetOrCreateIfAbsent(0, factory);
+  std::shared_ptr<int> b = m.GetOrCreateIfAbsent(0, factory);
+  std::shared_ptr<int> c = m.GetOrCreateIfAbsent(1, factory);
   EXPECT_EQ(a.get(), b.get());
   EXPECT_NE(a.get(), c.get());
 }
 
 TEST(RefcountingHashMapTest, DefaultInitialized) {
   RefcountingHashMap<int, int> m;
-  EXPECT_EQ(*m[42], 0);
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(42, factory), 0);
 }
 
 TEST(RefcountingHashMapTest, DeletesEagerly) {
   RefcountingHashMap<int, DeleteNotifier> m;
   bool deleted = false;
-  auto handle = m[0];
+  auto factory = [](const int&) { return absl::make_unique<DeleteNotifier>(); };
+  auto handle = m.GetOrCreateIfAbsent(0, factory);
   handle->fn = [&] { deleted = true; };
   EXPECT_FALSE(deleted);
   handle = nullptr;
@@ -70,10 +73,10 @@ TEST(RefcountingHashMapTest, DeletesEagerly) {
 }
 
 TEST(RefcountingHashMapTest, CustomFactory) {
-  RefcountingHashMap<int, int> m(
-      [](const int& x) { return absl::make_unique<int>(x + 1); });
-  EXPECT_EQ(*m[0], 1);
-  EXPECT_EQ(*m[100], 101);
+  RefcountingHashMap<int, int> m;
+  auto factory = [](const int& x) { return absl::make_unique<int>(x + 1); };
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(0, factory), 1);
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
 }
 
 TEST(RefcountingHashMapTest, ForEachEmpty) {
@@ -85,8 +88,9 @@ TEST(RefcountingHashMapTest, ForEachEmpty) {
 TEST(RefcountingHashMapTest, ForEachNonempty) {
   RefcountingHashMap<int, int> m;
-  auto a = m[0];
-  auto b = m[1];
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  auto a = m.GetOrCreateIfAbsent(0, factory);
+  auto b = m.GetOrCreateIfAbsent(1, factory);
   std::vector<int> seen_keys;
   std::vector<int*> seen_values;


@@ -4585,6 +4585,7 @@ cc_library(
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/service:pattern_matcher",
+        "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",  # fixdeps: keep
         "//tensorflow/stream_executor/lib",


@@ -44,7 +44,7 @@ absl::optional<ReductionKind> MatchReductionComputation(
 }
 
 StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
+    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
     int64 total_replica_count, const DeviceAssignment& device_assn) {
   std::vector<int64> participating_replicas;
@@ -58,7 +58,7 @@ StatusOr<std::vector<int64>> GetParticipatingReplicas(
   // Use the DeviceAssignment to figure out our replica-id.
   TF_ASSIGN_OR_RETURN(int replica_id,
-                      device_assn.ReplicaIdForDeviceOrdinal(device_ordinal));
+                      device_assn.ReplicaIdForDeviceOrdinal(device_id.value()));
 
   // Figure out the other replicas that go together with this one.
   absl::optional<ReplicaGroup> replica_group;


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -37,9 +38,9 @@ absl::optional<ReductionKind> MatchReductionComputation(
     const HloComputation* computation);
 
 // Figures out which devices (named by their replica-ids) are participating in
-// the all-reduce subgroup that contains device_ordinal.
+// the all-reduce subgroup that contains device_id.
 StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
+    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
     int64 total_replica_count, const DeviceAssignment& device_assn);
 
 // Key that identifies a particular Rendezvous object in our global hashtable.
@@ -72,16 +73,18 @@ struct RendezvousKey {
   };
 
   explicit RendezvousKey(const RunId& run_id,
-                         std::vector<int64> participating_replicas,
+                         std::vector<GlobalDeviceId> global_devices,
+                         int num_local_participants,
                          CollectiveOpKind collective_op_kind, int64 op_id)
       : run_id(run_id),
-        participating_replicas(participating_replicas),
+        global_devices(std::move(global_devices)),
+        num_local_participants(num_local_participants),
         collective_op_kind(collective_op_kind),
         op_id(op_id) {}
 
   static RendezvousKey FromInstruction(
-      const RunId& run_id, std::vector<int64> participating_replicas,
-      const HloInstruction* instr) {
+      const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
+      int num_local_participants, const HloInstruction* instr) {
     CollectiveOpKind collective_op_kind;
     int64 op_id;
@@ -91,20 +94,19 @@ struct RendezvousKey {
             : std::make_pair(
                   kCrossReplica,
                   static_cast<int64>(instr->GetModule()->unique_id()));
-    return RendezvousKey(run_id, participating_replicas, collective_op_kind,
-                         op_id);
+    return RendezvousKey(run_id, std::move(global_devices),
+                         num_local_participants, collective_op_kind, op_id);
   }
 
-  int num_participants() const { return participating_replicas.size(); }
-
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
-    return H::combine(std::move(h), k.run_id, k.participating_replicas,
+    return H::combine(std::move(h), k.run_id, k.global_devices,
+                      k.num_local_participants,
                       static_cast<int>(k.collective_op_kind), k.op_id);
   }
   friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
-    return a.run_id == b.run_id &&
-           a.participating_replicas == b.participating_replicas &&
+    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
+           a.num_local_participants == b.num_local_participants &&
            a.collective_op_kind == b.collective_op_kind &&  //
            a.op_id == b.op_id;
   }
@@ -114,14 +116,15 @@ struct RendezvousKey {
   string ToString() const {
     return absl::StrFormat(
-        "RendezvousKey{run_id=%s, participating_replicas=[%s], "
-        "collective_op_kind=%d, op_id=%d}",
-        run_id.ToString(), absl::StrJoin(participating_replicas, ","),
-        static_cast<int>(collective_op_kind), op_id);
+        "RendezvousKey{run_id=%s, global_devices=[%s], "
+        "num_local_participants=%d, collective_op_kind=%d, op_id=%d}",
+        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
+        num_local_participants, static_cast<int>(collective_op_kind), op_id);
   }
 
   RunId run_id;
-  std::vector<int64> participating_replicas;
+  std::vector<GlobalDeviceId> global_devices;
+  int num_local_participants;
   CollectiveOpKind collective_op_kind;
   int64 op_id;
 };
@@ -164,10 +167,13 @@ struct AllReduceParticipantData {
   };
   std::vector<Buffer> buffers;
   se::Stream* stream;
+  const NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
 
   ReductionKind reduction_kind;
 
-  int num_participants() const { return rendezvous_key.num_participants(); }
+  // A (global ID, local device ordinal) pair for each local participant in
+  // the all-reduce. Participants are in no particular order.
+  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
 
   string ToString() const {
     std::vector<std::string> buffer_strs;
@@ -303,12 +309,13 @@ class Rendezvous {
   const RendezvousKey key_;
 
   tensorflow::BlockingCounter all_participants_present_{
-      key_.num_participants()};
-  tensorflow::BlockingCounter done_{key_.num_participants()};
+      key_.num_local_participants};
+  tensorflow::BlockingCounter done_{key_.num_local_participants};
 
   // tensorflow::BlockingCounter returned by SubmitParticipant.
   std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
-      std::make_shared<tensorflow::BlockingCounter>(key_.num_participants())};
+      std::make_shared<tensorflow::BlockingCounter>(
+          key_.num_local_participants)};
 };
 
 }  // end namespace xla


@@ -382,10 +382,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
 xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
 GlobalRendezvousMap() {
   static auto& m =
-      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>(
-          [](const xla::RendezvousKey& k) {
-            return absl::make_unique<CpuAllReduceRendezvous>(k);
-          });
+      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
   return m;
 }
@@ -411,18 +408,28 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   std::vector<xla::ReplicaGroup> group =
       xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
-  xla::int32 replica_count = run_options->device_assignment()->replica_count();
-  std::vector<xla::int64> participating_replicas_vec =
-      xla::GetParticipatingReplicas(device_ordinal, group, replica_count,
+  const xla::DeviceAssignment& device_assignment =
+      *run_options->device_assignment();
+  xla::int32 replica_count = device_assignment.replica_count();
+  CHECK_EQ(device_assignment.computation_count(), 1);
+  std::vector<xla::int64> participating_replicas =
+      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
+                                    replica_count,
                                     *run_options->device_assignment())
           .ValueOrDie();
 
   xla::RendezvousKey::CollectiveOpKind op_kind =
       channel_id_present ? xla::RendezvousKey::kCrossModule
                          : xla::RendezvousKey::kCrossReplica;
-  xla::RendezvousKey rendezvous_key(run_options->run_id(),
-                                    participating_replicas_vec, op_kind, op_id);
+  std::vector<xla::GlobalDeviceId> participating_devices;
+  participating_devices.reserve(participating_replicas.size());
+  for (xla::int64 replica : participating_replicas) {
+    participating_devices.push_back(
+        xla::GlobalDeviceId(device_assignment(replica, 0)));
+  }
+  xla::RendezvousKey rendezvous_key(
+      run_options->run_id(), std::move(participating_devices),
+      participating_replicas.size(), op_kind, op_id);
 
   auto shape_str = ShapeString(shape_ptr, shape_length);
   VLOG(2) << "All-reduce input/output shape : " << shape_str;
@@ -444,10 +451,17 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   participant.buffers = {buffer};
   participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
 
-  TF_CHECK_OK(
-      CpuAllReduceRendezvous::SubmitParticipant(
-          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant)
-          .status());
+  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
+    return absl::make_unique<CpuAllReduceRendezvous>(k);
+  };
+
+  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
+                  [&] {
+                    return GlobalRendezvousMap().GetOrCreateIfAbsent(
+                        rendezvous_key, make_cpu_rendezvous);
+                  },
+                  participant)
+                  .status());
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(


@@ -54,6 +54,22 @@ tf_proto_library_cc(
     protodeps = ["//tensorflow/compiler/xla:xla_data_proto"],
 )
 
+cc_library(
+    name = "gpu_executable_run_options",
+    srcs = ["gpu_executable_run_options.cc"],
+    hdrs = ["gpu_executable_run_options.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/core:lib_internal",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 cc_library(
     name = "gpu_constants",
     srcs = ["gpu_constants.cc"],
@@ -385,6 +401,7 @@ cc_library(
     hdrs = ["thunk.h"],
     deps = [
         ":buffer_allocations",
+        ":gpu_executable_run_options",
         ":hlo_execution_profiler",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",
@@ -413,6 +430,7 @@ tf_cuda_library(
         ":buffer_allocations",
         ":hlo_execution_profiler",
         ":thunk",
+        ":gpu_executable_run_options",
         "@com_google_absl//absl/base:core_headers",
         "//tensorflow/compiler/xla/service:pattern_matcher",
        "//tensorflow/compiler/xla:refcounting_hash_map",
@@ -522,6 +540,7 @@ cc_library(
         ":cudnn_batchnorm_runner",
         ":gpu_conv_runner",
        ":gpu_debug_info_manager",
+        ":gpu_executable_run_options",
         ":gpu_types",
         ":hlo_execution_profiler",
         ":infeed_manager",


@@ -211,10 +211,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
 // Rendezvous objects are one-time use, so they're removed from this map once
 // we're through with them.
 RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
-  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
-      [](const RendezvousKey& key) {
-        return absl::make_unique<Rendezvous>(key);
-      });
+  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>();
   return m;
 }
@@ -233,7 +230,11 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
   // Rendezvous with the threads for all other devices that are participating in
   // this CollectivePermute.
   RendezvousKey key{params.run_id, params.device_assn->replica_count()};
-  std::shared_ptr<Rendezvous> rendezvous = GlobalRendezvousMap()[key];
+  auto rendezvous_factory = [](const RendezvousKey& key) {
+    return absl::make_unique<Rendezvous>(key);
+  };
+  std::shared_ptr<Rendezvous> rendezvous =
+      GlobalRendezvousMap().GetOrCreateIfAbsent(key, rendezvous_factory);
 
   TF_ASSIGN_OR_RETURN(int64 replica_id,
                       params.device_assn->ReplicaIdForDeviceOrdinal(


@@ -34,7 +34,7 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
 
 NcclAllReduceThunk::~NcclAllReduceThunk() = default;
 
-/*static*/ absl::flat_hash_set<int>
+/*static*/ absl::flat_hash_set<GlobalDeviceId>
 NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
   return {};
 }


@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -196,13 +197,21 @@ Status GpuExecutable::ExecuteThunks(
     VLOG(2) << "Executing the thunk for "
             << thunk->hlo_instruction()->ToString() << " on stream "
             << stream_no;
+    const GpuExecutableRunOptions* gpu_options =
+        run_options->run_options().gpu_executable_run_options();
     Thunk::ExecuteParams thunk_params{
         &buffer_allocations,
         stream,
         run_options->run_options().run_id(),
         &profiler,
         run_options->run_options().device_assignment(),
-        &deferred_host_callbacks};
+        &deferred_host_callbacks,
+        gpu_options && gpu_options->gpu_global_device_ids()
+            ? &*gpu_options->gpu_global_device_ids()
+            : nullptr,
+        gpu_options && gpu_options->nccl_unique_id_callback()
+            ? &gpu_options->nccl_unique_id_callback()
+            : nullptr};
     TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
     if (thunk_schedule_->Depended(thunk)) {
       auto finish_event = absl::make_unique<se::Event>(main_stream->parent());


@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+
+namespace xla {
+
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
+  std::vector<int64> values;
+  values.reserve(ids.size());
+  for (GlobalDeviceId id : ids) {
+    values.push_back(id.value());
+  }
+  return absl::StrJoin(values, ",");
+}
+
+NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
+    : devices_(std::move(devices)) {
+  absl::c_sort(devices_);
+  CHECK(absl::c_adjacent_find(devices_) == devices_.end())
+      << "Duplicate devices are not allowed: "
+      << GlobalDeviceIdsToString(devices_);
+}
+
+GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids(
+    absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids) {
+  gpu_global_device_ids_ = std::move(gpu_global_device_ids);
+  return *this;
+}
+
+const absl::optional<std::vector<GlobalDeviceId>>&
+GpuExecutableRunOptions::gpu_global_device_ids() const {
+  return gpu_global_device_ids_;
+}
+
+GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_unique_id_callback(
+    NcclUniqueIdCallback nccl_unique_id_callback) {
+  nccl_unique_id_callback_ = std::move(nccl_unique_id_callback);
+  return *this;
+}
+
+const NcclUniqueIdCallback& GpuExecutableRunOptions::nccl_unique_id_callback()
+    const {
+  return nccl_unique_id_callback_;
+}
+
+}  // namespace xla


@@ -0,0 +1,90 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+
+namespace xla {
+
+// Strongly-typed integer type for naming a device globally within a
+// distributed system. XLA doesn't have a strong opinion about what global
+// numbering scheme is applied to GPUs; the user must provide a local ->
+// global mapping via GpuExecutableRunOptions for the local GPUs.
+TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
+
+// Returns a comma-separated string of global device IDs.
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
+
+// Key for naming a particular NCCL clique. This is just a set of unique
+// device IDs (i.e. GPU IDs). The device IDs must be global within a cluster.
+class NcclCliqueKey {
+ public:
+  explicit NcclCliqueKey(std::vector<GlobalDeviceId> devices);
+
+  template <typename H>
+  friend H AbslHashValue(H h, const NcclCliqueKey& k) {
+    return H::combine(std::move(h), k.devices_);
+  }
+  friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
+    return a.devices_ == b.devices_;
+  }
+
+  const std::vector<GlobalDeviceId>& devices() const { return devices_; }
+
+ private:
+  std::vector<GlobalDeviceId> devices_;
+};
+
+using NcclUniqueIdCallback =
+    std::function<StatusOr<std::string>(const NcclCliqueKey&)>;
+
+// GPU-specific executable options.
+// We keep these separate from ExecutableRunOptions to avoid adding
+// dependencies to ExecutableRunOptions.
+class GpuExecutableRunOptions {
+ public:
+  // Sets a mapping from local device ordinals to global device IDs.
+  // Used only on NVidia GPUs for cross-host NCCL collectives. If set, the
+  // elements of `device_assignment` are interpreted as global device IDs, not
+  // local device ordinals.
+  GpuExecutableRunOptions& set_gpu_global_device_ids(
+      absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids);
+  const absl::optional<std::vector<GlobalDeviceId>>& gpu_global_device_ids()
+      const;
+
+  // Callback that returns a ncclUniqueId encoded as a string for a group of
+  // communicating GPU devices. Used only on NVidia GPUs.
+  GpuExecutableRunOptions& set_nccl_unique_id_callback(
+      NcclUniqueIdCallback nccl_unique_id_callback);
+  const NcclUniqueIdCallback& nccl_unique_id_callback() const;
+
+ private:
+  absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids_;
+  NcclUniqueIdCallback nccl_unique_id_callback_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
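To make the rank convention concrete: within a clique, a device's NCCL rank is the offset of its global ID in the clique's sorted, duplicate-free global device set (which NcclCliqueKey guarantees). An illustrative helper, not part of this change:

#include <algorithm>
#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

// Mirrors the local_device_ranks computation in the NCCL thunk below:
// returns the NCCL rank of `id` within `global_devices`, or -1 if `id`
// is not a member of the clique.
int64_t RankInClique(const std::vector<xla::GlobalDeviceId>& global_devices,
                     xla::GlobalDeviceId id) {
  auto it = std::find(global_devices.begin(), global_devices.end(), id);
  return it == global_devices.end()
             ? -1
             : static_cast<int64_t>(it - global_devices.begin());
}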


@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -179,27 +180,18 @@ absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
   }
 }
 
-// Key for looking up a particular NCCL clique. This is just a set of unique
-// device ordinals (i.e. GPU IDs).
-struct NcclCliqueKey {
-  explicit NcclCliqueKey(absl::Span<const int64> devices)
-      : devices(devices.begin(), devices.end()) {
-    absl::c_sort(this->devices);
-    CHECK(absl::c_adjacent_find(devices) == devices.end())
-        << "Duplicate devices are not allowed: "
-        << absl::StrJoin(devices, ", ");
-  }
-
-  template <typename H>
-  friend H AbslHashValue(H h, const NcclCliqueKey& k) {
-    return H::combine(std::move(h), k.devices);
-  }
-  friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
-    return a.devices == b.devices;
-  }
-
-  std::vector<int64> devices;
-};
+Status StringToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
+  if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
+    return InvalidArgument(
+        "ncclUniqueId string must have %d bytes, got %d bytes",
+        NCCL_UNIQUE_ID_BYTES, str_id.size());
+  }
+  // NcclUniqueId is internally just a char[].
+  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
+                "NCCL_UNIQUE_ID_BYTES");
+  std::memcpy(static_cast<void*>(nccl_id), str_id.data(),
+              NCCL_UNIQUE_ID_BYTES);
+  return Status::OK();
+}
 
 // Owns a clique of NCCL comms which can be used for collective operations among
 // a particular set of GPUs.
@@ -216,20 +208,29 @@ struct NcclCliqueKey {
 // GPUs, you'll need a different clique.
 class NcclClique {
  public:
-  explicit NcclClique(absl::Span<const int64> devices)
-      : devices_(devices.begin(), devices.end()) {
-    absl::c_sort(devices_);
-    status_ = Init();
+  explicit NcclClique(
+      int64 num_global_devices, std::vector<int64> local_device_ordinals,
+      std::vector<int64> local_device_ranks,
+      const StatusOr<absl::optional<std::string>>& nccl_unique_id)
+      : num_global_devices_(num_global_devices),
+        local_device_ordinals_(std::move(local_device_ordinals)),
+        local_device_ranks_(std::move(local_device_ranks)) {
+    CHECK_EQ(local_device_ordinals_.size(), local_device_ranks_.size());
+    // It's unusual to pass a StatusOr<> into a class, but since this class
+    // already has an erroneous state, it turns out to be a little easier to
+    // implement this way than to change RefcountingHashMap.
+    status_ = Init(nccl_unique_id);
   }
 
   Status status() { return status_; }
 
-  absl::Span<const int64> devices() {
-    TF_CHECK_OK(status_);
-    return devices_;
-  }
-  ncclComm_t comm(int64 device) {
-    int64 idx = std::distance(devices_.begin(), absl::c_find(devices_, device));
+  // A NCCL communicator is the NCCL state associated with a participant (rank)
+  // in a reduction. This method returns the state associated with a particular
+  // local device ordinal.
+  ncclComm_t comm(int64 device_ordinal) {
+    int64 idx =
+        std::distance(local_device_ordinals_.begin(),
+                      absl::c_find(local_device_ordinals_, device_ordinal));
     return comms_.at(idx).comm();
   }
@@ -249,10 +250,12 @@ class NcclClique {
   }
 
  private:
-  Status Init() {
+  Status Init(
+      const StatusOr<absl::optional<std::string>>& maybe_nccl_unique_id) {
     VLOG(3) << absl::StreamFormat(
-        "Initializing nccl comms for participant devices {%s}",
-        absl::StrJoin(devices_, ", "));
+        "Initializing nccl comms for participant device ordinals %s ranks {%s}",
+        absl::StrJoin(local_device_ordinals_, ", "),
+        absl::StrJoin(local_device_ranks_, ", "));
 
     // Restore CUDA device after running this.  XLA shouldn't care, but maybe
     // another consumer does.
@@ -264,15 +267,23 @@ class NcclClique {
     // When using ncclGroupStart/End it seems that the ncclComm_t's are not
     // populated until the End() call.  This unfortunately makes error handling
    // tricky.
-    std::vector<ncclComm_t> raw_comms(devices_.size(), nullptr);
+    std::vector<ncclComm_t> raw_comms(local_device_ordinals_.size(), nullptr);
+    TF_ASSIGN_OR_RETURN(const absl::optional<std::string>& nccl_id_string,
+                        maybe_nccl_unique_id);
     ncclUniqueId nccl_id;
-    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
+    if (nccl_id_string) {
+      TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id));
+    } else {
+      XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
+    }
     XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
     Status status = [&] {
-      for (int i = 0; i < devices_.size(); ++i) {
-        XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(devices_[i]));
-        XLA_CUDA_RETURN_IF_ERROR(
-            ncclCommInitRank(&raw_comms[i], devices_.size(), nccl_id, i));
+      for (int i = 0; i < local_device_ordinals_.size(); ++i) {
+        XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(local_device_ordinals_[i]));
+        XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i],
+                                                  num_global_devices_, nccl_id,
+                                                  local_device_ranks_.at(i)));
       }
       return Status::OK();
     }();
@@ -282,9 +293,9 @@ class NcclClique {
     // Populate comms_ from the raw comms we created above.  If we encountered
     // an error above we'll later clear comms_ thus destroying any raw comms
     // that were created before the error.
-    for (int i = 0; i < devices_.size(); ++i) {
-      VLOG(3) << absl::StreamFormat("Device %d assigned ncclComm %p",
-                                    devices_[i], raw_comms[i]);
+    for (int i = 0; i < local_device_ordinals_.size(); ++i) {
+      VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
+                                    local_device_ordinals_[i], raw_comms[i]);
       CHECK(raw_comms[i] != nullptr || !status.ok());
       comms_.emplace_back(raw_comms[i]);
     }
@@ -296,7 +307,11 @@ class NcclClique {
   }
 
   Status status_;
-  std::vector<int64> devices_;
+  int64 num_global_devices_;
+  std::vector<int64> local_device_ordinals_;
+  // NCCL communicator rank for each local device. The rank of a device is
+  // equal to the offset of the local device in the global device set.
+  std::vector<int64> local_device_ranks_;
   std::vector<NcclComm> comms_;
 
   // This mutex is in a unique_ptr so NcclClique can be movable.
@@ -312,10 +327,7 @@ class NcclClique {
 // have one clique alive for a given set of GPUs.  This means that a process
 // will never do two collective operations concurrently on the same set of GPUs.
 RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
-  static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>(
-      [](const NcclCliqueKey& key) {
-        return absl::make_unique<NcclClique>(key.devices);
-      });
+  static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
   return m;
 }
@@ -341,10 +353,7 @@ class RendezvousNcclAllReduce : public Rendezvous<std::shared_ptr<NcclClique>> {
 RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>&
 GlobalRendezvousMap() {
   static auto& m =
-      *new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>(
-          [](const RendezvousKey& k) {
-            return absl::make_unique<RendezvousNcclAllReduce>(k);
-          });
+      *new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>();
   return m;
 }
@@ -365,12 +374,46 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
   // ensuring that there's a NCCL clique available for us to use.
   primary = !initialized_;
 
+  TF_RET_CHECK(participant.local_devices.size() ==
+               participant.rendezvous_key.num_local_participants);
+
   // Look up or create the NCCL clique for this set of devices.
-  std::vector<int64> devices;
-  for (const auto& p : participants_) {
-    devices.push_back(p.device_ordinal);
-  }
-  clique = GlobalNcclCliqueMap()[NcclCliqueKey(devices)];
+  NcclCliqueKey clique_key(participant.rendezvous_key.global_devices);
+
+  auto clique_factory =
+      [&](const NcclCliqueKey& key) -> std::unique_ptr<NcclClique> {
+    std::vector<int64> local_device_ranks;
+    std::vector<int64> local_device_ordinals;
+    local_device_ranks.reserve(participant.local_devices.size());
+    local_device_ordinals.reserve(participant.local_devices.size());
+    for (const auto& l : participant.local_devices) {
+      auto it =
+          absl::c_find(participant.rendezvous_key.global_devices, l.first);
+      CHECK(it != participant.rendezvous_key.global_devices.end()) << l.first;
+      local_device_ranks.push_back(std::distance(
+          participant.rendezvous_key.global_devices.begin(), it));
+      local_device_ordinals.push_back(l.second);
+    }
+    StatusOr<absl::optional<std::string>> nccl_unique_id;
+    if (participant.nccl_unique_id_callback) {
+      nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key);
+    } else {
+      if (participant.rendezvous_key.global_devices.size() !=
+          participant.rendezvous_key.num_local_participants) {
+        nccl_unique_id = InvalidArgument(
+            "Multihost AllReduce on GPU requires a nccl_unique_id_callback "
+            "to be provided by the client.");
+      } else {
+        nccl_unique_id = absl::optional<std::string>();
+      }
+    }
+    return absl::make_unique<NcclClique>(
+        participant.rendezvous_key.global_devices.size(),
+        std::move(local_device_ordinals), std::move(local_device_ranks),
+        nccl_unique_id);
+  };
+  clique =
+      GlobalNcclCliqueMap().GetOrCreateIfAbsent(clique_key, clique_factory);
 
   if (primary) {
     VLOG(3) << "Primary initializing accounting data.";
@@ -463,12 +506,12 @@ struct NcclAllReduceThunk::AuxData {
          crs->IsCrossReplicaAllReduce() && operands_are_supported();
 }
 
-/*static*/ absl::flat_hash_set<int>
+/*static*/ absl::flat_hash_set<GlobalDeviceId>
 NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
-  absl::flat_hash_set<int> devices;
+  absl::flat_hash_set<GlobalDeviceId> devices;
   GlobalNcclCliqueMap().ForEach(
      [&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
-        devices.insert(k.devices.begin(), k.devices.end());
+        devices.insert(k.devices().begin(), k.devices().end());
      });
   return devices;
 }
@@ -491,23 +534,57 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
       params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
 
   auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
-  int64 device_ordinal = params.stream->parent()->device_ordinal();
+  int64 local_device_ordinal = params.stream->parent()->device_ordinal();
+  GlobalDeviceId global_device_id;
+  if (params.gpu_global_device_ids) {
+    TF_RET_CHECK(0 <= local_device_ordinal &&
+                 local_device_ordinal < params.gpu_global_device_ids->size());
+    global_device_id = (*params.gpu_global_device_ids)[local_device_ordinal];
+  } else {
+    // No local -> global mapping was provided; assume the identity mapping.
+    global_device_id = GlobalDeviceId(local_device_ordinal);
+  }
 
+  // Determines the set of global and local devices that are participating in
+  // the same collective group as the caller.
   TF_ASSIGN_OR_RETURN(
-      std::vector<int64> participating_replicas,
-      GetParticipatingReplicas(device_ordinal, instr->replica_groups(),
+      std::vector<int64> global_participating_replicas,
+      GetParticipatingReplicas(global_device_id, instr->replica_groups(),
                                replica_count_, *params.device_assn));
+  std::vector<GlobalDeviceId> global_devices;
+  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
+  local_devices.reserve(global_participating_replicas.size());
+  global_devices.reserve(global_participating_replicas.size());
+  TF_RET_CHECK(params.device_assn->computation_count() == 1)
+      << params.device_assn->ToString();
+  for (int64 replica : global_participating_replicas) {
+    GlobalDeviceId global_device(
+        (*params.device_assn)(replica, /*computation=*/0));
+    global_devices.push_back(global_device);
+    if (!params.gpu_global_device_ids) {
+      local_devices.emplace_back(global_device, global_device.value());
+    } else {
+      auto it = absl::c_find(*params.gpu_global_device_ids, global_device);
+      if (it != params.gpu_global_device_ids->end()) {
+        local_devices.emplace_back(
+            *it, std::distance(params.gpu_global_device_ids->begin(), it));
+      }
+    }
+  }
+  absl::c_sort(global_devices);
 
   // Find or create the rendezvous for this collective operation.
   RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
-      params.run_id, participating_replicas, hlo_instruction());
+      params.run_id, global_devices, local_devices.size(), hlo_instruction());
 
   VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
-          << ", participating replicas: "
-          << absl::StrJoin(participating_replicas, ", ");
+          << ", global participating replicas: "
+          << absl::StrJoin(global_participating_replicas, ", ")
+          << ", global participating devices: "
+          << GlobalDeviceIdsToString(global_devices);
 
   AllReduceParticipantData participant(rendezvous_key);
-  participant.device_ordinal = device_ordinal;
+  participant.device_ordinal = local_device_ordinal;
   for (size_t i = 0; i < buffers_.size(); ++i) {
     const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
     AllReduceParticipantData::Buffer pbuffer;
@@ -521,15 +598,24 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
     participant.buffers.push_back(pbuffer);
   }
   participant.stream = params.stream;
+  participant.local_devices = std::move(local_devices);
+  participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
   auto reduction_kind =
      MatchReductionComputation(hlo_instruction()->to_apply());
   CHECK(reduction_kind.has_value());
   participant.reduction_kind = *reduction_kind;
 
-  TF_ASSIGN_OR_RETURN(
-      std::shared_ptr<NcclClique> clique,
-      RendezvousNcclAllReduce::SubmitParticipant(
-          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant));
+  auto rendezvous_factory = [](const RendezvousKey& k) {
+    return absl::make_unique<RendezvousNcclAllReduce>(k);
+  };
+
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique,
+                      RendezvousNcclAllReduce::SubmitParticipant(
+                          [&] {
+                            return GlobalRendezvousMap().GetOrCreateIfAbsent(
+                                rendezvous_key, rendezvous_factory);
+                          },
+                          participant));
 
   // Keep the clique we used alive for as long as this Thunk lives.  Creating
  // new NCCL cliques is expensive, and this is how we avoid thrashing them.


@@ -19,6 +19,7 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -46,7 +47,7 @@ class NcclAllReduceThunk : public Thunk {
   // (Indeed, because the NCCL channels are a global variable, in the real
   // world, the value returned here is stale as soon as you read it, so it's not
   // clear how you *could* use it for anything other than tests.)
-  static absl::flat_hash_set<int> DevicesWithOpenNcclChannels();
+  static absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
 
   // TODO(b/125951860): Support all-reduces with replica groups, i.e.
   // all-reduces that compute multiple sums across subsets of all replicas.


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -98,6 +99,8 @@ class Thunk {
     HloExecutionProfiler* profiler;                               // never null
     const DeviceAssignment* device_assn;                          // never null
     std::vector<std::function<void()>>* deferred_host_callbacks;  // never null
+    const std::vector<GlobalDeviceId>* gpu_global_device_ids;     // may be null
+    const NcclUniqueIdCallback* nccl_unique_id_callback;          // may be null
   };
 
   // Execute the kernel for the thunk on the given stream. This method must be


@@ -159,7 +159,7 @@ DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
 }
 
 // Shorter alias for this function.
-absl::flat_hash_set<int> OpenNcclChannels() {
+absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
   return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
 }