[XLA:GPU] Add an AllReduceCombiner pass that merges AllReduce operations.

On GPU, implement combined all-reduces using NCCL groups.

PiperOrigin-RevId: 296130269
Change-Id: I763f0139c8ed9a59d7d691e3252e6b46244fefd6

parent 20c1ba21a9
commit db85f4c207
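For illustration (a sketch, not part of the commit message; the names and shapes below are invented), the pass rewrites several independent all-reduces that share a reduction into one tuple-shaped all-reduce and reads the results back out with get-tuple-element:

Before:
    ar0 = f32[1024] all-reduce(p0), to_apply=add
    ar1 = f32[512] all-reduce(p1), to_apply=add

After:
    combined = (f32[1024], f32[512]) all-reduce(p0, p1), to_apply=add
    gte0 = f32[1024] get-tuple-element(combined), index=0
    gte1 = f32[512] get-tuple-element(combined), index=1

One combined operation amortizes the per-op latency of many small all-reduces, which is what the NCCL group implementation below exploits on GPU.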
@@ -1947,6 +1947,51 @@ tf_cc_test(
    ],
)

cc_library(
    name = "all_reduce_combiner",
    srcs = ["all_reduce_combiner.cc"],
    hdrs = ["all_reduce_combiner.h"],
    deps = [
        ":hlo",
        ":hlo_domain_map",
        ":hlo_pass",
        ":hlo_query",
        ":hlo_reachability",
        ":shape_inference",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
    ],
)

tf_cc_test(
    name = "all_reduce_combiner_test",
    srcs = ["all_reduce_combiner_test.cc"],
    deps = [
        ":all_reduce_combiner",
        ":hlo",
        ":hlo_matchers",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/memory",
    ],
)

cc_library(
    name = "all_reduce_simplifier",
    srcs = ["all_reduce_simplifier.cc"],
452  tensorflow/compiler/xla/service/all_reduce_combiner.cc  Normal file
@@ -0,0 +1,452 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"

#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

// Combines the elements of to_combine into a single AllReduce op. All
// entries in to_combine must be AllReduce ops with exactly one operand
// and the same reduction operation.
Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
  if (to_combine.size() < 2) {
    return Status::OK();
  }
  VLOG(1) << "Combined " << to_combine.size() << " CRS ops";

  HloComputation& computation = *to_combine.back()->parent();
  HloComputation* reduction = to_combine[0]->to_apply();
  const HloOpcode type = reduction->root_instruction()->opcode();

  // Create a single bigger AllReduce of the operands of the smaller
  // AllReduces.
  std::vector<HloInstruction*> operands;
  std::vector<Shape> operand_shapes;
  VLOG(1) << "Combining set";
  for (HloInstruction* hlo : to_combine) {
    VLOG(1) << "Set element: " << hlo->ToString();
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllReduce);
    TF_RET_CHECK(hlo->operands().size() == 1);
    TF_RET_CHECK(hlo->to_apply() == reduction ||
                 (hlo->to_apply()->instruction_count() == 3 &&
                  hlo->to_apply()->num_parameters() == 2 &&
                  hlo->to_apply()->root_instruction()->opcode() == type));
    TF_RET_CHECK(hlo->shape().IsArray());
    for (HloInstruction* operand : hlo->operands()) {
      operands.push_back(operand);
      operand_shapes.push_back(operand->shape());
    }
  }

  HloInstruction* combined;
  // AllReduce ops with more than one operand produce a tuple.
  TF_RET_CHECK(operands.size() >= 2);
  combined = computation.AddInstruction(HloInstruction::CreateAllReduce(
      ShapeUtil::MakeTupleShape(operand_shapes), operands, reduction,
      to_combine.front()->replica_groups(),
      /*constrain_layout=*/false, to_combine.front()->channel_id()));

  // We have to propagate the sharding manually because Domain instructions are
  // not guaranteed to preserve it for side effecting instructions.
  if (to_combine.front()->has_sharding()) {
    combined->set_sharding(to_combine.front()->sharding());
  }
  VLOG(1) << "Replacing with : " << combined->ToString();

  // Replace all the smaller AllReduces with elements of the tuple output
  // of the single bigger AllReduce.
  for (int64 i = 0; i < to_combine.size(); ++i) {
    auto replace_with = HloInstruction::CreateGetTupleElement(
        to_combine[i]->shape(), combined, i);
    TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
        to_combine[i], std::move(replace_with)));
  }
  return Status::OK();
}

struct GroupKey {
  GroupKey(const HloInstruction* hlo, const HloDomainMap& domain_map)
      : opcode(hlo->to_apply()->root_instruction()->opcode()),
        accum_type(hlo->to_apply()->root_instruction()->shape().element_type()),
        domain_id(domain_map.GetDomainMetadataId(hlo)),
        is_cross_shard(hlo->channel_id().has_value()),
        replica_groups(hlo->replica_groups()) {}

  bool operator<(const GroupKey& other) const {
    if (opcode != other.opcode) {
      return opcode < other.opcode;
    }
    if (accum_type != other.accum_type) {
      return accum_type < other.accum_type;
    }
    if (domain_id != other.domain_id) {
      return domain_id < other.domain_id;
    }
    if (is_cross_shard != other.is_cross_shard) {
      return is_cross_shard < other.is_cross_shard;
    }
    if (replica_groups.size() != other.replica_groups.size()) {
      return replica_groups.size() < other.replica_groups.size();
    }
    for (int64 i = 0; i < replica_groups.size(); ++i) {
      const auto& rg = replica_groups[i];
      const auto& org = other.replica_groups[i];
      if (rg.replica_ids_size() != org.replica_ids_size()) {
        return rg.replica_ids_size() < org.replica_ids_size();
      }
      for (int64 j = 0; j < rg.replica_ids_size(); ++j) {
        if (rg.replica_ids(j) != org.replica_ids(j)) {
          return rg.replica_ids(j) < org.replica_ids(j);
        }
      }
    }
    return false;
  }

  HloOpcode opcode;
  PrimitiveType accum_type;
  int64 domain_id;
  bool is_cross_shard;
  std::vector<ReplicaGroup> replica_groups;
};

// Group AllReduce instructions by the reduction type, e.g., add, min,
// max, as well as by replica groups and domain. For cross-module all-reduce
// instructions we group them by the set of domains they are reducing across.
//
// Note that the shape of the reduction computation is not included in the
// reduction types, e.g.: "f32[] add" and "bf16[] add" will be the same type. We
// also need to disallow combining CRS instructions with different domain
// metadata, as that could end up short-cutting two or more different domains.
//
// In each group, the instructions should be in post order. We will then
// iterate over each group and try to combine them, so to prevent
// non-determinism, we use std::map here.
//
// The return value is a list of groups where every group contains a list of
// all-reduce instruction sets in topological order and with a deterministic
// order within the set. Additionally, due to the above constraints, every
// all-reduce set within a group will contain the same number of elements
// and every instruction within an all-reduce set will have the same
// all-reduce-id (if specified) and thus shape (all-reduce sets without an
// all-reduce-id will have a single instruction).
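//
// Illustrative example (added for exposition; not in the original source):
// given
//   ar0 = f32[1024] all-reduce(a), to_apply=add
//   ar1 = f32[512] all-reduce(b), to_apply=add
//   ar2 = f32[1024] all-reduce(c), to_apply=max
// ar0 and ar1 share a GroupKey (same reduction opcode, accumulation type,
// domain and replica groups) and are candidates for combining, while ar2
// lands in a different group because its reduction opcode differs.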
using InstructionGroups =
    std::vector<std::vector<std::vector<HloInstruction*>>>;
StatusOr<InstructionGroups> CreateComputationGroups(
    HloComputation* computation) {
  TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

  // Group instructions by opcode, domain id and replica group.
  std::map<GroupKey, std::vector<HloInstruction*>> opcode_groups;
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() != HloOpcode::kAllReduce) {
      continue;
    }
    if (instruction->to_apply()->instruction_count() != 3 ||
        instruction->to_apply()->num_parameters() != 2) {
      VLOG(1) << "Skipping due to non-trivial reduction function.";
      continue;
    }
    opcode_groups[GroupKey(instruction, *domain_map)].push_back(instruction);
  }

  // Generate a unique all-reduce-id for instructions without one by negating
  // the unique id of the hlo. This way we can treat cross module and normal CRS
  // instructions uniformly.
  auto channel_id = [](const HloInstruction* all_reduce) {
    return all_reduce->IsCrossModuleAllReduce()
               ? all_reduce->channel_id().value()
               : -1 * all_reduce->unique_id();
  };

  // Group instructions by all-reduce id. The instructions for an all-reduce id
  // are listed along with their group id, and the (group id, instruction)
  // pairs are sorted by group id in the vector.
  std::map<int64, std::vector<std::pair<int64, HloInstruction*>>>
      all_reduce_sets;
  int64 group_id = 0;
  for (auto& domain_groups : opcode_groups) {
    for (HloInstruction* hlo : domain_groups.second) {
      all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo);
    }
    ++group_id;
  }

  // Group instructions by participating group ids. Instructions within a group
  // are sorted by topological order, and instructions within an all-reduce set
  // are still sorted by group id.
  std::map<std::vector<int64>, std::vector<std::vector<HloInstruction*>>>
      all_reduce_group_map;
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() != HloOpcode::kAllReduce) {
      continue;
    }
    if (instruction->to_apply()->instruction_count() != 3 ||
        instruction->to_apply()->num_parameters() != 2) {
      VLOG(1) << "Skipping due to non-trivial reduction function.";
      continue;
    }

    int64 arid = channel_id(instruction);
    if (all_reduce_sets.count(arid) == 0) {
      // Already processed.
      continue;
    }

    std::vector<int64> group_ids;
    std::vector<HloInstruction*> instructions;
    for (const auto& hlo : all_reduce_sets[arid]) {
      group_ids.push_back(hlo.first);
      instructions.push_back(hlo.second);
    }
    all_reduce_group_map[group_ids].push_back(std::move(instructions));
    all_reduce_sets.erase(arid);
  }
  CHECK(all_reduce_sets.empty());

  InstructionGroups groups;
  for (const auto& all_reduce_group : all_reduce_group_map) {
    groups.push_back(all_reduce_group.second);
  }
  return std::move(groups);
}

}  // namespace

AllReduceCombiner::AllReduceCombiner(int64 combine_threshold_in_bytes,
                                     int64 combine_threshold_count)
    : combine_threshold_in_bytes_(combine_threshold_in_bytes),
      combine_threshold_count_(combine_threshold_count) {}

StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
  VLOG(1) << "Running AllReduceCombiner with threshold of "
          << combine_threshold_in_bytes_ << " bytes";

  if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
    VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce "
               "with constrained layouts";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation));
    for (auto group : groups) {
      // Recompute reachability after every combine group because we can't
      // maintain a cross-group topological order to be able to rely on the
      // transitive dependencies to detect cycles.
      auto reachability = HloReachabilityMap::Build(computation);

      // Create a map to be able to find an instruction group based on the
      // first instruction in the group. It will be used during the post order
      // iteration to be able to process full groups at a time. Doing it only
      // for one instruction in every group will be sufficient because all
      // instructions have to be scheduled at the same time due to cross-core
      // dependencies.
      absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>*>
          group_map;
      for (auto& instruction : group) {
        group_map[instruction.front()] = &instruction;
      }

      // Collect sets of AllReduce instructions to combine.
      std::vector<std::vector<std::vector<HloInstruction*>>> combine_sets(1);
      int64 current_size_in_bytes = 0;
      int64 current_operand_count = 0;

      // Iterate all instructions in post order and skip the ones not in the
      // current group. We have to create a new post order iteration for every
      // group because merging instructions in the previous group can make the
      // original post order no longer hold.
      // This will make it likely that we won't increase memory pressure much
      // above combine_threshold_in_bytes, since two AllReduces that are
      // near in post order are most likely, but not for sure, also near in
      // scheduled order.
      //
      // TODO(b/70235266): This should usually be fine, but it's probably
      // possible to construct some case where the memory usage increases
      // beyond the threshold due to reordering of the instructions in
      // scheduling. If this ever comes up as a real problem, it would be nice
      // to implement safeguards so that that cannot possibly happen.
      for (const HloInstruction* inst :
           computation->MakeInstructionPostOrder()) {
        auto it = group_map.find(inst);
        if (it == group_map.end()) {
          // Instruction belongs to a different group.
          continue;
        }
        const auto& instructions = *it->second;

        VLOG(1) << "Considering HLO " << instructions.front()->ToString()
                << " with current set size of " << current_size_in_bytes
                << " and current operand count of " << current_operand_count;

        // We do not handle AllReduce ops that do not have exactly 1
        // operand since that is simpler and this pass is the only way to
        // generate such ops and it should rarely be important to consider the
        // same ops again.
        if (instructions.front()->operands().size() != 1) {
          VLOG(1) << "Skipping due to "
                  << instructions.front()->operands().size() << " operands";
          continue;
        }

        int64 size_in_bytes;
        TF_RET_CHECK(instructions.front()->shape().IsArray());
        size_in_bytes = ShapeUtil::ByteSizeOf(instructions.front()->shape());

        if (size_in_bytes > combine_threshold_in_bytes_) {
          VLOG(1) << "Skipping due to size " << size_in_bytes
                  << " above threshold";
          // If the instruction is greater than the threshold, then we can
          // never combine it with anything.
          continue;
        }

        // If the current set is dependent on the instruction, then create a
        // new one to avoid the dependency. We move on from the current set
        // instead of ignoring the instruction since otherwise a single
        // AllReduce instruction that all the other ones depend on (such as one
        // on the forward pass of a model) could disable this optimization
        // entirely.
        TF_RET_CHECK(!combine_sets.empty());
        for (const auto& previous : combine_sets.back()) {
          // The reachability information does not reflect the planned
          // combination from combine_sets. We cannot just bring it up to date
          // cheaply since HloReachabilityMap does not track reachability
          // updates transitively and doing it directly is expensive. However,
          // leaving it stale has no effect on the reachability queries that we
          // are doing here because we are considering the ops in a topological
          // order, so we can just leave it stale.
          //
          // Proof: Suppose A is the instruction we are looking to combine and
          // B is an element of the current combine set that we are looking to
          // combine A into.
          //
          // First of all, we check that all elements in each set do not depend
          // on each other, so combining the *current* combine set cannot
          // create new dependencies between A and B. It remains to prove that
          // combining the prior combine sets also cannot create a dependency
          // between A and B.
          //
          // Assume to get a contradiction that there are two AllReduce
          // ops C and D in combine_sets that will be combined and that A and B
          // are not connected now but that they will be after combining C and
          // D. Then there exist paths in the dependency graph such that one of
          // these cases is true:
          //
          //   A -> ... -> C and D -> ... -> B
          //   A -> ... -> D and C -> ... -> B
          //   B -> ... -> C and D -> ... -> A
          //   B -> ... -> D and C -> ... -> A
          //
          // None of these cases are possible because we are visiting the nodes
          // in a topological order, so C and D cannot be in-between A and B.
          // That is a contradiction, so combining the prior combine sets also
          // cannot create a dependency between A and B.
          bool new_set = false;
          for (int64 i = 0; i < instructions.size(); ++i) {
            if (reachability->IsReachable(previous[i], instructions[i])) {
              VLOG(1) << "Starting new set due to dependency between "
                      << previous[i]->ToString() << " AND "
                      << instructions[i]->ToString();
              new_set = true;
              break;
            }
          }
          if (new_set) {
            combine_sets.emplace_back();
            current_size_in_bytes = 0;
            current_operand_count = 0;
            break;
          }
        }

        if (current_size_in_bytes + size_in_bytes >
                combine_threshold_in_bytes_ ||
            current_operand_count + 1 > combine_threshold_count_) {
          VLOG(1) << "The instruction cannot be entered into the set due "
                     "to the combined size being too large.";
          // In this case we cannot include the instruction into the current
          // set since then it would grow beyond the threshold. The set of
          // instructions to carry forward will either be the current set or
          // the instruction by itself, whichever is smaller, since that
          // maximizes the chance of being able to combine with the next
          // instruction.
          if (size_in_bytes > current_size_in_bytes) {
            VLOG(1) << "Skipping as the instruction is larger than the set.";
            continue;  // keep the current set
          }
          VLOG(1)
              << "Resetting the set as the set is larger than the instruction.";
          combine_sets.emplace_back();
          current_size_in_bytes = 0;
          current_operand_count = 0;
        }

        VLOG(1) << "Adding instruction to set.";
        combine_sets.back().push_back(instructions);
        current_size_in_bytes += size_in_bytes;
        current_operand_count += 1;
        TF_RET_CHECK(current_size_in_bytes <= combine_threshold_in_bytes_);
        TF_RET_CHECK(current_operand_count <= combine_threshold_count_);
      }
      VLOG(1) << "Done constructing sets. Final set size is "
              << current_size_in_bytes << " bytes and " << current_operand_count
              << " operands";

      // Combine the collected sets of AllReduce instructions.
      for (const auto& combine_set : combine_sets) {
        if (combine_set.size() >= 2) {
          changed = true;
          for (int64 i = 0; i < combine_set.front().size(); ++i) {
            std::vector<HloInstruction*> to_combine;
            to_combine.reserve(combine_set.size());
            for (const auto& c : combine_set) {
              to_combine.push_back(c[i]);
            }
            TF_RETURN_IF_ERROR(CombineAllReduces(to_combine));
          }
        }
      }
    }
  }

  return changed;
}

}  // namespace xla
51  tensorflow/compiler/xla/service/all_reduce_combiner.h  Normal file
@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// Combines small non-dependent AllReduce ops into larger combined
// AllReduce ops. A typical AllReduce implementation has a minimum
// latency-induced time for an AllReduce op, so a single combined op can be
// more efficient than many small ones.
class AllReduceCombiner : public HloModulePass {
 public:
  AllReduceCombiner(int64 combine_threshold_in_bytes,
                    int64 combine_threshold_count);

  absl::string_view name() const override { return "all-reduce-combiner"; }

  StatusOr<bool> Run(HloModule* module) override;

 private:
  // Combine all reduce ops up to this threshold.
  int64 combine_threshold_in_bytes_;

  // Combine all reduce ops up to this threshold (number of operands).
  int64 combine_threshold_count_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_
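The header above is all a client needs. A minimal sketch of invoking the pass (mirroring how GpuCompiler wires it into a pipeline later in this commit; the helper function name is hypothetical):

    #include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

    namespace xla {

    // Hypothetical helper: runs the combiner on a module with the same
    // thresholds GpuCompiler uses below (30 MiB, 256 operands).
    Status RunAllReduceCombiner(HloModule* hlo_module) {
      HloPassPipeline pipeline("all_reduce_combiner");
      pipeline.AddPass<AllReduceCombiner>(
          /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
          /*combine_threshold_count=*/256);
      return pipeline.Run(hlo_module).status();
    }

    }  // namespace xla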
477  tensorflow/compiler/xla/service/all_reduce_combiner_test.cc  Normal file
@@ -0,0 +1,477 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using absl::nullopt;
using ::testing::AllOf;
namespace op = xla::testing::opcode_matchers;
int64 kMaxCombineCount = 256;

int64 AllReduceCount(const HloModule& module) {
  int64 count = 0;
  for (HloComputation* computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (HloInstruction* hlo : computation->instructions()) {
      if (hlo->opcode() == HloOpcode::kAllReduce) {
        ++count;
      }
    }
  }
  return count;
}

// inputs[i] will be some op producing a shape of size sizes_in_kib[i] which
// feeds into an all-reduce op in all_reduces[i]. Returns a tuple
// of the all_reduces.
HloInstruction* MakeCrossReplicaReductions(
    std::vector<int64> sizes_in_kib, std::vector<HloComputation*> reductions,
    std::vector<HloInstruction*>* inputs, HloComputation::Builder* b) {
  CHECK_EQ(reductions.size(), sizes_in_kib.size());
  std::vector<HloInstruction*> all_reduces;
  for (int i = 0; i < sizes_in_kib.size(); i++) {
    int64 size_in_kib = sizes_in_kib[i];
    HloComputation* reduction = reductions[i];
    auto constant = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
    Shape shape = ShapeUtil::MakeShape(
        F32, {static_cast<int32>(size_in_kib * 1024 / sizeof(float))});
    auto input =
        b->AddInstruction(HloInstruction::CreateBroadcast(shape, constant, {}));
    inputs->push_back(input);
    all_reduces.push_back(b->AddInstruction(HloInstruction::CreateAllReduce(
        shape, {input}, reduction, /*replica_groups=*/{},
        /*constrain_layout=*/false, /*channel_id=*/nullopt)));
  }
  return b->AddInstruction(HloInstruction::CreateTuple(all_reduces));
}

// Create and add a reduction computation in the given type to the module.
HloComputation* MakeReduction(const HloOpcode type, HloModule* module) {
  HloComputation::Builder sum_builder(HloOpcodeString(type));
  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
  sum_builder.AddInstruction(
      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {}), type, x, y));
  HloComputation* reduction =
      module->AddEmbeddedComputation(sum_builder.Build());
  return reduction;
}

// Creates replica groups for AllReduce. groups[i] represents replica ids
// for group 'i'.
std::vector<ReplicaGroup> CreateReplicaGroups(
    absl::Span<const std::vector<int64>> groups) {
  std::vector<ReplicaGroup> replica_groups(groups.size());
  for (int64 i = 0; i < groups.size(); ++i) {
    *replica_groups[i].mutable_replica_ids() = {groups[i].begin(),
                                                groups[i].end()};
  }
  return replica_groups;
}

using AllReduceCombinerTest = HloTestBase;

// Tests combination of several AllReduce instructions.
TEST_F(AllReduceCombinerTest, CombineAllReduces) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  auto root = MakeCrossReplicaReductions(
      {1, 2, 10, 7, 6}, {sum, sum, sum, sum, sum}, &inputs, &b);
  auto computation = module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass.
  AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), inputs.size());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  ASSERT_EQ(AllReduceCount(*module), 1);
  EXPECT_TRUE(changed);

  ASSERT_EQ(root, computation->root_instruction());
  ASSERT_EQ(inputs.size(), root->operands().size());

  HloInstruction* combined = nullptr;
  for (int64 i = 0; i < root->operands().size(); ++i) {
    HloInstruction* hlo = root->mutable_operand(i);
    ASSERT_TRUE(hlo->opcode() == HloOpcode::kGetTupleElement);
    EXPECT_EQ(hlo->tuple_index(), i);
    EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape()));

    if (combined == nullptr) {
      // Verify the combined all reduce instruction.
      combined = hlo->mutable_operand(0);
      ASSERT_TRUE(combined->opcode() == HloOpcode::kAllReduce);
      EXPECT_TRUE(ShapeUtil::Equal(root->shape(), combined->shape()));
      ASSERT_EQ(combined->operands().size(), inputs.size());
    }
    EXPECT_EQ(combined, hlo->operand(0));
    EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape()));
    EXPECT_EQ(combined->operand(i), inputs[i]);
    EXPECT_EQ(1, inputs[i]->users().size());
  }
  ASSERT_NE(combined, nullptr);
}

// Tests combination of several cross replica reduction instructions in
// different types.
TEST_F(AllReduceCombinerTest, CombineCrossReplicaReductionsInGroups) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());
  HloComputation* min = MakeReduction(HloOpcode::kMinimum, module.get());
  HloComputation* max = MakeReduction(HloOpcode::kMaximum, module.get());
  HloComputation* sum_2 = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  MakeCrossReplicaReductions(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
      {sum, sum_2, min, min, min, max, max, max, sum, sum_2}, &inputs, &b);
  module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass.
  AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), inputs.size());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  ASSERT_EQ(AllReduceCount(*module), 3)
      << "expects 3 groups for 3 reduction types.";
  EXPECT_TRUE(changed);
}

// Tests that the combination threshold is respected.
TEST_F(AllReduceCombinerTest, RespectThreshold) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  MakeCrossReplicaReductions({8, 4}, {sum, sum}, &inputs, &b);
  module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass with threshold less than
  // the combined size of the all reduce ops so that the combination
  // cannot occur.
  {
    AllReduceCombiner combine((8 + 4) * 1024 - 1, kMaxCombineCount);
    ASSERT_EQ(AllReduceCount(*module), inputs.size());
    TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
    EXPECT_EQ(AllReduceCount(*module), inputs.size());
    EXPECT_FALSE(changed);
  }

  // Run the AllReduce combiner optimization pass again with a slightly
  // higher threshold so that the combination can occur.
  {
    AllReduceCombiner combine((8 + 4) * 1024, kMaxCombineCount);
    ASSERT_EQ(AllReduceCount(*module), inputs.size());
    TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
    EXPECT_EQ(AllReduceCount(*module), 1);
    EXPECT_TRUE(changed);
  }
}

// Tests that dependent all reduces are not combined.
TEST_F(AllReduceCombinerTest, NoDependentCombination) {
  auto module = CreateNewVerifiedModule();
  HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  auto constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
  auto all_reduce = b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {constant}, reduction, /*replica_groups=*/{},
      /*constrain_layout=*/false, /*channel_id=*/nullopt));
  b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {all_reduce}, reduction,
      /*replica_groups=*/{}, /*constrain_layout=*/false,
      /*channel_id=*/nullopt));

  module->AddEntryComputation(b.Build());

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

// Tests that AllReduce ops with different groups are not combined.
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());

  auto constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
  auto crs0 = b.AddInstruction(
      HloInstruction::CreateAllReduce(constant->shape(), {constant}, reduction,
                                      CreateReplicaGroups({{0, 1}, {2, 3}}),
                                      /*constrain_layout=*/false,
                                      /*channel_id=*/nullopt));
  auto crs1 = b.AddInstruction(
      HloInstruction::CreateAllReduce(constant->shape(), {constant}, reduction,
                                      CreateReplicaGroups({{0, 2}, {1, 3}}),
                                      /*constrain_layout=*/false,
                                      /*channel_id=*/nullopt));
  b.AddInstruction(HloInstruction::CreateTuple({crs0, crs1}));

  module->AddEntryComputation(b.Build());

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, DomainPreventsCombining) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  crs0 = f32[128] all-reduce(param0),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param1),
    replica_groups={}, to_apply=summit, sharding={maximal device=1}
  domain0 = f32[128] domain(crs0),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=0}}
  domain1 = f32[128] domain(crs1),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=1}}
  ROOT tuple = (f32[128], f32[128]) tuple(domain0, domain1),
    sharding={{maximal device=0}, {maximal device=1}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

// This test checks that two CRS instructions that are in separate domains
// but with the same domain metadata can be combined.
TEST_F(AllReduceCombinerTest, CombineFromTwoDomainsWithSameMetadata) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  param2 = f32[128] parameter(2), sharding={maximal device=1}
  crs0 = f32[128] all-reduce(param0),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param1),
    replica_groups={}, to_apply=summit, sharding={maximal device=1}
  crs2 = f32[128] all-reduce(param2),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  domain0 = f32[128] domain(crs0),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=0}}
  domain1 = f32[128] domain(crs1),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=1}}
  domain2 = f32[128] domain(crs2),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=0}}
  ROOT tuple = (f32[128], f32[128], f32[128]) tuple(domain0, domain1, domain2),
    sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 3);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_TRUE(changed);
}

TEST_F(AllReduceCombinerTest, DoNotCombineCrossShardAndCrosReplicaInSPMD) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  cross_shard_ar = f32[128] all-reduce(param0),
    replica_groups={{0}}, to_apply=summit, channel_id=1
  cross_replica_ar = f32[128] all-reduce(param1),
    replica_groups={{0}}, to_apply=summit, sharding={maximal device=1}
  ROOT tuple = (f32[128], f32[128]) tuple(cross_shard_ar, cross_replica_ar)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, CrossCoreAllReduce) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  crs00 = f32[128] all-reduce(param0),
    replica_groups={{0}}, channel_id=1, to_apply=summit,
    sharding={maximal device=0}
  crs01 = f32[128] all-reduce(param1),
    replica_groups={{0}}, channel_id=1, to_apply=summit,
    sharding={maximal device=1}
  crs10 = f32[128] all-reduce(param0),
    replica_groups={{0}}, channel_id=2, to_apply=summit,
    sharding={maximal device=0}
  crs11 = f32[128] all-reduce(param1),
    replica_groups={{0}}, channel_id=2, to_apply=summit,
    sharding={maximal device=1}
  domain0 = f32[128] domain(crs00),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  ROOT add = f32[128] add(domain0, crs11),
    sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 4);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_TRUE(changed);

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Add(op::Domain(op::GetTupleElement(
                  AllOf(op::AllReduce(op::Parameter(0), op::Parameter(0)),
                        op::Shape("(f32[128], f32[128])")),
                  1)),
              op::GetTupleElement(
                  AllOf(op::AllReduce(op::Parameter(1), op::Parameter(1)),
                        op::Shape("(f32[128], f32[128])")),
                  0)));
}

TEST_F(AllReduceCombinerTest, CrossCombineGroupCycle) {
  const char* const hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

%max {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] maximum(lhs, rhs)
}
ENTRY %comp {
  p0 = f32[128] parameter(0)
  p1 = f32[128] parameter(1)

  crs00 = f32[128] all-reduce(p0), to_apply=add
  crs10 = f32[128] all-reduce(p1), to_apply=max

  crs01 = f32[128] all-reduce(crs00), to_apply=max
  crs11 = f32[128] all-reduce(crs10), to_apply=add
  add0 = f32[128] add(crs01, crs11)

  crs02 = f32[128] all-reduce(add0), to_apply=add
  crs12 = f32[128] all-reduce(crs11), to_apply=add
  ROOT tuple = (f32[128], f32[128]) tuple(crs02, crs12)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 6);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 4);
  EXPECT_TRUE(changed);

  auto crs0 = op::AllReduce(op::Parameter(0), op::AllReduce(op::Parameter(1)));
  auto add = op::Add(op::AllReduce(op::GetTupleElement(crs0, 0)),
                     op::GetTupleElement(crs0, 1));
  auto crs1 = op::AllReduce(add, op::GetTupleElement(crs0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::GetTupleElement(crs1, 0), op::GetTupleElement(crs1, 1)));
}

}  // namespace
}  // namespace xla
@@ -149,7 +149,6 @@ struct AllReduceParticipantData
   explicit AllReduceParticipantData(RendezvousKey rendezvous_key)
       : rendezvous_key(rendezvous_key) {}

-  int64 element_count;
   int64 device_ordinal;
   RendezvousKey rendezvous_key;

@@ -157,20 +156,30 @@ struct AllReduceParticipantData
   // source_buffer == destination_buffer if that avoids a NCCL copy (will depend
   // on how well the NCCL in-place implementation performs vs the out-of-place
   // implementation).
-  se::DeviceMemoryBase source_data;
-  se::DeviceMemoryBase destination_data;
+  struct Buffer {
+    int64 element_count;
+    se::DeviceMemoryBase source_data;
+    se::DeviceMemoryBase destination_data;
+    PrimitiveType primitive_type;
+  };
+  std::vector<Buffer> buffers;
   se::Stream* stream;

   ReductionKind reduction_kind;
-  PrimitiveType primitive_type;

   int num_participants() const { return rendezvous_key.num_participants(); }

   string ToString() const {
+    std::vector<std::string> buffer_strs;
+    for (const Buffer& buffer : buffers) {
+      buffer_strs.push_back(
+          absl::StrFormat("{element_count=%d}", buffer.element_count));
+    }
     return absl::StrFormat(
-        "AllReduceParticipantData{element_count=%d, rendezvous_key=%s, "
+        "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
         "device_ordinal=%d, stream=%p}",
-        element_count, rendezvous_key.ToString(), device_ordinal, stream);
+        absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
+        device_ordinal, stream);
   }
 };

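Given the new Buffer struct, a hedged sketch (not part of the diff) of how a caller could fill a participant carrying several buffers, one per operand of a combined all-reduce; `rendezvous_key`, `device_ordinal`, `stream`, and the `OperandDesc` descriptors are assumed to exist:

    xla::AllReduceParticipantData participant(rendezvous_key);
    participant.device_ordinal = device_ordinal;
    participant.stream = stream;
    for (const OperandDesc& operand : operands) {  // hypothetical descriptor type
      xla::AllReduceParticipantData::Buffer buffer;
      buffer.element_count = operand.element_count;
      buffer.primitive_type = operand.primitive_type;
      buffer.source_data = operand.source;            // se::DeviceMemoryBase
      buffer.destination_data = operand.destination;  // se::DeviceMemoryBase
      participant.buffers.push_back(buffer);
    }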
@@ -245,7 +254,7 @@ class Rendezvous

     // Spot check for consistent replica counts among submitting threads.
     if (!participants_.empty() &&
-        (participants_.back().element_count != participant.element_count ||
+        (participants_.back().buffers.size() != participant.buffers.size() ||
          participants_.back().rendezvous_key != participant.rendezvous_key)) {
       return InvalidArgument(
           "Mismatch among all-reduce participants. Expected same "
@@ -262,7 +262,8 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t>
  protected:
   xla::StatusOr<std::pair<std::nullptr_t, bool>> SubmitParticipantImpl(
       xla::AllReduceParticipantData participant) override {
-    xla::PrimitiveType datatype = participant.primitive_type;
+    TF_RET_CHECK(participant.buffers.size() == 1);
+    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
     bool primary = [&] {
       tensorflow::mutex_lock lock(mu_);
       if (!initialized_) {
@@ -316,10 +317,8 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t>
     using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
     tensorflow::mutex_lock lock(mu_);
     CHECK(!participants_.empty());
-    xla::int64 element_count = participant.element_count;
     xla::ReductionKind reduction_kind = participant.reduction_kind;
     for (const auto& p : participants_) {
-      CHECK_EQ(p.element_count, element_count);
       CHECK(p.reduction_kind == reduction_kind);
     }

@@ -329,11 +328,19 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t>
     output_buffers.reserve(participants_.size());

     for (auto& p : participants_) {
-      input_buffers.emplace_back(static_cast<T*>(p.source_data.opaque()),
-                                 element_count);
-      output_buffers.emplace_back(static_cast<T*>(p.destination_data.opaque()),
-                                  element_count);
+      CHECK_EQ(p.buffers.size(), 1);
+      CHECK_EQ(p.buffers.front().element_count,
+               participants_.front().buffers.front().element_count);
+      xla::int64 element_count = participant.buffers.front().element_count;
+      input_buffers.emplace_back(
+          static_cast<T*>(p.buffers.front().source_data.opaque()),
+          element_count);
+      output_buffers.emplace_back(
+          static_cast<T*>(p.buffers.front().destination_data.opaque()),
+          element_count);
     }
+    xla::int64 element_count =
+        participants_.front().buffers.front().element_count;

     auto compute = [reduction_kind](T a, T b) -> T {
       switch (reduction_kind) {
@@ -416,7 +423,6 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   xla::RendezvousKey rendezvous_key(run_options->run_id(),
                                     participating_replicas_vec, op_kind, op_id);

-
   auto shape_str = ShapeString(shape_ptr, shape_length);
   VLOG(2) << "All-reduce input/output shape : " << shape_str;

@@ -426,14 +432,16 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
       << "All-reduce on CPU is implemented only for dense arrays";

   xla::AllReduceParticipantData participant(rendezvous_key);
-  participant.element_count = xla::ShapeUtil::ElementsIn(shape);
   participant.device_ordinal = device_ordinal;
-  participant.primitive_type = shape.element_type();
   participant.stream = run_options->stream();
-  participant.source_data =
+  xla::AllReduceParticipantData::Buffer buffer;
+  buffer.element_count = xla::ShapeUtil::ElementsIn(shape);
+  buffer.primitive_type = shape.element_type();
+  buffer.source_data =
       se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
-  participant.destination_data =
+  buffer.destination_data =
       se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
+  participant.buffers = {buffer};
   participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);

   TF_CHECK_OK(
@@ -1131,6 +1131,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:all_reduce_combiner",
         "//tensorflow/compiler/xla/service:batchnorm_expander",
         "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:call_inliner",
@@ -42,15 +42,11 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
 struct NcclAllReduceThunk::AuxData {};

 NcclAllReduceThunk::NcclAllReduceThunk(
-    int64 replica_count, int64 element_count,
-    const BufferAllocation::Slice& source_buffer,
-    const BufferAllocation::Slice& destination_buffer,
+    int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
     const HloInstruction* all_reduce)
     : Thunk(Thunk::kNcclAllReduce, all_reduce),
       replica_count_(replica_count),
-      element_count_(element_count),
-      source_buffer_(source_buffer),
-      destination_buffer_(destination_buffer) {}
+      buffers_(std::move(buffers)) {}

 }  // namespace gpu
 }  // namespace xla
@@ -31,6 +31,7 @@ limitations under the License.
 #include "llvm/IR/Verifier.h"
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
@@ -291,7 +292,13 @@ Status GpuCompiler::OptimizeHloModule(
     horizontal_fusion.AddPass<HloDCE>();
     TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
   }

+  {
+    HloPassPipeline pipeline("all_reduce_combiner");
+    pipeline.AddPass<AllReduceCombiner>(
+        /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
+        /*combine_threshold_count=*/256);
+    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+  }
   return Status::OK();
 }

@@ -1210,10 +1210,7 @@ Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
   return Status::OK();
 }

-namespace {
-
-
-}  // namespace
+namespace {}  // namespace

 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
   VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
@@ -1226,13 +1223,37 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
       NcclAllReduceThunk::CanImplement(crs);

   if (should_use_nccl_thunk) {
-    CHECK(crs->operand(0)->shape().IsArray())
-        << "Operands to all-reduce must be arrays: " << crs->ToString();
-    AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
+    std::vector<NcclAllReduceThunk::Buffer> buffers;
+    std::vector<BufferAllocation::Slice> tuple_element_buffers;
+    buffers.resize(crs->operand_count());
+    tuple_element_buffers.reserve(crs->operand_count());
+    CHECK(crs->shape().IsArray() && crs->operand_count() == 1 ||
+          crs->shape().IsTuple() &&
+              crs->shape().tuple_shapes_size() == crs->operand_count());
+    for (int i = 0; i < crs->operand_count(); ++i) {
+      CHECK(crs->operand(i)->shape().IsArray())
+          << "Operands to all-reduce must be arrays: " << crs->ToString();
+      buffers[i].element_count =
+          ShapeUtil::ElementsIn(crs->operand(i)->shape());
+      buffers[i].source_buffer = GetAllocationSlice(*crs->operand(i));
+      buffers[i].destination_buffer = GetAllocationSlice(
+          *crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
+      tuple_element_buffers.push_back(buffers[i].destination_buffer);
+    }
+    auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
         /*replica_count=*/hlo_module_config_.replica_count(),
-        /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
-        /*source_address=*/GetAllocationSlice(*crs->operand(0)),
-        /*destination_buffer=*/GetAllocationSlice(*crs), crs));
+        /*buffers=*/std::move(buffers), crs);
+    if (crs->shape().IsTuple()) {
+      std::vector<std::unique_ptr<Thunk>> thunks;
+      thunks.push_back(std::move(all_reduce_thunk));
+      thunks.push_back(absl::make_unique<TupleThunk>(
+          tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
+      AddThunkToThunkSequence(
+          absl::make_unique<SequentialThunk>(std::move(thunks), crs));
+    } else {
+      AddThunkToThunkSequence(std::move(all_reduce_thunk));
+    }

     return Status::OK();
   }

@ -1957,32 +1978,32 @@ void IrEmitterUnnested::EmitTile(
   //
   // TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
   // workaround.
-  ksl->For(
-      loop_name + "_y_in_tile",
-      /*start=*/constant(0),
-      /*end=*/
-      ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
-                    num_threads_y),
-      /*step=*/constant(1), [&](llvm::Value* y_indvar) {
-        llvm::Value* y_loc = b_.CreateAdd(
-            thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
-        for (int64 j = 0; j < x_num_steps; j++) {
-          llvm::Value* x_loc =
-              b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
-          IrArray::Index source_idx_x =
-              source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
-                  .AddOffsetToDim(constant(j * step_x), kDimX, &b_);
-          auto emit_element = [&] {
-            return emit_elem_function(source_idx_x, y_loc, x_loc, j);
-          };
-          if (!x_tile_fits) {
-            ksl->If(loop_name + "_x_in_tile",
-                    b_.CreateICmpULT(x_loc, tile_width), emit_element);
-          } else {
-            emit_element();
-          }
-        }
-      });
+  ksl->For(loop_name + "_y_in_tile",
+           /*start=*/constant(0),
+           /*end=*/
+           ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
+                         num_threads_y),
+           /*step=*/constant(1), [&](llvm::Value* y_indvar) {
+             llvm::Value* y_loc =
+                 b_.CreateAdd(thread_id_info.thread_id_y,
+                              b_.CreateMul(y_indvar, num_threads_y));
+             for (int64 j = 0; j < x_num_steps; j++) {
+               llvm::Value* x_loc =
+                   b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
+               IrArray::Index source_idx_x =
+                   source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
+                       .AddOffsetToDim(constant(j * step_x), kDimX, &b_);
+               auto emit_element = [&] {
+                 return emit_elem_function(source_idx_x, y_loc, x_loc, j);
+               };
+               if (!x_tile_fits) {
+                 ksl->If(loop_name + "_x_in_tile",
+                         b_.CreateICmpULT(x_loc, tile_width), emit_element);
+               } else {
+                 emit_element();
+               }
+             }
+           });
 }
 
 // Emits code to process a tensor element in a tile for the given kCopy HLO that

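Both versions of this loop emit the same strip-mined traversal; only the formatting differs. In scalar form (a model of the generated IR, not XLA code; parameter names follow the emitter's variables):

#include <cstdint>
#include <functional>

// Each of num_threads_y threads covers rows y = thread_id_y,
// thread_id_y + num_threads_y, ... below tile_height; each row is cut into
// x_num_steps column steps of width step_x, guarded against tile_width when
// the tile does not divide evenly.
void EmitTileModel(int64_t tile_height, int64_t tile_width,
                   int64_t num_threads_y, int64_t thread_id_y,
                   int64_t x_num_steps, int64_t step_x,
                   int64_t start_offset_x, bool x_tile_fits,
                   const std::function<void(int64_t, int64_t)>& emit_element) {
  for (int64_t y = thread_id_y; y < tile_height; y += num_threads_y) {
    for (int64_t j = 0; j < x_num_steps; ++j) {
      int64_t x = start_offset_x + j * step_x;
      if (x_tile_fits || x < tile_width) {
        emit_element(y, x);
      }
    }
  }
}
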
@ -154,10 +154,6 @@ ncclRedOp_t ReductionKindToNccl(ReductionKind kind) {
   }
 }
 
-PrimitiveType AllReducePrimitiveType(const HloInstruction* instr) {
-  return instr->operand(0)->shape().element_type();
-}
-
 absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
   switch (element_type) {
     case S8:

@ -402,9 +398,6 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
   VLOG(3) << "Performing all reduce from device ordinal: "
           << participant.device_ordinal;
   ncclRedOp_t computation = ReductionKindToNccl(participant.reduction_kind);
-  absl::optional<ncclDataType_t> allreduce_datatype =
-      DatatypeToNccl(participant.primitive_type);
-  CHECK(allreduce_datatype.has_value());
 
   se::StreamExecutor* executor = participant.stream->parent();
   se::cuda::ScopedActivateExecutorContext scoped_context(executor);

@ -412,19 +405,26 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
       participant.stream->implementation()->GpuStreamMemberHack());
   VLOG(3) << "Using stream pointer: " << cu_stream
           << " on device: " << participant.device_ordinal;
-  void* send_buffer = participant.source_data.opaque();
-  void* recv_buffer = participant.destination_data.opaque();
-  VLOG(3) << absl::StreamFormat(
-      "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
-      "comm=%p, stream=%p)",
-      send_buffer, recv_buffer, participant.element_count,
-      static_cast<const void*>(comm), cu_stream);
-  XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
-                                         /*count=*/participant.element_count,
-                                         /*datatype=*/*allreduce_datatype,
-                                         /*op=*/computation,
-                                         /*comm=*/comm,
-                                         /*stream=*/*cu_stream));
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
+  for (auto& buffer : participant.buffers) {
+    void* send_buffer = buffer.source_data.opaque();
+    void* recv_buffer = buffer.destination_data.opaque();
+    absl::optional<ncclDataType_t> allreduce_datatype =
+        DatatypeToNccl(buffer.primitive_type);
+    CHECK(allreduce_datatype.has_value());
+    VLOG(3) << absl::StreamFormat(
+        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
+        "comm=%p, stream=%p)",
+        send_buffer, recv_buffer, buffer.element_count,
+        static_cast<const void*>(comm), cu_stream);
+    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
+                                           /*count=*/buffer.element_count,
+                                           /*datatype=*/*allreduce_datatype,
+                                           /*op=*/computation,
+                                           /*comm=*/comm,
+                                           /*stream=*/*cu_stream));
+  }
+  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
 
   VLOG(3) << "Done performing all reduce for ordinal: "
           << participant.device_ordinal;

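This loop is the NCCL grouping pattern that makes a combined thunk cheap: calls issued between ncclGroupStart() and ncclGroupEnd() are deferred and launched together at group end, so N combined all-reduces avoid N separately synchronized collective calls. A minimal standalone sketch of the same pattern (error handling reduced to early returns; DeviceBuffer is a made-up helper for the example):

#include <cstddef>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

struct DeviceBuffer {
  void* send;            // device pointer holding the operand
  void* recv;            // device pointer receiving the reduced result
  size_t count;          // number of elements
  ncclDataType_t dtype;  // e.g. ncclFloat
};

ncclResult_t GroupedAllReduce(const std::vector<DeviceBuffer>& buffers,
                              ncclRedOp_t op, ncclComm_t comm,
                              cudaStream_t stream) {
  ncclResult_t result = ncclGroupStart();
  if (result != ncclSuccess) return result;
  for (const DeviceBuffer& b : buffers) {
    // Deferred by the group; nothing is launched yet.
    result = ncclAllReduce(b.send, b.recv, b.count, b.dtype, op, comm, stream);
    if (result != ncclSuccess) return result;
  }
  // All queued all-reduces are issued together here.
  return ncclGroupEnd();
}
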
@ -453,11 +453,14 @@ struct NcclAllReduceThunk::AuxData {
 };
 
 /*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
+  auto operands_are_supported = [crs]() {
+    return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
+      return LayoutUtil::IsDenseArray(operand->shape()) &&
+             DatatypeToNccl(operand->shape().element_type()).has_value();
+    });
+  };
   return MatchReductionComputation(crs->to_apply()).has_value() &&
-         DatatypeToNccl(AllReducePrimitiveType(crs)).has_value() &&
-         crs->IsCrossReplicaAllReduce() &&
-         crs->operand_count() == 1 &&  // One array to reduce.
-         LayoutUtil::IsDenseArray(crs->operand(0)->shape());
+         crs->IsCrossReplicaAllReduce() && operands_are_supported();
 }
 
 /*static*/ absl::flat_hash_set<int>

@ -471,16 +474,14 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
 }
 
 NcclAllReduceThunk::NcclAllReduceThunk(
-    int64 replica_count, int64 element_count,
-    const BufferAllocation::Slice& source_buffer,
-    const BufferAllocation::Slice& destination_buffer,
+    int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
     const HloInstruction* all_reduce)
     : Thunk(Thunk::kNcclAllReduce, all_reduce),
       replica_count_(replica_count),
-      element_count_(element_count),
-      source_buffer_(source_buffer),
-      destination_buffer_(destination_buffer),
-      aux_data_(absl::make_unique<AuxData>()) {}
+      buffers_(std::move(buffers)),
+      aux_data_(absl::make_unique<AuxData>()) {
+  CHECK_EQ(hlo_instruction()->operand_count(), buffers_.size());
+}
 
 // Figures out which devices (named by their replica-ids) are participating in
 // the all-reduce subgroup that contains device_ordinal.

@ -506,18 +507,24 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
           << absl::StrJoin(participating_replicas, ", ");
 
   AllReduceParticipantData participant(rendezvous_key);
-  participant.element_count = element_count_;
   participant.device_ordinal = device_ordinal;
-  participant.source_data =
-      params.buffer_allocations->GetDeviceAddress(source_buffer_);
-  participant.destination_data =
-      params.buffer_allocations->GetDeviceAddress(destination_buffer_);
+  for (size_t i = 0; i < buffers_.size(); ++i) {
+    const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
+    AllReduceParticipantData::Buffer pbuffer;
+    pbuffer.element_count = buffer.element_count;
+    pbuffer.source_data =
+        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
+    pbuffer.destination_data =
+        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
+    pbuffer.primitive_type =
+        hlo_instruction()->operand(i)->shape().element_type();
+    participant.buffers.push_back(pbuffer);
+  }
   participant.stream = params.stream;
   auto reduction_kind =
       MatchReductionComputation(hlo_instruction()->to_apply());
   CHECK(reduction_kind.has_value());
   participant.reduction_kind = *reduction_kind;
-  participant.primitive_type = AllReducePrimitiveType(hlo_instruction());
 
   TF_ASSIGN_OR_RETURN(
       std::shared_ptr<NcclClique> clique,

@ -50,9 +50,12 @@ class NcclAllReduceThunk : public Thunk {
 
   // TODO(b/125951860): Support all-reduces with replica groups, i.e.
   // all-reduces that compute multiple sums across subsets of all replicas.
-  NcclAllReduceThunk(int64 replica_count, int64 element_count,
-                     const BufferAllocation::Slice& source_buffer,
-                     const BufferAllocation::Slice& destination_buffer,
+  struct Buffer {
+    int64 element_count;
+    BufferAllocation::Slice source_buffer;
+    BufferAllocation::Slice destination_buffer;
+  };
+  NcclAllReduceThunk(int64 replica_count, std::vector<Buffer> buffers,
                     const HloInstruction* all_reduce);
   ~NcclAllReduceThunk() override;
 
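A compact model of the new interface (stand-in types, illustration only): the thunk takes the Buffer list by value, moves it into an immutable member, and requires exactly one Buffer per all-reduce operand, mirroring the CHECK_EQ in the constructor definition above.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

struct Slice {  // stand-in for BufferAllocation::Slice
  int64_t allocation = 0;
  int64_t offset = 0;
  int64_t size = 0;
};

struct Buffer {
  int64_t element_count = 0;  // elements reduced for this operand
  Slice source_buffer;        // where the operand's data lives
  Slice destination_buffer;   // where this operand's result is written
};

class AllReduceThunkModel {
 public:
  AllReduceThunkModel(int64_t replica_count, std::vector<Buffer> buffers,
                      int64_t operand_count)
      : replica_count_(replica_count), buffers_(std::move(buffers)) {
    // Mirrors CHECK_EQ(hlo_instruction()->operand_count(), buffers_.size()).
    assert(static_cast<int64_t>(buffers_.size()) == operand_count);
  }

  int64_t replica_count() const { return replica_count_; }
  const std::vector<Buffer>& buffers() const { return buffers_; }

 private:
  const int64_t replica_count_;
  const std::vector<Buffer> buffers_;  // one entry per operand
};
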
@ -70,9 +73,7 @@ class NcclAllReduceThunk : public Thunk {
   struct AuxData;
 
   const int64 replica_count_;
-  const int64 element_count_;
-  const BufferAllocation::Slice source_buffer_;
-  const BufferAllocation::Slice destination_buffer_;
+  const std::vector<Buffer> buffers_;
   std::unique_ptr<AuxData> aux_data_;
 };
 
@ -368,6 +368,55 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
   done.Wait();
 }
 
+// Runs an executable containing two combinable all-reduces; each result must
+// match the expected reduction of its own operand.
+XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) {
+  std::string hlo_string = R"(
+    HloModule test
+
+    apply_op {
+      x = f32[] parameter(0)
+      y = f32[] parameter(1)
+      ROOT apply_op = f32[] add(x, y)
+    }
+
+    ENTRY test_computation {
+      p0 = f32[5] parameter(0)
+      p1 = f32[5] parameter(1)
+      crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op
+      crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op
+      ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1)
+    }
+  )";
+  static constexpr int kNumReplicas = 2;
+  auto config = GetModuleConfigForTest();
+  config.set_replica_count(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string, config));
+
+  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
+  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
+  std::vector<float> input1_vec = {7., 3., 4., 1., 2.};
+  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
+                        /*num_replicas=*/kNumReplicas,
+                        /*use_threads=*/true));
+  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
+  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
+  std::vector<float> expected1_vec = {14., 6., 8., 2., 4.};
+  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
+  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
+    auto rs = results[replica_idx].DecomposeTuple();
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
+                                             ErrorSpec{1e-5, 1e-5}));
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
+                                             ErrorSpec{1e-5, 1e-5}));
+  }
+}
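The expected literals above follow from replica symmetry: every replica feeds the same input into an add all-reduce, so each element is simply scaled by kNumReplicas (1 * 2 = 2, ..., 7 * 2 = 14). A small illustrative helper, not part of the test, that reproduces them:

#include <vector>

std::vector<float> ExpectedAfterAddAllReduce(const std::vector<float>& input,
                                             int num_replicas) {
  std::vector<float> out;
  out.reserve(input.size());
  for (float v : input) out.push_back(v * num_replicas);
  return out;
}
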
 
 // Runs an all-reduce with three partitions:
 //   {0}, {1,2}, {3}
 // meaning, the all-reduce is a nop for devices 0 and 3, and only devices 1 and