diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index fc07578a73f..62428f0f3e8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2456,9 +2456,10 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Emits a kernel for the hlo instruction using the given tiling scheme. void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info, - KernelSupportLibrary* ksl, - llvm::Type* index_ty, - TileGenerator emit_one_tile) { + KernelSupportLibrary* ksl, llvm::Value* y, + llvm::Value* x, + TileElementGenerator tile_generator) { + llvm::Type* index_ty = kernel_info->GetIndexType(); KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); absl::Span dims_in_block = @@ -2505,7 +2506,7 @@ void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info, mapping_scheme->GetDimensionsInElements(); // Emit the tile with a given tile_index, by calculating the tight bounds for - // each dimension of the tile and then calling emit_one_tile. + // each dimension of the tile and then calling tile_generator. auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) { std::vector output_tile_bounds(3); for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; @@ -2523,7 +2524,8 @@ void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info, IrArray::Index tile_origin = mapping_scheme->GetElementIndexForTileOrigin(tile_index); - emit_one_tile(tile_origin, output_tile_bounds); + tile_generator(y, x, tile_origin, "output", output_tile_bounds[1], + output_tile_bounds[2], ksl); }; const IrArray::Index starting_block = @@ -2574,6 +2576,7 @@ void IrEmitterUnnested::EmitKernel( ? b_.getInt64Ty() : GetIndexTypeForKernel(unnested_hlo, launch_dimensions.launch_bound(), &b_); + kernel_info->SetIndexType(index_ty); // Calculate the starting element coordinate within a tile for the current // thread, (y, x) from thread_id. @@ -2584,17 +2587,10 @@ void IrEmitterUnnested::EmitKernel( kernel_info->SetLaneId( mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x : nullptr); - kernel_info->SetIndexType(index_ty); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); block_prologue_generator(unnested_hlo, kernel_info); - EmitBlock(kernel_info, &ksl, index_ty, - [&](const IrArray::Index& output_tile_origin, - absl::Span output_tile_bounds) { - tile_element_generator(y, x, output_tile_origin, "output", - output_tile_bounds[1], - output_tile_bounds[2], &ksl); - }); + EmitBlock(kernel_info, &ksl, y, x, tile_element_generator); block_epilogue_generator(unnested_hlo, kernel_info); UpdateLaunchDimensions(launch_dimensions, kernel_thunk, ir_emitter_context_->llvm_module()); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 4fcc5dedb67..fbd3ad39d95 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -52,13 +52,6 @@ namespace gpu { class IrEmitterUnnested : public IrEmitter, private ThunkEmitter::EmissionContext { public: - // Parameter block_contains_multi_tiles indicates whether a tile block - // consists of multiple tiles or not. If the tile block contains only one - // tile, there is no need to use atomic operation to accumulate a local result - // to a global result to implement reduction. - using TileGenerator = - std::function output_tile_bounds)>; // KernelCodegenInfo records the common information to support the code // generation for a kernel to process tensor elements by blocks. A block of // tensor elements may contain one or multiple tiles. The code generators that @@ -251,7 +244,8 @@ class IrEmitterUnnested : public IrEmitter, BlockEpilogueGenerator block_epilogue_generator); void EmitBlock(KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl, - llvm::Type* index_ty, TileGenerator emit_one_tile); + llvm::Value* y, llvm::Value* x, + TileElementGenerator tile_generator); // Emits code to process a tensor element in a tile for the given kCopy HLO // that performs a 0-2-1 transpose.