From faa6548fba44e4fbb8ac596f8bea7e4a6247ea4d Mon Sep 17 00:00:00 2001
From: Tim Shen
Date: Tue, 24 Nov 2020 12:52:15 -0800
Subject: [PATCH] [XLA/GPU] Simplify reduction implementation.

Specifically, before this change:
* output_instructions are created by callers,
* output_instructions are passed around,
* output_instructions are used to reverse-lookup ShapeIndex to find IrArrays.

After this change:
* there is no output_instructions. It's derived by the callee.
* instr_index_group gets passed around instead.
* output_instructions are derived from unnested_hlo and index.
* ShapeIndex is derived from index.

This also removes some uses of the type "HloInstruction", making later
transitions to MLIR easier.

PiperOrigin-RevId: 344114759
Change-Id: I2fe7642e44d1c639453faa32fe5f128167ee6291
---
 .../xla/service/gpu/cudnn_batchnorm_runner.h  |   0
 .../xla/service/gpu/ir_emission_utils.cc      |  41 ++---
 .../xla/service/gpu/ir_emission_utils.h       |  14 +-
 .../xla/service/gpu/ir_emitter_unnested.cc    | 171 +++++++++---------
 .../xla/service/gpu/ir_emitter_unnested.h     |  30 +--
 5 files changed, 132 insertions(+), 124 deletions(-)
 mode change 100755 => 100644 tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h

diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h
old mode 100755
new mode 100644
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 53474dcdc66..d00ca4ca4e8 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -496,32 +496,23 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
   return b->CreateAnd(is_thread0, is_block0);
 }
 
-bool AreFusedReductionOutputsConsistent(
-    absl::Span output_instructions,
-    const HloInstruction* first_reduce) {
-  for (const HloInstruction* inst : output_instructions) {
-    if (IsReductionFromOrToContiguousDimensions(*inst)) {
-      // Shapes, layouts and dimensions must be the same for all reduces
-      // inside of this fusion.
-      // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
-      if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
-            ShapeUtil::Equal(first_reduce->operand(0)->shape(),
-                             inst->operand(0)->shape()) &&
-            ShapeUtil::Equal(first_reduce->operand(1)->shape(),
-                             inst->operand(1)->shape()) &&
-            first_reduce->dimensions() == inst->dimensions())) {
-        return false;
-      }
-    } else {
-      if (!(ShapeUtil::CompatibleIgnoringElementType(
-                first_reduce->operand(0)->shape(), inst->shape()) &&
-            LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
-                              inst->shape().layout()))) {
-        return false;
-      }
-    }
+bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
+                                      const HloInstruction* first_reduce) {
+  if (IsReductionFromOrToContiguousDimensions(*inst)) {
+    // Shapes, layouts and dimensions must be the same for all reduces
+    // inside of this fusion.
+    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
+    return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
+           ShapeUtil::Equal(first_reduce->operand(0)->shape(),
+                            inst->operand(0)->shape()) &&
+           ShapeUtil::Equal(first_reduce->operand(1)->shape(),
+                            inst->operand(1)->shape()) &&
+           first_reduce->dimensions() == inst->dimensions();
   }
-  return true;
+  return ShapeUtil::CompatibleIgnoringElementType(
+             first_reduce->operand(0)->shape(), inst->shape()) &&
+         LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
+                           inst->shape().layout());
 }
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index a782eb3f507..66bcb409f28 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -217,10 +217,18 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
 // block 0 of the kernel.
 llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);
 
-// Returns whether the outputs of a fusion with reduction are consistent.
-bool AreFusedReductionOutputsConsistent(
+// Returns whether the output `inst` of a fusion with reduction is consistent
+// with `first_reduce`.
+bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
+                                      const HloInstruction* first_reduce);
+
+inline bool AreFusedReductionOutputsConsistent(
     absl::Span output_instructions,
-    const HloInstruction* first_reduce);
+    const HloInstruction* first_reduce) {
+  return absl::c_all_of(output_instructions, [=](const HloInstruction* inst) {
+    return IsFusedReductionOutputConsistent(inst, first_reduce);
+  });
+}
 
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 0a4bd6d3820..db9a6ba8390 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1247,8 +1247,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
       }
       CHECK_GE(root->operand_count(), 1);
-      return EmitReductionFromOrToContiguousDimensions(fusion,
-                                                       root->operands());
+      return EmitReductionFromOrToContiguousDimensions(fusion);
     }
     case HloOpcode::kReduce: {
       // HandleFusion specializes reduction from a multi-dimensional array to
@@ -1259,7 +1258,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
         return Unimplemented(
             "Vectorized variadic reduce is not supported on GPU");
       }
-      return EmitReductionFromOrToContiguousDimensions(fusion, {root});
+      return EmitReductionFromOrToContiguousDimensions(fusion);
     }
     case HloOpcode::kSlice: {
       return EmitInputFusibleNonStridedSlices(fusion);
     }
@@ -1385,7 +1384,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
   if (IsReductionFromOrToContiguousDimensions(*reduce) &&
       reduce->shape().IsArray()) {
-    return EmitReductionFromOrToContiguousDimensions(reduce, {reduce});
+    return EmitReductionFromOrToContiguousDimensions(reduce);
   }
   return DefaultActionForMlir(input);
@@ -3673,7 +3672,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction(
     llvm::Type* index_ty, HloInstruction* unnested_hlo,
     const ReductionCodegenInfo& reduction_info,
     absl::Span reduce_instructions,
-    absl::Span reduction_output_shape_indices,
+    absl::Span reduction_output_indices,
     absl::Span reducers,
     const TilingKernelInfo& tiling_kernel_info) {
   const KernelMappingScheme& mapping_scheme =
@@ -3728,8 +3727,13 @@ void IrEmitterUnnested::EmitEpilogueForReduction(
     // At this point in the function we have a
"partial sum" of input elements // (stored in partial_result_addresses), and we need to accumulate it into // the correct output element. - auto output_array = GetIrArray(*unnested_hlo, *unnested_hlo, - reduction_output_shape_indices[i]); + ShapeIndex index; + if (unnested_hlo->IsMultiOutputFusion()) { + index.push_back(reduction_output_indices[i]); + } else { + CHECK_EQ(0, reduction_output_indices[i]); + } + auto output_array = GetIrArray(*unnested_hlo, *unnested_hlo, index); IrArray::Index element_index( /*linear=*/untransposed_output_linear_address, reduction_kept_element_shape, &b_); @@ -3872,31 +3876,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId( }); } -namespace { - -// Obtains the corresponding index of the out_instr in the outputs of the -// `unnested_hlo`. -ShapeIndex CreateShapeIndexForOutputInstruction( - const HloInstruction& unnested_hlo, const HloInstruction& out_instr) { - if (!unnested_hlo.IsMultiOutputFusion()) { - return ShapeIndex({}); +static HloInstruction* GetReduceFromUnnested(HloInstruction* unnested_hlo, + int index) { + if (unnested_hlo->opcode() == HloOpcode::kReduce) { + CHECK_EQ(0, index); + return unnested_hlo; } - const auto& all_outputs = unnested_hlo.fused_expression_root()->operands(); - for (size_t i = 0; i < all_outputs.size(); ++i) { - if (all_outputs[i] == &out_instr) { - return ShapeIndex({static_cast(i)}); + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + auto root = unnested_hlo->fused_expression_root(); + if (root->opcode() == HloOpcode::kReduce) { + CHECK_EQ(0, index); + return root; + } + if (root->opcode() == HloOpcode::kTuple) { + return root->mutable_operand(index); } } - LOG(FATAL) << " Fusion root does not contain output instruction; " - << " fusion: " << unnested_hlo.ToString() - << ", output instruction: " << out_instr.ToString(); + return nullptr; } -} // namespace - void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const Shape& reduction_operand_shape, - absl::Span output_instructions, + absl::Span instr_index_group, const llvm_ir::IrArray::Index& index, const ReductionCodegenInfo& reduction_info, absl::Span reducers, int64 x_iter_num) { @@ -3915,10 +3916,12 @@ void IrEmitterUnnested::EmitTileElementForReduction( if (unnested_hlo->opcode() == HloOpcode::kFusion) { BindFusionArguments(unnested_hlo, &fused_emitter); - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - const HloInstruction* inst = output_instructions[i]; - ShapeIndex idx = - CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst); + for (int index : instr_index_group) { + const HloInstruction* inst = GetReduceFromUnnested(unnested_hlo, index); + ShapeIndex idx; + if (unnested_hlo->IsMultiOutputFusion()) { + idx.push_back(index); + } if (IsReductionFromOrToContiguousDimensions(*inst)) { input_gens.push_back(*fused_emitter.GetGenerator(inst->operand(0))); } else { @@ -4689,22 +4692,20 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( } void IrEmitterUnnested::EmitIRForReduction( - HloInstruction* unnested_hlo, - absl::Span output_instructions, + HloInstruction* unnested_hlo, absl::Span instr_index_group, ReductionCodegenInfo* reduction_info, const Shape& input_shape) { std::vector reduce_instructions; - InlinedVector reduction_output_shape_indices; + InlinedVector reduction_output_indices; InlinedVector reducers; - for (size_t i = 0; i < output_instructions.size(); ++i) { - if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + for (int index : instr_index_group) { + 
HloInstruction* output_instruction = + GetReduceFromUnnested(unnested_hlo, index); + if (!IsReductionFromOrToContiguousDimensions(*output_instruction)) { continue; } - HloInstruction* output_instruction = output_instructions[i]; reduce_instructions.push_back(output_instruction); - reduction_output_shape_indices.push_back( - CreateShapeIndexForOutputInstruction(*unnested_hlo, - *output_instruction)); + reduction_output_indices.push_back(index); reducers.push_back(output_instruction->to_apply()); } CHECK(reduce_instructions.size() != 0) @@ -4722,7 +4723,7 @@ void IrEmitterUnnested::EmitIRForReduction( [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { EmitTileElementForReduction(unnested_hlo, input_shape, - output_instructions, index, *reduction_info, + instr_index_group, index, *reduction_info, reducers, x_iter_num); }; @@ -4736,7 +4737,7 @@ void IrEmitterUnnested::EmitIRForReduction( emit_reduction_tile); }); EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info, - reduce_instructions, reduction_output_shape_indices, + reduce_instructions, reduction_output_indices, reducers, tiling_kernel_info); } @@ -4751,33 +4752,33 @@ bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) { ShapeUtil::IsScalar(instr.operand(0)->shape()))); } -// Divides output_instructions into groups. Different groups will be executed +// Divides `num_reduces` reduces into groups. Different groups will be executed // in parallel. Generally speaking, we'd like to run the reduce instructions // in parallel without incurring too much recomputation overhead. The current // heuristic is to place reduce instructions who share nothing or only // (broadcasted) scalars/constants into different groups; otherwise, they are // placed in the same group. Non-reduce instructions always go with the reduce // instructions into the same group so long as they share any predecessors. -std::vector> DivideOutputInstructionsIntoGroups( - HloInstruction* unnested_hlo, - absl::Span output_instructions) { - CHECK(!output_instructions.empty()); - if (output_instructions.size() == 1) { - return {{output_instructions[0]}}; +std::vector> DivideOutputInstructionsIntoGroups( + HloInstruction* unnested_hlo, int num_reduces) { + CHECK_NE(0, num_reduces); + if (num_reduces == 1) { + return {{0}}; } std::vector> disjoint_sets( - output_instructions.size()); - for (size_t i = 0; i < output_instructions.size(); ++i) { - disjoint_sets[i].Get() = output_instructions[i]; + num_reduces); + for (size_t i = 0; i < num_reduces; ++i) { + disjoint_sets[i].Get() = GetReduceFromUnnested(unnested_hlo, i); } std::unique_ptr reachability_map = HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation()); for (auto* instr : unnested_hlo->fused_instructions()) { std::vector reached_output_ids; - for (size_t oid = 0; oid < output_instructions.size(); ++oid) { - if (HloOpcode::kReduce == output_instructions[oid]->opcode() && + for (size_t oid = 0; oid < num_reduces; ++oid) { + auto reduce = GetReduceFromUnnested(unnested_hlo, oid); + if (HloOpcode::kReduce == reduce->opcode() && (IsBroadcastedConstantOrScalar(*instr))) { // Do not group output reduce instructions through broadcasted // constants or scalars, as the recomputation should be acceptable. @@ -4785,9 +4786,9 @@ std::vector> DivideOutputInstructionsIntoGroups( continue; } // Now group output instructions if they have common predecessors. 
- if (reachability_map->IsReachable(instr, output_instructions[oid])) { - VLOG(3) << "Reaching " << output_instructions[oid]->ToString() - << " from " << instr->ToString(); + if (reachability_map->IsReachable(instr, reduce)) { + VLOG(3) << "Reaching " << reduce->ToString() << " from " + << instr->ToString(); reached_output_ids.push_back(oid); } } @@ -4797,12 +4798,12 @@ std::vector> DivideOutputInstructionsIntoGroups( } } // Place output instructions in the same set into the same group. - absl::flat_hash_map> groups; - for (size_t oid = 0; oid < output_instructions.size(); ++oid) { - groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid)); + absl::flat_hash_map> groups; + for (size_t oid = 0; oid < num_reduces; ++oid) { + groups[disjoint_sets[oid].Get()].push_back(oid); } - std::vector> ret; + std::vector> ret; absl::c_for_each( groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); return ret; @@ -4811,15 +4812,20 @@ std::vector> DivideOutputInstructionsIntoGroups( } // namespace Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( - HloInstruction* unnested_hlo, - absl::Span output_instructions) { - bool returns_tuple = output_instructions.size() > 1; + HloInstruction* unnested_hlo) { + int num_reduces = 1; + if (unnested_hlo->IsMultiOutputFusion()) { + num_reduces = unnested_hlo->fused_expression_root()->operand_count(); + } + + bool returns_tuple = num_reduces > 1; VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); // Build an initializer thunk to initialize each reduction output. std::vector> thunks; - for (int i = 0; i < output_instructions.size(); ++i) { - if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + for (int i = 0; i < num_reduces; ++i) { + if (!IsReductionFromOrToContiguousDimensions( + *GetReduceFromUnnested(unnested_hlo, i))) { continue; } @@ -4831,17 +4837,20 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( // Build a kernel thunk to compute all the outputs. const HloInstruction* first_reduce = nullptr; - for (int i = 0; i < output_instructions.size(); ++i) { - if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { - first_reduce = output_instructions[i]; + for (int i = 0; i < num_reduces; ++i) { + if (IsReductionFromOrToContiguousDimensions( + *GetReduceFromUnnested(unnested_hlo, i))) { + first_reduce = GetReduceFromUnnested(unnested_hlo, i); break; } } CHECK(first_reduce); - if (output_instructions.size() > 1) { - if (!AreFusedReductionOutputsConsistent(output_instructions, - first_reduce)) { - return InternalError("Inconsistent reduction fusion outputs"); + if (num_reduces > 1) { + for (int i = 1; i < num_reduces; i++) { + if (!IsFusedReductionOutputConsistent( + GetReduceFromUnnested(unnested_hlo, i), first_reduce)) { + return InternalError("Inconsistent reduction fusion outputs"); + } } } const Shape& input_shape = first_reduce->operand(0)->shape(); @@ -4852,24 +4861,24 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( << first_reduce->ToString(); // Group output instructions. Each group will be executed in parallel. 
- std::vector> instr_groups = - DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions); - VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ", + std::vector> instr_index_groups = + DivideOutputInstructionsIntoGroups(unnested_hlo, num_reduces); + VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ", unnested_hlo->ToString()); std::unique_ptr kernel_thunk = BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - for (size_t i = 0; i < instr_groups.size(); ++i) { + for (size_t i = 0; i < instr_index_groups.size(); ++i) { // Create a new ReductionCodegenInfo instance as it contains states for // code generation per reduction group. For now, let's always use the very // first reduce as representative to construct ReductionCodegenInfo, since // all the reductions are required to have the same shape and layout as - // verified by `AreFusedReductionOutputsConsistent()`. We can loosen the + // verified by `IsFusedReductionOutputConsistent()`. We can loosen the // constraint later when the needs arise. ReductionCodegenInfo reduction_info = ComputeReductionCodegenInfo(unnested_hlo, first_reduce); auto emit_reduction_func = [&] { - EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info, + EmitIRForReduction(unnested_hlo, instr_index_groups[i], &reduction_info, input_shape); }; // Use raw block_id_y to select the i-th parallel reduction to run. Using @@ -4878,7 +4887,7 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( // the indices used within the reductions. llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); - llvm_ir::AddRangeMetadata(0, instr_groups.size(), + llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), llvm::cast(raw_block_id_y)); llvm::Value* guarding_cond = b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)); @@ -4888,11 +4897,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( ComputeReductionCodegenInfo(unnested_hlo, first_reduce); const KernelMappingScheme& mapping_scheme = reduction_info.GetKernelMappingScheme(); - // block_y_count is set to instr_groups.size(), so that each reduction group - // can be run in parallel by a different BlockIdy. + // block_y_count is set to instr_index_groups.size(), so that each reduction + // group can be run in parallel by a different BlockIdy. LaunchDimensions launch_dimensions( {/*x=*/mapping_scheme.GetNumberOfBlocks(), - /*y=*/static_cast(instr_groups.size()), + /*y=*/static_cast(instr_index_groups.size()), /*z=*/1}, {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1}); VLOG(3) << "Launch dimensions of " << unnested_hlo->name() diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 0744739a026..3928b01f38b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -405,13 +405,8 @@ class IrEmitterUnnested : public IrEmitter, // complicating the index calculation in the code generation of the reduce // instructions. In other words, a block_id_y is assigned to a group and so // different groups can be run in parallel. 
-  //
-  // output_instructions: Output instructions in the computation: instruction
-  // itself if it's not a fusion, fusion root if fusion is not multi-output, and
-  // elements of the fusion multi-output tuple otherwise.
   Status EmitReductionFromOrToContiguousDimensions(
-      HloInstruction* unnested_hlo,
-      absl::Span output_instructions);
+      HloInstruction* unnested_hlo);
 
   // Computes the KernelMappingScheme for the reduce HLO and indicates whether
   // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
@@ -555,12 +550,17 @@ class IrEmitterUnnested : public IrEmitter,
   //
   // Calculates and stores the temporary reduction value in the corresponding
   // alloca.
-  void EmitTileElementForReduction(
-      HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
-      absl::Span output_instructions,
-      const llvm_ir::IrArray::Index& index,
-      const ReductionCodegenInfo& reduction_info,
-      absl::Span reducers, int64 x_iter_num);
+  //
+  // `instr_index_group` indicates a set of reductions this call needs to emit;
+  // each index i points to the i-th output of unnested_hlo. Note that if
+  // unnested_hlo is not a multi-output fusion, instr_index_group is always {0}.
+  void EmitTileElementForReduction(HloInstruction* unnested_hlo,
+                                   const Shape& reduction_operand_shape,
+                                   absl::Span instr_index_group,
+                                   const llvm_ir::IrArray::Index& index,
+                                   const ReductionCodegenInfo& reduction_info,
+                                   absl::Span reducers,
+                                   int64 x_iter_num);
 
   // Prepares for the code generation for a tile block of a reduction kernel.
   //
@@ -577,13 +577,13 @@ class IrEmitterUnnested : public IrEmitter,
       llvm::Type* index_ty, HloInstruction* unnested_hlo,
       const ReductionCodegenInfo& reduction_info,
       absl::Span reduce_instructions,
-      absl::Span reduction_output_shape_indices,
+      absl::Span reduction_output_indices,
       absl::Span reducers,
       const TilingKernelInfo& tiling_kernel_info);
 
-  // Emits code for reductions in the output_instructions.
+  // Emits code for reductions in the instr_index_group.
   void EmitIRForReduction(HloInstruction* unnested_hlo,
-                          absl::Span output_instructions,
+                          absl::Span instr_index_group,
                           ReductionCodegenInfo* reduction_info,
                           const Shape& input_shape);
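
A note on the ir_emission_utils.h change: the whole-fusion consistency check is now a per-output predicate (IsFusedReductionOutputConsistent) plus an inline absl::c_all_of wrapper. Below is a minimal, standalone sketch of that refactoring shape. The Out struct, its integer shape/layout fields, and the function names are invented stand-ins for illustration only; the real predicate compares HloInstruction shapes, layouts and dimensions via ShapeUtil/LayoutUtil, and the real wrapper uses absl::c_all_of (std::all_of is used here so the snippet compiles without absl).

#include <algorithm>
#include <cassert>
#include <vector>

// Toy stand-in for a fusion output: reduces must match the first reduce's
// shape and layout exactly; other outputs only need the same layout here.
struct Out {
  bool is_reduce;
  int shape;
  int layout;
};

// Per-output check, analogous in shape to IsFusedReductionOutputConsistent.
bool IsOutputConsistent(const Out& inst, const Out& first_reduce) {
  if (inst.is_reduce) {
    return inst.shape == first_reduce.shape &&
           inst.layout == first_reduce.layout;
  }
  return inst.layout == first_reduce.layout;
}

// Whole-fusion check: "all outputs pass the per-output check", mirroring the
// new inline AreFusedReductionOutputsConsistent wrapper.
bool AreOutputsConsistent(const std::vector<Out>& outputs,
                          const Out& first_reduce) {
  return std::all_of(outputs.begin(), outputs.end(), [&](const Out& o) {
    return IsOutputConsistent(o, first_reduce);
  });
}

int main() {
  const Out first{true, 4, 1};
  assert((AreOutputsConsistent({first, {false, 9, 1}, {true, 4, 1}}, first)));
  assert((!AreOutputsConsistent({first, {true, 5, 1}}, first)));
  return 0;
}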
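
The central simplification in ir_emitter_unnested.cc is that callers now pass plain integer output indices, and the callee derives both the output instruction (GetReduceFromUnnested) and its ShapeIndex ({} for a single output, {i} for a multi-output fusion) from unnested_hlo. The following is a toy sketch of that lookup pattern under simplified assumptions: Instr, Opcode, GetOutputFromUnnested and OutputShapeIndex are hypothetical stand-ins, not the real XLA HloInstruction/ShapeIndex classes.

#include <cassert>
#include <vector>

// Toy stand-ins, just enough to show how an output and its ShapeIndex can be
// derived from (unnested_hlo, index) instead of passing instructions around.
enum class Opcode { kReduce, kFusion, kTuple, kOther };

struct Instr {
  Opcode opcode;
  std::vector<Instr*> operands;  // for kTuple roots: the fusion outputs
  Instr* fused_root;             // for kFusion
  bool IsMultiOutputFusion() const {
    return opcode == Opcode::kFusion && fused_root->opcode == Opcode::kTuple;
  }
};

using ShapeIndex = std::vector<int>;

// Analogous to GetReduceFromUnnested in the patch: resolve the index-th
// output from the unnested instruction itself.
Instr* GetOutputFromUnnested(Instr* unnested, int index) {
  if (unnested->opcode == Opcode::kReduce) {
    assert(index == 0);
    return unnested;
  }
  if (unnested->opcode == Opcode::kFusion) {
    Instr* root = unnested->fused_root;
    if (root->opcode == Opcode::kReduce) {
      assert(index == 0);
      return root;
    }
    if (root->opcode == Opcode::kTuple) {
      return root->operands[index];
    }
  }
  return nullptr;
}

// The ShapeIndex is likewise derived from the integer index: {} for a single
// output, {index} for a multi-output fusion.
ShapeIndex OutputShapeIndex(const Instr& unnested, int index) {
  return unnested.IsMultiOutputFusion() ? ShapeIndex{index} : ShapeIndex{};
}

int main() {
  Instr r0{Opcode::kReduce, {}, nullptr};
  Instr r1{Opcode::kReduce, {}, nullptr};
  Instr tuple{Opcode::kTuple, {&r0, &r1}, nullptr};
  Instr fusion{Opcode::kFusion, {}, &tuple};

  assert(GetOutputFromUnnested(&fusion, 1) == &r1);
  assert((OutputShapeIndex(fusion, 1) == ShapeIndex{1}));
  assert(OutputShapeIndex(r0, 0).empty());
  return 0;
}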
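
DivideOutputInstructionsIntoGroups keeps its reachability-based heuristic but now partitions output indices rather than HloInstruction pointers. The sketch below models the same grouping idea with a small path-compressing union-find over indices; the UnionFind type, DivideIntoGroups, and the (is_trivial, reached_outputs) input encoding are simplifications assumed for illustration, since the real code walks fused_instructions() with an HloReachabilityMap and skips broadcasted scalars/constants.

#include <cstdio>
#include <map>
#include <numeric>
#include <utility>
#include <vector>

// Minimal path-compressing union-find over output indices, standing in for
// the disjoint-set type used by DivideOutputInstructionsIntoGroups (whose
// exact template arguments are not visible in this diff).
struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Merge(int a, int b) { parent[Find(a)] = Find(b); }
};

// Mirrors the grouping heuristic: every non-trivial fused instruction merges
// all reduction outputs it reaches, so outputs that share real work end up in
// one group, while outputs that only share broadcasted scalars/constants stay
// in separate groups (recomputing those is considered cheap).
std::vector<std::vector<int>> DivideIntoGroups(
    int num_outputs,
    // Per fused instruction: {is_trivial_scalar_or_constant,
    //                         indices of outputs reachable from it}.
    const std::vector<std::pair<bool, std::vector<int>>>& fused_instrs) {
  UnionFind sets(num_outputs);
  for (const auto& instr : fused_instrs) {
    if (instr.first) continue;  // skip broadcasted scalars/constants
    const std::vector<int>& reached = instr.second;
    for (size_t i = 1; i < reached.size(); ++i) {
      sets.Merge(reached[0], reached[i]);
    }
  }
  std::map<int, std::vector<int>> groups;
  for (int i = 0; i < num_outputs; ++i) {
    groups[sets.Find(i)].push_back(i);
  }
  std::vector<std::vector<int>> result;
  for (auto& kv : groups) result.push_back(std::move(kv.second));
  return result;
}

int main() {
  // Three reduction outputs. A real producer feeds outputs 0 and 1; a
  // broadcasted scalar feeds all three. Expected groups: {0, 1} and {2}.
  const auto groups = DivideIntoGroups(3, {{false, {0, 1}}, {true, {0, 1, 2}}});
  for (const auto& group : groups) {
    std::printf("group:");
    for (int i : group) std::printf(" %d", i);
    std::printf("\n");
  }
  return 0;
}

In this toy run, outputs 0 and 1 share a non-trivial producer and land in one group, while output 2 only shares the broadcasted scalar and gets its own group, which is what lets the emitted kernel select it with a different block_id_y at run time.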