diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 5044aa770c1..37204af7b89 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1878,8 +1878,7 @@ bool MayPreventVectorization(const HloInstruction& hlo) {
 }  // namespace
 
 Status IrEmitterUnnested::EmitTargetElementLoop(
-    const HloInstruction& hlo,
-    const llvm_ir::ElementGenerator& element_generator) {
+    const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
   int unroll_factor = 1;
   if (!MayPreventVectorization(hlo)) {
     unroll_factor = ComputeMaxUnrollFactor(&hlo);
@@ -1888,7 +1887,7 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
   std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
       &hlo, /*implements_whole_instruction=*/true, unroll_factor);
   Status emit_status =
-      EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
+      EmitTargetElementLoopInThunk(hlo, body_emitter, kernel_thunk.get());
   thunk_sequence_->emplace_back(std::move(kernel_thunk));
 
   return emit_status;
@@ -1912,8 +1911,8 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
 void IrEmitterUnnested::EmitTile(
     const KernelMappingScheme& mapping_scheme,
     const IrArray::Index& tile_origin_index, const string& loop_name,
-    KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
-    llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
+    KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
+    llvm::Value* tile_height, llvm::Value* tile_width,
     const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
   llvm::Type* index_ty = tile_width->getType();
   auto constant = [&](int64 val) {
@@ -1924,8 +1923,8 @@ void IrEmitterUnnested::EmitTile(
   int64 tile_size_x = mapping_scheme.GetTileSizeX();
 
   int64 x_num_steps = tile_size_x / num_threads_x;
-  llvm::Value* start_offset_x =
-      GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
+  llvm::Value* start_offset_x = GetStartOffsetX(
+      mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
 
   // Using dilated mapping scheme, each thread steps with a stride of number
   // of threads.
@@ -1962,22 +1961,25 @@ void IrEmitterUnnested::EmitTile(
       loop_name + "_y_in_tile",
       /*start=*/constant(0),
       /*end=*/
-      ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
+      ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
+                    num_threads_y),
       /*step=*/constant(1), [&](llvm::Value* y_indvar) {
-        llvm::Value* y_loc =
-            b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
+        llvm::Value* y_loc = b_.CreateAdd(
+            thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
         for (int64 j = 0; j < x_num_steps; j++) {
           llvm::Value* x_loc =
               b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
           IrArray::Index source_idx_x =
              source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
                  .AddOffsetToDim(constant(j * step_x), kDimX, &b_);
+          auto emit_element = [&] {
+            return emit_elem_function(source_idx_x, y_loc, x_loc, j);
+          };
           if (!x_tile_fits) {
             ksl->If(loop_name + "_x_in_tile",
-                    b_.CreateICmpULT(x_loc, tile_width),
-                    [&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
+                    b_.CreateICmpULT(x_loc, tile_width), emit_element);
           } else {
-            emit_elem_function(source_idx_x, y_loc, x_loc, j);
+            emit_element();
           }
         }
       });
@@ -1989,14 +1991,11 @@ void IrEmitterUnnested::EmitTile(
 // index: The index for the first output element in the normalized tensor. The
 // normalized tensor is the resulting tensor after collapsing contiguous
 // dimensions that play the same role in the transpose.
-// y_loc: The y coordinate within a tile.
-// x_loc: The x coordinate within a tile.
 // mapping_scheme: Kernel mapping scheme specifying the tiling
 void IrEmitterUnnested::EmitTileElementForCopy(
     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
-    llvm::Value* x_loc, int64 /*x_iter_num*/,
-    absl::Span<llvm::Value* const> param_shmem_buffers) {
+    llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
   // TODO(jlebar): Add AA metadata to this load.
   llvm::Instruction* load_from_shmem_buffer =
       Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
@@ -2027,13 +2026,10 @@ static IrArray::Index GetUnnormalizedIndex(
 // is the resulting tensor after collapsing contiguous dimensions that play
 // the same role in the transpose.
 // kernel_info: Other information to support the kernel code generation.
-// y_loc: The y coordinate within a tile.
-// x_loc: The x coordinate within a tile.
 void IrEmitterUnnested::EmitTileElementForFusion(
     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
-    llvm::Value* x_loc, int64 /*x_iter_num*/,
-    absl::Span<llvm::Value* const> param_shmem_buffers) {
+    llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
   std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
                                      GetNestedComputer());
@@ -2545,8 +2541,7 @@ IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
   }();
 
   auto emit_tile = [&](const IrArray::Index& tile) {
-    tile_element_generator(thread_id_info.thread_id_y,
-                           thread_id_info.thread_id_x, tile, "output",
+    tile_element_generator(thread_id_info, tile, "output",
                            output_tile_bounds[1], output_tile_bounds[2], &ksl);
   };
 
@@ -2667,16 +2662,16 @@ void IrEmitterUnnested::EmitHlo021Tile(
           llvm::Value* x_loc, int64 x_iter_num) {
         if (hlo->opcode() == HloOpcode::kCopy) {
           EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
-                                 x_iter_num, param_shmem_buffers);
+                                 param_shmem_buffers);
         } else {
           CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
           EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
-                                   x_iter_num, param_shmem_buffers);
+                                   param_shmem_buffers);
         }
       };
 
   TileElementGenerator tile_generator =
-      [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
+      [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
           const string& loop_name, llvm::Value* tile_height,
           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
         // If shared memory transpose is needed, wait for all threads to reach
@@ -2689,11 +2684,11 @@ void IrEmitterUnnested::EmitHlo021Tile(
               Permute({0, 2, 1}, index.dims()), index.GetType());
 
           // Copy input parameter values to shared memory buffers:
-          // tile[y, x] = input[index]
+          // tile[thread_id_y, thread_id_x] = input[index]
           // Note that tile_width and tile_height are flipped here because we
           // are reading a transposed tile.
-          EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
-                   tile_width, tile_height,
+          EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
+                   thread_id_info, tile_width, tile_height,
                    [&](const IrArray::Index& index, llvm::Value* y_loc,
                        llvm::Value* x_loc, int64 /*x_iter_num*/) {
                      for (int64 id : tiled_param_ids) {
@@ -2717,8 +2712,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
             EmitSyncThreads();
           }
 
-          EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
-                   tile_width, element_generator);
+          EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
+                   tile_height, tile_width, element_generator);
           bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
 
           // If a tile block contains multiple tiles and shared memory buffers are
@@ -3133,11 +3128,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
   TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
       mapping_scheme, index_ty,
-      [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
+      [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
           const string& loop_name, llvm::Value* tile_height,
           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
         EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
-                 y, x, tile_height, tile_width, emit_reduction_tile);
+                 thread_id_info, tile_height, tile_width, emit_reduction_tile);
       });
 
   EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
                            reduce_instructions, reduction_output_shape_indices,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index d31e34cb9b7..3d6d0950db9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -52,13 +52,25 @@ namespace gpu {
 class IrEmitterUnnested : public IrEmitter,
                           private ThunkEmitter::EmissionContext {
  public:
+  struct ThreadIdInfo {
+    // Raw thread id.
+    llvm::Value* thread_id;
+
+    // X-coordinate calculated from thread id: `thread_id % num_threads_x`
+    llvm::Value* thread_id_x;
+
+    // Y-coordinate calculated from thread id: `thread_id / num_threads_x`
+    llvm::Value* thread_id_y;
+
+    // Lane id: `thread_id % kWarpSize`
+    llvm::Value* lane_id;
+  };
+
   // A function object to generate code to process one element in a tile.
   //
-  // hlo: the instruction for which the code is generated for.
   // index: the index for the first output element of the current thread.
   // y_loc: The y coordinate within a tile.
   // x_loc: The x coordinate within a tile.
-  // kernel_info: Other information to support the kernel code generation.
   // x_iter_num: When a thread process N elements in the X dimension, x_iter_num
   // has a value of 0..N-1 to identify the element being process.
   using EmitElementFunction = std::function<void(
@@ -291,26 +303,27 @@ class IrEmitterUnnested : public IrEmitter,
   void EmitTile(
       const KernelMappingScheme& mapping_scheme,
       const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
-      KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
-      llvm::Value* thread_id_x, llvm::Value* tile_height,
-      llvm::Value* tile_width,
+      KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
+      llvm::Value* tile_height, llvm::Value* tile_width,
      const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
 
   // Emits code to process a tensor element in a tile for the given kCopy HLO
   // that performs a 0-2-1 transpose.
+  // y_loc: The y coordinate within a tile.
+  // x_loc: The x coordinate within a tile.
   void EmitTileElementForCopy(
       HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
       const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
-      llvm::Value* x_loc, int64 x_iter_num,
-      absl::Span<llvm::Value* const> param_shmem_buffers);
+      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
 
   // Emits code to process a tensor element in a tile for the given kLoop
   // fusion HLO containing parameters that are 0-2-1 transpose of its outputs.
+  // y_loc: The y coordinate within a tile.
+  // x_loc: The x coordinate within a tile.
   void EmitTileElementForFusion(
       HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
       const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
-      llvm::Value* x_loc, int64 x_iter_num,
-      absl::Span<llvm::Value* const> param_shmem_buffers);
+      llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
 
   // Emits code to process a tensor element in a tile for the given input hlo
   // that is either a unnested kReduce or a kInput fusion.
@@ -389,20 +402,6 @@ class IrEmitterUnnested : public IrEmitter,
   // Sets the return value range to [0, threads_per_block).
   llvm::Value* EmitThreadId(int64 threads_per_block, llvm::Type* index_ty);
 
-  struct ThreadIdInfo {
-    // Raw thread id.
-    llvm::Value* thread_id;
-
-    // X-coordinate calculated from thread id: `thread_id % num_threads_x`
-    llvm::Value* thread_id_x;
-
-    // Y-coordinate calculated from thread id: `thread_id / num_threads_x`
-    llvm::Value* thread_id_y;
-
-    // Lane id: `thread_id % kWarpSize`
-    llvm::Value* lane_id;
-  };
-
   // Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
   // id.
   //
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 5df25677332..5045d7b0c13 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -161,7 +161,7 @@ Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
     // address-space-based AA in LLVM, it wouldn't help us much here.
     return b_->CreateLoad(
         b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
-                                          tile_param_x_, tile_param_y_}),
+                                          thread_id_x_, thread_id_y_}),
        "tiled_buffer");
   }
 }
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index c4e00f8889a..d13b0262180 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -60,13 +60,13 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
 
   FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
                  ElementalIrEmitter* elemental_emitter,
-                 llvm::Value* tile_param_x = nullptr,
-                 llvm::Value* tile_param_y = nullptr,
+                 llvm::Value* thread_id_x = nullptr,
+                 llvm::Value* thread_id_y = nullptr,
                  absl::Span<llvm::Value* const> param_shmem_buffers = {})
       : operand_arrays_(),
         operand_arrays_generator_(std::move(operand_arrays_generator)),
-        tile_param_x_(tile_param_x),
-        tile_param_y_(tile_param_y),
+        thread_id_x_(thread_id_x),
+        thread_id_y_(thread_id_y),
         param_shmem_buffers_(param_shmem_buffers.begin(),
                              param_shmem_buffers.end()),
         elemental_emitter_(elemental_emitter),
@@ -121,10 +121,10 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
   GeneratorForOperandIrArrays operand_arrays_generator_;
 
   // The x coordinate within a tile.
-  llvm::Value* tile_param_x_;
+  llvm::Value* thread_id_x_;
 
   // The y coordinate within a tile.
-  llvm::Value* tile_param_y_;
+  llvm::Value* thread_id_y_;
 
   // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
   // if the parameter is not tiled.
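
Note (not part of the patch): the comments on the new ThreadIdInfo struct describe how each field is derived from the raw thread id. A minimal standalone sketch of that derivation with llvm::IRBuilder, assuming a hypothetical helper name and struct, a caller-supplied num_threads_x, and a caller-supplied warp size (32 on current NVIDIA GPUs):

#include <cstdint>

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

// Hypothetical stand-in for IrEmitterUnnested::ThreadIdInfo; illustrative only.
struct ThreadIdInfoSketch {
  llvm::Value* thread_id;
  llvm::Value* thread_id_x;
  llvm::Value* thread_id_y;
  llvm::Value* lane_id;
};

// Derives the tile coordinates and lane id from the raw thread id, mirroring
// the formulas in the struct comments: x = id % num_threads_x,
// y = id / num_threads_x, lane = id % warp_size.
ThreadIdInfoSketch MakeThreadIdInfo(llvm::IRBuilder<>* b, llvm::Value* thread_id,
                                    int64_t num_threads_x, int64_t warp_size) {
  llvm::Type* ty = thread_id->getType();
  auto constant = [&](int64_t v) { return llvm::ConstantInt::get(ty, v); };
  return ThreadIdInfoSketch{
      thread_id,
      b->CreateURem(thread_id, constant(num_threads_x), "thread_id.x"),
      b->CreateUDiv(thread_id, constant(num_threads_x), "thread_id.y"),
      b->CreateURem(thread_id, constant(warp_size), "lane_id")};
}

Passing the whole struct to EmitTile (instead of separate thread_id_y/thread_id_x values) keeps all four derived values together, which is why the callbacks above now take a const ThreadIdInfo&.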
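Note (not part of the patch): EmitTile's comment about the dilated mapping scheme can be read as follows: each thread emits x_num_steps elements along X; with dilation the thread starts at its thread_id_x and strides by num_threads_x, otherwise it is assumed (based on GetStartOffsetX, whose body is not shown here) to own a contiguous run starting at thread_id_x * x_num_steps. A scalar sketch of the x_loc sequence, using plain integers as stand-ins for the LLVM IR values:

#include <cstdint>
#include <vector>

// Returns the x coordinates a single thread visits inside one tile, mirroring
// start_offset_x and step_x in EmitTile. All names here are illustrative.
std::vector<int64_t> XLocsForThread(bool dilated_x, int64_t thread_id_x,
                                    int64_t num_threads_x, int64_t tile_size_x) {
  const int64_t x_num_steps = tile_size_x / num_threads_x;
  const int64_t step_x = dilated_x ? num_threads_x : 1;
  const int64_t start_offset_x =
      dilated_x ? thread_id_x : thread_id_x * x_num_steps;
  std::vector<int64_t> xs;
  for (int64_t j = 0; j < x_num_steps; ++j) {
    xs.push_back(start_offset_x + j * step_x);  // x_loc for iteration j
  }
  return xs;
}

For example, with num_threads_x = 4 and tile_size_x = 8, thread 1 visits {1, 5} in the dilated scheme and {2, 3} otherwise; the emit_element lambda introduced above is invoked once per such x_loc, guarded by the "_x_in_tile" bounds check when the tile does not fit.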