[XLA:GPU] Add an AllReduceCombiner pass that merges AllReduce operations.

On GPU, implement combined allreduces using NCCL groups.

PiperOrigin-RevId: 296130269
Change-Id: I763f0139c8ed9a59d7d691e3252e6b46244fefd6
Peter Hawkins 2020-02-19 22:04:16 -08:00 committed by TensorFlower Gardener
parent 20c1ba21a9
commit db85f4c207
13 changed files with 1229 additions and 105 deletions
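The GPU implementation below issues one NCCL collective per combined operand inside an ncclGroupStart()/ncclGroupEnd() pair, so the grouped calls are launched together as a single fused collective. The following is a minimal standalone sketch of that pattern, not part of the commit itself; ReduceBuffer and GroupedAllReduce are illustrative names, and the communicator and stream are assumed to be set up by the caller.

#include <cstddef>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

// Illustrative per-operand buffer description (hypothetical type).
struct ReduceBuffer {
  const void* send;      // device pointer holding the input
  void* recv;            // device pointer receiving the reduced result
  size_t count;          // number of elements
  ncclDataType_t dtype;  // e.g. ncclFloat
};

// Issues one all-reduce per buffer inside a NCCL group, so all of them are
// launched together; this mirrors the loop added to the NCCL thunk below.
ncclResult_t GroupedAllReduce(const std::vector<ReduceBuffer>& buffers,
                              ncclComm_t comm, cudaStream_t stream) {
  ncclResult_t result = ncclGroupStart();
  if (result != ncclSuccess) return result;
  for (const ReduceBuffer& b : buffers) {
    // Calls inside the group are only queued; nothing runs until GroupEnd.
    result = ncclAllReduce(b.send, b.recv, b.count, b.dtype, ncclSum, comm,
                           stream);
    if (result != ncclSuccess) return result;
  }
  return ncclGroupEnd();
}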


@@ -1947,6 +1947,51 @@ tf_cc_test(
],
)
cc_library(
name = "all_reduce_combiner",
srcs = ["all_reduce_combiner.cc"],
hdrs = ["all_reduce_combiner.h"],
deps = [
":hlo",
":hlo_domain_map",
":hlo_pass",
":hlo_query",
":hlo_reachability",
":shape_inference",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "all_reduce_combiner_test",
srcs = ["all_reduce_combiner_test.cc"],
deps = [
":all_reduce_combiner",
":hlo",
":hlo_matchers",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/memory",
],
)
cc_library(
name = "all_reduce_simplifier",
srcs = ["all_reduce_simplifier.cc"],


@@ -0,0 +1,452 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
// Combines the elements of to_combine into a single AllReduce op. All
// entries in to_combine must be AllReduce ops with exactly one operand
// and the same reduction operation.
Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
if (to_combine.size() < 2) {
return Status::OK();
}
VLOG(1) << "Combined " << to_combine.size() << " CRS ops";
HloComputation& computation = *to_combine.back()->parent();
HloComputation* reduction = to_combine[0]->to_apply();
const HloOpcode type = reduction->root_instruction()->opcode();
// Create a single bigger AllReduce of the operands of the smaller
// AllReduces.
std::vector<HloInstruction*> operands;
std::vector<Shape> operand_shapes;
VLOG(1) << "Combining set";
for (HloInstruction* hlo : to_combine) {
VLOG(1) << "Set element: " << hlo->ToString();
TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllReduce);
TF_RET_CHECK(hlo->operands().size() == 1);
TF_RET_CHECK(hlo->to_apply() == reduction ||
(hlo->to_apply()->instruction_count() == 3 &&
hlo->to_apply()->num_parameters() == 2 &&
hlo->to_apply()->root_instruction()->opcode() == type));
TF_RET_CHECK(hlo->shape().IsArray());
for (HloInstruction* operand : hlo->operands()) {
operands.push_back(operand);
operand_shapes.push_back(operand->shape());
}
}
HloInstruction* combined;
// AllReduce ops with more than one operand produce a tuple.
TF_RET_CHECK(operands.size() >= 2);
combined = computation.AddInstruction(HloInstruction::CreateAllReduce(
ShapeUtil::MakeTupleShape(operand_shapes), operands, reduction,
to_combine.front()->replica_groups(),
/*constrain_layout=*/false, to_combine.front()->channel_id()));
// We have to propagate the sharding manually because Domain instructions are
// not guaranteed to preserve it for side effecting instructions.
if (to_combine.front()->has_sharding()) {
combined->set_sharding(to_combine.front()->sharding());
}
VLOG(1) << "Replacing with : " << combined->ToString();
// Replace all the smaller AllReduces with elements of the tuple output
// of the single bigger AllReduce.
for (int64 i = 0; i < to_combine.size(); ++i) {
auto replace_with = HloInstruction::CreateGetTupleElement(
to_combine[i]->shape(), combined, i);
TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
to_combine[i], std::move(replace_with)));
}
return Status::OK();
}
struct GroupKey {
GroupKey(const HloInstruction* hlo, const HloDomainMap& domain_map)
: opcode(hlo->to_apply()->root_instruction()->opcode()),
accum_type(hlo->to_apply()->root_instruction()->shape().element_type()),
domain_id(domain_map.GetDomainMetadataId(hlo)),
is_cross_shard(hlo->channel_id().has_value()),
replica_groups(hlo->replica_groups()) {}
bool operator<(const GroupKey& other) const {
if (opcode != other.opcode) {
return opcode < other.opcode;
}
if (accum_type != other.accum_type) {
return accum_type < other.accum_type;
}
if (domain_id != other.domain_id) {
return domain_id < other.domain_id;
}
if (is_cross_shard != other.is_cross_shard) {
return is_cross_shard < other.is_cross_shard;
}
if (replica_groups.size() != other.replica_groups.size()) {
return replica_groups.size() < other.replica_groups.size();
}
for (int64 i = 0; i < replica_groups.size(); ++i) {
const auto& rg = replica_groups[i];
const auto& org = other.replica_groups[i];
if (rg.replica_ids_size() != org.replica_ids_size()) {
return rg.replica_ids_size() < org.replica_ids_size();
}
for (int64 j = 0; j < rg.replica_ids_size(); ++j) {
if (rg.replica_ids(j) != org.replica_ids(j)) {
return rg.replica_ids(j) < org.replica_ids(j);
}
}
}
return false;
}
HloOpcode opcode;
PrimitiveType accum_type;
int64 domain_id;
bool is_cross_shard;
std::vector<ReplicaGroup> replica_groups;
};
// Group AllReduce instructions by reduction type (e.g., add, min, max),
// accumulation type, replica groups and domain. For cross-module all-reduce
// instructions we group them by the set of domains they are reducing across.
//
// Note that the accumulation element type of the reduction computation is part
// of the group key, so e.g. "f32[] add" and "bf16[] add" end up in different
// groups. We also need to disallow combining CRS instructions with different
// domain metadata, as that could end up short-cutting two or more different
// domains.
//
// In each group, the instructions should be in post order. We will then iterate
// each group and try to combine them, so to prevent non-determinism, we use
// std::map here.
//
// The return value is a list of groups where every group contains a list of
// all-reduce instruction sets in topological order and with a deterministic
// order within the set. Additionally due to the above constraints every all
// reduce set within a group will contain the same number of elements
// and every instruction within an all reduce set will have the same
// all-reduce-id (if specified) and thus shape (all reduce sets without an
// all-reduce-id will have a single instruction).
using InstructionGroups =
std::vector<std::vector<std::vector<HloInstruction*>>>;
StatusOr<InstructionGroups> CreateComputationGroups(
HloComputation* computation) {
TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));
// Group instructions by opcode, domain id and replica group.
std::map<GroupKey, std::vector<HloInstruction*>> opcode_groups;
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kAllReduce) {
continue;
}
if (instruction->to_apply()->instruction_count() != 3 ||
instruction->to_apply()->num_parameters() != 2) {
VLOG(1) << "Skipping due to non-trivial reduction function.";
continue;
}
opcode_groups[GroupKey(instruction, *domain_map)].push_back(instruction);
}
// Generate a unique all-reduce-id for instructions without one by negating
// the unique id of the hlo. This way we can treat cross module and normal CRS
// instructions uniformly.
auto channel_id = [](const HloInstruction* all_reduce) {
return all_reduce->IsCrossModuleAllReduce()
? all_reduce->channel_id().value()
: -1 * all_reduce->unique_id();
};
// Group instructions by all-reduce id. The instructions for each all-reduce
// id are listed along with their group id, and the (group id, instruction)
// pairs are sorted by group id in the vector.
std::map<int64, std::vector<std::pair<int64, HloInstruction*>>>
all_reduce_sets;
int64 group_id = 0;
for (auto& domain_groups : opcode_groups) {
for (HloInstruction* hlo : domain_groups.second) {
all_reduce_sets[channel_id(hlo)].emplace_back(group_id, hlo);
}
++group_id;
}
// Group instructions by participating group ids. Instructions within a group
// are sorted in topological order, and instructions within an all-reduce group
// are still sorted by group id.
std::map<std::vector<int64>, std::vector<std::vector<HloInstruction*>>>
all_reduce_group_map;
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kAllReduce) {
continue;
}
if (instruction->to_apply()->instruction_count() != 3 ||
instruction->to_apply()->num_parameters() != 2) {
VLOG(1) << "Skipping due to non-trivial reduction function.";
continue;
}
int64 arid = channel_id(instruction);
if (all_reduce_sets.count(arid) == 0) {
// Already processed.
continue;
}
std::vector<int64> group_ids;
std::vector<HloInstruction*> instructions;
for (const auto& hlo : all_reduce_sets[arid]) {
group_ids.push_back(hlo.first);
instructions.push_back(hlo.second);
}
all_reduce_group_map[group_ids].push_back(std::move(instructions));
all_reduce_sets.erase(arid);
}
CHECK(all_reduce_sets.empty());
InstructionGroups groups;
for (const auto& all_reduce_group : all_reduce_group_map) {
groups.push_back(all_reduce_group.second);
}
return std::move(groups);
}
} // namespace
AllReduceCombiner::AllReduceCombiner(int64 combine_threshold_in_bytes,
int64 combine_threshold_count)
: combine_threshold_in_bytes_(combine_threshold_in_bytes),
combine_threshold_count_(combine_threshold_count) {}
StatusOr<bool> AllReduceCombiner::Run(HloModule* module) {
VLOG(1) << "Running AllReduceCombiner with threshold of "
<< combine_threshold_in_bytes_ << " bytes";
if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce "
"with constrained layouts";
return false;
}
bool changed = false;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
TF_ASSIGN_OR_RETURN(auto groups, CreateComputationGroups(computation));
for (auto group : groups) {
// Recompute reachability after every combine group because we can't
// maintain a cross-group topological order that would let us rely on
// transitive dependencies to detect cycles.
auto reachability = HloReachabilityMap::Build(computation);
// Create a map to be able to find an instruction group based on the first
// instruction in the group. It will be used during the post order
// iteration to be able to process full groups at a time. Doing it only
// for one instruction in every group will be sufficient because all
// instructions have to be scheduled at the same time due to cross-core
// dependencies.
absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>*>
group_map;
for (auto& instruction : group) {
group_map[instruction.front()] = &instruction;
}
// Collect sets of AllReduce instructions to combine.
std::vector<std::vector<std::vector<HloInstruction*>>> combine_sets(1);
int64 current_size_in_bytes = 0;
int64 current_operand_count = 0;
// Iterate all instructions in post order and skip the ones not in the
// current group. We have to create a new post order iteration for every
// group because merging instructions in the previous group can invalidate
// the original post order.
// This will make it likely that we won't increase memory pressure much
// above combine_threshold_in_bytes, since two AllReduces that are
// near in post order are most likely, but not for sure, also near in
// scheduled order.
//
// TODO(b/70235266): This should usually be fine, but it's probably
// possible to construct some case where the memory usage increases beyond
// the threshold due to reordering of the instructions in scheduling. If
// this ever comes up as a real problem, it would be nice to implement
// safeguards so that that cannot possibly happen.
for (const HloInstruction* inst :
computation->MakeInstructionPostOrder()) {
auto it = group_map.find(inst);
if (it == group_map.end()) {
// Instruction belongs to a different group.
continue;
}
const auto& instructions = *it->second;
VLOG(1) << "Considering HLO " << instructions.front()->ToString()
<< " with current set size of " << current_size_in_bytes
<< " and current operand count of " << current_operand_count;
// We do not handle AllReduce ops that do not have exactly 1 operand:
// handling only the single-operand case is simpler, this pass is the only
// way to generate multi-operand all-reduces, and it should rarely be
// important to consider the same ops again.
if (instructions.front()->operands().size() != 1) {
VLOG(1) << "Skipping due to "
<< instructions.front()->operands().size() << " operands";
continue;
}
int64 size_in_bytes;
TF_RET_CHECK(instructions.front()->shape().IsArray());
size_in_bytes = ShapeUtil::ByteSizeOf(instructions.front()->shape());
if (size_in_bytes > combine_threshold_in_bytes_) {
VLOG(1) << "Skipping due to size " << size_in_bytes
<< " above threshold";
// If the instruction is larger than the threshold, then we can
// never combine it with anything.
continue;
}
// If the current set is dependent on the instruction, then create a new
// one to avoid the dependency. We move on from the current set instead
// of ignoring the instruction since otherwise a single AllReduce
// instruction that all the other ones depend on (such as one on the
// forward pass of a model) could disable this optimization entirely.
TF_RET_CHECK(!combine_sets.empty());
for (const auto& previous : combine_sets.back()) {
// The reachability information does not reflect the planned
// combination from combine_sets. We cannot just bring it up to date
// cheaply since HloReachabilityMap does not track reachability
// updates transitively and doing it directly is expensive. However,
// leaving it stale has no effect on the reachability queries that we
// are doing here because we are considering the ops in a topological
// order, so we can just leave it stale.
//
// Proof: Suppose A is the instruction we are looking to combine and B
// is an element of the current combine set that we are looking to
// combine A into.
//
// First of all, we check that all elements in each set do not depend
// on each other, so combining the *current* combine set cannot create
// new dependencies between A and B. It remains to prove that
// combining the prior combine sets also cannot create a dependency
// between A and B.
//
// Assume to get a contradiction that there are two AllReduce
// ops C and D in combine_sets that will be combined and that A and B
// are not connected now but that they will be after combining C and
// D. Then there exist paths in the dependency graph such that one of
// these cases is true:
//
// A -> ... -> C and D -> ... -> B
// A -> ... -> D and C -> ... -> B
// B -> ... -> C and D -> ... -> A
// B -> ... -> D and C -> ... -> A
//
// None of these cases are possible because we are visiting the nodes
// in a topological order, so C and D cannot be in-between A and B.
// That is a contradiction, so combining the prior combine sets also
// cannot create a dependency between A and B.
bool new_set = false;
for (int64 i = 0; i < instructions.size(); ++i) {
if (reachability->IsReachable(previous[i], instructions[i])) {
VLOG(1) << "Starting new set due to dependency between "
<< previous[i]->ToString() << " AND "
<< instructions[i]->ToString();
new_set = true;
break;
}
}
if (new_set) {
combine_sets.emplace_back();
current_size_in_bytes = 0;
current_operand_count = 0;
break;
}
}
if (current_size_in_bytes + size_in_bytes >
combine_threshold_in_bytes_ ||
current_operand_count + 1 > combine_threshold_count_) {
VLOG(1) << "The instruction cannot be entered into the set due "
"to the combined size being too large.";
// In this case we cannot include the instruction into the current set
// since then it would grow beyond the threshold. The set of
// instructions to carry forward will either be the current set or the
// instruction by itself, whichever is smaller, since that maximizes
// the chance of being able to combine with the next instruction.
if (size_in_bytes > current_size_in_bytes) {
VLOG(1) << "Skipping as the instruction is larger than the set.";
continue; // keep the current set
}
VLOG(1)
<< "Resetting the set as the set is larger than the instruction.";
combine_sets.emplace_back();
current_size_in_bytes = 0;
current_operand_count = 0;
}
VLOG(1) << "Adding instruction to set.";
combine_sets.back().push_back(instructions);
current_size_in_bytes += size_in_bytes;
current_operand_count += 1;
TF_RET_CHECK(current_size_in_bytes <= combine_threshold_in_bytes_);
TF_RET_CHECK(current_operand_count <= combine_threshold_count_);
}
VLOG(1) << "Done constructing sets. Final set size is "
<< current_size_in_bytes << " bytes and " << current_operand_count
<< " operands";
// Combine the collected sets of AllReduce instructions.
for (const auto& combine_set : combine_sets) {
if (combine_set.size() >= 2) {
changed = true;
for (int64 i = 0; i < combine_set.front().size(); ++i) {
std::vector<HloInstruction*> to_combine;
to_combine.reserve(combine_set.size());
for (const auto& c : combine_set) {
to_combine.push_back(c[i]);
}
TF_RETURN_IF_ERROR(CombineAllReduces(to_combine));
}
}
}
}
}
return changed;
}
} // namespace xla


@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Combines small non-dependent AllReduce ops into larger combined
// AllReduce ops. A typical AllReduce implementation has a minimum
// latency-induced time for an AllReduce op, so a single combined op can be
// more efficient than many small ones.
class AllReduceCombiner : public HloModulePass {
public:
AllReduceCombiner(int64 combine_threshold_in_bytes,
int64 combine_threshold_count);
absl::string_view name() const override { return "all-reduce-combiner"; }
StatusOr<bool> Run(HloModule* module) override;
private:
// Combine all reduce ops up to this threshold.
int64 combine_threshold_in_bytes_;
// Combine all reduce ops up to this threshold (number of operands).
int64 combine_threshold_count_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_REDUCE_COMBINER_H_
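A usage sketch, not part of this commit: the combiner is added to an HloPassPipeline with a byte threshold and an operand-count threshold, using the same values the GPU pipeline registers later in this diff; the RunAllReduceCombiner wrapper name is illustrative.

#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/status.h"

namespace xla {

// Runs the combiner as its own pipeline stage; compatible all-reduces are
// merged until either threshold would be exceeded.
Status RunAllReduceCombiner(HloModule* module) {
  HloPassPipeline pipeline("all_reduce_combiner");
  pipeline.AddPass<AllReduceCombiner>(
      /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
      /*combine_threshold_count=*/256);
  return pipeline.Run(module).status();
}

}  // namespace xla

With these defaults, f32[1048576] operands (4 MiB each), for example, are merged at most seven at a time, since an eighth would push the combined size past the 30 MiB threshold.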


@@ -0,0 +1,477 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
using absl::nullopt;
using ::testing::AllOf;
namespace op = xla::testing::opcode_matchers;
int64 kMaxCombineCount = 256;
int64 AllReduceCount(const HloModule& module) {
int64 count = 0;
for (HloComputation* computation : module.computations()) {
if (computation->IsFusionComputation()) {
continue;
}
for (HloInstruction* hlo : computation->instructions()) {
if (hlo->opcode() == HloOpcode::kAllReduce) {
++count;
}
}
}
return count;
}
// inputs[i] will be some op producing a shape of size sizes_in_kib[i] KiB,
// which feeds into an all-reduce op in all_reduces[i]. Returns a tuple
// of the all_reduces.
HloInstruction* MakeCrossReplicaReductions(
std::vector<int64> sizes_in_kib, std::vector<HloComputation*> reductions,
std::vector<HloInstruction*>* inputs, HloComputation::Builder* b) {
CHECK_EQ(reductions.size(), sizes_in_kib.size());
std::vector<HloInstruction*> all_reduces;
for (int i = 0; i < sizes_in_kib.size(); i++) {
int64 size_in_kib = sizes_in_kib[i];
HloComputation* reduction = reductions[i];
auto constant = b->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
Shape shape = ShapeUtil::MakeShape(
F32, {static_cast<int32>(size_in_kib * 1024 / sizeof(float))});
auto input =
b->AddInstruction(HloInstruction::CreateBroadcast(shape, constant, {}));
inputs->push_back(input);
all_reduces.push_back(b->AddInstruction(HloInstruction::CreateAllReduce(
shape, {input}, reduction, /*replica_groups=*/{},
/*constrain_layout=*/false, /*channel_id=*/nullopt)));
}
return b->AddInstruction(HloInstruction::CreateTuple(all_reduces));
}
// Create and add a reduction computation in the given type to the module.
HloComputation* MakeReduction(const HloOpcode type, HloModule* module) {
HloComputation::Builder sum_builder(HloOpcodeString(type));
auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
sum_builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {}), type, x, y));
HloComputation* reduction =
module->AddEmbeddedComputation(sum_builder.Build());
return reduction;
}
// Creates replica groups for AllReduce. groups[i] represents replica ids
// for group 'i'.
std::vector<ReplicaGroup> CreateReplicaGroups(
absl::Span<const std::vector<int64>> groups) {
std::vector<ReplicaGroup> replica_groups(groups.size());
for (int64 i = 0; i < groups.size(); ++i) {
*replica_groups[i].mutable_replica_ids() = {groups[i].begin(),
groups[i].end()};
}
return replica_groups;
}
using AllReduceCombinerTest = HloTestBase;
// Tests combination of several AllReduce instructions.
TEST_F(AllReduceCombinerTest, CombineAllReduces) {
auto module = CreateNewVerifiedModule();
HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());
HloComputation::Builder b(TestName());
std::vector<HloInstruction*> inputs;
auto root = MakeCrossReplicaReductions(
{1, 2, 10, 7, 6}, {sum, sum, sum, sum, sum}, &inputs, &b);
auto computation = module->AddEntryComputation(b.Build());
// Run the AllReduce combiner optimization pass.
AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
ASSERT_EQ(AllReduceCount(*module), 1);
EXPECT_TRUE(changed);
ASSERT_EQ(root, computation->root_instruction());
ASSERT_EQ(inputs.size(), root->operands().size());
HloInstruction* combined = nullptr;
for (int64 i = 0; i < root->operands().size(); ++i) {
HloInstruction* hlo = root->mutable_operand(i);
ASSERT_TRUE(hlo->opcode() == HloOpcode::kGetTupleElement);
EXPECT_EQ(hlo->tuple_index(), i);
EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape()));
if (combined == nullptr) {
// Verify the combined all reduce instruction.
combined = hlo->mutable_operand(0);
ASSERT_TRUE(combined->opcode() == HloOpcode::kAllReduce);
EXPECT_TRUE(ShapeUtil::Equal(root->shape(), combined->shape()));
ASSERT_EQ(combined->operands().size(), inputs.size());
}
EXPECT_EQ(combined, hlo->operand(0));
EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape()));
EXPECT_EQ(combined->operand(i), inputs[i]);
EXPECT_EQ(1, inputs[i]->users().size());
}
ASSERT_NE(combined, nullptr);
}
// Tests combination of several cross replica reduction instructions with
// different reduction types.
TEST_F(AllReduceCombinerTest, CombineCrossReplicaReductionsInGroups) {
auto module = CreateNewVerifiedModule();
HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());
HloComputation* min = MakeReduction(HloOpcode::kMinimum, module.get());
HloComputation* max = MakeReduction(HloOpcode::kMaximum, module.get());
HloComputation* sum_2 = MakeReduction(HloOpcode::kAdd, module.get());
HloComputation::Builder b(TestName());
std::vector<HloInstruction*> inputs;
MakeCrossReplicaReductions(
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{sum, sum_2, min, min, min, max, max, max, sum, sum_2}, &inputs, &b);
module->AddEntryComputation(b.Build());
// Run the AllReduce combiner optimization pass.
AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
ASSERT_EQ(AllReduceCount(*module), 3)
<< "expects 3 groups for 3 reduction types.";
EXPECT_TRUE(changed);
}
// Tests that the combination threshold is respected.
TEST_F(AllReduceCombinerTest, RespectThreshold) {
auto module = CreateNewVerifiedModule();
HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());
HloComputation::Builder b(TestName());
std::vector<HloInstruction*> inputs;
MakeCrossReplicaReductions({8, 4}, {sum, sum}, &inputs, &b);
module->AddEntryComputation(b.Build());
// Run the AllReduce combiner optimization pass with threshold less than
// the combined size of the all reduce ops so that the combination
// cannot occur.
{
AllReduceCombiner combine((8 + 4) * 1024 - 1, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), inputs.size());
EXPECT_FALSE(changed);
}
// Run the AllReduce combiner optimization pass again with a slightly
// higher threshold so that the combination can occur.
{
AllReduceCombiner combine((8 + 4) * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), inputs.size());
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 1);
EXPECT_TRUE(changed);
}
}
// Tests that dependent all reduces are not combined.
TEST_F(AllReduceCombinerTest, NoDependentCombination) {
auto module = CreateNewVerifiedModule();
HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());
HloComputation::Builder b(TestName());
auto constant = b.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
auto all_reduce = b.AddInstruction(HloInstruction::CreateAllReduce(
constant->shape(), {constant}, reduction, /*replica_groups=*/{},
/*constrain_layout=*/false, /*channel_id=*/nullopt));
b.AddInstruction(HloInstruction::CreateAllReduce(
constant->shape(), {all_reduce}, reduction,
/*replica_groups=*/{}, /*constrain_layout=*/false,
/*channel_id=*/nullopt));
module->AddEntryComputation(b.Build());
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}
// Tests that AllReduce ops with different groups are not combined.
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
auto module = CreateNewVerifiedModule();
HloComputation::Builder b(TestName());
HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());
auto constant = b.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
auto crs0 = b.AddInstruction(
HloInstruction::CreateAllReduce(constant->shape(), {constant}, reduction,
CreateReplicaGroups({{0, 1}, {2, 3}}),
/*constrain_layout=*/false,
/*channel_id=*/nullopt));
auto crs1 = b.AddInstruction(
HloInstruction::CreateAllReduce(constant->shape(), {constant}, reduction,
CreateReplicaGroups({{0, 2}, {1, 3}}),
/*constrain_layout=*/false,
/*channel_id=*/nullopt));
b.AddInstruction(HloInstruction::CreateTuple({crs0, crs1}));
module->AddEntryComputation(b.Build());
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}
TEST_F(AllReduceCombinerTest, DomainPreventsCombining) {
const char* const hlo_string = R"(
HloModule Module
summit {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
param0 = f32[128] parameter(0), sharding={maximal device=0}
param1 = f32[128] parameter(1), sharding={maximal device=1}
crs0 = f32[128] all-reduce(param0),
replica_groups={}, to_apply=summit, sharding={maximal device=0}
crs1 = f32[128] all-reduce(param1),
replica_groups={}, to_apply=summit, sharding={maximal device=1}
domain0 = f32[128] domain(crs0),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=0}}
domain1 = f32[128] domain(crs1),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=1}}
ROOT tuple = (f32[128], f32[128]) tuple(domain0, domain1),
sharding={{maximal device=0}, {maximal device=1}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}
// This test checks that two CRS instructions that are in separate domains
// but with the same domain metadata can be combined.
TEST_F(AllReduceCombinerTest, CombineFromTwoDomainsWithSameMetadata) {
const char* const hlo_string = R"(
HloModule Module
summit {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
param0 = f32[128] parameter(0), sharding={maximal device=0}
param1 = f32[128] parameter(1), sharding={maximal device=1}
param2 = f32[128] parameter(2), sharding={maximal device=1}
crs0 = f32[128] all-reduce(param0),
replica_groups={}, to_apply=summit, sharding={maximal device=0}
crs1 = f32[128] all-reduce(param1),
replica_groups={}, to_apply=summit, sharding={maximal device=1}
crs2 = f32[128] all-reduce(param2),
replica_groups={}, to_apply=summit, sharding={maximal device=0}
domain0 = f32[128] domain(crs0),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
{maximal device=0}}, exit={maximal device=0}}
domain1 = f32[128] domain(crs1),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
{maximal device=0}}, exit={maximal device=1}}
domain2 = f32[128] domain(crs2),
domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
{maximal device=0}}, exit={maximal device=0}}
ROOT tuple = (f32[128], f32[128], f32[128]) tuple(domain0, domain1, domain2),
sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 3);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_TRUE(changed);
}
TEST_F(AllReduceCombinerTest, DoNotCombineCrossShardAndCrosReplicaInSPMD) {
const char* const hlo_string = R"(
HloModule Module
summit {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
param0 = f32[128] parameter(0), sharding={maximal device=0}
param1 = f32[128] parameter(1), sharding={maximal device=1}
cross_shard_ar = f32[128] all-reduce(param0),
replica_groups={{0}}, to_apply=summit, channel_id=1
cross_replica_ar = f32[128] all-reduce(param1),
replica_groups={{0}}, to_apply=summit, sharding={maximal device=1}
ROOT tuple = (f32[128], f32[128]) tuple(cross_shard_ar, cross_replica_ar)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 2);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_FALSE(changed);
}
TEST_F(AllReduceCombinerTest, CrossCoreAllReduce) {
const char* const hlo_string = R"(
HloModule Module
summit {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry {
param0 = f32[128] parameter(0), sharding={maximal device=0}
param1 = f32[128] parameter(1), sharding={maximal device=1}
crs00 = f32[128] all-reduce(param0),
replica_groups={{0}}, channel_id=1, to_apply=summit,
sharding={maximal device=0}
crs01 = f32[128] all-reduce(param1),
replica_groups={{0}}, channel_id=1, to_apply=summit,
sharding={maximal device=1}
crs10 = f32[128] all-reduce(param0),
replica_groups={{0}}, channel_id=2, to_apply=summit,
sharding={maximal device=0}
crs11 = f32[128] all-reduce(param1),
replica_groups={{0}}, channel_id=2, to_apply=summit,
sharding={maximal device=1}
domain0 = f32[128] domain(crs00),
domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
ROOT add = f32[128] add(domain0, crs11),
sharding={maximal device=1}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 4);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 2);
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Add(op::Domain(op::GetTupleElement(
AllOf(op::AllReduce(op::Parameter(0), op::Parameter(0)),
op::Shape("(f32[128], f32[128])")),
1)),
op::GetTupleElement(
AllOf(op::AllReduce(op::Parameter(1), op::Parameter(1)),
op::Shape("(f32[128], f32[128])")),
0)));
}
TEST_F(AllReduceCombinerTest, CrossCombineGroupCycle) {
const char* const hlo_string = R"(
HloModule module
%add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
%max {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] maximum(lhs, rhs)
}
ENTRY %comp {
p0 = f32[128] parameter(0)
p1 = f32[128] parameter(1)
crs00 = f32[128] all-reduce(p0), to_apply=add
crs10 = f32[128] all-reduce(p1), to_apply=max
crs01 = f32[128] all-reduce(crs00), to_apply=max
crs11 = f32[128] all-reduce(crs10), to_apply=add
add0 = f32[128] add(crs01, crs11)
crs02 = f32[128] all-reduce(add0), to_apply=add
crs12 = f32[128] all-reduce(crs11), to_apply=add
ROOT tuple = (f32[128], f32[128]) tuple(crs02, crs12)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
ASSERT_EQ(AllReduceCount(*module), 6);
TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
EXPECT_EQ(AllReduceCount(*module), 4);
EXPECT_TRUE(changed);
auto crs0 = op::AllReduce(op::Parameter(0), op::AllReduce(op::Parameter(1)));
auto add = op::Add(op::AllReduce(op::GetTupleElement(crs0, 0)),
op::GetTupleElement(crs0, 1));
auto crs1 = op::AllReduce(add, op::GetTupleElement(crs0));
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(op::GetTupleElement(crs1, 0), op::GetTupleElement(crs1, 1)));
}
} // namespace
} // namespace xla


@@ -149,7 +149,6 @@ struct AllReduceParticipantData
explicit AllReduceParticipantData(RendezvousKey rendezvous_key)
: rendezvous_key(rendezvous_key) {}
int64 element_count;
int64 device_ordinal;
RendezvousKey rendezvous_key;
@@ -157,20 +156,30 @@ struct AllReduceParticipantData
// source_buffer == destination_buffer if that avoids a NCCL copy (will depend
// on how well the NCCL in-place implementation performs vs the out-of-place
// implementation).
se::DeviceMemoryBase source_data;
se::DeviceMemoryBase destination_data;
struct Buffer {
int64 element_count;
se::DeviceMemoryBase source_data;
se::DeviceMemoryBase destination_data;
PrimitiveType primitive_type;
};
std::vector<Buffer> buffers;
se::Stream* stream;
ReductionKind reduction_kind;
PrimitiveType primitive_type;
int num_participants() const { return rendezvous_key.num_participants(); }
string ToString() const {
std::vector<std::string> buffer_strs;
for (const Buffer& buffer : buffers) {
buffer_strs.push_back(
absl::StrFormat("{element_count=%d}", buffer.element_count));
}
return absl::StrFormat(
"AllReduceParticipantData{element_count=%d, rendezvous_key=%s, "
"AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
"device_ordinal=%d, stream=%p}",
element_count, rendezvous_key.ToString(), device_ordinal, stream);
absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
device_ordinal, stream);
}
};
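With the per-buffer refactoring above, each participant carries one Buffer entry per operand of a combined all-reduce instead of a single element count and source/destination pair. The helper below is a hypothetical sketch of how a caller populates it; the real population happens in the CPU runtime and NCCL thunk changes later in this diff.

// Sketch (hypothetical helper, not part of this commit). The rendezvous key,
// stream, and device memory handles are assumed to be owned by the caller.
AllReduceParticipantData MakeCombinedParticipant(
    const RendezvousKey& rendezvous_key, int64 device_ordinal,
    se::Stream* stream, ReductionKind reduction_kind,
    const std::vector<int64>& element_counts,
    const std::vector<PrimitiveType>& element_types,
    const std::vector<se::DeviceMemoryBase>& sources,
    const std::vector<se::DeviceMemoryBase>& destinations) {
  AllReduceParticipantData participant(rendezvous_key);
  participant.device_ordinal = device_ordinal;
  participant.stream = stream;
  participant.reduction_kind = reduction_kind;
  for (int64 i = 0; i < element_counts.size(); ++i) {
    // One Buffer per combined operand.
    AllReduceParticipantData::Buffer buffer;
    buffer.element_count = element_counts[i];
    buffer.primitive_type = element_types[i];
    buffer.source_data = sources[i];
    buffer.destination_data = destinations[i];
    participant.buffers.push_back(buffer);
  }
  return participant;
}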
@@ -245,7 +254,7 @@ class Rendezvous {
// Spot check for consistent replica counts among submitting threads.
if (!participants_.empty() &&
(participants_.back().element_count != participant.element_count ||
(participants_.back().buffers.size() != participant.buffers.size() ||
participants_.back().rendezvous_key != participant.rendezvous_key)) {
return InvalidArgument(
"Mismatch among all-reduce participants. Expected same "


@@ -262,7 +262,8 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
protected:
xla::StatusOr<std::pair<std::nullptr_t, bool>> SubmitParticipantImpl(
xla::AllReduceParticipantData participant) override {
xla::PrimitiveType datatype = participant.primitive_type;
TF_RET_CHECK(participant.buffers.size() == 1);
xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
bool primary = [&] {
tensorflow::mutex_lock lock(mu_);
if (!initialized_) {
@@ -316,10 +317,8 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
tensorflow::mutex_lock lock(mu_);
CHECK(!participants_.empty());
xla::int64 element_count = participant.element_count;
xla::ReductionKind reduction_kind = participant.reduction_kind;
for (const auto& p : participants_) {
CHECK_EQ(p.element_count, element_count);
CHECK(p.reduction_kind == reduction_kind);
}
@@ -329,11 +328,19 @@ class CpuAllReduceRendezvous : public xla::Rendezvous<std::nullptr_t> {
output_buffers.reserve(participants_.size());
for (auto& p : participants_) {
input_buffers.emplace_back(static_cast<T*>(p.source_data.opaque()),
element_count);
output_buffers.emplace_back(static_cast<T*>(p.destination_data.opaque()),
element_count);
CHECK_EQ(p.buffers.size(), 1);
CHECK_EQ(p.buffers.front().element_count,
participants_.front().buffers.front().element_count);
xla::int64 element_count = participant.buffers.front().element_count;
input_buffers.emplace_back(
static_cast<T*>(p.buffers.front().source_data.opaque()),
element_count);
output_buffers.emplace_back(
static_cast<T*>(p.buffers.front().destination_data.opaque()),
element_count);
}
xla::int64 element_count =
participants_.front().buffers.front().element_count;
auto compute = [reduction_kind](T a, T b) -> T {
switch (reduction_kind) {
@@ -416,7 +423,6 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
xla::RendezvousKey rendezvous_key(run_options->run_id(),
participating_replicas_vec, op_kind, op_id);
auto shape_str = ShapeString(shape_ptr, shape_length);
VLOG(2) << "All-reduce input/output shape : " << shape_str;
@@ -426,14 +432,16 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
<< "All-reduce on CPU is implemented only for dense arrays";
xla::AllReduceParticipantData participant(rendezvous_key);
participant.element_count = xla::ShapeUtil::ElementsIn(shape);
participant.device_ordinal = device_ordinal;
participant.primitive_type = shape.element_type();
participant.stream = run_options->stream();
participant.source_data =
xla::AllReduceParticipantData::Buffer buffer;
buffer.element_count = xla::ShapeUtil::ElementsIn(shape);
buffer.primitive_type = shape.element_type();
buffer.source_data =
se::DeviceMemoryBase(input_buffer, xla::ShapeUtil::ByteSizeOf(shape));
participant.destination_data =
buffer.destination_data =
se::DeviceMemoryBase(output_buffer, xla::ShapeUtil::ByteSizeOf(shape));
participant.buffers = {buffer};
participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
TF_CHECK_OK(


@@ -1131,6 +1131,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:all_reduce_combiner",
"//tensorflow/compiler/xla/service:batchnorm_expander",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:call_inliner",


@@ -42,15 +42,11 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
struct NcclAllReduceThunk::AuxData {};
NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, int64 element_count,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
replica_count_(replica_count),
element_count_(element_count),
source_buffer_(source_buffer),
destination_buffer_(destination_buffer) {}
buffers_(std::move(buffers)) {}
} // namespace gpu
} // namespace xla


@@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/IR/Verifier.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
@@ -291,7 +292,13 @@ Status GpuCompiler::OptimizeHloModule(
horizontal_fusion.AddPass<HloDCE>();
TF_RETURN_IF_ERROR(horizontal_fusion.Run(hlo_module).status());
}
{
HloPassPipeline pipeline("all_reduce_combiner");
pipeline.AddPass<AllReduceCombiner>(
/*combine_threshold_in_bytes=*/30 * 1024 * 1024,
/*combine_threshold_count=*/256);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
return Status::OK();
}


@@ -1210,10 +1210,7 @@ Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
return Status::OK();
}
namespace {
} // namespace
namespace {} // namespace
Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
@@ -1226,13 +1223,37 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
NcclAllReduceThunk::CanImplement(crs);
if (should_use_nccl_thunk) {
CHECK(crs->operand(0)->shape().IsArray())
<< "Operands to all-reduce must be arrays: " << crs->ToString();
AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
std::vector<NcclAllReduceThunk::Buffer> buffers;
std::vector<BufferAllocation::Slice> tuple_element_buffers;
buffers.resize(crs->operand_count());
tuple_element_buffers.reserve(crs->operand_count());
CHECK(crs->shape().IsArray() && crs->operand_count() == 1 ||
crs->shape().IsTuple() &&
crs->shape().tuple_shapes_size() == crs->operand_count());
for (int i = 0; i < crs->operand_count(); ++i) {
CHECK(crs->operand(i)->shape().IsArray())
<< "Operands to all-reduce must be arrays: " << crs->ToString();
buffers[i].element_count =
ShapeUtil::ElementsIn(crs->operand(i)->shape());
buffers[i].source_buffer = GetAllocationSlice(*crs->operand(i));
buffers[i].destination_buffer = GetAllocationSlice(
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
tuple_element_buffers.push_back(buffers[i].destination_buffer);
}
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
/*replica_count=*/hlo_module_config_.replica_count(),
/*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
/*source_address=*/GetAllocationSlice(*crs->operand(0)),
/*destination_buffer=*/GetAllocationSlice(*crs), crs));
/*buffers=*/std::move(buffers), crs);
if (crs->shape().IsTuple()) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(all_reduce_thunk));
thunks.push_back(absl::make_unique<TupleThunk>(
tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(std::move(thunks), crs));
} else {
AddThunkToThunkSequence(std::move(all_reduce_thunk));
}
return Status::OK();
}
@@ -1957,32 +1978,32 @@ void IrEmitterUnnested::EmitTile(
//
// TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
// workaround.
ksl->For(
loop_name + "_y_in_tile",
/*start=*/constant(0),
/*end=*/
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
num_threads_y),
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
llvm::Value* y_loc = b_.CreateAdd(
thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
for (int64 j = 0; j < x_num_steps; j++) {
llvm::Value* x_loc =
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
IrArray::Index source_idx_x =
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
auto emit_element = [&] {
return emit_elem_function(source_idx_x, y_loc, x_loc, j);
};
if (!x_tile_fits) {
ksl->If(loop_name + "_x_in_tile",
b_.CreateICmpULT(x_loc, tile_width), emit_element);
} else {
emit_element();
}
}
});
ksl->For(loop_name + "_y_in_tile",
/*start=*/constant(0),
/*end=*/
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
num_threads_y),
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
llvm::Value* y_loc =
b_.CreateAdd(thread_id_info.thread_id_y,
b_.CreateMul(y_indvar, num_threads_y));
for (int64 j = 0; j < x_num_steps; j++) {
llvm::Value* x_loc =
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
IrArray::Index source_idx_x =
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
auto emit_element = [&] {
return emit_elem_function(source_idx_x, y_loc, x_loc, j);
};
if (!x_tile_fits) {
ksl->If(loop_name + "_x_in_tile",
b_.CreateICmpULT(x_loc, tile_width), emit_element);
} else {
emit_element();
}
}
});
}
// Emits code to process a tensor element in a tile for the given kCopy HLO that


@@ -154,10 +154,6 @@ ncclRedOp_t ReductionKindToNccl(ReductionKind kind) {
}
}
PrimitiveType AllReducePrimitiveType(const HloInstruction* instr) {
return instr->operand(0)->shape().element_type();
}
absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
switch (element_type) {
case S8:
@@ -402,9 +398,6 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
VLOG(3) << "Performing all reduce from device ordinal: "
<< participant.device_ordinal;
ncclRedOp_t computation = ReductionKindToNccl(participant.reduction_kind);
absl::optional<ncclDataType_t> allreduce_datatype =
DatatypeToNccl(participant.primitive_type);
CHECK(allreduce_datatype.has_value());
se::StreamExecutor* executor = participant.stream->parent();
se::cuda::ScopedActivateExecutorContext scoped_context(executor);
@@ -412,19 +405,26 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
participant.stream->implementation()->GpuStreamMemberHack());
VLOG(3) << "Using stream pointer: " << cu_stream
<< " on device: " << participant.device_ordinal;
void* send_buffer = participant.source_data.opaque();
void* recv_buffer = participant.destination_data.opaque();
VLOG(3) << absl::StreamFormat(
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
"comm=%p, stream=%p)",
send_buffer, recv_buffer, participant.element_count,
static_cast<const void*>(comm), cu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
/*count=*/participant.element_count,
/*datatype=*/*allreduce_datatype,
/*op=*/computation,
/*comm=*/comm,
/*stream=*/*cu_stream));
XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
for (auto& buffer : participant.buffers) {
void* send_buffer = buffer.source_data.opaque();
void* recv_buffer = buffer.destination_data.opaque();
absl::optional<ncclDataType_t> allreduce_datatype =
DatatypeToNccl(buffer.primitive_type);
CHECK(allreduce_datatype.has_value());
VLOG(3) << absl::StreamFormat(
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
"comm=%p, stream=%p)",
send_buffer, recv_buffer, buffer.element_count,
static_cast<const void*>(comm), cu_stream);
XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
/*count=*/buffer.element_count,
/*datatype=*/*allreduce_datatype,
/*op=*/computation,
/*comm=*/comm,
/*stream=*/*cu_stream));
}
XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
VLOG(3) << "Done performing all reduce for ordinal: "
<< participant.device_ordinal;
@@ -453,11 +453,14 @@ struct NcclAllReduceThunk::AuxData {
};
/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
auto operands_are_supported = [crs]() {
return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
return LayoutUtil::IsDenseArray(operand->shape()) &&
DatatypeToNccl(operand->shape().element_type()).has_value();
});
};
return MatchReductionComputation(crs->to_apply()).has_value() &&
DatatypeToNccl(AllReducePrimitiveType(crs)).has_value() &&
crs->IsCrossReplicaAllReduce() &&
crs->operand_count() == 1 && // One array to reduce.
LayoutUtil::IsDenseArray(crs->operand(0)->shape());
crs->IsCrossReplicaAllReduce() && operands_are_supported();
}
/*static*/ absl::flat_hash_set<int>
@@ -471,16 +474,14 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
}
NcclAllReduceThunk::NcclAllReduceThunk(
int64 replica_count, int64 element_count,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
int64 replica_count, std::vector<NcclAllReduceThunk::Buffer> buffers,
const HloInstruction* all_reduce)
: Thunk(Thunk::kNcclAllReduce, all_reduce),
replica_count_(replica_count),
element_count_(element_count),
source_buffer_(source_buffer),
destination_buffer_(destination_buffer),
aux_data_(absl::make_unique<AuxData>()) {}
buffers_(std::move(buffers)),
aux_data_(absl::make_unique<AuxData>()) {
CHECK_EQ(hlo_instruction()->operand_count(), buffers_.size());
}
// Figures out which devices (named by their replica-ids) are participating in
// the all-reduce subgroup that contains device_ordinal.
@@ -506,18 +507,24 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
<< absl::StrJoin(participating_replicas, ", ");
AllReduceParticipantData participant(rendezvous_key);
participant.element_count = element_count_;
participant.device_ordinal = device_ordinal;
participant.source_data =
params.buffer_allocations->GetDeviceAddress(source_buffer_);
participant.destination_data =
params.buffer_allocations->GetDeviceAddress(destination_buffer_);
for (size_t i = 0; i < buffers_.size(); ++i) {
const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
AllReduceParticipantData::Buffer pbuffer;
pbuffer.element_count = buffer.element_count;
pbuffer.source_data =
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
pbuffer.destination_data =
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
pbuffer.primitive_type =
hlo_instruction()->operand(i)->shape().element_type();
participant.buffers.push_back(pbuffer);
}
participant.stream = params.stream;
auto reduction_kind =
MatchReductionComputation(hlo_instruction()->to_apply());
CHECK(reduction_kind.has_value());
participant.reduction_kind = *reduction_kind;
participant.primitive_type = AllReducePrimitiveType(hlo_instruction());
TF_ASSIGN_OR_RETURN(
std::shared_ptr<NcclClique> clique,


@@ -50,9 +50,12 @@ class NcclAllReduceThunk : public Thunk {
// TODO(b/125951860): Support all-reduces with replica groups, i.e.
// all-reduces that compute multiple sums across subsets of all replicas.
NcclAllReduceThunk(int64 replica_count, int64 element_count,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
struct Buffer {
int64 element_count;
BufferAllocation::Slice source_buffer;
BufferAllocation::Slice destination_buffer;
};
NcclAllReduceThunk(int64 replica_count, std::vector<Buffer> buffers,
const HloInstruction* all_reduce);
~NcclAllReduceThunk() override;
@@ -70,9 +73,7 @@ class NcclAllReduceThunk : public Thunk {
struct AuxData;
const int64 replica_count_;
const int64 element_count_;
const BufferAllocation::Slice source_buffer_;
const BufferAllocation::Slice destination_buffer_;
const std::vector<Buffer> buffers_;
std::unique_ptr<AuxData> aux_data_;
};


@@ -368,6 +368,55 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
done.Wait();
}
// Runs a computation containing two independent all-reduces, which the
// all-reduce combiner may merge into one; both results must still be correct.
XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) {
std::string hlo_string = R"(
HloModule test
apply_op {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT apply_op = f32[] add(x, y)
}
ENTRY test_computation {
p0 = f32[5] parameter(0)
p1 = f32[5] parameter(1)
crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op
crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op
ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1)
}
)";
static constexpr int kNumReplicas = 2;
auto config = GetModuleConfigForTest();
config.set_replica_count(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string, config));
std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
std::vector<float> input1_vec = {7., 3., 4., 1., 2.};
auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
/*num_replicas=*/kNumReplicas,
/*use_threads=*/true));
std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
std::vector<float> expected1_vec = {14., 6., 8., 2., 4.};
auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
auto rs = results[replica_idx].DecomposeTuple();
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
ErrorSpec{1e-5, 1e-5}));
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
ErrorSpec{1e-5, 1e-5}));
}
}
// Runs an all-reduce with three partitions:
// {0}, {1,2}, {3}
// meaning, the all-reduce is a nop for devices 0 and 3, and only devices 1 and