From d5236e610b82215a2ace49d73c04ab4f20faa520 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 19 Mar 2019 09:20:19 -0700 Subject: [PATCH] [XLA:GPU] Enhancement to the implementation of reduction to vector. When the kept elements of a reduction are contiguous in the input tensor, we use the implementation of reduction to vector to emit high performance code. Previously, we require the reduction result have the same layout as the kept kept elements in the input tensor. This change removes such a constraint. Add test cases. PiperOrigin-RevId: 239204904 --- .../xla/service/gpu/ir_emission_utils.cc | 7 +-- .../xla/service/gpu/ir_emitter_unnested.cc | 39 ++++++++++-- .../gpu/tests/gpu_kernel_tiling_test.cc | 63 +++++++++++++++++++ 3 files changed, 99 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6b9cbdd94b3..47b5b25e73a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -169,12 +169,7 @@ bool IsReductionToVector(const HloInstruction& reduce) { } } return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), - dims_to_keep) && - ShapeUtil::Equal( - reduce.shape(), - ShapeUtil::FilterDimensions( - [&](int64 dim) { return absl::c_count(dims_to_keep, dim); }, - input->shape())); + dims_to_keep); } // This emits a device-side call to diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 07038607bdc..e6bd96e6e30 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2608,6 +2608,9 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { AddressVector reduction_input_addresses_; InlinedVector reducers_; InlinedVector reduction_output_shape_indices_; + // The address of the memory that stores the linear index of the current + // output, assuming that the output doesn't change the layout of the kept + // elements in the reduction input. llvm::AllocaInst* current_output_linear_index_address_; llvm::AllocaInst* current_output_inbound_address_; bool is_row_reduction_; @@ -2806,21 +2809,49 @@ void IrEmitterUnnested::EmitEpilogueForReduction( llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); } + HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion + ? unnested_hlo->fused_expression_root() + : unnested_hlo; + std::vector reduce_instructions; + absl::c_for_each(GetOutputInstructions(&reduce_or_tuple), + [&](const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kReduce) { + reduce_instructions.push_back(instr); + } + }); int num_partial_results = reduction_info->GetNumberOfPartialResults(); // Emit an atomic operation that accumulates the partial reduction to the // output element. For row reduction, this is only for lane 0 due to the // if-statement emitted above. for (int i = 0; i != num_reduces; ++i) { + const HloInstruction* reduce_hlo = reduce_instructions[i]; + Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(reduce_hlo->dimensions(), dim); + }, + reduce_hlo->operand(0)->shape()); for (int j = 0; j < num_partial_results; ++j) { + // A reduction is allowed to transpose its output. For example, suppose + // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are + // allowed to produce as output either f32[10,30]{1,0} (no transpose) or + // f32[10,30]{0,1} (transposing the two output dims). + // + // At this point in the function we have a "partial sum" of input elements + // (stored in partial_result_addresses), and we need to accumulate it into + // the correct output element. + // + // *reduction_info->GetCurrentOutputLinearIndexAddress() stores the linear + // index in the output into which we would need to accumulate *if the + // output layout matched the input layout*. This is why we use + // `reduction_kept_element_shape` rather than `unnested_hlo->shape()` when + // computing `element_index` below. IrArray::Index element_index( /*linear=*/Load( InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), {b_.getInt32(j)}), - "output_linear_addr"), - ShapeUtil::GetSubshape(unnested_hlo->shape(), - reduction_output_shape_indices[i]), - &b_); + "untransposed_output_linear_addr"), + reduction_kept_element_shape, &b_); llvm::Value* output_address = GetIrArray(*unnested_hlo, *unnested_hlo, reduction_output_shape_indices[i]) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 869724db601..7f9c8202a4b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -411,6 +411,69 @@ TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) { // Check that the kernel runs correctly. EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); } + +TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) { + const char *const kHloString = R"( + HloModule reduce_with_layout_change + reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) + } + + ENTRY kernel_entry { + arg0 = f32[4,32,32,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[4,32,16,12,12]{4,3,2,1,0} reduce(arg0, constant0), + dimensions={1,6,7}, to_apply=reduction0 + })"; + + // Check that the kernel is tiled by looking for llvm.nvvm.atomic. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @reduce +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +} + +TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { + const char *const kHloString = R"( + HloModule reduce_with_layout_change + reduction0 { + x0 = f32[] parameter(0) + y0 = f32[] parameter(1) + ROOT add0 = f32[] add(x0, y0) + } + + ENTRY kernel_entry { + arg0 = f32[8,6,64]{2,1,0} parameter(0) + constant0 = f32[] constant(0) + ROOT reduce0 = f32[8,6]{0,1} reduce(arg0, constant0), dimensions={2}, + to_apply=reduction0 + })"; + + // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @reduce +; CHECK: call float @llvm.nvvm.shfl.sync.down.f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +} + } // namespace } // namespace gpu } // namespace xla