[XLA/GPU] Simplify reduction implementation.

Specifically, before this change:
* output_instructions are created by callers,
* output_instructions are passed around,
* output_instructions are used to reverse-lookup ShapeIndex to find IrArrays.

After this change:
* there is no output_instructions parameter; it is derived by the callee.
* instr_index_group gets passed around instead.
* output_instructions are derived from unnested_hlo and index.
* ShapeIndex is derived from index.

This also removes some usages of the type "HloInstruction", making later transitions to MLIR easier.
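
To make the new shape of the code concrete, here is a toy, self-contained C++ sketch of the pattern this change adopts (illustrative only, not XLA code; Instr, GetOutput, and the other names below are made up): callers pass integer indices such as instr_index_group, and the callee resolves each output from the unnested instruction itself, much like the GetReduceFromUnnested helper added in this commit.

#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in for an HLO instruction; the real code uses HloInstruction.
struct Instr {
  const char* name;
  std::vector<Instr*> outputs;  // one entry per output for a multi-output root
  bool is_multi_output = false;
};

// Resolves the index-th output from the unnested instruction itself, mirroring
// the role of GetReduceFromUnnested: no output_instructions list is passed in.
Instr* GetOutput(Instr* unnested, int index) {
  if (!unnested->is_multi_output) {
    assert(index == 0);
    return unnested;
  }
  return unnested->outputs[index];
}

int main() {
  Instr r0{"reduce.0", {}}, r1{"reduce.1", {}};
  Instr root{"tuple", {&r0, &r1}, /*is_multi_output=*/true};
  // Callers now hand around index groups instead of instruction lists.
  std::vector<int> instr_index_group = {0, 1};
  for (int index : instr_index_group) {
    std::printf("output %d -> %s\n", index, GetOutput(&root, index)->name);
  }
  return 0;
}

The diff below applies the same idea to EmitReductionFromOrToContiguousDimensions and its helpers, which now take only unnested_hlo plus spans of indices.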

PiperOrigin-RevId: 344114759
Change-Id: I2fe7642e44d1c639453faa32fe5f128167ee6291
Authored by Tim Shen on 2020-11-24 12:52:15 -08:00; committed by TensorFlower Gardener
parent 22e335d781
commit faa6548fba
5 changed files with 132 additions and 124 deletions


@@ -496,32 +496,23 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
return b->CreateAnd(is_thread0, is_block0);
}
bool AreFusedReductionOutputsConsistent(
absl::Span<const HloInstruction* const> output_instructions,
const HloInstruction* first_reduce) {
for (const HloInstruction* inst : output_instructions) {
if (IsReductionFromOrToContiguousDimensions(*inst)) {
// Shapes, layouts and dimensions must be the same for all reduces
// inside of this fusion.
// TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
ShapeUtil::Equal(first_reduce->operand(0)->shape(),
inst->operand(0)->shape()) &&
ShapeUtil::Equal(first_reduce->operand(1)->shape(),
inst->operand(1)->shape()) &&
first_reduce->dimensions() == inst->dimensions())) {
return false;
}
} else {
if (!(ShapeUtil::CompatibleIgnoringElementType(
first_reduce->operand(0)->shape(), inst->shape()) &&
LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
inst->shape().layout()))) {
return false;
}
}
bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
const HloInstruction* first_reduce) {
if (IsReductionFromOrToContiguousDimensions(*inst)) {
// Shapes, layouts and dimensions must be the same for all reduces
// inside of this fusion.
// TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
ShapeUtil::Equal(first_reduce->operand(0)->shape(),
inst->operand(0)->shape()) &&
ShapeUtil::Equal(first_reduce->operand(1)->shape(),
inst->operand(1)->shape()) &&
first_reduce->dimensions() == inst->dimensions();
}
return true;
return ShapeUtil::CompatibleIgnoringElementType(
first_reduce->operand(0)->shape(), inst->shape()) &&
LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
inst->shape().layout());
}
} // namespace gpu


@@ -217,10 +217,18 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
// block 0 of the kernel.
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);
// Returns whether the outputs of a fusion with reduction are consistent.
bool AreFusedReductionOutputsConsistent(
// Returns whether the output of a fusion with reduction is consistent with
// `first_reduce`.
bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
const HloInstruction* first_reduce);
inline bool AreFusedReductionOutputsConsistent(
absl::Span<const HloInstruction* const> output_instructions,
const HloInstruction* first_reduce);
const HloInstruction* first_reduce) {
return absl::c_all_of(output_instructions, [=](const HloInstruction* inst) {
return IsFusedReductionOutputConsistent(inst, first_reduce);
});
}
} // namespace gpu
} // namespace xla


@@ -1247,8 +1247,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
CHECK_GE(root->operand_count(), 1);
return EmitReductionFromOrToContiguousDimensions(fusion,
root->operands());
return EmitReductionFromOrToContiguousDimensions(fusion);
}
case HloOpcode::kReduce: {
// HandleFusion specializes reduction from a multi-dimensional array to
@@ -1259,7 +1258,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
return Unimplemented(
"Vectorized variadic reduce is not supported on GPU");
}
return EmitReductionFromOrToContiguousDimensions(fusion, {root});
return EmitReductionFromOrToContiguousDimensions(fusion);
}
case HloOpcode::kSlice: {
return EmitInputFusibleNonStridedSlices(fusion);
@@ -1385,7 +1384,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
if (IsReductionFromOrToContiguousDimensions(*reduce) &&
reduce->shape().IsArray()) {
return EmitReductionFromOrToContiguousDimensions(reduce, {reduce});
return EmitReductionFromOrToContiguousDimensions(reduce);
}
return DefaultActionForMlir(input);
@@ -3673,7 +3672,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction(
llvm::Type* index_ty, HloInstruction* unnested_hlo,
const ReductionCodegenInfo& reduction_info,
absl::Span<const HloInstruction* const> reduce_instructions,
absl::Span<const ShapeIndex> reduction_output_shape_indices,
absl::Span<const int> reduction_output_indices,
absl::Span<HloComputation* const> reducers,
const TilingKernelInfo& tiling_kernel_info) {
const KernelMappingScheme& mapping_scheme =
@@ -3728,8 +3727,13 @@ void IrEmitterUnnested::EmitEpilogueForReduction(
// At this point in the function we have a "partial sum" of input elements
// (stored in partial_result_addresses), and we need to accumulate it into
// the correct output element.
auto output_array = GetIrArray(*unnested_hlo, *unnested_hlo,
reduction_output_shape_indices[i]);
ShapeIndex index;
if (unnested_hlo->IsMultiOutputFusion()) {
index.push_back(reduction_output_indices[i]);
} else {
CHECK_EQ(0, reduction_output_indices[i]);
}
auto output_array = GetIrArray(*unnested_hlo, *unnested_hlo, index);
IrArray::Index element_index(
/*linear=*/untransposed_output_linear_address,
reduction_kept_element_shape, &b_);
@@ -3872,31 +3876,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId(
});
}
namespace {
// Obtains the corresponding index of the out_instr in the outputs of the
// `unnested_hlo`.
ShapeIndex CreateShapeIndexForOutputInstruction(
const HloInstruction& unnested_hlo, const HloInstruction& out_instr) {
if (!unnested_hlo.IsMultiOutputFusion()) {
return ShapeIndex({});
static HloInstruction* GetReduceFromUnnested(HloInstruction* unnested_hlo,
int index) {
if (unnested_hlo->opcode() == HloOpcode::kReduce) {
CHECK_EQ(0, index);
return unnested_hlo;
}
const auto& all_outputs = unnested_hlo.fused_expression_root()->operands();
for (size_t i = 0; i < all_outputs.size(); ++i) {
if (all_outputs[i] == &out_instr) {
return ShapeIndex({static_cast<int64>(i)});
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
auto root = unnested_hlo->fused_expression_root();
if (root->opcode() == HloOpcode::kReduce) {
CHECK_EQ(0, index);
return root;
}
if (root->opcode() == HloOpcode::kTuple) {
return root->mutable_operand(index);
}
}
LOG(FATAL) << " Fusion root does not contain output instruction; "
<< " fusion: " << unnested_hlo.ToString()
<< ", output instruction: " << out_instr.ToString();
return nullptr;
}
} // namespace
void IrEmitterUnnested::EmitTileElementForReduction(
HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
absl::Span<HloInstruction* const> output_instructions,
absl::Span<const int> instr_index_group,
const llvm_ir::IrArray::Index& index,
const ReductionCodegenInfo& reduction_info,
absl::Span<HloComputation* const> reducers, int64 x_iter_num) {
@@ -3915,10 +3916,12 @@ void IrEmitterUnnested::EmitTileElementForReduction(
if (unnested_hlo->opcode() == HloOpcode::kFusion) {
BindFusionArguments(unnested_hlo, &fused_emitter);
for (int i = 0, e = output_instructions.size(); i != e; ++i) {
const HloInstruction* inst = output_instructions[i];
ShapeIndex idx =
CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst);
for (int index : instr_index_group) {
const HloInstruction* inst = GetReduceFromUnnested(unnested_hlo, index);
ShapeIndex idx;
if (unnested_hlo->IsMultiOutputFusion()) {
idx.push_back(index);
}
if (IsReductionFromOrToContiguousDimensions(*inst)) {
input_gens.push_back(*fused_emitter.GetGenerator(inst->operand(0)));
} else {
@@ -4689,22 +4692,20 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
}
void IrEmitterUnnested::EmitIRForReduction(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions,
HloInstruction* unnested_hlo, absl::Span<const int> instr_index_group,
ReductionCodegenInfo* reduction_info, const Shape& input_shape) {
std::vector<HloInstruction*> reduce_instructions;
InlinedVector<ShapeIndex, 1> reduction_output_shape_indices;
InlinedVector<int, 1> reduction_output_indices;
InlinedVector<HloComputation*, 1> reducers;
for (size_t i = 0; i < output_instructions.size(); ++i) {
if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
for (int index : instr_index_group) {
HloInstruction* output_instruction =
GetReduceFromUnnested(unnested_hlo, index);
if (!IsReductionFromOrToContiguousDimensions(*output_instruction)) {
continue;
}
HloInstruction* output_instruction = output_instructions[i];
reduce_instructions.push_back(output_instruction);
reduction_output_shape_indices.push_back(
CreateShapeIndexForOutputInstruction(*unnested_hlo,
*output_instruction));
reduction_output_indices.push_back(index);
reducers.push_back(output_instruction->to_apply());
}
CHECK(reduce_instructions.size() != 0)
@@ -4722,7 +4723,7 @@ void IrEmitterUnnested::EmitIRForReduction(
[&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 x_iter_num) {
EmitTileElementForReduction(unnested_hlo, input_shape,
output_instructions, index, *reduction_info,
instr_index_group, index, *reduction_info,
reducers, x_iter_num);
};
@@ -4736,7 +4737,7 @@ void IrEmitterUnnested::EmitIRForReduction(
emit_reduction_tile);
});
EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info,
reduce_instructions, reduction_output_shape_indices,
reduce_instructions, reduction_output_indices,
reducers, tiling_kernel_info);
}
@@ -4751,33 +4752,33 @@ bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
ShapeUtil::IsScalar(instr.operand(0)->shape())));
}
// Divides output_instructions into groups. Different groups will be executed
// Divides `num_reduces` reduces into groups. Different groups will be executed
// in parallel. Generally speaking, we'd like to run the reduce instructions
// in parallel without incurring too much recomputation overhead. The current
// heuristic is to place reduce instructions that share nothing or only
// (broadcasted) scalars/constants into different groups; otherwise, they are
// placed in the same group. Non-reduce instructions always go with the reduce
// instructions into the same group so long as they share any predecessors.
std::vector<std::vector<HloInstruction*>> DivideOutputInstructionsIntoGroups(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions) {
CHECK(!output_instructions.empty());
if (output_instructions.size() == 1) {
return {{output_instructions[0]}};
std::vector<std::vector<int>> DivideOutputInstructionsIntoGroups(
HloInstruction* unnested_hlo, int num_reduces) {
CHECK_NE(0, num_reduces);
if (num_reduces == 1) {
return {{0}};
}
std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets(
output_instructions.size());
for (size_t i = 0; i < output_instructions.size(); ++i) {
disjoint_sets[i].Get() = output_instructions[i];
num_reduces);
for (size_t i = 0; i < num_reduces; ++i) {
disjoint_sets[i].Get() = GetReduceFromUnnested(unnested_hlo, i);
}
std::unique_ptr<HloReachabilityMap> reachability_map =
HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation());
for (auto* instr : unnested_hlo->fused_instructions()) {
std::vector<int64> reached_output_ids;
for (size_t oid = 0; oid < output_instructions.size(); ++oid) {
if (HloOpcode::kReduce == output_instructions[oid]->opcode() &&
for (size_t oid = 0; oid < num_reduces; ++oid) {
auto reduce = GetReduceFromUnnested(unnested_hlo, oid);
if (HloOpcode::kReduce == reduce->opcode() &&
(IsBroadcastedConstantOrScalar(*instr))) {
// Do not group output reduce instructions through broadcasted
// constants or scalars, as the recomputation should be acceptable.
@@ -4785,9 +4786,9 @@ std::vector<std::vector<HloInstruction*>> DivideOutputInstructionsIntoGroups(
continue;
}
// Now group output instructions if they have common predecessors.
if (reachability_map->IsReachable(instr, output_instructions[oid])) {
VLOG(3) << "Reaching " << output_instructions[oid]->ToString()
<< " from " << instr->ToString();
if (reachability_map->IsReachable(instr, reduce)) {
VLOG(3) << "Reaching " << reduce->ToString() << " from "
<< instr->ToString();
reached_output_ids.push_back(oid);
}
}
@@ -4797,12 +4798,12 @@ std::vector<std::vector<HloInstruction*>> DivideOutputInstructionsIntoGroups(
}
}
// Place output instructions in the same set into the same group.
absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>> groups;
for (size_t oid = 0; oid < output_instructions.size(); ++oid) {
groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid));
absl::flat_hash_map<HloInstruction*, std::vector<int>> groups;
for (size_t oid = 0; oid < num_reduces; ++oid) {
groups[disjoint_sets[oid].Get()].push_back(oid);
}
std::vector<std::vector<HloInstruction*>> ret;
std::vector<std::vector<int>> ret;
absl::c_for_each(
groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
return ret;
@@ -4811,15 +4812,20 @@ std::vector<std::vector<HloInstruction*>> DivideOutputInstructionsIntoGroups(
} // namespace
Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions) {
bool returns_tuple = output_instructions.size() > 1;
HloInstruction* unnested_hlo) {
int num_reduces = 1;
if (unnested_hlo->IsMultiOutputFusion()) {
num_reduces = unnested_hlo->fused_expression_root()->operand_count();
}
bool returns_tuple = num_reduces > 1;
VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();
// Build an initializer thunk to initialize each reduction output.
std::vector<std::unique_ptr<Thunk>> thunks;
for (int i = 0; i < output_instructions.size(); ++i) {
if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
for (int i = 0; i < num_reduces; ++i) {
if (!IsReductionFromOrToContiguousDimensions(
*GetReduceFromUnnested(unnested_hlo, i))) {
continue;
}
@@ -4831,17 +4837,20 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
// Build a kernel thunk to compute all the outputs.
const HloInstruction* first_reduce = nullptr;
for (int i = 0; i < output_instructions.size(); ++i) {
if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
first_reduce = output_instructions[i];
for (int i = 0; i < num_reduces; ++i) {
if (IsReductionFromOrToContiguousDimensions(
*GetReduceFromUnnested(unnested_hlo, i))) {
first_reduce = GetReduceFromUnnested(unnested_hlo, i);
break;
}
}
CHECK(first_reduce);
if (output_instructions.size() > 1) {
if (!AreFusedReductionOutputsConsistent(output_instructions,
first_reduce)) {
return InternalError("Inconsistent reduction fusion outputs");
if (num_reduces > 1) {
for (int i = 1; i < num_reduces; i++) {
if (!IsFusedReductionOutputConsistent(
GetReduceFromUnnested(unnested_hlo, i), first_reduce)) {
return InternalError("Inconsistent reduction fusion outputs");
}
}
}
const Shape& input_shape = first_reduce->operand(0)->shape();
@@ -4852,24 +4861,24 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
<< first_reduce->ToString();
// Group output instructions. Each group will be executed in parallel.
std::vector<std::vector<HloInstruction*>> instr_groups =
DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions);
VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ",
std::vector<std::vector<int>> instr_index_groups =
DivideOutputInstructionsIntoGroups(unnested_hlo, num_reduces);
VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ",
unnested_hlo->ToString());
std::unique_ptr<KernelThunk> kernel_thunk =
BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
for (size_t i = 0; i < instr_groups.size(); ++i) {
for (size_t i = 0; i < instr_index_groups.size(); ++i) {
// Create a new ReductionCodegenInfo instance as it contains states for
// code generation per reduction group. For now, let's always use the very
// first reduce as representative to construct ReductionCodegenInfo, since
// all the reductions are required to have the same shape and layout as
// verified by `AreFusedReductionOutputsConsistent()`. We can loosen the
// verified by `IsFusedReductionOutputConsistent()`. We can loosen the
// constraint later when the needs arise.
ReductionCodegenInfo reduction_info =
ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
auto emit_reduction_func = [&] {
EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info,
EmitIRForReduction(unnested_hlo, instr_index_groups[i], &reduction_info,
input_shape);
};
// Use raw block_id_y to select the i-th parallel reduction to run. Using
@@ -4878,7 +4887,7 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
// the indices used within the reductions.
llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
llvm_ir::AddRangeMetadata(0, instr_groups.size(),
llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
llvm::cast<llvm::Instruction>(raw_block_id_y));
llvm::Value* guarding_cond =
b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i));
@@ -4888,11 +4897,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
const KernelMappingScheme& mapping_scheme =
reduction_info.GetKernelMappingScheme();
// block_y_count is set to instr_groups.size(), so that each reduction group
// can be run in parallel by a different BlockIdy.
// block_y_count is set to instr_index_groups.size(), so that each reduction
// group can be run in parallel by a different BlockIdy.
LaunchDimensions launch_dimensions(
{/*x=*/mapping_scheme.GetNumberOfBlocks(),
/*y=*/static_cast<int64>(instr_groups.size()),
/*y=*/static_cast<int64>(instr_index_groups.size()),
/*z=*/1},
{/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1});
VLOG(3) << "Launch dimensions of " << unnested_hlo->name()


@@ -405,13 +405,8 @@ class IrEmitterUnnested : public IrEmitter,
// complicating the index calculation in the code generation of the reduce
// instructions. In other words, a block_id_y is assigned to a group and so
// different groups can be run in parallel.
//
// output_instructions: Output instructions in the computation: instruction
// itself if it's not a fusion, fusion root if fusion is not multi-output, and
// elements of the fusion multi-output tuple otherwise.
Status EmitReductionFromOrToContiguousDimensions(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions);
HloInstruction* unnested_hlo);
// Computes the KernelMappingScheme for the reduce HLO and indicates whether
// the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
@@ -555,12 +550,17 @@ class IrEmitterUnnested : public IrEmitter,
//
// Calculates and stores the temporary reduction value in the corresponding
// alloca.
void EmitTileElementForReduction(
HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
absl::Span<HloInstruction* const> output_instructions,
const llvm_ir::IrArray::Index& index,
const ReductionCodegenInfo& reduction_info,
absl::Span<HloComputation* const> reducers, int64 x_iter_num);
//
// `instr_index_group` indicates the set of reductions this call needs to emit;
// each index i refers to the i-th output of unnested_hlo. Note that if
// unnested_hlo is not a multi-output fusion, instr_index_group is always {0}.
void EmitTileElementForReduction(HloInstruction* unnested_hlo,
const Shape& reduction_operand_shape,
absl::Span<const int> instr_index_group,
const llvm_ir::IrArray::Index& index,
const ReductionCodegenInfo& reduction_info,
absl::Span<HloComputation* const> reducers,
int64 x_iter_num);
// Prepares for the code generation for a tile block of a reduction kernel.
//
@@ -577,13 +577,13 @@ class IrEmitterUnnested : public IrEmitter,
llvm::Type* index_ty, HloInstruction* unnested_hlo,
const ReductionCodegenInfo& reduction_info,
absl::Span<const HloInstruction* const> reduce_instructions,
absl::Span<const ShapeIndex> reduction_output_shape_indices,
absl::Span<const int> reduction_output_indices,
absl::Span<HloComputation* const> reducers,
const TilingKernelInfo& tiling_kernel_info);
// Emits code for reductions in the output_instructions.
// Emits code for reductions in the instr_index_group.
void EmitIRForReduction(HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions,
absl::Span<const int> instr_index_group,
ReductionCodegenInfo* reduction_info,
const Shape& input_shape);