[XLA:GPU] Add experimental, lightly tested support for multi-host and multi-process NCCL AllReduce.
This change makes several API changes: * we allow the client to provide a mapping from the local device ordinals on the machine to global device IDs. If provided, we interpret the device IDs in the DeviceAssignment provided by the client as global IDs, not as local device ordinals. This allows us to describe computations that cross a host boundary. * we allow the client to provide a callback for manufacturing a ncclUniqueId for a particular subset of global devices. The idea is that the client should use some other distributed system of their own (e.g., MPI) to share ncclUniqueId values needed for a computation. NCCL allows for cross-host/process collectives iff the same ncclUniqueId value is used. Refactors the common collective logic and the NCCL collective logic in particular to support a local/global distinction. PiperOrigin-RevId: 296505571 Change-Id: I5ed42d65597b0960df78890745421f77e9789ba3
This commit is contained in:
parent
371c29f627
commit
8a72c4466a
@ -100,6 +100,17 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
|
||||
return device_assignment_;
|
||||
}
|
||||
|
||||
ExecutableRunOptions& ExecutableRunOptions::set_gpu_executable_run_options(
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options) {
|
||||
gpu_executable_run_options_ = gpu_executable_run_options;
|
||||
return *this;
|
||||
}
|
||||
|
||||
const GpuExecutableRunOptions*
|
||||
ExecutableRunOptions::gpu_executable_run_options() const {
|
||||
return gpu_executable_run_options_;
|
||||
}
|
||||
|
||||
ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
|
||||
rng_seed_ = rng_seed;
|
||||
return *this;
|
||||
|
@ -38,6 +38,7 @@ namespace xla {
|
||||
|
||||
class DeviceAssignment;
|
||||
class ExecutionProfile;
|
||||
class GpuExecutableRunOptions;
|
||||
|
||||
// A unique identifier for a particular "logical execution" of an XLA model.
|
||||
//
|
||||
@ -137,6 +138,12 @@ class ExecutableRunOptions {
|
||||
return then_execute_function_;
|
||||
}
|
||||
|
||||
// GPU-backend specific options. These are kept out-of-line to avoid bloating
|
||||
// the size of this dependency for CPU-only AOT builds.
|
||||
ExecutableRunOptions& set_gpu_executable_run_options(
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options);
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options() const;
|
||||
|
||||
private:
|
||||
stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
|
||||
int device_ordinal_ = -1;
|
||||
@ -148,6 +155,7 @@ class ExecutableRunOptions {
|
||||
stream_executor::Stream* host_to_device_stream_ = nullptr;
|
||||
ThenExecuteFunction* then_execute_function_ = nullptr;
|
||||
RunId run_id_;
|
||||
const GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -42,13 +42,7 @@ template <typename K, typename V>
|
||||
class RefcountingHashMap {
|
||||
public:
|
||||
// Default-constructs new values.
|
||||
RefcountingHashMap()
|
||||
: value_factory_([](const K&) { return absl::make_unique<V>(); }) {}
|
||||
|
||||
// Constructs new values according to the given factory function.
|
||||
explicit RefcountingHashMap(
|
||||
std::function<std::unique_ptr<V>(const K&)> value_factory)
|
||||
: value_factory_(std::move(value_factory)) {}
|
||||
RefcountingHashMap() = default;
|
||||
|
||||
// Not copyable or movable because this contains internal pointers (namely,
|
||||
// instances of Deleter contain pointers to `this` and into `map_`).
|
||||
@ -60,8 +54,10 @@ class RefcountingHashMap {
|
||||
// Gets the value for the given key.
|
||||
//
|
||||
// If the map doesn't contain a live value for the key, constructs one
|
||||
// according to the factory passed to the map's constructor.
|
||||
std::shared_ptr<V> operator[](const K& key) {
|
||||
// using `value_factory`.
|
||||
std::shared_ptr<V> GetOrCreateIfAbsent(
|
||||
const K& key,
|
||||
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = map_.find(key);
|
||||
// We ensure that the entry has not expired in case deleter was running when
|
||||
@ -76,7 +72,7 @@ class RefcountingHashMap {
|
||||
// Create entry in the map and then set its value, so the value can
|
||||
// contain a pointer back into the map.
|
||||
it = map_.emplace(key, std::weak_ptr<V>()).first;
|
||||
std::shared_ptr<V> value(value_factory_(key).release(),
|
||||
std::shared_ptr<V> value(value_factory(key).release(),
|
||||
Deleter{&it->first, this});
|
||||
it->second = value; // Set the weak ptr to the shared ptr.
|
||||
return value;
|
||||
@ -112,7 +108,6 @@ class RefcountingHashMap {
|
||||
}
|
||||
};
|
||||
|
||||
std::function<std::unique_ptr<V>(const K&)> value_factory_;
|
||||
absl::Mutex mu_;
|
||||
absl::node_hash_map<K, std::weak_ptr<V>> map_ ABSL_GUARDED_BY(mu_);
|
||||
};
|
||||
|
@ -47,22 +47,25 @@ struct DeleteNotifier {
|
||||
|
||||
TEST(RefcountingHashMapTest, PointerIdentity) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
std::shared_ptr<int> a = m[0];
|
||||
std::shared_ptr<int> b = m[0];
|
||||
std::shared_ptr<int> c = m[1];
|
||||
auto factory = [](const int&) { return absl::make_unique<int>(); };
|
||||
std::shared_ptr<int> a = m.GetOrCreateIfAbsent(0, factory);
|
||||
std::shared_ptr<int> b = m.GetOrCreateIfAbsent(0, factory);
|
||||
std::shared_ptr<int> c = m.GetOrCreateIfAbsent(1, factory);
|
||||
EXPECT_EQ(a.get(), b.get());
|
||||
EXPECT_NE(a.get(), c.get());
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, DefaultInitialized) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
EXPECT_EQ(*m[42], 0);
|
||||
auto factory = [](const int&) { return absl::make_unique<int>(); };
|
||||
EXPECT_EQ(*m.GetOrCreateIfAbsent(42, factory), 0);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, DeletesEagerly) {
|
||||
RefcountingHashMap<int, DeleteNotifier> m;
|
||||
bool deleted = false;
|
||||
auto handle = m[0];
|
||||
auto factory = [](const int&) { return absl::make_unique<DeleteNotifier>(); };
|
||||
auto handle = m.GetOrCreateIfAbsent(0, factory);
|
||||
handle->fn = [&] { deleted = true; };
|
||||
EXPECT_FALSE(deleted);
|
||||
handle = nullptr;
|
||||
@ -70,10 +73,10 @@ TEST(RefcountingHashMapTest, DeletesEagerly) {
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, CustomFactory) {
|
||||
RefcountingHashMap<int, int> m(
|
||||
[](const int& x) { return absl::make_unique<int>(x + 1); });
|
||||
EXPECT_EQ(*m[0], 1);
|
||||
EXPECT_EQ(*m[100], 101);
|
||||
RefcountingHashMap<int, int> m;
|
||||
auto factory = [](const int& x) { return absl::make_unique<int>(x + 1); };
|
||||
EXPECT_EQ(*m.GetOrCreateIfAbsent(0, factory), 1);
|
||||
EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
|
||||
}
|
||||
|
||||
TEST(RefcountingHashMapTest, ForEachEmpty) {
|
||||
@ -85,8 +88,9 @@ TEST(RefcountingHashMapTest, ForEachEmpty) {
|
||||
|
||||
TEST(RefcountingHashMapTest, ForEachNonempty) {
|
||||
RefcountingHashMap<int, int> m;
|
||||
auto a = m[0];
|
||||
auto b = m[1];
|
||||
auto factory = [](const int&) { return absl::make_unique<int>(); };
|
||||
auto a = m.GetOrCreateIfAbsent(0, factory);
|
||||
auto b = m.GetOrCreateIfAbsent(1, factory);
|
||||
|
||||
std::vector<int> seen_keys;
|
||||
std::vector<int*> seen_values;
|
||||
|
@ -4585,6 +4585,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal", # fixdeps: keep
|
||||
"//tensorflow/stream_executor/lib",
|
||||
|
@ -44,7 +44,7 @@ absl::optional<ReductionKind> MatchReductionComputation(
|
||||
}
|
||||
|
||||
StatusOr<std::vector<int64>> GetParticipatingReplicas(
|
||||
int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
|
||||
GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
|
||||
int64 total_replica_count, const DeviceAssignment& device_assn) {
|
||||
std::vector<int64> participating_replicas;
|
||||
|
||||
@ -58,7 +58,7 @@ StatusOr<std::vector<int64>> GetParticipatingReplicas(
|
||||
|
||||
// Use the DeviceAssignment to figure out our replica-id.
|
||||
TF_ASSIGN_OR_RETURN(int replica_id,
|
||||
device_assn.ReplicaIdForDeviceOrdinal(device_ordinal));
|
||||
device_assn.ReplicaIdForDeviceOrdinal(device_id.value()));
|
||||
|
||||
// Figure out the other replicas that go together with this one.
|
||||
absl::optional<ReplicaGroup> replica_group;
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
@ -37,9 +38,9 @@ absl::optional<ReductionKind> MatchReductionComputation(
|
||||
const HloComputation* computation);
|
||||
|
||||
// Figures out which devices (named by their replica-ids) are participating in
|
||||
// the all-reduce subgroup that contains device_ordinal.
|
||||
// the all-reduce subgroup that contains device_id.
|
||||
StatusOr<std::vector<int64>> GetParticipatingReplicas(
|
||||
int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
|
||||
GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
|
||||
int64 total_replica_count, const DeviceAssignment& device_assn);
|
||||
|
||||
// Key that identifies a particular Rendezvous object in our global hashtable.
|
||||
@ -72,16 +73,18 @@ struct RendezvousKey {
|
||||
};
|
||||
|
||||
explicit RendezvousKey(const RunId& run_id,
|
||||
std::vector<int64> participating_replicas,
|
||||
std::vector<GlobalDeviceId> global_devices,
|
||||
int num_local_participants,
|
||||
CollectiveOpKind collective_op_kind, int64 op_id)
|
||||
: run_id(run_id),
|
||||
participating_replicas(participating_replicas),
|
||||
global_devices(std::move(global_devices)),
|
||||
num_local_participants(num_local_participants),
|
||||
collective_op_kind(collective_op_kind),
|
||||
op_id(op_id) {}
|
||||
|
||||
static RendezvousKey FromInstruction(
|
||||
const RunId& run_id, std::vector<int64> participating_replicas,
|
||||
const HloInstruction* instr) {
|
||||
const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
|
||||
int num_local_participants, const HloInstruction* instr) {
|
||||
CollectiveOpKind collective_op_kind;
|
||||
int64 op_id;
|
||||
|
||||
@ -91,20 +94,19 @@ struct RendezvousKey {
|
||||
: std::make_pair(
|
||||
kCrossReplica,
|
||||
static_cast<int64>(instr->GetModule()->unique_id()));
|
||||
return RendezvousKey(run_id, participating_replicas, collective_op_kind,
|
||||
op_id);
|
||||
return RendezvousKey(run_id, std::move(global_devices),
|
||||
num_local_participants, collective_op_kind, op_id);
|
||||
}
|
||||
|
||||
int num_participants() const { return participating_replicas.size(); }
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const RendezvousKey& k) {
|
||||
return H::combine(std::move(h), k.run_id, k.participating_replicas,
|
||||
return H::combine(std::move(h), k.run_id, k.global_devices,
|
||||
k.num_local_participants,
|
||||
static_cast<int>(k.collective_op_kind), k.op_id);
|
||||
}
|
||||
friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
|
||||
return a.run_id == b.run_id &&
|
||||
a.participating_replicas == b.participating_replicas &&
|
||||
return a.run_id == b.run_id && a.global_devices == b.global_devices &&
|
||||
a.num_local_participants == b.num_local_participants &&
|
||||
a.collective_op_kind == b.collective_op_kind && //
|
||||
a.op_id == b.op_id;
|
||||
}
|
||||
@ -114,14 +116,15 @@ struct RendezvousKey {
|
||||
|
||||
string ToString() const {
|
||||
return absl::StrFormat(
|
||||
"RendezvousKey{run_id=%s, participating_replicas=[%s], "
|
||||
"collective_op_kind=%d, op_id=%d}",
|
||||
run_id.ToString(), absl::StrJoin(participating_replicas, ","),
|
||||
static_cast<int>(collective_op_kind), op_id);
|
||||
"RendezvousKey{run_id=%s, global_devices=[%s], "
|
||||
"num_local_participants=%d, collective_op_kind=%d, op_id=%d}",
|
||||
run_id.ToString(), GlobalDeviceIdsToString(global_devices),
|
||||
num_local_participants, static_cast<int>(collective_op_kind), op_id);
|
||||
}
|
||||
|
||||
RunId run_id;
|
||||
std::vector<int64> participating_replicas;
|
||||
std::vector<GlobalDeviceId> global_devices;
|
||||
int num_local_participants;
|
||||
CollectiveOpKind collective_op_kind;
|
||||
int64 op_id;
|
||||
};
|
||||
@ -164,10 +167,13 @@ struct AllReduceParticipantData {
|
||||
};
|
||||
std::vector<Buffer> buffers;
|
||||
se::Stream* stream;
|
||||
const NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
|
||||
|
||||
ReductionKind reduction_kind;
|
||||
|
||||
int num_participants() const { return rendezvous_key.num_participants(); }
|
||||
// For each local all-reduce participant a (global ID, local device ordinal)
|
||||
// pair for the participant. Participants are in no particular order.
|
||||
std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
|
||||
|
||||
string ToString() const {
|
||||
std::vector<std::string> buffer_strs;
|
||||
@ -303,12 +309,13 @@ class Rendezvous {
|
||||
const RendezvousKey key_;
|
||||
|
||||
tensorflow::BlockingCounter all_participants_present_{
|
||||
key_.num_participants()};
|
||||
tensorflow::BlockingCounter done_{key_.num_participants()};
|
||||
key_.num_local_participants};
|
||||
tensorflow::BlockingCounter done_{key_.num_local_participants};
|
||||
|
||||
// tensorflow::BlockingCounter returned by SubmitParticipant.
|
||||
std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
|
||||
std::make_shared<tensorflow::BlockingCounter>(key_.num_participants())};
|
||||
std::make_shared<tensorflow::BlockingCounter>(
|
||||
key_.num_local_participants)};
|
||||
};
|
||||
|
||||
} // end namespace xla
|
||||
|
@ -382,10 +382,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
|
||||
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
|
||||
GlobalRendezvousMap() {
|
||||
static auto& m =
|
||||
*new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>(
|
||||
[](const xla::RendezvousKey& k) {
|
||||
return absl::make_unique<CpuAllReduceRendezvous>(k);
|
||||
});
|
||||
*new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
|
||||
return m;
|
||||
}
|
||||
|
||||
@ -411,18 +408,28 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
||||
|
||||
std::vector<xla::ReplicaGroup> group =
|
||||
xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
|
||||
xla::int32 replica_count = run_options->device_assignment()->replica_count();
|
||||
std::vector<xla::int64> participating_replicas_vec =
|
||||
xla::GetParticipatingReplicas(device_ordinal, group, replica_count,
|
||||
const xla::DeviceAssignment& device_assignment =
|
||||
*run_options->device_assignment();
|
||||
xla::int32 replica_count = device_assignment.replica_count();
|
||||
CHECK_EQ(device_assignment.computation_count(), 1);
|
||||
std::vector<xla::int64> participating_replicas =
|
||||
xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
|
||||
replica_count,
|
||||
*run_options->device_assignment())
|
||||
.ValueOrDie();
|
||||
|
||||
xla::RendezvousKey::CollectiveOpKind op_kind =
|
||||
channel_id_present ? xla::RendezvousKey::kCrossModule
|
||||
: xla::RendezvousKey::kCrossReplica;
|
||||
xla::RendezvousKey rendezvous_key(run_options->run_id(),
|
||||
participating_replicas_vec, op_kind, op_id);
|
||||
|
||||
std::vector<xla::GlobalDeviceId> participating_devices;
|
||||
participating_devices.reserve(participating_replicas.size());
|
||||
for (xla::int64 replica : participating_replicas) {
|
||||
participating_devices.push_back(
|
||||
xla::GlobalDeviceId(device_assignment(replica, 0)));
|
||||
}
|
||||
xla::RendezvousKey rendezvous_key(
|
||||
run_options->run_id(), std::move(participating_devices),
|
||||
participating_replicas.size(), op_kind, op_id);
|
||||
auto shape_str = ShapeString(shape_ptr, shape_length);
|
||||
VLOG(2) << "All-reduce input/output shape : " << shape_str;
|
||||
|
||||
@ -444,10 +451,17 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
||||
participant.buffers = {buffer};
|
||||
participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
|
||||
|
||||
TF_CHECK_OK(
|
||||
CpuAllReduceRendezvous::SubmitParticipant(
|
||||
[&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant)
|
||||
.status());
|
||||
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
|
||||
return absl::make_unique<CpuAllReduceRendezvous>(k);
|
||||
};
|
||||
|
||||
TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
|
||||
[&] {
|
||||
return GlobalRendezvousMap().GetOrCreateIfAbsent(
|
||||
rendezvous_key, make_cpu_rendezvous);
|
||||
},
|
||||
participant)
|
||||
.status());
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
|
||||
|
@ -54,6 +54,22 @@ tf_proto_library_cc(
|
||||
protodeps = ["//tensorflow/compiler/xla:xla_data_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_executable_run_options",
|
||||
srcs = ["gpu_executable_run_options.cc"],
|
||||
hdrs = ["gpu_executable_run_options.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_constants",
|
||||
srcs = ["gpu_constants.cc"],
|
||||
@ -385,6 +401,7 @@ cc_library(
|
||||
hdrs = ["thunk.h"],
|
||||
deps = [
|
||||
":buffer_allocations",
|
||||
":gpu_executable_run_options",
|
||||
":hlo_execution_profiler",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
@ -413,6 +430,7 @@ tf_cuda_library(
|
||||
":buffer_allocations",
|
||||
":hlo_execution_profiler",
|
||||
":thunk",
|
||||
":gpu_executable_run_options",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"//tensorflow/compiler/xla/service:pattern_matcher",
|
||||
"//tensorflow/compiler/xla:refcounting_hash_map",
|
||||
@ -522,6 +540,7 @@ cc_library(
|
||||
":cudnn_batchnorm_runner",
|
||||
":gpu_conv_runner",
|
||||
":gpu_debug_info_manager",
|
||||
":gpu_executable_run_options",
|
||||
":gpu_types",
|
||||
":hlo_execution_profiler",
|
||||
":infeed_manager",
|
||||
|
@ -211,10 +211,7 @@ StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
|
||||
// Rendezvous objects are one-time use, so they're removed from this map once
|
||||
// we're through with them.
|
||||
RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
|
||||
static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
|
||||
[](const RendezvousKey& key) {
|
||||
return absl::make_unique<Rendezvous>(key);
|
||||
});
|
||||
static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>();
|
||||
return m;
|
||||
}
|
||||
|
||||
@ -233,7 +230,11 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
// Rendezvous with the threads for all other devices that are participating in
|
||||
// this CollectivePermute.
|
||||
RendezvousKey key{params.run_id, params.device_assn->replica_count()};
|
||||
std::shared_ptr<Rendezvous> rendezvous = GlobalRendezvousMap()[key];
|
||||
auto rendezvous_factory = [](const RendezvousKey& key) {
|
||||
return absl::make_unique<Rendezvous>(key);
|
||||
};
|
||||
std::shared_ptr<Rendezvous> rendezvous =
|
||||
GlobalRendezvousMap().GetOrCreateIfAbsent(key, rendezvous_factory);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(int64 replica_id,
|
||||
params.device_assn->ReplicaIdForDeviceOrdinal(
|
||||
|
@ -34,7 +34,7 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
||||
NcclAllReduceThunk::~NcclAllReduceThunk() = default;
|
||||
|
||||
/*static*/ absl::flat_hash_set<int>
|
||||
/*static*/ absl::flat_hash_set<GlobalDeviceId>
|
||||
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||
return {};
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -196,13 +197,21 @@ Status GpuExecutable::ExecuteThunks(
|
||||
VLOG(2) << "Executing the thunk for "
|
||||
<< thunk->hlo_instruction()->ToString() << " on stream "
|
||||
<< stream_no;
|
||||
const GpuExecutableRunOptions* gpu_options =
|
||||
run_options->run_options().gpu_executable_run_options();
|
||||
Thunk::ExecuteParams thunk_params{
|
||||
&buffer_allocations,
|
||||
stream,
|
||||
run_options->run_options().run_id(),
|
||||
&profiler,
|
||||
run_options->run_options().device_assignment(),
|
||||
&deferred_host_callbacks};
|
||||
&deferred_host_callbacks,
|
||||
gpu_options && gpu_options->gpu_global_device_ids()
|
||||
? &*gpu_options->gpu_global_device_ids()
|
||||
: nullptr,
|
||||
gpu_options && gpu_options->nccl_unique_id_callback()
|
||||
? &gpu_options->nccl_unique_id_callback()
|
||||
: nullptr};
|
||||
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
|
||||
if (thunk_schedule_->Depended(thunk)) {
|
||||
auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
|
||||
|
@ -0,0 +1,62 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
|
||||
std::vector<int64> values;
|
||||
values.reserve(ids.size());
|
||||
for (GlobalDeviceId id : ids) {
|
||||
values.push_back(id.value());
|
||||
}
|
||||
return absl::StrJoin(values, ",");
|
||||
}
|
||||
|
||||
NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
|
||||
: devices_(std::move(devices)) {
|
||||
absl::c_sort(devices_);
|
||||
CHECK(absl::c_adjacent_find(devices_) == devices_.end())
|
||||
<< "Duplicate devices are not allowed: "
|
||||
<< GlobalDeviceIdsToString(devices_);
|
||||
}
|
||||
|
||||
GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids(
|
||||
absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids) {
|
||||
gpu_global_device_ids_ = std::move(gpu_global_device_ids);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const absl::optional<std::vector<GlobalDeviceId>>&
|
||||
GpuExecutableRunOptions::gpu_global_device_ids() const {
|
||||
return gpu_global_device_ids_;
|
||||
}
|
||||
|
||||
GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_unique_id_callback(
|
||||
NcclUniqueIdCallback nccl_unique_id_callback) {
|
||||
nccl_unique_id_callback_ = std::move(nccl_unique_id_callback);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const NcclUniqueIdCallback& GpuExecutableRunOptions::nccl_unique_id_callback()
|
||||
const {
|
||||
return nccl_unique_id_callback_;
|
||||
}
|
||||
|
||||
} // namespace xla
|
@ -0,0 +1,90 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/gtl/int_type.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Strongly-typed integer type for naming a device globally within a distributed
|
||||
// system. XLA doesn't have a strong opinion about what global numbering scheme
|
||||
// is applied to GPUs; the user must provide a local -> global mapping via
|
||||
// GpuExecutableRunOptions for the local GPUs.
|
||||
TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
|
||||
|
||||
// Returns a comma-separated string of global device IDs.
|
||||
std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
|
||||
|
||||
// Key for naming up a particular NCCL clique. This is just a set of unique
|
||||
// device IDs (i.e. GPU IDs). The device IDs must be global within a cluster.
|
||||
class NcclCliqueKey {
|
||||
public:
|
||||
explicit NcclCliqueKey(std::vector<GlobalDeviceId> devices);
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const NcclCliqueKey& k) {
|
||||
return H::combine(std::move(h), k.devices_);
|
||||
}
|
||||
friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
|
||||
return a.devices_ == b.devices_;
|
||||
}
|
||||
|
||||
const std::vector<GlobalDeviceId>& devices() const { return devices_; }
|
||||
|
||||
private:
|
||||
std::vector<GlobalDeviceId> devices_;
|
||||
};
|
||||
|
||||
using NcclUniqueIdCallback =
|
||||
std::function<StatusOr<std::string>(const NcclCliqueKey&)>;
|
||||
|
||||
// GPU-specific executable options.
|
||||
// We keep these separate from ExecutableRunOptions to avoid adding
|
||||
// dependencies to ExecutableRunOptions.
|
||||
class GpuExecutableRunOptions {
|
||||
public:
|
||||
// Sets a mapping from local device ordinals to global device IDs.
|
||||
// Used only on NVidia GPUs for cross-host NCCL collectives. If set, the
|
||||
// elements of `device_assignment` are interpreted as global device IDs, not
|
||||
// local device ordinals.
|
||||
GpuExecutableRunOptions& set_gpu_global_device_ids(
|
||||
absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids);
|
||||
const absl::optional<std::vector<GlobalDeviceId>>& gpu_global_device_ids()
|
||||
const;
|
||||
|
||||
// Callback that returns a ncclUniqueId encoded as a string for a group of
|
||||
// communicating GPU devices. Used only on NVidia GPUs.
|
||||
GpuExecutableRunOptions& set_nccl_unique_id_callback(
|
||||
NcclUniqueIdCallback nccl_unique_id_callback);
|
||||
const NcclUniqueIdCallback& nccl_unique_id_callback() const;
|
||||
|
||||
private:
|
||||
absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids_;
|
||||
NcclUniqueIdCallback nccl_unique_id_callback_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -179,27 +180,18 @@ absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
|
||||
}
|
||||
}
|
||||
|
||||
// Key for looking up a particular NCCL clique. This is just a set of unique
|
||||
// device ordinals (i.e. GPU IDs).
|
||||
struct NcclCliqueKey {
|
||||
explicit NcclCliqueKey(absl::Span<const int64> devices)
|
||||
: devices(devices.begin(), devices.end()) {
|
||||
absl::c_sort(this->devices);
|
||||
CHECK(absl::c_adjacent_find(devices) == devices.end())
|
||||
<< "Duplicate devices are not allowed: "
|
||||
<< absl::StrJoin(devices, ", ");
|
||||
Status StringToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
|
||||
if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
|
||||
return InvalidArgument(
|
||||
"ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
|
||||
NCCL_UNIQUE_ID_BYTES);
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const NcclCliqueKey& k) {
|
||||
return H::combine(std::move(h), k.devices);
|
||||
}
|
||||
friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
|
||||
return a.devices == b.devices;
|
||||
}
|
||||
|
||||
std::vector<int64> devices;
|
||||
};
|
||||
// NcclUniqueId is internally just a char[].
|
||||
static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
|
||||
"NCCL_UNIQUE_ID_BYTES");
|
||||
std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Owns a clique of NCCL comms which can be used for collective operations among
|
||||
// a particular set of GPUs.
|
||||
@ -216,20 +208,29 @@ struct NcclCliqueKey {
|
||||
// GPUs, you'll need a different clique.
|
||||
class NcclClique {
|
||||
public:
|
||||
explicit NcclClique(absl::Span<const int64> devices)
|
||||
: devices_(devices.begin(), devices.end()) {
|
||||
absl::c_sort(devices_);
|
||||
status_ = Init();
|
||||
explicit NcclClique(
|
||||
int64 num_global_devices, std::vector<int64> local_device_ordinals,
|
||||
std::vector<int64> local_device_ranks,
|
||||
const StatusOr<absl::optional<std::string>>& nccl_unique_id)
|
||||
: num_global_devices_(num_global_devices),
|
||||
local_device_ordinals_(std::move(local_device_ordinals)),
|
||||
local_device_ranks_(std::move(local_device_ranks)) {
|
||||
CHECK_EQ(local_device_ordinals_.size(), local_device_ranks_.size());
|
||||
// It's unusual to pass a StatusOr<> into a class, but since this class
|
||||
// already has a erroneous state, it turns out to be a little easier to
|
||||
// implement this way than to change RefcountingHashMap.
|
||||
status_ = Init(nccl_unique_id);
|
||||
}
|
||||
|
||||
Status status() { return status_; }
|
||||
|
||||
absl::Span<const int64> devices() {
|
||||
TF_CHECK_OK(status_);
|
||||
return devices_;
|
||||
}
|
||||
ncclComm_t comm(int64 device) {
|
||||
int64 idx = std::distance(devices_.begin(), absl::c_find(devices_, device));
|
||||
// A NCCL communicator is the NCCL state associated with a participant (rank)
|
||||
// in a reduction. This method returns the state associated with a particular
|
||||
// local device ordinal.
|
||||
ncclComm_t comm(int64 device_ordinal) {
|
||||
int64 idx =
|
||||
std::distance(local_device_ordinals_.begin(),
|
||||
absl::c_find(local_device_ordinals_, device_ordinal));
|
||||
return comms_.at(idx).comm();
|
||||
}
|
||||
|
||||
@ -249,10 +250,12 @@ class NcclClique {
|
||||
}
|
||||
|
||||
private:
|
||||
Status Init() {
|
||||
Status Init(
|
||||
const StatusOr<absl::optional<std::string>>& maybe_nccl_unique_id) {
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
"Initializing nccl comms for participant devices {%s}",
|
||||
absl::StrJoin(devices_, ", "));
|
||||
"Initializing nccl comms for participant device ordinals %s ranks {%s}",
|
||||
absl::StrJoin(local_device_ordinals_, ", "),
|
||||
absl::StrJoin(local_device_ranks_, ", "));
|
||||
|
||||
// Restore CUDA device after running this. XLA shouldn't care, but maybe
|
||||
// another consumer does.
|
||||
@ -264,15 +267,23 @@ class NcclClique {
|
||||
// When using ncclGroupStart/End it seems that the ncclComm_t's are not
|
||||
// populated until the End() call. This unfortunately makes error handling
|
||||
// tricky.
|
||||
std::vector<ncclComm_t> raw_comms(devices_.size(), nullptr);
|
||||
std::vector<ncclComm_t> raw_comms(local_device_ordinals_.size(), nullptr);
|
||||
TF_ASSIGN_OR_RETURN(const absl::optional<std::string>& nccl_id_string,
|
||||
maybe_nccl_unique_id);
|
||||
|
||||
ncclUniqueId nccl_id;
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
|
||||
if (nccl_id_string) {
|
||||
TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id));
|
||||
} else {
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
|
||||
}
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
|
||||
Status status = [&] {
|
||||
for (int i = 0; i < devices_.size(); ++i) {
|
||||
XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(devices_[i]));
|
||||
XLA_CUDA_RETURN_IF_ERROR(
|
||||
ncclCommInitRank(&raw_comms[i], devices_.size(), nccl_id, i));
|
||||
for (int i = 0; i < local_device_ordinals_.size(); ++i) {
|
||||
XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(local_device_ordinals_[i]));
|
||||
XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i],
|
||||
num_global_devices_, nccl_id,
|
||||
local_device_ranks_.at(i)));
|
||||
}
|
||||
return Status::OK();
|
||||
}();
|
||||
@ -282,9 +293,9 @@ class NcclClique {
|
||||
// Populate comms_ from the raw comms we created above. If we encountered
|
||||
// an error above we'll later clear comms_ thus destroying any raw comms
|
||||
// that were created before the error.
|
||||
for (int i = 0; i < devices_.size(); ++i) {
|
||||
VLOG(3) << absl::StreamFormat("Device %d assigned ncclComm %p",
|
||||
devices_[i], raw_comms[i]);
|
||||
for (int i = 0; i < local_device_ordinals_.size(); ++i) {
|
||||
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
|
||||
local_device_ordinals_[i], raw_comms[i]);
|
||||
CHECK(raw_comms[i] != nullptr || !status.ok());
|
||||
comms_.emplace_back(raw_comms[i]);
|
||||
}
|
||||
@ -296,7 +307,11 @@ class NcclClique {
|
||||
}
|
||||
|
||||
Status status_;
|
||||
std::vector<int64> devices_;
|
||||
int64 num_global_devices_;
|
||||
std::vector<int64> local_device_ordinals_;
|
||||
// NCCL communicator rank for each local device. The rank of a device is equal
|
||||
// to the offset of the local device in the global device set.
|
||||
std::vector<int64> local_device_ranks_;
|
||||
std::vector<NcclComm> comms_;
|
||||
|
||||
// This mutex is in a unique_ptr so NcclClique can be movable.
|
||||
@ -312,10 +327,7 @@ class NcclClique {
|
||||
// have one clique alive for a given set of GPUs. This means that a process
|
||||
// will never do two collective operations concurrently on the same set of GPUs.
|
||||
RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
|
||||
static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>(
|
||||
[](const NcclCliqueKey& key) {
|
||||
return absl::make_unique<NcclClique>(key.devices);
|
||||
});
|
||||
static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
|
||||
return m;
|
||||
}
|
||||
|
||||
@ -341,10 +353,7 @@ class RendezvousNcclAllReduce : public Rendezvous<std::shared_ptr<NcclClique>> {
|
||||
RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>&
|
||||
GlobalRendezvousMap() {
|
||||
static auto& m =
|
||||
*new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>(
|
||||
[](const RendezvousKey& k) {
|
||||
return absl::make_unique<RendezvousNcclAllReduce>(k);
|
||||
});
|
||||
*new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>();
|
||||
return m;
|
||||
}
|
||||
|
||||
@ -365,12 +374,46 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
|
||||
// ensuring that there's a NCCL clique available for us to use.
|
||||
primary = !initialized_;
|
||||
|
||||
TF_RET_CHECK(participant.local_devices.size() ==
|
||||
participant.rendezvous_key.num_local_participants);
|
||||
|
||||
// Look up or create the NCCL clique for this set of devices.
|
||||
std::vector<int64> devices;
|
||||
for (const auto& p : participants_) {
|
||||
devices.push_back(p.device_ordinal);
|
||||
}
|
||||
clique = GlobalNcclCliqueMap()[NcclCliqueKey(devices)];
|
||||
NcclCliqueKey clique_key(participant.rendezvous_key.global_devices);
|
||||
|
||||
auto clique_factory =
|
||||
[&](const NcclCliqueKey& key) -> std::unique_ptr<NcclClique> {
|
||||
std::vector<int64> local_device_ranks;
|
||||
std::vector<int64> local_device_ordinals;
|
||||
local_device_ranks.reserve(participant.local_devices.size());
|
||||
local_device_ordinals.reserve(participant.local_devices.size());
|
||||
for (const auto& l : participant.local_devices) {
|
||||
auto it =
|
||||
absl::c_find(participant.rendezvous_key.global_devices, l.first);
|
||||
CHECK(it != participant.rendezvous_key.global_devices.end()) << l.first;
|
||||
local_device_ranks.push_back(std::distance(
|
||||
participant.rendezvous_key.global_devices.begin(), it));
|
||||
local_device_ordinals.push_back(l.second);
|
||||
}
|
||||
StatusOr<absl::optional<std::string>> nccl_unique_id;
|
||||
if (participant.nccl_unique_id_callback) {
|
||||
nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key);
|
||||
} else {
|
||||
if (participant.rendezvous_key.global_devices.size() !=
|
||||
participant.rendezvous_key.num_local_participants) {
|
||||
nccl_unique_id = InvalidArgument(
|
||||
"Multihost AllReduce on GPU requires a nccl_unique_id_callback "
|
||||
"to be provided by the client.");
|
||||
} else {
|
||||
nccl_unique_id = absl::optional<std::string>();
|
||||
}
|
||||
}
|
||||
return absl::make_unique<NcclClique>(
|
||||
participant.rendezvous_key.global_devices.size(),
|
||||
std::move(local_device_ordinals), std::move(local_device_ranks),
|
||||
nccl_unique_id);
|
||||
};
|
||||
clique =
|
||||
GlobalNcclCliqueMap().GetOrCreateIfAbsent(clique_key, clique_factory);
|
||||
|
||||
if (primary) {
|
||||
VLOG(3) << "Primary initializing accounting data.";
|
||||
@ -463,12 +506,12 @@ struct NcclAllReduceThunk::AuxData {
|
||||
crs->IsCrossReplicaAllReduce() && operands_are_supported();
|
||||
}
|
||||
|
||||
/*static*/ absl::flat_hash_set<int>
|
||||
/*static*/ absl::flat_hash_set<GlobalDeviceId>
|
||||
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||
absl::flat_hash_set<int> devices;
|
||||
absl::flat_hash_set<GlobalDeviceId> devices;
|
||||
GlobalNcclCliqueMap().ForEach(
|
||||
[&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
|
||||
devices.insert(k.devices.begin(), k.devices.end());
|
||||
devices.insert(k.devices().begin(), k.devices().end());
|
||||
});
|
||||
return devices;
|
||||
}
|
||||
@ -491,23 +534,57 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
|
||||
|
||||
auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
|
||||
int64 device_ordinal = params.stream->parent()->device_ordinal();
|
||||
int64 local_device_ordinal = params.stream->parent()->device_ordinal();
|
||||
GlobalDeviceId global_device_id;
|
||||
if (params.gpu_global_device_ids) {
|
||||
TF_RET_CHECK(0 <= local_device_ordinal &&
|
||||
local_device_ordinal < params.gpu_global_device_ids->size());
|
||||
global_device_id = (*params.gpu_global_device_ids)[local_device_ordinal];
|
||||
} else {
|
||||
// No local -> global mapping was provided; assume the identity mapping.
|
||||
global_device_id = GlobalDeviceId(local_device_ordinal);
|
||||
}
|
||||
|
||||
// Determines the set of global and local devices that are participating in
|
||||
// the same collective group as the caller.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<int64> participating_replicas,
|
||||
GetParticipatingReplicas(device_ordinal, instr->replica_groups(),
|
||||
std::vector<int64> global_participating_replicas,
|
||||
GetParticipatingReplicas(global_device_id, instr->replica_groups(),
|
||||
replica_count_, *params.device_assn));
|
||||
std::vector<GlobalDeviceId> global_devices;
|
||||
std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
|
||||
local_devices.reserve(global_participating_replicas.size());
|
||||
global_devices.reserve(global_participating_replicas.size());
|
||||
TF_RET_CHECK(params.device_assn->computation_count() == 1)
|
||||
<< params.device_assn->ToString();
|
||||
for (int64 replica : global_participating_replicas) {
|
||||
GlobalDeviceId global_device(
|
||||
(*params.device_assn)(replica, /*computation=*/0));
|
||||
global_devices.push_back(global_device);
|
||||
if (!params.gpu_global_device_ids) {
|
||||
local_devices.emplace_back(global_device, global_device.value());
|
||||
} else {
|
||||
auto it = absl::c_find(*params.gpu_global_device_ids, global_device);
|
||||
if (it != params.gpu_global_device_ids->end()) {
|
||||
local_devices.emplace_back(
|
||||
*it, std::distance(params.gpu_global_device_ids->begin(), it));
|
||||
}
|
||||
}
|
||||
}
|
||||
absl::c_sort(global_devices);
|
||||
|
||||
// Find or create the rendezvous for this collective operation.
|
||||
RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
|
||||
params.run_id, participating_replicas, hlo_instruction());
|
||||
params.run_id, global_devices, local_devices.size(), hlo_instruction());
|
||||
|
||||
VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
|
||||
<< ", participating replicas: "
|
||||
<< absl::StrJoin(participating_replicas, ", ");
|
||||
<< ", global participating replicas: "
|
||||
<< absl::StrJoin(global_participating_replicas, ", ")
|
||||
<< ", global participating devices: "
|
||||
<< GlobalDeviceIdsToString(global_devices);
|
||||
|
||||
AllReduceParticipantData participant(rendezvous_key);
|
||||
participant.device_ordinal = device_ordinal;
|
||||
participant.device_ordinal = local_device_ordinal;
|
||||
for (size_t i = 0; i < buffers_.size(); ++i) {
|
||||
const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
|
||||
AllReduceParticipantData::Buffer pbuffer;
|
||||
@ -521,15 +598,24 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
participant.buffers.push_back(pbuffer);
|
||||
}
|
||||
participant.stream = params.stream;
|
||||
participant.local_devices = std::move(local_devices);
|
||||
participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
|
||||
auto reduction_kind =
|
||||
MatchReductionComputation(hlo_instruction()->to_apply());
|
||||
CHECK(reduction_kind.has_value());
|
||||
participant.reduction_kind = *reduction_kind;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::shared_ptr<NcclClique> clique,
|
||||
RendezvousNcclAllReduce::SubmitParticipant(
|
||||
[&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant));
|
||||
auto rendezvous_factory = [](const RendezvousKey& k) {
|
||||
return absl::make_unique<RendezvousNcclAllReduce>(k);
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique,
|
||||
RendezvousNcclAllReduce::SubmitParticipant(
|
||||
[&] {
|
||||
return GlobalRendezvousMap().GetOrCreateIfAbsent(
|
||||
rendezvous_key, rendezvous_factory);
|
||||
},
|
||||
participant));
|
||||
|
||||
// Keep the clique we used alive for as long as this Thunk lives. Creating
|
||||
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -46,7 +47,7 @@ class NcclAllReduceThunk : public Thunk {
|
||||
// (Indeed, because the NCCL channels are a global variable, in the real
|
||||
// world, the value returned here is stale as soon as you read it, so it's not
|
||||
// clear how you *could* use it for anything other than tests.)
|
||||
static absl::flat_hash_set<int> DevicesWithOpenNcclChannels();
|
||||
static absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
|
||||
|
||||
// TODO(b/125951860): Support all-reduces with replica groups, i.e.
|
||||
// all-reduces that compute multiple sums across subsets of all replicas.
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -98,6 +99,8 @@ class Thunk {
|
||||
HloExecutionProfiler* profiler; // never null
|
||||
const DeviceAssignment* device_assn; // never null
|
||||
std::vector<std::function<void()>>* deferred_host_callbacks; // never null
|
||||
const std::vector<GlobalDeviceId>* gpu_global_device_ids; // may be null
|
||||
const NcclUniqueIdCallback* nccl_unique_id_callback; // may be null
|
||||
};
|
||||
|
||||
// Execute the kernel for the thunk on the given stream. This method must be
|
||||
|
@ -159,7 +159,7 @@ DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
|
||||
}
|
||||
|
||||
// Shorter alias for this function.
|
||||
absl::flat_hash_set<int> OpenNcclChannels() {
|
||||
absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
|
||||
return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user